This is an automated email from the ASF dual-hosted git repository.

adelapena pushed a commit to branch cassandra-5.0
in repository https://gitbox.apache.org/repos/asf/cassandra.git


The following commit(s) were added to refs/heads/cassandra-5.0 by this push:
     new c76b32492f Add support of vector type to cqlsh COPY command
c76b32492f is described below

commit c76b32492f08c4af56846518488ae0b191e077e8
Author: Szymon Miężał <[email protected]>
AuthorDate: Thu Nov 30 17:56:48 2023 +0100

    Add support of vector type to cqlsh COPY command
    
    This patch adds a converter that allows parsing vector literals
    passed via csv files to the COPY command.
    
    patch by Szymon Miezal; reviewed by Andrés de la Peña, Stefan Miklosovic and Maxwell Guo for CASSANDRA-19118
---
 CHANGES.txt                                        |   1 +
 pylib/cqlshlib/copyutil.py                         |   9 +-
 .../apache/cassandra/tools/cqlsh/CqlshTest.java    | 126 ++++++++++++++++++++-
 3 files changed, 132 insertions(+), 4 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index a7859ee9ec..1d71cb52c3 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 5.0-beta2
+ * Add support of vector type to cqlsh COPY command (CASSANDRA-19118)
  * Make CQLSSTableWriter to support building of SAI indexes (CASSANDRA-18714)
  * Append additional JVM options when using JDK17+ (CASSANDRA-19001)
  * Upgrade Python driver to 3.29.0 (CASSANDRA-19245)
diff --git a/pylib/cqlshlib/copyutil.py b/pylib/cqlshlib/copyutil.py
index 2a8a11d1bf..af35731005 100644
--- a/pylib/cqlshlib/copyutil.py
+++ b/pylib/cqlshlib/copyutil.py
@@ -46,7 +46,7 @@ from queue import Queue
 
 from cassandra import OperationTimedOut
 from cassandra.cluster import Cluster, DefaultConnection
-from cassandra.cqltypes import ReversedType, UserType, VarcharType
+from cassandra.cqltypes import ReversedType, UserType, VarcharType, VectorType
 from cassandra.metadata import protect_name, protect_names, protect_value
 from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, 
DCAwareRoundRobinPolicy, FallthroughRetryPolicy
 from cassandra.query import BatchStatement, BatchType, SimpleStatement, 
tuple_factory
@@ -2074,6 +2074,12 @@ class ImportConversion(object):
             return ImmutableDict(frozenset((convert_mandatory(ct.subtypes[0], 
v[0]), convert(ct.subtypes[1], v[1]))
                                  for v in [split(split_format_str % vv, 
sep=sep) for vv in split(val)]))
 
