Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Vector Search #1639

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,17 @@ public <T extends Value<?>> List<T> getList(String name) {
return (List<T>) getValue(name).get();
}

/**
 * Looks up the property called {@code name} and returns its payload viewed as a vector, i.e. a
 * list of {@link DoubleValue} components.
 *
 * @param name the name of the property to read
 * @return the property's value as a list of double values
 * @throws DatastoreException if no such property
 * @throws ClassCastException if value is not a vector
 */
@SuppressWarnings("unchecked")
public List<DoubleValue> getVector(String name) {
  // The payload type is not known statically; the cast below is unchecked by design.
  Object payload = getValue(name).get();
  return (List<DoubleValue>) payload;
}

/**
* Returns the property value as a blob.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.datastore;

import com.google.common.base.MoreObjects;
import com.google.common.base.MoreObjects.ToStringHelper;
import com.google.protobuf.DoubleValue;
import com.google.protobuf.Int32Value;
import java.io.Serializable;
import java.util.Objects;
import javax.annotation.Nullable;

/**
 * A query that finds the entities whose vector fields are closest to a certain query vector. Create
 * an instance of {@code FindNearest} with {@link Query#findNearest}.
 *
 * <p>Instances are immutable value objects: {@link #equals(Object)}, {@link #hashCode()} and
 * {@link #toString()} are all derived from the six configuration fields.
 */
public final class FindNearest implements Serializable {

  private static final long serialVersionUID = 4688656124180403551L;

  /** Name of the property whose vector values are compared against the query vector. */
  private final String vectorProperty;

  /** The vector to search around. */
  private final VectorValue queryVector;

  /** How the distance between two vectors is computed. */
  private final DistanceMeasure measure;

  /** The limit sent with the request (serialized as an {@code Int32Value}). */
  private final int limit;

  /**
   * Optional name of the field to output the result of the vector distance calculation. {@code
   * null} means the field is not set on the request.
   */
  private final @Nullable String distanceResultField;

  /** Optional distance threshold. {@code null} means the threshold is not set on the request. */
  private final @Nullable Double distanceThreshold;

  /**
   * Creates a {@code FindNearest} with every option spelled out.
   *
   * @param vectorProperty name of the property holding the vectors to search
   * @param queryVector the vector to search around
   * @param measure the distance measure used to compare vectors
   * @param limit the limit to send with the request
   * @param distanceResultField optional name of the field that receives the computed distance, or
   *     {@code null} to leave it unset
   * @param distanceThreshold optional distance threshold, or {@code null} to leave it unset
   */
  public FindNearest(
      String vectorProperty,
      VectorValue queryVector,
      DistanceMeasure measure,
      int limit,
      @Nullable String distanceResultField,
      @Nullable Double distanceThreshold) {
    this.vectorProperty = vectorProperty;
    this.queryVector = queryVector;
    this.measure = measure;
    this.limit = limit;
    this.distanceResultField = distanceResultField;
    this.distanceThreshold = distanceThreshold;
  }

  /** Creates a {@code FindNearest} without a distance result field or a distance threshold. */
  public FindNearest(
      String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit) {
    this(vectorProperty, queryVector, measure, limit, null, null);
  }

  /** Creates a {@code FindNearest} with a distance result field but no distance threshold. */
  public FindNearest(
      String vectorProperty,
      VectorValue queryVector,
      DistanceMeasure measure,
      int limit,
      @Nullable String distanceResultField) {
    this(vectorProperty, queryVector, measure, limit, distanceResultField, null);
  }

  /** Creates a {@code FindNearest} with a distance threshold but no distance result field. */
  public FindNearest(
      String vectorProperty,
      VectorValue queryVector,
      DistanceMeasure measure,
      int limit,
      @Nullable Double distanceThreshold) {
    this(vectorProperty, queryVector, measure, limit, null, distanceThreshold);
  }

  @Override
  public int hashCode() {
    return Objects.hash(
        vectorProperty, queryVector, measure, limit, distanceResultField, distanceThreshold);
  }

