Skip to content

Commit

Permalink
issue #957: adds vector index template
Browse files Browse the repository at this point in the history
  • Loading branch information
mrk-vi committed Jul 8, 2024
1 parent 9c5535d commit fa750e7
Showing 1 changed file with 184 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import org.opensearch.client.opensearch.core.bulk.BulkOperation;
import org.opensearch.client.opensearch.core.bulk.BulkResponseItem;
import org.opensearch.client.opensearch.core.bulk.IndexOperation;
import org.opensearch.client.opensearch.indices.PutIndexTemplateResponse;
import org.opensearch.client.transport.endpoints.BooleanResponse;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -48,12 +50,15 @@

public class VectorIndexWriter extends AbstractBehavior<Writer.Command> {

private final static org.jboss.logging.Logger logger =
private final static org.jboss.logging.Logger log =
Logger.getLogger(VectorIndexWriter.class);

private final OpenSearchAsyncClient asyncClient;
private final String vectorIndexName;
private final ActorRef<Writer.Response> replyTo;
private final String templateName;
private boolean indexTemplateCreated = false;
private int vectorSize;

public VectorIndexWriter(
ActorContext<Writer.Command> context,
Expand All @@ -63,7 +68,9 @@ public VectorIndexWriter(
super(context);
this.asyncClient = CDI.current().select(OpenSearchAsyncClient.class).get();
this.vectorIndexName = scheduler.getVectorIndexName();
this.templateName = vectorIndexName + "-template";
this.replyTo = replyTo;

}

public static Behavior<Writer.Command> create(
Expand All @@ -76,7 +83,11 @@ public static Behavior<Writer.Command> create(
public Receive<Writer.Command> createReceive() {
return newReceiveBuilder()
.onMessage(Writer.Start.class, this::onStart)
.onMessage(AsyncResponse.class, this::onAsyncResponse)
.onMessage(CheckIndexTemplate.class, this::onCheckIndexTemplate)
.onMessage(CheckIndexTemplateResponse.class, this::onCheckIndexTemplateResponse)
.onMessage(PutTemplateResponse.class, this::onPutTemplateResponse)
.onMessage(IndexDocument.class, this::onIndexDocument)
.onMessage(IndexDocumentResponse.class, this::onIndexDocumentResponse)
.build();
}

Expand All @@ -88,9 +99,143 @@ private Behavior<Writer.Command> onStart(Writer.Start start) {
EmbeddingService.EmbeddedChunks dataPayload = Json.decodeValue(
Buffer.buffer(data), EmbeddingService.EmbeddedChunks.class);

getContext().getSelf().tell(new CheckIndexTemplate(dataPayload, heldMessage));

return this;
}

private Behavior<Writer.Command> onCheckIndexTemplate(CheckIndexTemplate checkIndexTemplate) {

var embeddedChunks = checkIndexTemplate.embeddedChunks();
var heldMessage = checkIndexTemplate.heldMessage();

if (indexTemplateCreated) {

getContext().getSelf().tell(new IndexDocument(
embeddedChunks, heldMessage)
);

return this;
}

try {

getContext().pipeToSelf(
asyncClient.indices().existsIndexTemplate(req -> req.name(templateName)),
(r, t) -> new CheckIndexTemplateResponse(
embeddedChunks, heldMessage, r, t)
);

}
catch (IOException e) {

replyTo.tell(new Writer.Failure(e, heldMessage));

}

return this;
}

private Behavior<Writer.Command> onCheckIndexTemplateResponse(CheckIndexTemplateResponse response) {

var exists = response.exists();
var throwable = response.throwable();
var heldMessage = response.heldMessage();

if (throwable != null) {

replyTo.tell(new Writer.Failure((Exception) throwable, heldMessage));

return this;

}

if (exists.value()) {

this.indexTemplateCreated = true;

getContext().getSelf().tell(new IndexDocument(response.embeddedChunks(), heldMessage));

}
else {

var chunks = response.embeddedChunks().list().iterator();

if (chunks.hasNext()) {
this.vectorSize = chunks.next().vector().size();
}

try {

getContext().pipeToSelf(
asyncClient.indices().putIndexTemplate(req -> req
.name(templateName)
.indexPatterns(vectorIndexName)
.template(template -> template
.settings(settings -> settings
.knn(true))
.mappings(mapping -> mapping
.properties("indexName", p -> p
.text(text -> text))
.properties("contentId", p -> p
.text(text -> text))
.properties("number", p -> p
.integer(int_ -> int_))
.properties("total", p -> p
.integer(int_ -> int_))
.properties("text", p -> p
.text(text -> text))
.properties("vector", p -> p
.knnVector(knn -> knn.dimension(vectorSize)))
)
)
),
(r, t) -> new PutTemplateResponse(
response.embeddedChunks(),
heldMessage,
r,
t
)
);

}
catch (IOException e) {

replyTo.tell(new Writer.Failure(e, heldMessage));

}

}

return this;
}

private Behavior<Writer.Command> onPutTemplateResponse(PutTemplateResponse putTemplateResponse) {

var embeddedChunks = putTemplateResponse.embeddedChunks();
var heldMessage = putTemplateResponse.heldMessage();
var throwable = putTemplateResponse.throwable();

if (throwable != null) {

replyTo.tell(new Writer.Failure((Exception) throwable, heldMessage));

return this;
}

getContext().getSelf().tell(new IndexDocument(embeddedChunks, heldMessage));

return this;
}

private Behavior<Writer.Command> onIndexDocument(IndexDocument indexDocument) {

var embeddedChunks = indexDocument.embeddedChunks();
var heldMessage = indexDocument.heldMessage();

var bulkOperations = new ArrayList<BulkOperation>();

for (EmbeddingService.EmbeddedChunk chunk : dataPayload.list()) {
for (EmbeddingService.EmbeddedChunk chunk : embeddedChunks.list()) {
var document = JsonObject.mapFrom(chunk);

Map<?, ?> acl = (Map<?, ?>) document.getValue("acl");
Expand All @@ -115,21 +260,24 @@ private Behavior<Writer.Command> onStart(Writer.Start start) {
.build();

try {

getContext().pipeToSelf(
asyncClient.bulk(bulkRequest),
(bulkResponse, exception) -> new AsyncResponse(
dataPayload, heldMessage, bulkResponse, (Exception) exception)
(bulkResponse, exception) -> new IndexDocumentResponse(
embeddedChunks, heldMessage, bulkResponse, (Exception) exception)
);

}
catch (IOException e) {
replyTo.tell(new Writer.Failure(e, heldMessage));
}

return this;

}

private Behavior<Writer.Command> onAsyncResponse(
AsyncResponse brc) {
private Behavior<Writer.Command> onIndexDocumentResponse(
IndexDocumentResponse brc) {

var bulkResponse = brc.bulkResponse;
var heldMessage = brc.heldMessage;
Expand All @@ -146,7 +294,7 @@ private Behavior<Writer.Command> onAsyncResponse(
.map(ErrorCause::reason)
.collect(Collectors.joining());

logger.error("Bulk request error: " + reasons);
log.error("Bulk request error: " + reasons);
replyTo.tell(new Writer.Failure(
new RuntimeException(reasons),
heldMessage
Expand All @@ -156,7 +304,7 @@ private Behavior<Writer.Command> onAsyncResponse(
}

if (throwable != null) {
logger.error("Error on bulk request", throwable);
log.error("Error on bulk request", throwable);
replyTo.tell(new Writer.Failure(throwable, heldMessage));
}
else {
Expand All @@ -167,8 +315,34 @@ private Behavior<Writer.Command> onAsyncResponse(
return this;
}

private record CheckIndexTemplate(
EmbeddingService.EmbeddedChunks embeddedChunks,
HeldMessage heldMessage
) implements Writer.Command {}

private record CheckIndexTemplateResponse(
EmbeddingService.EmbeddedChunks embeddedChunks,
HeldMessage heldMessage,
BooleanResponse exists,
Throwable throwable
)
implements Writer.Command {}

private record PutTemplateResponse(
EmbeddingService.EmbeddedChunks embeddedChunks,
HeldMessage heldMessage,
PutIndexTemplateResponse response,
Throwable throwable
)
implements Writer.Command {}

private record IndexDocument(
EmbeddingService.EmbeddedChunks embeddedChunks,
HeldMessage heldMessage
)
implements Writer.Command {}

private record AsyncResponse(
private record IndexDocumentResponse(
EmbeddingService.EmbeddedChunks embeddedChunks,
HeldMessage heldMessage,
BulkResponse bulkResponse,
Expand Down

0 comments on commit fa750e7

Please sign in to comment.