+        def convert_vector(val, ct=cql_type):
+            string_coordinates = split(val)
+            if len(string_coordinates) != ct.vector_size:
+                raise ParseError("The length of given vector value '%d' is not 
equal to the vector size from the type definition '%d'" % 
(len(string_coordinates), ct.vector_size))
+            return [convert_mandatory(ct.subtype, v) for v in 
string_coordinates]
+
         def convert_user_type(val, ct=cql_type):
             """
             A user type is a dictionary except that we must convert each key 
into
@@ -2130,6 +2136,7 @@ class ImportConversion(object):
             'map': convert_map,
             'tuple': convert_tuple,
             'frozen': convert_single_subtype,
+            VectorType.typename: convert_vector,
         }
 
         return converters.get(cql_type.typename, convert_unknown)
diff --git a/test/unit/org/apache/cassandra/tools/cqlsh/CqlshTest.java 
b/test/unit/org/apache/cassandra/tools/cqlsh/CqlshTest.java
index 4e6dd2088b..356769b840 100644
--- a/test/unit/org/apache/cassandra/tools/cqlsh/CqlshTest.java
+++ b/test/unit/org/apache/cassandra/tools/cqlsh/CqlshTest.java
@@ -18,16 +18,24 @@
 
 package org.apache.cassandra.tools.cqlsh;
 
+import java.io.IOException;
+import java.io.Writer;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+
 import org.junit.BeforeClass;
 import org.junit.Test;
 
 import org.apache.cassandra.cql3.CQLTester;
+import org.apache.cassandra.cql3.UntypedResultSet;
 import org.apache.cassandra.tools.ToolRunner;
 import org.apache.cassandra.tools.ToolRunner.ToolResult;
-import org.hamcrest.CoreMatchers;
 
+import static java.lang.String.format;
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
 
 public class CqlshTest extends CQLTester
 {
@@ -41,7 +49,119 @@ public class CqlshTest extends CQLTester
     public void testKeyspaceRequired()
     {
         ToolResult tool = ToolRunner.invokeCqlsh("SELECT * FROM test");
-        assertThat(tool.getCleanedStderr(), 
CoreMatchers.containsStringIgnoringCase("No keyspace has been specified"));
+        tool.asserts().errorContains("No keyspace has been specified");
         assertEquals(2, tool.getExitCode());
     }
+
+    @Test
+    public void testCopyFloatVector() throws IOException
+    {
+        assertCopyOfVectorTypeSucceeds("float", 6, new Object[][] {
+            row(1, vector(0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f)),
+            row(2, vector(-0.1f, -0.2f, -0.3f, -0.4f, -0.5f, -0.6f)),
+            row(3, vector(0.9f, 0.8f, 0.7f, 0.6f, 0.5f, 0.4f))
+        });
+
+        assertCopyOfVectorTypeSucceeds("float", 3, new Object[][] {
+            row(1, vector(0.1f, 0.2f, 0.3f)),
+            row(2, vector(-0.4f, -0.5f, -0.6f)),
+            row(3, vector(0.7f, 0.8f, 0.9f))
+        });
+    }
+
+    @Test
+    public void testCopyIntVector() throws IOException
+    {
+        assertCopyOfVectorTypeSucceeds("int", 6, new Object[][] {
+            row(1, vector(1, 2, 3, 4, 5, 6)),
+            row(2, vector(-1, -2, -3, -4, -5, -6)),
+            row(3, vector(9, 8, 7, 6, 5, 4))
+        });
+
+        assertCopyOfVectorTypeSucceeds("int", 3, new Object[][] {
+            row(1, vector(1, 2, 3)),
+            row(2, vector(-4, -5, -6)),
+            row(3, vector(7, 8, 9))
+        });
+    }
+
+    private void assertCopyOfVectorTypeSucceeds(String vectorType, int 
vectorSize, Object[][] rows) throws IOException
+    {
+        // given a table with a vector column
+        createTable(KEYSPACE, format("CREATE TABLE %%s (id int PRIMARY KEY, 
embedding_vector vector<%s, %d>)", vectorType, vectorSize));
+        assertTrue("table should be initially empty", execute("SELECT * FROM 
%s").isEmpty());
+
+        // write the rows into the table
+        for (Object[] row : rows)
+            execute("INSERT INTO %s (id, embedding_vector) VALUES (?, ?)", 
row);
+
+        // when running COPY TO CSV via cqlsh
+        Path csv = createTempFile("test_copy_to_vector");
+        ToolRunner.ToolResult copyToResult = 
ToolRunner.invokeCqlsh(format("COPY %s.%s TO '%s'", KEYSPACE, currentTable(), 
csv.toAbsolutePath()));
+
+        // then all rows should be exported
+        copyToResult.asserts().success();
+        // verify that the exported CSV contains the expected rows
+        assertThat(csv).hasSameTextualContentAs(prepareCSVFile(rows));
+
+        // truncate the table
+        execute("TRUNCATE %s");
+        assertTrue("table should be empty", execute("SELECT * FROM 
%s").isEmpty());
+
+        // when running COPY FROM via cqlsh
+        ToolRunner.ToolResult copyFromResult = 
ToolRunner.invokeCqlsh(format("COPY %s.%s FROM '%s'", KEYSPACE, currentTable(), 
csv.toAbsolutePath()));
+
+        // then all rows should be imported
+        copyFromResult.asserts().success();
+        UntypedResultSet importedRows = execute("SELECT * FROM %s");
+        assertRowsIgnoringOrder(importedRows, rows);
+    }
+
+    @Test
+    public void testCopyOnlyThoseRowsThatMatchVectorTypeSize() throws 
IOException
+    {
+        // given a table with a vector column and a file containing vector 
literals
+        createTable(KEYSPACE, "CREATE TABLE %s (id int PRIMARY KEY, 
embedding_vector vector<int, 6>)");
+        assertTrue("table should be initially empty", execute("SELECT * FROM 
%s").isEmpty());
+
+        Object[][] rows = {
+            row(1, vector(1, 2, 3, 4, 5, 6)),
+            row(2, vector(1, 2, 3, 4, 5)),
+            row(3, vector(1, 2, 3, 4, 6, 7))
+        };
+
+        Path csv = prepareCSVFile(rows);
+
+        // when running COPY via cqlsh
+        ToolRunner.ToolResult result = ToolRunner.invokeCqlsh(format("COPY 
%s.%s FROM '%s'", KEYSPACE, currentTable(), csv.toAbsolutePath()));
+
+        // then only rows that match type size should be imported
+        result.asserts().failure();
+        result.asserts().errorContains("The length of given vector value '5' 
is not equal to the vector size from the type definition '6'");
+        UntypedResultSet importedRows = execute("SELECT * FROM %s");
+        assertRowsIgnoringOrder(importedRows, row(1, vector(1, 2, 3, 4, 5, 6)),
+                                row(3, vector(1, 2, 3, 4, 6, 7)));
+    }
+
+    private static Path prepareCSVFile(Object[][] rows) throws IOException
+    {
+        Path csv = createTempFile("test_copy_from_vector");
+
+        try (Writer out = Files.newBufferedWriter(csv, StandardCharsets.UTF_8))
+        {
+            for (Object[] row : rows)
+            {
+                out.write(String.format("%s,\"%s\"\n", row[0], row[1]));
+            }
+        }
+
+        return csv;
+    }
+
+    private static Path createTempFile(String prefix) throws IOException
+    {
+        Path csv = Files.createTempFile(prefix, ".csv");
+        csv.toFile().deleteOnExit();
+        return csv;
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to