This is an automated email from the ASF dual-hosted git repository.
adelapena pushed a commit to branch cassandra-5.0
in repository https://gitbox.apache.org/repos/asf/cassandra.git
The following commit(s) were added to refs/heads/cassandra-5.0 by this push:
new c76b32492f Add support of vector type to cqlsh COPY command
c76b32492f is described below
commit c76b32492f08c4af56846518488ae0b191e077e8
Author: Szymon Miężał <[email protected]>
AuthorDate: Thu Nov 30 17:56:48 2023 +0100
Add support of vector type to cqlsh COPY command
This patch adds a converter that allows parsing vector literals
passed via csv files to the COPY command.
patch by Szymon Miezal; reviewed by Andrés de la Peña, Stefan Miklosovic
and Maxwell Guo for CASSANDRA-19118
---
CHANGES.txt | 1 +
pylib/cqlshlib/copyutil.py | 9 +-
.../apache/cassandra/tools/cqlsh/CqlshTest.java | 126 ++++++++++++++++++++-
3 files changed, 132 insertions(+), 4 deletions(-)
diff --git a/CHANGES.txt b/CHANGES.txt
index a7859ee9ec..1d71cb52c3 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
5.0-beta2
+ * Add support of vector type to cqlsh COPY command (CASSANDRA-19118)
* Make CQLSSTableWriter to support building of SAI indexes (CASSANDRA-18714)
* Append additional JVM options when using JDK17+ (CASSANDRA-19001)
* Upgrade Python driver to 3.29.0 (CASSANDRA-19245)
diff --git a/pylib/cqlshlib/copyutil.py b/pylib/cqlshlib/copyutil.py
index 2a8a11d1bf..af35731005 100644
--- a/pylib/cqlshlib/copyutil.py
+++ b/pylib/cqlshlib/copyutil.py
@@ -46,7 +46,7 @@ from queue import Queue
from cassandra import OperationTimedOut
from cassandra.cluster import Cluster, DefaultConnection
-from cassandra.cqltypes import ReversedType, UserType, VarcharType
+from cassandra.cqltypes import ReversedType, UserType, VarcharType, VectorType
from cassandra.metadata import protect_name, protect_names, protect_value
from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, FallthroughRetryPolicy
from cassandra.query import BatchStatement, BatchType, SimpleStatement, tuple_factory
@@ -2074,6 +2074,12 @@ class ImportConversion(object):
return ImmutableDict(frozenset((convert_mandatory(ct.subtypes[0], v[0]), convert(ct.subtypes[1], v[1]))
for v in [split(split_format_str % vv, sep=sep) for vv in split(val)]))
+ def convert_vector(val, ct=cql_type):
+ string_coordinates = split(val)
+ if len(string_coordinates) != ct.vector_size:
+ raise ParseError("The length of given vector value '%d' is not equal to the vector size from the type definition '%d'" % (len(string_coordinates), ct.vector_size))
+ return [convert_mandatory(ct.subtype, v) for v in string_coordinates]
+
def convert_user_type(val, ct=cql_type):
"""
A user type is a dictionary except that we must convert each key into
@@ -2130,6 +2136,7 @@ class ImportConversion(object):
'map': convert_map,
'tuple': convert_tuple,
'frozen': convert_single_subtype,
+ VectorType.typename: convert_vector,
}
return converters.get(cql_type.typename, convert_unknown)
diff --git a/test/unit/org/apache/cassandra/tools/cqlsh/CqlshTest.java b/test/unit/org/apache/cassandra/tools/cqlsh/CqlshTest.java
index 4e6dd2088b..356769b840 100644
--- a/test/unit/org/apache/cassandra/tools/cqlsh/CqlshTest.java
+++ b/test/unit/org/apache/cassandra/tools/cqlsh/CqlshTest.java
@@ -18,16 +18,24 @@
package org.apache.cassandra.tools.cqlsh;
+import java.io.IOException;
+import java.io.Writer;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+
import org.junit.BeforeClass;
import org.junit.Test;
import org.apache.cassandra.cql3.CQLTester;
+import org.apache.cassandra.cql3.UntypedResultSet;
import org.apache.cassandra.tools.ToolRunner;
import org.apache.cassandra.tools.ToolRunner.ToolResult;
-import org.hamcrest.CoreMatchers;
+import static java.lang.String.format;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
public class CqlshTest extends CQLTester
{
@@ -41,7 +49,119 @@ public class CqlshTest extends CQLTester
public void testKeyspaceRequired()
{
ToolResult tool = ToolRunner.invokeCqlsh("SELECT * FROM test");
- assertThat(tool.getCleanedStderr(), CoreMatchers.containsStringIgnoringCase("No keyspace has been specified"));
+ tool.asserts().errorContains("No keyspace has been specified");
assertEquals(2, tool.getExitCode());
}
+
+ @Test
+ public void testCopyFloatVector() throws IOException
+ {
+ assertCopyOfVectorTypeSucceeds("float", 6, new Object[][] {
+ row(1, vector(0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f)),
+ row(2, vector(-0.1f, -0.2f, -0.3f, -0.4f, -0.5f, -0.6f)),
+ row(3, vector(0.9f, 0.8f, 0.7f, 0.6f, 0.5f, 0.4f))
+ });
+
+ assertCopyOfVectorTypeSucceeds("float", 3, new Object[][] {
+ row(1, vector(0.1f, 0.2f, 0.3f)),
+ row(2, vector(-0.4f, -0.5f, -0.6f)),
+ row(3, vector(0.7f, 0.8f, 0.9f))
+ });
+ }
+
+ @Test
+ public void testCopyIntVector() throws IOException
+ {
+ assertCopyOfVectorTypeSucceeds("int", 6, new Object[][] {
+ row(1, vector(1, 2, 3, 4, 5, 6)),
+ row(2, vector(-1, -2, -3, -4, -5, -6)),
+ row(3, vector(9, 8, 7, 6, 5, 4))
+ });
+
+ assertCopyOfVectorTypeSucceeds("int", 3, new Object[][] {
+ row(1, vector(1, 2, 3)),
+ row(2, vector(-4, -5, -6)),
+ row(3, vector(7, 8, 9))
+ });
+ }
+
+ private void assertCopyOfVectorTypeSucceeds(String vectorType, int vectorSize, Object[][] rows) throws IOException
+ {
+ // given a table with a vector column
+ createTable(KEYSPACE, format("CREATE TABLE %%s (id int PRIMARY KEY, embedding_vector vector<%s, %d>)", vectorType, vectorSize));
+ assertTrue("table should be initially empty", execute("SELECT * FROM %s").isEmpty());
+
+ // write the rows into the table
+ for (Object[] row : rows)
+ execute("INSERT INTO %s (id, embedding_vector) VALUES (?, ?)", row);
+
+ // when running COPY TO CSV via cqlsh
+ Path csv = createTempFile("test_copy_to_vector");
+ ToolRunner.ToolResult copyToResult = ToolRunner.invokeCqlsh(format("COPY %s.%s TO '%s'", KEYSPACE, currentTable(), csv.toAbsolutePath()));
+
+ // then all rows should be exported
+ copyToResult.asserts().success();
+ // verify that the exported CSV contains the expected rows
+ assertThat(csv).hasSameTextualContentAs(prepareCSVFile(rows));
+
+ // truncate the table
+ execute("TRUNCATE %s");
+ assertTrue("table should be empty", execute("SELECT * FROM %s").isEmpty());
+
+ // when running COPY FROM via cqlsh
+ ToolRunner.ToolResult copyFromResult = ToolRunner.invokeCqlsh(format("COPY %s.%s FROM '%s'", KEYSPACE, currentTable(), csv.toAbsolutePath()));
+
+ // then all rows should be imported
+ copyFromResult.asserts().success();
+ UntypedResultSet importedRows = execute("SELECT * FROM %s");
+ assertRowsIgnoringOrder(importedRows, rows);
+ }
+
+ @Test
+ public void testCopyOnlyThoseRowsThatMatchVectorTypeSize() throws IOException
+ {
+ // given a table with a vector column and a file containing vector literals
+ createTable(KEYSPACE, "CREATE TABLE %s (id int PRIMARY KEY, embedding_vector vector<int, 6>)");
+ assertTrue("table should be initially empty", execute("SELECT * FROM %s").isEmpty());
+
+ Object[][] rows = {
+ row(1, vector(1, 2, 3, 4, 5, 6)),
+ row(2, vector(1, 2, 3, 4, 5)),
+ row(3, vector(1, 2, 3, 4, 6, 7))
+ };
+
+ Path csv = prepareCSVFile(rows);
+
+ // when running COPY via cqlsh
+ ToolRunner.ToolResult result = ToolRunner.invokeCqlsh(format("COPY %s.%s FROM '%s'", KEYSPACE, currentTable(), csv.toAbsolutePath()));
+
+ // then only rows that match type size should be imported
+ result.asserts().failure();
+ result.asserts().errorContains("The length of given vector value '5' is not equal to the vector size from the type definition '6'");
+ UntypedResultSet importedRows = execute("SELECT * FROM %s");
+ assertRowsIgnoringOrder(importedRows, row(1, vector(1, 2, 3, 4, 5, 6)),
+ row(3, vector(1, 2, 3, 4, 6, 7)));
+ }
+
+ private static Path prepareCSVFile(Object[][] rows) throws IOException
+ {
+ Path csv = createTempFile("test_copy_from_vector");
+
+ try (Writer out = Files.newBufferedWriter(csv, StandardCharsets.UTF_8))
+ {
+ for (Object[] row : rows)
+ {
+ out.write(String.format("%s,\"%s\"\n", row[0], row[1]));
+ }
+ }
+
+ return csv;
+ }
+
+ private static Path createTempFile(String prefix) throws IOException
+ {
+ Path csv = Files.createTempFile(prefix, ".csv");
+ csv.toFile().deleteOnExit();
+ return csv;
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org