diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/BaseEntity.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/BaseEntity.java index 608dc7187..d50770215 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/BaseEntity.java +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/BaseEntity.java @@ -615,6 +615,17 @@ public > List getList(String name) { return (List) getValue(name).get(); } + /** + * Returns the property value as a vector. + * + * @throws DatastoreException if no such property + * @throws ClassCastException if value is not a vector + */ + @SuppressWarnings("unchecked") + public List getVector(String name) { + return (List) getValue(name).get(); + } + /** * Returns the property value as a blob. * diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/FindNearest.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/FindNearest.java new file mode 100644 index 000000000..855c62d07 --- /dev/null +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/FindNearest.java @@ -0,0 +1,200 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.datastore; + +import com.google.common.base.MoreObjects; +import com.google.common.base.MoreObjects.ToStringHelper; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.Int32Value; +import java.io.Serializable; +import java.util.Objects; +import javax.annotation.Nullable; + +/** + * A query that finds the entities whose vector fields are closest to a certain query vector. Create + * an instance of `FindNearest` with {@link Query#findNearest}. + */ +public final class FindNearest implements Serializable { + + private final String vectorProperty; + private final VectorValue queryVector; + private final DistanceMeasure measure; + private final int limit; + /* + Optional. Optional name of the field to output the result of the vector + * distance calculation. + */ + private final @Nullable String distanceResultField; + private final @Nullable Double distanceThreshold; + + private static final long serialVersionUID = 4688656124180403551L; + + /** Creates a VectorQuery */ + public FindNearest( + String vectorProperty, + VectorValue queryVector, + DistanceMeasure measure, + int limit, + @Nullable String distanceResultField, + @Nullable Double distanceThreshold) { + this.vectorProperty = vectorProperty; + this.queryVector = queryVector; + this.measure = measure; + this.limit = limit; + this.distanceResultField = distanceResultField; + this.distanceThreshold = distanceThreshold; + } + + public FindNearest( + String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit) { + this(vectorProperty, queryVector, measure, limit, null, null); + } + + public FindNearest( + String vectorProperty, + VectorValue queryVector, + DistanceMeasure measure, + int limit, + @Nullable String distanceResultField) { + this(vectorProperty, queryVector, measure, limit, distanceResultField, null); + } + + public FindNearest( + String vectorProperty, + VectorValue queryVector, + DistanceMeasure measure, + int limit, + @Nullable Double distanceThreshold) { + this(vectorProperty, queryVector, measure, limit, null, distanceThreshold); + } + + @Override + public int hashCode() { + return Objects.hash( + vectorProperty, queryVector, measure, limit, distanceResultField, distanceThreshold); + } + + /** + * Returns true if this VectorQuery is equal to the provided object. + * + * @param obj The object to compare against. + * @return Whether this VectorQuery is equal to the provided object. + */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || !(obj instanceof FindNearest)) { + return false; + } + FindNearest otherQuery = (FindNearest) obj; + return Objects.equals(vectorProperty, otherQuery.vectorProperty) + && Objects.equals(queryVector, otherQuery.queryVector) + && Objects.equals(distanceResultField, otherQuery.distanceResultField) + && Objects.equals(distanceThreshold, otherQuery.distanceThreshold) + && limit == otherQuery.limit + && measure == otherQuery.measure; + } + + @Override + public String toString() { + ToStringHelper toStringHelper = MoreObjects.toStringHelper(this); + toStringHelper.add("vectorProperty", vectorProperty); + toStringHelper.add("queryVector", queryVector); + toStringHelper.add("measure", measure); + toStringHelper.add("limit", limit); + toStringHelper.add("distanceResultField", distanceResultField); + toStringHelper.add("distanceThreshold", distanceThreshold); + return toStringHelper.toString(); + } + + static FindNearest fromPb(com.google.datastore.v1.FindNearest findNearestPb) { + String vectorProperty = findNearestPb.getVectorProperty().getName(); + VectorValue queryVector = + VectorValue.MARSHALLER.fromProto(findNearestPb.getQueryVector()).build(); + DistanceMeasure distanceMeasure = + DistanceMeasure.valueOf(findNearestPb.getDistanceMeasure().toString()); + int limit = findNearestPb.getLimit().getValue(); + String distanceResultField = + findNearestPb.getDistanceResultProperty() == null + || findNearestPb.getDistanceResultProperty().isEmpty() + ? null + : findNearestPb.getDistanceResultProperty(); + Double distanceThreshold = + findNearestPb.getDistanceThreshold() == null + || findNearestPb.getDistanceThreshold() == DoubleValue.getDefaultInstance() + ? null + : findNearestPb.getDistanceThreshold().getValue(); + return new FindNearest( + vectorProperty, + queryVector, + distanceMeasure, + limit, + distanceResultField, + distanceThreshold); + } + + com.google.datastore.v1.FindNearest toPb() { + com.google.datastore.v1.FindNearest.Builder findNearestPb = + com.google.datastore.v1.FindNearest.newBuilder(); + findNearestPb.getVectorPropertyBuilder().setName(vectorProperty); + findNearestPb.setQueryVector(queryVector.toPb()); + findNearestPb.setDistanceMeasure(toProto(measure)); + findNearestPb.setLimit(Int32Value.of(limit)); + if (distanceResultField != null) { + findNearestPb.setDistanceResultProperty(distanceResultField); + } + if (distanceThreshold != null) { + findNearestPb.setDistanceThreshold(DoubleValue.of(distanceThreshold)); + } + return findNearestPb.build(); + } + + protected static com.google.datastore.v1.FindNearest.DistanceMeasure toProto( + DistanceMeasure distanceMeasure) { + switch (distanceMeasure) { + case COSINE: + return com.google.datastore.v1.FindNearest.DistanceMeasure.COSINE; + case EUCLIDEAN: + return com.google.datastore.v1.FindNearest.DistanceMeasure.EUCLIDEAN; + case DOT_PRODUCT: + return com.google.datastore.v1.FindNearest.DistanceMeasure.DOT_PRODUCT; + default: + return com.google.datastore.v1.FindNearest.DistanceMeasure.UNRECOGNIZED; + } + } + + /** + * The distance measure to use when comparing vectors in a {@link FindNearest query}. + * + * @see com.google.cloud.datastore.Query#findNearest + */ + public enum DistanceMeasure { + /** + * COSINE distance compares vectors based on the angle between them, which allows you to measure + * similarity that isn't based on the vectors' magnitude. We recommend using DOT_PRODUCT with + * unit normalized vectors instead of COSINE distance, which is mathematically equivalent with + * better performance. + */ + COSINE, + /** Measures the EUCLIDEAN distance between the vectors. */ + EUCLIDEAN, + /** Similar to cosine but is affected by the magnitude of the vectors. */ + DOT_PRODUCT + } +} diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQuery.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQuery.java index 30cd05759..7541dbd75 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQuery.java +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQuery.java @@ -101,6 +101,7 @@ public abstract class StructuredQuery extends Query implements RecordQuery private final Cursor endCursor; private final int offset; private final Integer limit; + private final FindNearest findNearest; private final ResultType resultType; @@ -731,6 +732,9 @@ public interface Builder { /** Adds settings to the existing order by clause. */ Builder addOrderBy(OrderBy orderBy, OrderBy... others); + /** Sets the find_nearest for the query. */ + Builder setFindNearest(FindNearest findNearest); + StructuredQuery build(); } @@ -753,6 +757,7 @@ abstract static class BuilderImpl> implements Bui private Cursor endCursor; private int offset; private Integer limit; + private FindNearest findNearest; BuilderImpl(ResultType resultType) { this.resultType = resultType; @@ -770,6 +775,7 @@ abstract static class BuilderImpl> implements Bui endCursor = query.endCursor; offset = query.offset; limit = query.limit; + findNearest = query.findNearest; } @SuppressWarnings("unchecked") @@ -841,6 +847,13 @@ public B addOrderBy(OrderBy orderBy, OrderBy... others) { return self(); } + @Override + public B setFindNearest(FindNearest findNearest) { + Preconditions.checkArgument(findNearest != null, "vector query must not be null"); + this.findNearest = findNearest; + return self(); + } + B clearProjection() { projection.clear(); return self(); @@ -904,6 +917,9 @@ B mergeFrom(com.google.datastore.v1.Query queryPb) { for (com.google.datastore.v1.PropertyReference distinctOnPb : queryPb.getDistinctOnList()) { addDistinctOn(distinctOnPb.getName()); } + if (queryPb.getFindNearest() != null) { + setFindNearest(FindNearest.fromPb(queryPb.getFindNearest())); + } return self(); } } @@ -920,6 +936,7 @@ B mergeFrom(com.google.datastore.v1.Query queryPb) { endCursor = builder.endCursor; offset = builder.offset; limit = builder.limit; + findNearest = builder.findNearest; } @Override @@ -935,6 +952,7 @@ public String toString() { .add("orderBy", orderBy) .add("projection", projection) .add("distinctOn", distinctOn) + .add("findNearest", findNearest) .toString(); } @@ -950,7 +968,8 @@ public int hashCode() { filter, orderBy, projection, - distinctOn); + distinctOn, + findNearest); } @Override @@ -971,7 +990,8 @@ public boolean equals(Object obj) { && Objects.equals(filter, other.filter) && Objects.equals(orderBy, other.orderBy) && Objects.equals(projection, other.projection) - && Objects.equals(distinctOn, other.distinctOn); + && Objects.equals(distinctOn, other.distinctOn) + && Objects.equals(findNearest, other.findNearest); } /** Returns the kind for this query. */ @@ -1023,6 +1043,11 @@ public Integer getLimit() { return limit; } + /** Returns the vector query for this query. */ + public FindNearest getFindNearest() { + return findNearest; + } + public abstract Builder toBuilder(); @InternalApi diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQueryProtoPreparer.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQueryProtoPreparer.java index fda6f8f4a..c7e39f3d4 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQueryProtoPreparer.java +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQueryProtoPreparer.java @@ -60,6 +60,9 @@ public Query prepare(StructuredQuery query) { .build(); queryPb.addProjection(expressionPb); } + if (query.getFindNearest() != null) { + queryPb.setFindNearest(query.getFindNearest().toPb()); + } return queryPb.build(); } diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/Value.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/Value.java index 4bd0a5133..93bd4298c 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/Value.java +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/Value.java @@ -16,6 +16,7 @@ package com.google.cloud.datastore; +import static com.google.cloud.datastore.VectorValue.VECTOR_MEANING; import static com.google.common.base.Preconditions.checkNotNull; import com.google.cloud.GcpLaunchStage; @@ -214,8 +215,18 @@ com.google.datastore.v1.Value toPb() { public static Value fromPb(com.google.datastore.v1.Value proto) { ValueTypeCase descriptorId = proto.getValueTypeCase(); ValueType valueType = ValueType.getByDescriptorId(descriptorId.getNumber()); - return valueType == null - ? RawValue.MARSHALLER.fromProto(proto).build() - : valueType.getMarshaller().fromProto(proto).build(); + if (valueType == null) return RawValue.MARSHALLER.fromProto(proto).build(); + + Value returnValue = valueType.getMarshaller().fromProto(proto).build(); + if (valueType == ValueType.LIST && proto.getMeaning() == VECTOR_MEANING) { + for (com.google.datastore.v1.Value item : proto.getArrayValue().getValuesList()) { + if (item.getValueTypeCase() != ValueTypeCase.DOUBLE_VALUE) { + return returnValue; + } + } + returnValue = VectorValue.MARSHALLER.fromProto(proto).build(); + } + + return returnValue; } } diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/ValueType.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/ValueType.java index 13e3c7af6..df2fb4099 100644 --- a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/ValueType.java +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/ValueType.java @@ -19,7 +19,7 @@ import com.google.common.collect.ImmutableMap; /** - * The type of a Datastore property. + * The type of Datastore property. * * @see Google @@ -61,7 +61,10 @@ public enum ValueType { RAW_VALUE(RawValue.MARSHALLER), /** Represents a {@link LatLng} value. */ - LAT_LNG(LatLngValue.MARSHALLER); + LAT_LNG(LatLngValue.MARSHALLER), + + /** Represents a {@link Vector} value. */ + VECTOR(VectorValue.MARSHALLER); private static final ImmutableMap DESCRIPTOR_TO_TYPE_MAP; diff --git a/google-cloud-datastore/src/main/java/com/google/cloud/datastore/VectorValue.java b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/VectorValue.java new file mode 100644 index 000000000..38a2903dc --- /dev/null +++ b/google-cloud-datastore/src/main/java/com/google/cloud/datastore/VectorValue.java @@ -0,0 +1,154 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.datastore; + +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.List; + +/** A Google Cloud Datastore Vector value. A list value is a list of {@link Value} objects. */ +public final class VectorValue extends Value>> { + + private static final long serialVersionUID = -5121887228607148859L; + + public static final int VECTOR_MEANING = 31; + + static final BaseMarshaller>, VectorValue, Builder> MARSHALLER = + new BaseMarshaller>, VectorValue, Builder>() { + private static final long serialVersionUID = 7720473855548179943L; + + @Override + public int getProtoFieldId() { + return -1; + } + + @Override + public Builder newBuilder(List> values) { + return VectorValue.newBuilder().set(values); + } + + @Override + protected List> getValue(com.google.datastore.v1.Value from) { + List> properties = new ArrayList<>(from.getArrayValue().getValuesCount()); + for (com.google.datastore.v1.Value valuePb : from.getArrayValue().getValuesList()) { + properties.add((Value) Value.fromPb(valuePb)); + } + return properties; + } + + @Override + protected void setValue(VectorValue from, com.google.datastore.v1.Value.Builder to) { + List propertiesPb = new ArrayList<>(); + for (Value property : from.get()) { + propertiesPb.add(property.toPb()); + } + to.setArrayValue( + com.google.datastore.v1.ArrayValue.newBuilder().addAllValues(propertiesPb)); + } + }; + + public static final class Builder + extends Value.BaseBuilder>, VectorValue, Builder> { + private ImmutableList.Builder> vectorBuilder = ImmutableList.builder(); + + private Builder() { + super(ValueType.VECTOR); + } + + /** Adds the provided double values to the {@code VectorValue} builder. */ + public VectorValue.Builder addValue(Value first, Value... other) { + vectorBuilder.add(first); + for (Value value : other) { + vectorBuilder.add(value); + } + return this; + } + + public VectorValue.Builder addValue(double first, double... other) { + vectorBuilder.add(DoubleValue.of(first)); + for (double value : other) { + vectorBuilder.add(DoubleValue.of(value)); + } + return this; + } + + /** + * Sets the list of values of this {@code ListValue} builder to {@code values}. The provided + * list is copied. + * + * @see com.google.cloud.datastore.Value.BaseBuilder#set(java.lang.Object) + */ + @Override + public Builder set(List> values) { + vectorBuilder = ImmutableList.builder(); + for (Value value : values) { + addValue(value); + } + return this; + } + + @Override + public List> get() { + return vectorBuilder.build(); + } + + /** Creates a {@code ListValue} object. */ + @Override + public VectorValue build() { + return new VectorValue(this); + } + } + + public VectorValue(List> values) { + this(newBuilder().set(values)); + } + + private VectorValue(Builder builder) { + super(builder); + } + + /** Returns a builder for the list value object. */ + @Override + public Builder toBuilder() { + return new Builder().mergeFrom(this); + } + + /** Creates a {@code VectorValue} object given a number of double values. */ + public static VectorValue of(double first, double... other) { + return newBuilder().addValue(first, other).build(); + } + + /** Creates a {@code VectorValue} object given a list of {@code Value} objects. */ + public static VectorValue of(List> values) { + return new VectorValue(values); + } + + /** Returns a builder for {@code ListValue} objects. */ + public static Builder newBuilder() { + Builder builder = new VectorValue.Builder(); + builder.setExcludeFromIndexes(true); + builder.setMeaning(VECTOR_MEANING); + return builder; + } + + public static Builder newBuilder(double first, double... other) { + VectorValue.Builder builder = new VectorValue.Builder(); + builder.setExcludeFromIndexes(true); + builder.setMeaning(VECTOR_MEANING); + return builder.addValue(first, other); + } +} diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/BaseEntityTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/BaseEntityTest.java index 1b5380ab9..fc33c64b0 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/BaseEntityTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/BaseEntityTest.java @@ -16,6 +16,7 @@ package com.google.cloud.datastore; +import static com.google.cloud.datastore.VectorValue.VECTOR_MEANING; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -35,6 +36,11 @@ public class BaseEntityTest { private static final Blob BLOB = Blob.copyFrom(new byte[] {1, 2}); private static final Timestamp TIMESTAMP = Timestamp.now(); private static final LatLng LAT_LNG = new LatLng(37.422035, -122.084124); + private static final VectorValue VECTOR = + VectorValue.newBuilder(1.78, 2.56, 3.88) + .setMeaning(VECTOR_MEANING) + .setExcludeFromIndexes(true) + .build(); private static final Key KEY = Key.newBuilder("ds1", "k1", "n1").build(); private static final Entity ENTITY = Entity.newBuilder(KEY).set("name", "foo").build(); private static final IncompleteKey INCOMPLETE_KEY = IncompleteKey.newBuilder("ds1", "k1").build(); @@ -76,6 +82,7 @@ public void setUp() { builder.set("stringList", "s1", "s2", "s3"); builder.set("longList", 1, 23, 456); builder.set("latLngList", LAT_LNG, LAT_LNG); + builder.set("vector", VECTOR); } @Test @@ -182,6 +189,16 @@ public void testGetEntity() { assertEquals(PARTIAL_ENTITY, entity.getEntity("entity")); } + @Test + public void testGetVector() { + BaseEntity entity = builder.build(); + List vectorList = entity.getVector("vector"); + assertEquals(3, vectorList.size()); + assertEquals(Double.valueOf(1.78), vectorList.get(0).get()); + assertEquals(Double.valueOf(2.56), vectorList.get(1).get()); + assertEquals(Double.valueOf(3.88), vectorList.get(2).get()); + } + @Test public void testGetList() { BaseEntity entity = builder.build(); @@ -229,7 +246,7 @@ public void testNames() { .add("entity", "partialEntity", "null", "timestamp", "blob", "key", "blobList") .add( "booleanList", "timestampList", "doubleList", "keyList", "entityList", "stringList") - .add("longList", "latLng", "latLngList") + .add("longList", "latLng", "latLngList", "vector") .build(); BaseEntity entity = builder.build(); assertEquals(names, entity.getNames()); diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java index 8e2ba890a..57f71039c 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java @@ -17,6 +17,7 @@ import static com.google.datastore.v1.PropertyOrder.Direction.ASCENDING; +import com.google.cloud.datastore.FindNearest.DistanceMeasure; import com.google.datastore.v1.AggregationQuery.Aggregation; import com.google.datastore.v1.AggregationQuery.Aggregation.Count; import com.google.datastore.v1.Filter; @@ -27,6 +28,9 @@ import com.google.datastore.v1.PropertyOrder; import com.google.datastore.v1.PropertyReference; import com.google.datastore.v1.Value; +import com.google.protobuf.DoubleValue; +import com.google.protobuf.Int32Value; +import javax.annotation.Nullable; public class ProtoTestData { @@ -83,4 +87,33 @@ public static PropertyOrder propertyOrder(String value) { public static Projection projection(String value) { return Projection.newBuilder().setProperty(propertyReference(value)).build(); } + + public static com.google.datastore.v1.FindNearest FindNearest( + String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit) { + return FindNearest(vectorProperty, queryVector, measure, limit, null, null); + } + + public static com.google.datastore.v1.FindNearest FindNearest( + String vectorProperty, + VectorValue queryVector, + DistanceMeasure measure, + int limit, + @Nullable String distanceResultField, + @Nullable Double distanceThreshold) { + com.google.datastore.v1.FindNearest.Builder builder = + com.google.datastore.v1.FindNearest.newBuilder() + .setVectorProperty(propertyReference(vectorProperty)) + .setQueryVector(queryVector.toPb()) + .setDistanceMeasure(FindNearest.toProto(measure)) + .setLimit(Int32Value.of(limit)); + + if (distanceResultField != null) { + builder.setDistanceResultProperty(distanceResultField); + } + if (distanceThreshold != null) { + builder.setDistanceThreshold(DoubleValue.of(distanceThreshold)); + } + + return builder.build(); + } } diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryProtoPreparerTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryProtoPreparerTest.java index 60937fc28..549a8876b 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryProtoPreparerTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryProtoPreparerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ */ package com.google.cloud.datastore; +import static com.google.cloud.datastore.ProtoTestData.FindNearest; import static com.google.cloud.datastore.ProtoTestData.booleanValue; import static com.google.cloud.datastore.ProtoTestData.projection; import static com.google.cloud.datastore.ProtoTestData.propertyFilter; @@ -86,6 +87,18 @@ public void testFilter() { assertThat(queryProto.getFilter()).isEqualTo(propertyFilter("done", EQUAL, booleanValue(true))); } + @Test + public void testFindNearest() { + VectorValue VECTOR_VALUE = VectorValue.newBuilder(1.78, 2.56, 3.88).build(); + FindNearest FIND_NEAREST = + new FindNearest("vector_property", VECTOR_VALUE, FindNearest.DistanceMeasure.COSINE, 1); + Query queryProto = + protoPreparer.prepare(newEntityQueryBuilder().setFindNearest(FIND_NEAREST).build()); + assertThat(queryProto.getFindNearest()) + .isEqualTo( + FindNearest("vector_property", VECTOR_VALUE, FindNearest.DistanceMeasure.COSINE, 1)); + } + @Test public void testOrderBy() { Query queryProto = diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryTest.java index c59337586..c23bd3a9c 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryTest.java @@ -50,6 +50,9 @@ public class StructuredQueryTest { private static final String DISTINCT_ON1 = "p6"; private static final String DISTINCT_ON2 = "p7"; private static final List DISTINCT_ON = ImmutableList.of(DISTINCT_ON1, DISTINCT_ON2); + private static final VectorValue VECTOR_VALUE = VectorValue.newBuilder(1.78, 2.56, 3.88).build(); + private static final FindNearest FIND_NEAREST = + new FindNearest("vector_property", VECTOR_VALUE, FindNearest.DistanceMeasure.COSINE, 1); private static final EntityQuery ENTITY_QUERY = Query.newEntityQueryBuilder() .setNamespace(NAMESPACE) @@ -60,6 +63,7 @@ public class StructuredQueryTest { .setLimit(LIMIT) .setFilter(AND_FILTER) .setOrderBy(ORDER_BY_1, ORDER_BY_2) + .setFindNearest(FIND_NEAREST) .build(); private static final KeyQuery KEY_QUERY = Query.newKeyQueryBuilder() @@ -71,6 +75,7 @@ public class StructuredQueryTest { .setLimit(LIMIT) .setFilter(OR_FILTER) .setOrderBy(ORDER_BY_1, ORDER_BY_2) + .setFindNearest(FIND_NEAREST) .build(); private static final ProjectionEntityQuery PROJECTION_QUERY = Query.newProjectionEntityQueryBuilder() @@ -82,6 +87,7 @@ public class StructuredQueryTest { .setLimit(LIMIT) .setFilter(AND_FILTER) .setOrderBy(ORDER_BY_1, ORDER_BY_2) + .setFindNearest(FIND_NEAREST) .setProjection(PROJECTION1, PROJECTION2) .setDistinctOn(DISTINCT_ON1, DISTINCT_ON2) .build(); @@ -105,6 +111,7 @@ public void testKeyQueryBuilder() { assertEquals(ORDER_BY, KEY_QUERY.getOrderBy()); assertEquals(ImmutableList.of(StructuredQuery.KEY_PROPERTY_NAME), KEY_QUERY.getProjection()); assertTrue(KEY_QUERY.getDistinctOn().isEmpty()); + assertEquals(LIMIT, KEY_QUERY.getLimit()); } @Test @@ -123,6 +130,7 @@ private void compareBaseBuilderFields(StructuredQuery query) { assertEquals(LIMIT, query.getLimit()); assertEquals(AND_FILTER, query.getFilter()); assertEquals(ORDER_BY, query.getOrderBy()); + assertEquals(FIND_NEAREST, query.getFindNearest()); } @Test @@ -149,6 +157,9 @@ private void compareMergedQuery(StructuredQuery expected, StructuredQuery @Test public void testToAndFromPb() { + EntityQuery a = ENTITY_QUERY; + StructuredQuery pb = + StructuredQuery.fromPb(ResultType.ENTITY, ENTITY_QUERY.getNamespace(), ENTITY_QUERY.toPb()); assertEquals( ENTITY_QUERY, StructuredQuery.fromPb( diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ValueTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ValueTest.java index 8d53dc736..dd36919a1 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ValueTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/ValueTest.java @@ -42,6 +42,7 @@ public class ValueTest { private static final RawValue RAW_VALUE = RawValue.of(STRING_VALUE.toPb()); private static final LatLngValue LAT_LNG_VALUE = LatLngValue.of(new LatLng(37.422035, -122.084124)); + private static final VectorValue VECTOR_VALUE = VectorValue.newBuilder(1.78, 2.56, 3.88).build(); private static final ImmutableMap TYPES = ImmutableMap.builder() .put(ValueType.NULL, new Object[] {NullValue.class, NULL_VALUE.get()}) @@ -57,6 +58,7 @@ public class ValueTest { .put(ValueType.LONG, new Object[] {LongValue.class, 123L}) .put(ValueType.RAW_VALUE, new Object[] {RawValue.class, RAW_VALUE.get()}) .put(ValueType.LAT_LNG, new Object[] {LatLngValue.class, LAT_LNG_VALUE.get()}) + .put(ValueType.VECTOR, new Object[] {VectorValue.class, VECTOR_VALUE.get()}) .put(ValueType.STRING, new Object[] {StringValue.class, STRING_VALUE.get()}) .buildOrThrow(); @@ -123,7 +125,11 @@ public void testType() { @Test public void testExcludeFromIndexes() { for (Map.Entry> entry : typeToValue.entrySet()) { - assertFalse(entry.getValue().excludeFromIndexes()); + if (entry.getKey() == ValueType.VECTOR) { + assertTrue(entry.getValue().excludeFromIndexes()); + } else { + assertFalse(entry.getValue().excludeFromIndexes()); + } } TestBuilder builder = new TestBuilder(); assertFalse(builder.build().excludeFromIndexes()); diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/VectorValueTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/VectorValueTest.java new file mode 100644 index 000000000..d95e740a0 --- /dev/null +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/VectorValueTest.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.datastore; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import org.junit.Test; + +public class VectorValueTest { + private static final List> vectorList = + ImmutableList.of(DoubleValue.of(1.2), DoubleValue.of(3.6)); + + @Test + public void testToBuilder() { + // StringValue value = StringValue.of(CONTENT); + VectorValue value = VectorValue.of(0.3, 4.2, 3.7); + assertEquals(value, value.toBuilder().build()); + } + + @Test + public void testOf() { + VectorValue value = VectorValue.of(0.3, 4.2, 3.7); + assertEquals( + ImmutableList.of(DoubleValue.of(0.3), DoubleValue.of(4.2), DoubleValue.of(3.7)), + value.get()); + assertTrue(value.excludeFromIndexes()); + assertEquals(31, value.getMeaning()); + VectorValue value1 = VectorValue.of(vectorList); + assertEquals(vectorList, value1.get()); + assertTrue(value1.excludeFromIndexes()); + assertEquals(31, value1.getMeaning()); + } + + @SuppressWarnings("deprecation") + @Test + public void testBuilder() { + VectorValue.Builder builder = VectorValue.newBuilder(0.3, 4.2, 3.7); + VectorValue value = builder.setExcludeFromIndexes(true).build(); + assertEquals( + ImmutableList.of(DoubleValue.of(0.3), DoubleValue.of(4.2), DoubleValue.of(3.7)), + value.get()); + assertEquals(31, value.getMeaning()); + assertTrue(value.excludeFromIndexes()); + } +} diff --git a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreConceptsTest.java b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreConceptsTest.java index 770065778..d96b083ef 100644 --- a/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreConceptsTest.java +++ b/google-cloud-datastore/src/test/java/com/google/cloud/datastore/it/ITDatastoreConceptsTest.java @@ -23,29 +23,10 @@ import static org.junit.Assert.fail; import com.google.cloud.Timestamp; -import com.google.cloud.datastore.Cursor; -import com.google.cloud.datastore.Datastore; -import com.google.cloud.datastore.DatastoreException; -import com.google.cloud.datastore.DatastoreOptions; -import com.google.cloud.datastore.Entity; -import com.google.cloud.datastore.EntityQuery; -import com.google.cloud.datastore.FullEntity; -import com.google.cloud.datastore.IncompleteKey; -import com.google.cloud.datastore.Key; -import com.google.cloud.datastore.KeyFactory; -import com.google.cloud.datastore.KeyQuery; -import com.google.cloud.datastore.ListValue; -import com.google.cloud.datastore.PathElement; -import com.google.cloud.datastore.ProjectionEntity; -import com.google.cloud.datastore.Query; -import com.google.cloud.datastore.QueryResults; -import com.google.cloud.datastore.ReadOption; -import com.google.cloud.datastore.StringValue; -import com.google.cloud.datastore.StructuredQuery; +import com.google.cloud.datastore.*; import com.google.cloud.datastore.StructuredQuery.CompositeFilter; import com.google.cloud.datastore.StructuredQuery.OrderBy; import com.google.cloud.datastore.StructuredQuery.PropertyFilter; -import com.google.cloud.datastore.Transaction; import com.google.cloud.datastore.testing.RemoteDatastoreHelper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -175,6 +156,9 @@ private void setUpQueryTests() { "description", StringValue.newBuilder("Learn Cloud Datastore").setExcludeFromIndexes(true).build()) .set("tag", "fun", "l", "programming", "learn") + .set( + "vector_property", + VectorValue.newBuilder(3.0, 1.0, 2.0).setExcludeFromIndexes(true).build()) .build()); } @@ -591,6 +575,35 @@ public void testEqualAndInequalityRange() { assertValidQuery(query); } + @Test + public void testVectorSearch() { + VectorValue vectorValue = VectorValue.newBuilder(1.78, 2.56, 3.88).build(); + FindNearest vectorQuery = + new FindNearest( + "vector_property", vectorValue, FindNearest.DistanceMeasure.COSINE, 1, "distance"); + + Query query = Query.newEntityQueryBuilder().setFindNearest(vectorQuery).build(); + assertValidQuery(query); + } + + @Test + public void testVectorSearchWithEmptyVector() { + VectorValue emptyVector = VectorValue.newBuilder().build(); + FindNearest vectorQuery = + new FindNearest("vector_property", emptyVector, FindNearest.DistanceMeasure.EUCLIDEAN, 1); + Query query = Query.newEntityQueryBuilder().setFindNearest(vectorQuery).build(); + assertInvalidQuery(query); + } + + @Test + public void testVectorSearchWithUnmatchedVectorSize() { + VectorValue vectorValue = VectorValue.newBuilder(1.78, 2.56, 3.88, 4.33).build(); + FindNearest vectorQuery = + new FindNearest("vector_property", vectorValue, FindNearest.DistanceMeasure.DOT_PRODUCT, 1); + Query query = Query.newEntityQueryBuilder().setFindNearest(vectorQuery).build(); + assertInvalidQuery(query); + } + @Test public void testInequalitySort() { Query query =