Maxwell-Guo commented on code in PR #2957:
URL: https://github.com/apache/cassandra/pull/2957#discussion_r1411883520


##########
test/unit/org/apache/cassandra/cql3/validation/operations/CQLVectorTest.java:
##########
@@ -554,4 +562,97 @@ public void udf() throws Throwable
         // make sure the function referencing the UDT is dropped before 
dropping the UDT at cleanup
         execute("DROP FUNCTION " + f);
     }
+
+    @Test
+    public void testCopyFloatVectorFromFile() throws IOException
+    {
+        assertCopyOfVectorLiteralsFromFileSucceeds("float", 6, new Object[][] {
+            row(1, vector(0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f)),
+            row(2, vector(-0.1f, -0.2f, -0.3f, -0.4f, -0.5f, -0.6f)),
+            row(3, vector(0.9f, 0.8f, 0.7f, 0.6f, 0.5f, 0.4f))
+        });
+
+        assertCopyOfVectorLiteralsFromFileSucceeds("float", 3, new Object[][] {
+            row(1, vector(0.1f, 0.2f, 0.3f)),
+            row(2, vector(-0.4f, -0.5f, -0.6f)),
+            row(3, vector(0.7f, 0.8f, 0.9f))
+        });
+    }
+
+    @Test
+    public void testCopyIntVectorFromFile() throws IOException
+    {
+        assertCopyOfVectorLiteralsFromFileSucceeds("int", 6, new Object[][] {
+            row(1, vector(1, 2, 3, 4, 5, 6)),
+            row(2, vector(-1, -2, -3, -4, -5, -6)),
+            row(3, vector(9, 8, 7, 6, 5, 4))
+        });
+
+        assertCopyOfVectorLiteralsFromFileSucceeds("int", 3, new Object[][] {
+            row(1, vector(1, 2, 3)),
+            row(2, vector(-4, -5, -6)),
+            row(3, vector(7, 8, 9))
+        });
+    }
+
+    private void assertCopyOfVectorLiteralsFromFileSucceeds(String vectorType, 
int vectorSize, Object[][] rows) throws IOException

Review Comment:
   I would recommend moving these test cases to 
https://github.com/apache/cassandra/tree/trunk/test/unit/org/apache/cassandra/tools/cqlsh, since they exercise the cqlsh COPY command (via ToolRunner.invokeCqlsh).



##########
pylib/cqlshlib/copyutil.py:
##########
@@ -2130,6 +2136,7 @@ def convert_unknown(val, ct=cql_type):
             'map': convert_map,
             'tuple': convert_tuple,
             'frozen': convert_single_subtype,
+            VectorType.typename: convert_vector,

Review Comment:
   I would recommend using the literal 'vector' as the key here, to stay consistent with the other entries ('map', 'tuple', 'frozen').



##########
test/unit/org/apache/cassandra/cql3/validation/operations/CQLVectorTest.java:
##########
@@ -554,4 +562,97 @@ public void udf() throws Throwable
         // make sure the function referencing the UDT is dropped before 
dropping the UDT at cleanup
         execute("DROP FUNCTION " + f);
     }
+
+    @Test
+    public void testCopyFloatVectorFromFile() throws IOException
+    {
+        assertCopyOfVectorLiteralsFromFileSucceeds("float", 6, new Object[][] {
+            row(1, vector(0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f)),
+            row(2, vector(-0.1f, -0.2f, -0.3f, -0.4f, -0.5f, -0.6f)),
+            row(3, vector(0.9f, 0.8f, 0.7f, 0.6f, 0.5f, 0.4f))
+        });
+
+        assertCopyOfVectorLiteralsFromFileSucceeds("float", 3, new Object[][] {
+            row(1, vector(0.1f, 0.2f, 0.3f)),
+            row(2, vector(-0.4f, -0.5f, -0.6f)),
+            row(3, vector(0.7f, 0.8f, 0.9f))
+        });
+    }
+
+    @Test
+    public void testCopyIntVectorFromFile() throws IOException
+    {
+        assertCopyOfVectorLiteralsFromFileSucceeds("int", 6, new Object[][] {
+            row(1, vector(1, 2, 3, 4, 5, 6)),
+            row(2, vector(-1, -2, -3, -4, -5, -6)),
+            row(3, vector(9, 8, 7, 6, 5, 4))
+        });
+
+        assertCopyOfVectorLiteralsFromFileSucceeds("int", 3, new Object[][] {
+            row(1, vector(1, 2, 3)),
+            row(2, vector(-4, -5, -6)),
+            row(3, vector(7, 8, 9))
+        });
+    }
+
+    private void assertCopyOfVectorLiteralsFromFileSucceeds(String vectorType, 
int vectorSize, Object[][] rows) throws IOException
+    {
+        // given a table with a vector column and a file containing vector 
literals
+        requireNetwork();
+        createTable(KEYSPACE, format("CREATE TABLE %%s (id int PRIMARY KEY, 
embedding_vector vector<%s, %d>)", vectorType, vectorSize));
+        assertTrue("table should be initially empty", execute("SELECT * FROM 
%s").isEmpty());
+
+        Path csv = prepareCSVFile(rows);
+
+        // when running COPY via cqlsh
+        ToolRunner.ToolResult result = ToolRunner.invokeCqlsh(format("COPY 
%s.%s FROM '%s'", KEYSPACE, currentTable(), csv.toAbsolutePath()));
+        UntypedResultSet importedRows = execute("SELECT * FROM %s");
+
+        // then all rows should be imported
+        result.asserts().success();
+        assertRowsIgnoringOrder(importedRows, rows);
+    }
+
+    private Path prepareCSVFile(Object[][] rows) throws IOException
+    {
+        Path csv = Files.createTempFile("test_copy_vector", ".csv");
+        csv.toFile().deleteOnExit();
+
+        try (Writer out = Files.newBufferedWriter(csv, StandardCharsets.UTF_8))
+        {
+            for (Object[] row : rows)
+            {
+                out.write(String.format("%s,\"%s\"\n", row[0], row[1]));
+            }
+        }
+
+        return csv;
+    }
+
+    @Test
+    public void testCopyOnlyThoseRowsThatMatchVectorTypeSize() throws 
IOException
+    {
+        // given a table with a vector column and a file containing vector 
literals
+        requireNetwork();
+        createTable(KEYSPACE, "CREATE TABLE %s (id int PRIMARY KEY, 
embedding_vector vector<int, 6>)");
+        assertTrue("table should be initially empty", execute("SELECT * FROM 
%s").isEmpty());
+
+        Object[][] rows = {
+            row(1, vector(1, 2, 3, 4, 5, 6)),
+            row(2, vector(1, 2, 3, 4, 5)),
+            row(3, vector(1, 2, 3, 4, 6, 7))
+        };
+
+        Path csv = prepareCSVFile(rows);
+
+        // when running COPY via cqlsh
+        ToolRunner.ToolResult result = ToolRunner.invokeCqlsh(format("COPY 
%s.%s FROM '%s'", KEYSPACE, currentTable(), csv.toAbsolutePath()));
+        UntypedResultSet importedRows = execute("SELECT * FROM %s");
+
+        // then only rows that match type size should be imported
+        result.asserts().failure();
+        result.asserts().errorContains("The length of given vector value '5' 
is not equal to the vector size from the type definition '6'");
+        assertRowsIgnoringOrder(importedRows, row(1, vector(1, 2, 3, 4, 5, 6)),

Review Comment:
   The `row(...)` arguments should stay aligned with `importedRows`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to