  /**
   * Returns true if this {@code FindNearest} is equal to the provided object.
   *
   * @param obj The object to compare against.
   * @return Whether this {@code FindNearest} is equal to the provided object.
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    // instanceof is already false for null, so no separate null check is needed.
    if (!(obj instanceof FindNearest)) {
      return false;
    }
    FindNearest otherQuery = (FindNearest) obj;
    return Objects.equals(vectorProperty, otherQuery.vectorProperty)
        && Objects.equals(queryVector, otherQuery.queryVector)
        && Objects.equals(distanceResultField, otherQuery.distanceResultField)
        && Objects.equals(distanceThreshold, otherQuery.distanceThreshold)
        && limit == otherQuery.limit
        && measure == otherQuery.measure;
  }

  @Override
  public String toString() {
    ToStringHelper toStringHelper = MoreObjects.toStringHelper(this);
    toStringHelper.add("vectorProperty", vectorProperty);
    toStringHelper.add("queryVector", queryVector);
    toStringHelper.add("measure", measure);
    toStringHelper.add("limit", limit);
    toStringHelper.add("distanceResultField", distanceResultField);
    toStringHelper.add("distanceThreshold", distanceThreshold);
    return toStringHelper.toString();
  }

  /**
   * Builds a {@code FindNearest} from its protobuf representation.
   *
   * @param findNearestPb the proto message to convert
   * @return the equivalent model object; an empty distance result property and an unset distance
   *     threshold both map to {@code null}
   * @throws IllegalArgumentException if the proto's distance measure has no model counterpart
   */
  static FindNearest fromPb(com.google.datastore.v1.FindNearest findNearestPb) {
    String vectorProperty = findNearestPb.getVectorProperty().getName();
    VectorValue queryVector =
        VectorValue.MARSHALLER.fromProto(findNearestPb.getQueryVector()).build();
    DistanceMeasure distanceMeasure = fromProto(findNearestPb.getDistanceMeasure());
    int limit = findNearestPb.getLimit().getValue();

    // Generated string getters never return null; an unset field reads as "".
    String resultProperty = findNearestPb.getDistanceResultProperty();
    String distanceResultField = resultProperty.isEmpty() ? null : resultProperty;

    // Use field presence rather than identity with the default instance, so an explicitly set
    // threshold is always preserved and an unset one is always reported as null.
    Double distanceThreshold = null;
    if (findNearestPb.hasDistanceThreshold()) {
      distanceThreshold = findNearestPb.getDistanceThreshold().getValue();
    }

    return new FindNearest(
        vectorProperty,
        queryVector,
        distanceMeasure,
        limit,
        distanceResultField,
        distanceThreshold);
  }

  /**
   * Converts this object to its protobuf representation. The optional distance result field and
   * distance threshold are only written when they are non-null.
   */
  com.google.datastore.v1.FindNearest toPb() {
    com.google.datastore.v1.FindNearest.Builder findNearestPb =
        com.google.datastore.v1.FindNearest.newBuilder();
    findNearestPb.getVectorPropertyBuilder().setName(vectorProperty);
    findNearestPb.setQueryVector(queryVector.toPb());
    findNearestPb.setDistanceMeasure(toProto(measure));
    findNearestPb.setLimit(Int32Value.of(limit));
    if (distanceResultField != null) {
      findNearestPb.setDistanceResultProperty(distanceResultField);
    }
    if (distanceThreshold != null) {
      findNearestPb.setDistanceThreshold(DoubleValue.of(distanceThreshold));
    }
    return findNearestPb.build();
  }

  /**
   * Maps a model {@link DistanceMeasure} to the corresponding proto enum constant.
   *
   * @param distanceMeasure the model value to convert
   * @return the matching proto constant, or {@code UNRECOGNIZED} if no case matches
   */
  protected static com.google.datastore.v1.FindNearest.DistanceMeasure toProto(
      DistanceMeasure distanceMeasure) {
    switch (distanceMeasure) {
      case COSINE:
        return com.google.datastore.v1.FindNearest.DistanceMeasure.COSINE;
      case EUCLIDEAN:
        return com.google.datastore.v1.FindNearest.DistanceMeasure.EUCLIDEAN;
      case DOT_PRODUCT:
        return com.google.datastore.v1.FindNearest.DistanceMeasure.DOT_PRODUCT;
      default:
        return com.google.datastore.v1.FindNearest.DistanceMeasure.UNRECOGNIZED;
    }
  }

