This is an automated email from the ASF dual-hosted git repository.
absurdfarce pushed a commit to branch 4.x
in repository https://gitbox.apache.org/repos/asf/cassandra-java-driver.git
The following commit(s) were added to refs/heads/4.x by this push:
new 7689c5a81 JAVA-3118: Add support for vector data type in Schema
Builder, QueryBuilder patch by Jane He; reviewed by Mick Semb Wever and Bret
McGuire for JAVA-3118 reference: #1931
7689c5a81 is described below
commit 7689c5a81e89bce5598d9976a041068a1f5e2a7f
Author: SiyaoIsHiding <[email protected]>
AuthorDate: Tue Jan 7 14:30:58 2025 +0800
JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder
patch by Jane He; reviewed by Mick Semb Wever and Bret McGuire for JAVA-3118
reference: #1931
---
manual/query_builder/select/README.md | 23 ++++++
query-builder/revapi.json | 10 +++
.../oss/driver/api/querybuilder/select/Select.java | 11 +++
.../querybuilder/select/DefaultSelect.java | 86 ++++++++++++++++++++--
.../querybuilder/delete/DeleteSelectorTest.java | 11 +++
.../api/querybuilder/insert/RegularInsertTest.java | 7 ++
.../api/querybuilder/schema/AlterTableTest.java | 6 ++
.../api/querybuilder/schema/AlterTypeTest.java | 6 ++
.../api/querybuilder/schema/CreateTableTest.java | 9 +++
.../api/querybuilder/schema/CreateTypeTest.java | 9 +++
.../querybuilder/select/SelectOrderingTest.java | 20 +++++
.../querybuilder/select/SelectSelectorTest.java | 43 +++++++++++
12 files changed, 233 insertions(+), 8 deletions(-)
diff --git a/manual/query_builder/select/README.md
b/manual/query_builder/select/README.md
index 92c058608..0425423a4 100644
--- a/manual/query_builder/select/README.md
+++ b/manual/query_builder/select/README.md
@@ -387,6 +387,29 @@ selectFrom("sensor_data")
// SELECT reading FROM sensor_data WHERE id=? ORDER BY date DESC
```
+Vector Search:
+
+```java
+
+import com.datastax.oss.driver.api.core.data.CqlVector;
+
+selectFrom("foo")
+ .all()
+ .where(Relation.column("k").isEqualTo(literal(1)))
+ .orderByAnnOf("c1", CqlVector.newInstance(0.1, 0.2, 0.3));
+// SELECT * FROM foo WHERE k=1 ORDER BY c1 ANN OF [0.1, 0.2, 0.3]
+
+selectFrom("cycling", "comments_vs")
+ .column("comment")
+ .function(
+ "similarity_cosine",
+ Selector.column("comment_vector"),
+ literal(CqlVector.newInstance(0.2, 0.15, 0.3, 0.2, 0.05)))
+ .orderByAnnOf("comment_vector", CqlVector.newInstance(0.1, 0.15, 0.3,
0.12, 0.05))
+ .limit(1);
+// SELECT comment,similarity_cosine(comment_vector,[0.2, 0.15, 0.3, 0.2, 0.05]) FROM cycling.comments_vs ORDER BY comment_vector ANN OF [0.1, 0.15, 0.3, 0.12, 0.05] LIMIT 1
+```
+
Limits:
```java
diff --git a/query-builder/revapi.json b/query-builder/revapi.json
index 9d0163b48..c4d8aa272 100644
--- a/query-builder/revapi.json
+++ b/query-builder/revapi.json
@@ -2772,6 +2772,16 @@
"code": "java.method.addedToInterface",
"new": "method
com.datastax.oss.driver.api.querybuilder.update.UpdateStart
com.datastax.oss.driver.api.querybuilder.update.UpdateStart::usingTtl(int)",
"justification": "JAVA-2210: Add ability to set TTL for modification
queries"
+ },
+ {
+ "code": "java.method.addedToInterface",
+ "new": "method com.datastax.oss.driver.api.querybuilder.select.Select
com.datastax.oss.driver.api.querybuilder.select.Select::orderByAnnOf(java.lang.String,
com.datastax.oss.driver.api.core.data.CqlVector<?>)",
+ "justification": "JAVA-3118: Add support for vector data type in
Schema Builder, QueryBuilder"
+ },
+ {
+ "code": "java.method.addedToInterface",
+ "new": "method com.datastax.oss.driver.api.querybuilder.select.Select
com.datastax.oss.driver.api.querybuilder.select.Select::orderByAnnOf(com.datastax.oss.driver.api.core.CqlIdentifier,
com.datastax.oss.driver.api.core.data.CqlVector<?>)",
+ "justification": "JAVA-3118: Add support for vector data type in
Schema Builder, QueryBuilder"
}
]
}
diff --git
a/query-builder/src/main/java/com/datastax/oss/driver/api/querybuilder/select/Select.java
b/query-builder/src/main/java/com/datastax/oss/driver/api/querybuilder/select/Select.java
index a22b45c35..159657989 100644
---
a/query-builder/src/main/java/com/datastax/oss/driver/api/querybuilder/select/Select.java
+++
b/query-builder/src/main/java/com/datastax/oss/driver/api/querybuilder/select/Select.java
@@ -18,6 +18,7 @@
package com.datastax.oss.driver.api.querybuilder.select;
import com.datastax.oss.driver.api.core.CqlIdentifier;
+import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
import com.datastax.oss.driver.api.querybuilder.BindMarker;
import com.datastax.oss.driver.api.querybuilder.BuildableQuery;
@@ -146,6 +147,16 @@ public interface Select extends OngoingSelection,
OngoingWhereClause<Select>, Bu
return orderBy(CqlIdentifier.fromCql(columnName), order);
}
+ /**
+ * Shortcut for {@link #orderByAnnOf(CqlIdentifier, CqlVector)}, adding an ORDER BY ... ANN OF ...
+ * clause.
+ */
+ @NonNull
+ Select orderByAnnOf(@NonNull String columnName, @NonNull CqlVector<?> ann);
+
+ /** Adds the ORDER BY ... ANN OF ... clause, usually used for vector search. */
+ @NonNull
+ Select orderByAnnOf(@NonNull CqlIdentifier columnId, @NonNull CqlVector<?>
ann);
/**
* Adds a LIMIT clause to this query with a literal value.
*
diff --git
a/query-builder/src/main/java/com/datastax/oss/driver/internal/querybuilder/select/DefaultSelect.java
b/query-builder/src/main/java/com/datastax/oss/driver/internal/querybuilder/select/DefaultSelect.java
index 86a2a07a3..5daf252a9 100644
---
a/query-builder/src/main/java/com/datastax/oss/driver/internal/querybuilder/select/DefaultSelect.java
+++
b/query-builder/src/main/java/com/datastax/oss/driver/internal/querybuilder/select/DefaultSelect.java
@@ -20,8 +20,10 @@ package com.datastax.oss.driver.internal.querybuilder.select;
import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.cql.SimpleStatementBuilder;
+import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
import com.datastax.oss.driver.api.querybuilder.BindMarker;
+import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.relation.Relation;
import com.datastax.oss.driver.api.querybuilder.select.Select;
import com.datastax.oss.driver.api.querybuilder.select.SelectFrom;
@@ -49,6 +51,7 @@ public class DefaultSelect implements SelectFrom, Select {
private final ImmutableList<Relation> relations;
private final ImmutableList<Selector> groupByClauses;
private final ImmutableMap<CqlIdentifier, ClusteringOrder> orderings;
+ private final Ann ann;
private final Object limit;
private final Object perPartitionLimit;
private final boolean allowsFiltering;
@@ -65,6 +68,7 @@ public class DefaultSelect implements SelectFrom, Select {
ImmutableMap.of(),
null,
null,
+ null,
false);
}
@@ -74,6 +78,8 @@ public class DefaultSelect implements SelectFrom, Select {
* @param selectors if it contains {@link AllSelector#INSTANCE}, that must
be the only element.
* This isn't re-checked because methods that call this constructor
internally already do it,
* make sure you do it yourself.
+ * @param ann the approximate nearest neighbor (ANN) ordering, or {@code null} if there is none.
+ * ANN ordering does not support secondary ordering or ASC order.
*/
public DefaultSelect(
@Nullable CqlIdentifier keyspace,
@@ -84,6 +90,7 @@ public class DefaultSelect implements SelectFrom, Select {
@NonNull ImmutableList<Relation> relations,
@NonNull ImmutableList<Selector> groupByClauses,
@NonNull ImmutableMap<CqlIdentifier, ClusteringOrder> orderings,
+ @Nullable Ann ann,
@Nullable Object limit,
@Nullable Object perPartitionLimit,
boolean allowsFiltering) {
@@ -94,6 +101,9 @@ public class DefaultSelect implements SelectFrom, Select {
|| (limit instanceof Integer && (Integer) limit > 0)
|| limit instanceof BindMarker,
"limit must be a strictly positive integer or a bind marker");
+ Preconditions.checkArgument(
+ orderings.isEmpty() || ann == null, "ANN ordering does not support
secondary ordering");
+ this.ann = ann;
this.keyspace = keyspace;
this.table = table;
this.isJson = isJson;
@@ -117,6 +127,7 @@ public class DefaultSelect implements SelectFrom, Select {
relations,
groupByClauses,
orderings,
+ ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -134,6 +145,7 @@ public class DefaultSelect implements SelectFrom, Select {
relations,
groupByClauses,
orderings,
+ ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -193,6 +205,7 @@ public class DefaultSelect implements SelectFrom, Select {
relations,
groupByClauses,
orderings,
+ ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -221,6 +234,7 @@ public class DefaultSelect implements SelectFrom, Select {
newRelations,
groupByClauses,
orderings,
+ ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -249,6 +263,7 @@ public class DefaultSelect implements SelectFrom, Select {
relations,
newGroupByClauses,
orderings,
+ ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -260,6 +275,18 @@ public class DefaultSelect implements SelectFrom, Select {
return withOrderings(ImmutableCollections.append(orderings, columnId,
order));
}
+ @NonNull
+ @Override
+ public Select orderByAnnOf(@NonNull String columnName, @NonNull CqlVector<?>
ann) {
+ return withAnn(new Ann(CqlIdentifier.fromCql(columnName), ann));
+ }
+
+ @NonNull
+ @Override
+ public Select orderByAnnOf(@NonNull CqlIdentifier columnId, @NonNull
CqlVector<?> ann) {
+ return withAnn(new Ann(columnId, ann));
+ }
+
@NonNull
@Override
public Select orderByIds(@NonNull Map<CqlIdentifier, ClusteringOrder>
newOrderings) {
@@ -277,6 +304,24 @@ public class DefaultSelect implements SelectFrom, Select {
relations,
groupByClauses,
newOrderings,
+ ann,
+ limit,
+ perPartitionLimit,
+ allowsFiltering);
+ }
+
+ @NonNull
+ Select withAnn(@NonNull Ann ann) {
+ return new DefaultSelect(
+ keyspace,
+ table,
+ isJson,
+ isDistinct,
+ selectors,
+ relations,
+ groupByClauses,
+ orderings,
+ ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -295,6 +340,7 @@ public class DefaultSelect implements SelectFrom, Select {
relations,
groupByClauses,
orderings,
+ ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -312,6 +358,7 @@ public class DefaultSelect implements SelectFrom, Select {
relations,
groupByClauses,
orderings,
+ ann,
bindMarker,
perPartitionLimit,
allowsFiltering);
@@ -331,6 +378,7 @@ public class DefaultSelect implements SelectFrom, Select {
relations,
groupByClauses,
orderings,
+ ann,
limit,
perPartitionLimit,
allowsFiltering);
@@ -348,6 +396,7 @@ public class DefaultSelect implements SelectFrom, Select {
relations,
groupByClauses,
orderings,
+ ann,
limit,
bindMarker,
allowsFiltering);
@@ -365,6 +414,7 @@ public class DefaultSelect implements SelectFrom, Select {
relations,
groupByClauses,
orderings,
+ ann,
limit,
perPartitionLimit,
true);
@@ -391,15 +441,20 @@ public class DefaultSelect implements SelectFrom, Select {
CqlHelper.append(relations, builder, " WHERE ", " AND ", null);
CqlHelper.append(groupByClauses, builder, " GROUP BY ", ",", null);
- boolean first = true;
- for (Map.Entry<CqlIdentifier, ClusteringOrder> entry :
orderings.entrySet()) {
- if (first) {
- builder.append(" ORDER BY ");
- first = false;
- } else {
- builder.append(",");
+ if (ann != null) {
+ builder.append(" ORDER BY
").append(this.ann.columnId.asCql(true)).append(" ANN OF ");
+ QueryBuilder.literal(ann.vector).appendTo(builder);
+ } else {
+ boolean first = true;
+ for (Map.Entry<CqlIdentifier, ClusteringOrder> entry :
orderings.entrySet()) {
+ if (first) {
+ builder.append(" ORDER BY ");
+ first = false;
+ } else {
+ builder.append(",");
+ }
+ builder.append(entry.getKey().asCql(true)).append("
").append(entry.getValue().name());
}
- builder.append(entry.getKey().asCql(true)).append("
").append(entry.getValue().name());
}
if (limit != null) {
@@ -499,6 +554,11 @@ public class DefaultSelect implements SelectFrom, Select {
return limit;
}
+ @Nullable
+ public Ann getAnn() {
+ return ann;
+ }
+
@Nullable
public Object getPerPartitionLimit() {
return perPartitionLimit;
@@ -512,4 +572,14 @@ public class DefaultSelect implements SelectFrom, Select {
public String toString() {
return asCql();
}
+
+ public static class Ann {
+ private final CqlVector<?> vector;
+ private final CqlIdentifier columnId;
+
+ private Ann(CqlIdentifier columnId, CqlVector<?> vector) {
+ this.vector = vector;
+ this.columnId = columnId;
+ }
+ }
}
diff --git
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/delete/DeleteSelectorTest.java
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/delete/DeleteSelectorTest.java
index 23210971b..cce4cf51a 100644
---
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/delete/DeleteSelectorTest.java
+++
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/delete/DeleteSelectorTest.java
@@ -22,6 +22,7 @@ import static
com.datastax.oss.driver.api.querybuilder.QueryBuilder.bindMarker;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.deleteFrom;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
+import com.datastax.oss.driver.api.core.data.CqlVector;
import org.junit.Test;
public class DeleteSelectorTest {
@@ -34,6 +35,16 @@ public class DeleteSelectorTest {
.hasCql("DELETE v FROM ks.foo WHERE k=?");
}
+ @Test
+ public void should_generate_vector_deletion() {
+ assertThat(
+ deleteFrom("foo")
+ .column("v")
+ .whereColumn("k")
+ .isEqualTo(literal(CqlVector.newInstance(0.1, 0.2))))
+ .hasCql("DELETE v FROM foo WHERE k=[0.1, 0.2]");
+ }
+
@Test
public void should_generate_field_deletion() {
assertThat(
diff --git
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/insert/RegularInsertTest.java
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/insert/RegularInsertTest.java
index 36133445b..89c833ff1 100644
---
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/insert/RegularInsertTest.java
+++
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/insert/RegularInsertTest.java
@@ -23,6 +23,7 @@ import static
com.datastax.oss.driver.api.querybuilder.QueryBuilder.insertInto;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
import static org.assertj.core.api.Assertions.catchThrowable;
+import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.querybuilder.term.Term;
import com.datastax.oss.driver.internal.querybuilder.insert.DefaultInsert;
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap;
@@ -41,6 +42,12 @@ public class RegularInsertTest {
.hasCql("INSERT INTO foo (a,b) VALUES (?,?)");
}
+ @Test
+ public void should_generate_vector_literals() {
+ assertThat(insertInto("foo").value("a", literal(CqlVector.newInstance(0.1,
0.2, 0.3))))
+ .hasCql("INSERT INTO foo (a) VALUES ([0.1, 0.2, 0.3])");
+ }
+
@Test
public void should_keep_last_assignment_if_column_listed_twice() {
assertThat(
diff --git
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/AlterTableTest.java
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/AlterTableTest.java
index 1567b0848..2c99b154b 100644
---
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/AlterTableTest.java
+++
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/AlterTableTest.java
@@ -108,4 +108,10 @@ public class AlterTableTest {
assertThat(alterTable("bar").withNoCompression())
.hasCql("ALTER TABLE bar WITH compression={'sstable_compression':''}");
}
+
+ @Test
+ public void should_generate_alter_table_with_vector() {
+ assertThat(alterTable("bar").alterColumn("v",
DataTypes.vectorOf(DataTypes.FLOAT, 3)))
+ .hasCql("ALTER TABLE bar ALTER v TYPE vector<float, 3>");
+ }
}
diff --git
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/AlterTypeTest.java
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/AlterTypeTest.java
index 2becb9338..14bec0a6c 100644
---
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/AlterTypeTest.java
+++
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/AlterTypeTest.java
@@ -53,4 +53,10 @@ public class AlterTypeTest {
assertThat(alterType("bar").renameField("x", "y").renameField("u",
"v").renameField("b", "a"))
.hasCql("ALTER TYPE bar RENAME x TO y AND u TO v AND b TO a");
}
+
+ @Test
+ public void should_generate_alter_type_with_vector() {
+ assertThat(alterType("foo", "bar").alterField("vec",
DataTypes.vectorOf(DataTypes.FLOAT, 3)))
+ .hasCql("ALTER TYPE foo.bar ALTER vec TYPE vector<float, 3>");
+ }
}
diff --git
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/CreateTableTest.java
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/CreateTableTest.java
index 7a5542c51..15cd12c75 100644
---
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/CreateTableTest.java
+++
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/CreateTableTest.java
@@ -314,4 +314,13 @@ public class CreateTableTest {
.hasCql(
"CREATE TABLE bar (k int PRIMARY KEY,v text) WITH
compaction={'class':'TimeWindowCompactionStrategy','compaction_window_size':10,'compaction_window_unit':'DAYS','timestamp_resolution':'MICROSECONDS','unsafe_aggressive_sstable_expiration':false}");
}
+
+ @Test
+ public void should_generate_vector_column() {
+ assertThat(
+ createTable("foo")
+ .withPartitionKey("k", DataTypes.INT)
+ .withColumn("v", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
+ .hasCql("CREATE TABLE foo (k int PRIMARY KEY,v vector<float, 3>)");
+ }
}
diff --git
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/CreateTypeTest.java
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/CreateTypeTest.java
index d881a0500..f7c15788a 100644
---
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/CreateTypeTest.java
+++
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/schema/CreateTypeTest.java
@@ -83,4 +83,13 @@ public class CreateTypeTest {
.withField("map", DataTypes.mapOf(DataTypes.INT,
DataTypes.TEXT)))
.hasCql("CREATE TYPE ks1.type (map map<int, text>)");
}
+
+ @Test
+ public void should_create_type_with_vector() {
+ assertThat(
+ createType("ks1", "type")
+ .withField("c1", DataTypes.INT)
+ .withField("vec", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
+ .hasCql("CREATE TYPE ks1.type (c1 int,vec vector<float, 3>)");
+ }
}
diff --git
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/select/SelectOrderingTest.java
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/select/SelectOrderingTest.java
index ff27fde4f..a9c618e95 100644
---
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/select/SelectOrderingTest.java
+++
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/select/SelectOrderingTest.java
@@ -23,6 +23,7 @@ import static
com.datastax.oss.driver.api.querybuilder.Assertions.assertThat;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.selectFrom;
+import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.querybuilder.relation.Relation;
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap;
import org.junit.Test;
@@ -74,4 +75,23 @@ public class SelectOrderingTest {
.orderBy(ImmutableMap.of("c1", DESC, "c2", ASC)))
.hasCql("SELECT * FROM foo WHERE k=1 ORDER BY c3 ASC,c1 DESC,c2 ASC");
}
+
+ @Test
+ public void should_generate_ann_clause() {
+ assertThat(
+ selectFrom("foo")
+ .all()
+ .where(Relation.column("k").isEqualTo(literal(1)))
+ .orderByAnnOf("c1", CqlVector.newInstance(0.1, 0.2, 0.3)))
+ .hasCql("SELECT * FROM foo WHERE k=1 ORDER BY c1 ANN OF [0.1, 0.2,
0.3]");
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void should_fail_when_provided_ann_with_other_orderings() {
+ selectFrom("foo")
+ .all()
+ .where(Relation.column("k").isEqualTo(literal(1)))
+ .orderBy("c1", ASC)
+ .orderByAnnOf("c2", CqlVector.newInstance(0.1, 0.2, 0.3));
+ }
}
diff --git
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/select/SelectSelectorTest.java
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/select/SelectSelectorTest.java
index dc7cc98c6..7e03627d4 100644
---
a/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/select/SelectSelectorTest.java
+++
b/query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/select/SelectSelectorTest.java
@@ -22,6 +22,7 @@ import static
com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.raw;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.selectFrom;
+import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.codec.CodecNotFoundException;
import com.datastax.oss.driver.api.querybuilder.CharsetCodec;
@@ -230,6 +231,48 @@ public class SelectSelectorTest {
.hasCql("SELECT bar,baz FROM foo");
}
+ @Test
+ public void should_generate_similarity_functions() {
+ Select similarity_cosine_clause =
+ selectFrom("cycling", "comments_vs")
+ .column("comment")
+ .function(
+ "similarity_cosine",
+ Selector.column("comment_vector"),
+ literal(CqlVector.newInstance(0.2, 0.15, 0.3, 0.2, 0.05)))
+ .orderByAnnOf("comment_vector", CqlVector.newInstance(0.1, 0.15,
0.3, 0.12, 0.05))
+ .limit(1);
+ assertThat(similarity_cosine_clause)
+ .hasCql(
+ "SELECT comment,similarity_cosine(comment_vector,[0.2, 0.15, 0.3,
0.2, 0.05]) FROM cycling.comments_vs ORDER BY comment_vector ANN OF [0.1, 0.15,
0.3, 0.12, 0.05] LIMIT 1");
+
+ Select similarity_euclidean_clause =
+ selectFrom("cycling", "comments_vs")
+ .column("comment")
+ .function(
+ "similarity_euclidean",
+ Selector.column("comment_vector"),
+ literal(CqlVector.newInstance(0.2, 0.15, 0.3, 0.2, 0.05)))
+ .orderByAnnOf("comment_vector", CqlVector.newInstance(0.1, 0.15,
0.3, 0.12, 0.05))
+ .limit(1);
+ assertThat(similarity_euclidean_clause)
+ .hasCql(
+ "SELECT comment,similarity_euclidean(comment_vector,[0.2, 0.15,
0.3, 0.2, 0.05]) FROM cycling.comments_vs ORDER BY comment_vector ANN OF [0.1,
0.15, 0.3, 0.12, 0.05] LIMIT 1");
+
+ Select similarity_dot_product_clause =
+ selectFrom("cycling", "comments_vs")
+ .column("comment")
+ .function(
+ "similarity_dot_product",
+ Selector.column("comment_vector"),
+ literal(CqlVector.newInstance(0.2, 0.15, 0.3, 0.2, 0.05)))
+ .orderByAnnOf("comment_vector", CqlVector.newInstance(0.1, 0.15,
0.3, 0.12, 0.05))
+ .limit(1);
+ assertThat(similarity_dot_product_clause)
+ .hasCql(
+ "SELECT comment,similarity_dot_product(comment_vector,[0.2, 0.15,
0.3, 0.2, 0.05]) FROM cycling.comments_vs ORDER BY comment_vector ANN OF [0.1,
0.15, 0.3, 0.12, 0.05] LIMIT 1");
+ }
+
@Test
public void should_alias_selectors() {
assertThat(selectFrom("foo").column("bar").as("baz")).hasCql("SELECT bar
AS baz FROM foo");
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]