  /**
   * Maps a proto distance measure to the model {@link DistanceMeasure}.
   *
   * @throws IllegalArgumentException if the proto value has no model counterpart
   */
  private static DistanceMeasure fromProto(
      com.google.datastore.v1.FindNearest.DistanceMeasure distanceMeasurePb) {
    switch (distanceMeasurePb) {
      case COSINE:
        return DistanceMeasure.COSINE;
      case EUCLIDEAN:
        return DistanceMeasure.EUCLIDEAN;
      case DOT_PRODUCT:
        return DistanceMeasure.DOT_PRODUCT;
      default:
        throw new IllegalArgumentException("Unsupported distance measure: " + distanceMeasurePb);
    }
  }

  /**
   * The distance measure to use when comparing vectors in a {@link FindNearest query}.
   *
   * @see com.google.cloud.datastore.Query#findNearest
   */
  public enum DistanceMeasure {
    /**
     * COSINE distance compares vectors based on the angle between them, which allows you to measure
     * similarity that isn't based on the vectors' magnitude. We recommend using DOT_PRODUCT with
     * unit normalized vectors instead of COSINE distance, which is mathematically equivalent with
     * better performance.
     */
    COSINE,
    /** Measures the EUCLIDEAN distance between the vectors. */
    EUCLIDEAN,
    /** Similar to cosine but is affected by the magnitude of the vectors. */
    DOT_PRODUCT
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public abstract class StructuredQuery<V> extends Query<V> implements RecordQuery
private final Cursor endCursor;
private final int offset;
private final Integer limit;
private final FindNearest findNearest;

private final ResultType<V> resultType;

Expand Down Expand Up @@ -731,6 +732,9 @@ public interface Builder<V> {
/** Adds settings to the existing order by clause. */
Builder<V> addOrderBy(OrderBy orderBy, OrderBy... others);

/**
 * Sets the {@link FindNearest} (nearest-vector search) clause for the query.
 *
 * @param findNearest the clause to attach to the query
 * @return a builder, so that calls can be chained
 */
Builder<V> setFindNearest(FindNearest findNearest);

StructuredQuery<V> build();
}

Expand All @@ -753,6 +757,7 @@ abstract static class BuilderImpl<V, B extends BuilderImpl<V, B>> implements Bui
private Cursor endCursor;
private int offset;
private Integer limit;
private FindNearest findNearest;

BuilderImpl(ResultType<V> resultType) {
this.resultType = resultType;
Expand All @@ -770,6 +775,7 @@ abstract static class BuilderImpl<V, B extends BuilderImpl<V, B>> implements Bui
endCursor = query.endCursor;
offset = query.offset;
limit = query.limit;
findNearest = query.findNearest;
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -841,6 +847,13 @@ public B addOrderBy(OrderBy orderBy, OrderBy... others) {
return self();
}

/**
 * Stores the given {@link FindNearest} clause on this builder.
 *
 * @throws IllegalArgumentException if {@code findNearest} is null
 */
@Override
public B setFindNearest(FindNearest findNearest) {
  // Reject null up front so a built query never carries an accidentally cleared clause.
  if (findNearest == null) {
    throw new IllegalArgumentException("vector query must not be null");
  }
  this.findNearest = findNearest;
  return self();
}

B clearProjection() {
projection.clear();
return self();
Expand Down Expand Up @@ -904,6 +917,9 @@ B mergeFrom(com.google.datastore.v1.Query queryPb) {
for (com.google.datastore.v1.PropertyReference distinctOnPb : queryPb.getDistinctOnList()) {
addDistinctOn(distinctOnPb.getName());
}
if (queryPb.getFindNearest() != null) {
setFindNearest(FindNearest.fromPb(queryPb.getFindNearest()));
}
return self();
}
}
Expand All @@ -920,6 +936,7 @@ B mergeFrom(com.google.datastore.v1.Query queryPb) {
endCursor = builder.endCursor;
offset = builder.offset;
limit = builder.limit;
findNearest = builder.findNearest;
}

@Override
Expand All @@ -935,6 +952,7 @@ public String toString() {
.add("orderBy", orderBy)
.add("projection", projection)
.add("distinctOn", distinctOn)
.add("findNearest", findNearest)
.toString();
}

Expand All @@ -950,7 +968,8 @@ public int hashCode() {
filter,
orderBy,
projection,
distinctOn);
distinctOn,
findNearest);
}

@Override
Expand All @@ -971,7 +990,8 @@ public boolean equals(Object obj) {
&& Objects.equals(filter, other.filter)
&& Objects.equals(orderBy, other.orderBy)
&& Objects.equals(projection, other.projection)
&& Objects.equals(distinctOn, other.distinctOn);
&& Objects.equals(distinctOn, other.distinctOn)
&& Objects.equals(findNearest, other.findNearest);
}

/** Returns the kind for this query. */
Expand Down Expand Up @@ -1023,6 +1043,11 @@ public Integer getLimit() {
return limit;
}

/**
 * Returns the {@link FindNearest} clause attached to this query.
 *
 * @return the clause, or {@code null} when none was set on the builder
 */
public FindNearest getFindNearest() {
  return this.findNearest;
}

public abstract Builder<V> toBuilder();

@InternalApi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ public Query prepare(StructuredQuery<?> query) {
.build();
queryPb.addProjection(expressionPb);
}
if (query.getFindNearest() != null) {
queryPb.setFindNearest(query.getFindNearest().toPb());
}

return queryPb.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.cloud.datastore;

import static com.google.cloud.datastore.VectorValue.VECTOR_MEANING;
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.cloud.GcpLaunchStage;
Expand Down Expand Up @@ -214,8 +215,18 @@ com.google.datastore.v1.Value toPb() {
public static Value<?> fromPb(com.google.datastore.v1.Value proto) {
  ValueTypeCase descriptorId = proto.getValueTypeCase();
  ValueType valueType = ValueType.getByDescriptorId(descriptorId.getNumber());
  // No ValueType is registered for this descriptor id: fall back to the raw marshaller.
  if (valueType == null) {
    return RawValue.MARSHALLER.fromProto(proto).build();
  }

  // An array tagged with the vector meaning is decoded as a vector, but only when every element
  // is a double; otherwise it falls through to the marshaller of its declared type.
  if (valueType == ValueType.LIST && proto.getMeaning() == VECTOR_MEANING) {
    boolean allDoubles = true;
    for (com.google.datastore.v1.Value item : proto.getArrayValue().getValuesList()) {
      if (item.getValueTypeCase() != ValueTypeCase.DOUBLE_VALUE) {
        allDoubles = false;
        break;
      }
    }
    if (allDoubles) {
      return VectorValue.MARSHALLER.fromProto(proto).build();
    }
  }

  return valueType.getMarshaller().fromProto(proto).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import com.google.common.collect.ImmutableMap;

/**
* The type of a Datastore property.
*
* @see <a
* href="http://cloud.google.com/datastore/docs/concepts/entities#Datastore_Properties_and_value_types">Google
Expand Down Expand Up @@ -61,7 +61,10 @@ public enum ValueType {
RAW_VALUE(RawValue.MARSHALLER),

/** Represents a {@link LatLng} value. */
LAT_LNG(LatLngValue.MARSHALLER);
LAT_LNG(LatLngValue.MARSHALLER),

/** Represents a {@link VectorValue} value. */
VECTOR(VectorValue.MARSHALLER);

private static final ImmutableMap<Integer, ValueType> DESCRIPTOR_TO_TYPE_MAP;

Expand Down
Loading
Loading