adelapena commented on code in PR #2673: URL: https://github.com/apache/cassandra/pull/2673#discussion_r1348796706
########## src/java/org/apache/cassandra/index/sai/disk/v1/vector/CassandraDiskAnn.java: ########## @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.vector; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.PrimitiveIterator; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import io.github.jbellis.jvector.disk.CachingGraphIndex; +import io.github.jbellis.jvector.disk.CompressedVectors; +import io.github.jbellis.jvector.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.disk.ReaderSupplier; +import io.github.jbellis.jvector.graph.GraphIndex; +import io.github.jbellis.jvector.graph.GraphSearcher; +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.NeighborSimilarity; +import io.github.jbellis.jvector.graph.SearchResult.NodeScore; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import org.apache.cassandra.index.sai.IndexContext; +import 
org.apache.cassandra.index.sai.disk.format.IndexComponent; +import org.apache.cassandra.index.sai.disk.v1.PerColumnIndexFiles; +import org.apache.cassandra.index.sai.disk.v1.postings.ReorderingPostingList; +import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; +import org.apache.cassandra.io.util.FileHandle; +import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.io.util.RandomAccessReader; + +public class CassandraDiskAnn implements AutoCloseable Review Comment: Starting the class name with "Cassandra" doesn't give a lot of information, I think. Maybe this could be named `DiskANN`, `OnDiskANN`, or something like that? Same for `CassandraOnHeapGraph`. ########## src/java/org/apache/cassandra/index/sai/VectorQueryContext.java: ########## @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.NavigableSet; +import java.util.Set; +import java.util.TreeSet; + +import io.github.jbellis.jvector.util.Bits; +import org.apache.cassandra.db.ReadCommand; +import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; +import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; +import org.apache.cassandra.index.sai.disk.v1.vector.CassandraDiskAnn; +import org.apache.cassandra.index.sai.disk.v1.vector.CassandraOnHeapGraph; +import org.apache.cassandra.index.sai.utils.PrimaryKey; + + +/** + * This represents the state of a vector query. It is responsible for maintaining a list of any {@link PrimaryKey}s + * that have been updated or deleted during a search of the indexes. + * <p> + * The number of {@link #shadowedPrimaryKeys} is compared before and after a search is performed. If it changes, it + * means that a {@link PrimaryKey} was found to have been changed. In this case the whole search is repeated until the + * counts match. + * <p> + * When this process has completed, a {@link Bits} array is generated. This is used by the vector graph search to + * identify which nodes in the graph to include in the results. + */ +public class VectorQueryContext +{ + private TreeSet<PrimaryKey> shadowedPrimaryKeys; // allocate when needed + private int limit; Review Comment: Nit: can be `final` ########## src/java/org/apache/cassandra/utils/ReadWriteLockedList.java: ########## @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.utils; + +import java.util.AbstractList; +import java.util.List; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +public class ReadWriteLockedList<T> extends AbstractList<T> Review Comment: This entire class seems unused ########## test/unit/org/apache/cassandra/index/sai/cql/VectorSegmentationTest.java: ########## @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.cql; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; + +import org.apache.cassandra.cql3.UntypedResultSet; +import org.apache.cassandra.db.marshal.FloatType; +import org.apache.cassandra.db.marshal.VectorType; +import org.apache.cassandra.index.sai.disk.v1.segment.SegmentBuilder; + +import static org.assertj.core.api.Assertions.assertThat; + +public class VectorSegmentationTest extends VectorTester +{ + private static final int dimension = 100; + + @Test + public void testMultipleSegmentsForCreatingIndex() throws Throwable + { + createTable("CREATE TABLE %s (pk int, val vector<float, " + dimension + ">, PRIMARY KEY(pk))"); + + int vectorCount = 100; + List<float[]> vectors = new ArrayList<>(); + for (int row = 0; row < vectorCount; row++) + { + float[] vector = nextVector(); + vectors.add(vector); + execute("INSERT INTO %s (pk, val) VALUES (?, ?)", row, vector(vector)); + } + + flush(); + + SegmentBuilder.updateLastValidSegmentRowId(17); // 17 rows per segment + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + int limit = 35; + float[] queryVector = nextVector(); + UntypedResultSet resultSet = execute("SELECT * FROM %s ORDER BY val ANN OF ? 
LIMIT " + limit, vector(queryVector)); + assertThat(resultSet.size()).isEqualTo(limit); + + List<float[]> resultVectors = getVectorsFromResult(resultSet); + double recall = rawIndexedRecall(vectors, queryVector, resultVectors, limit); + assertThat(recall).isGreaterThanOrEqualTo(0.99); + } + + @Test + public void testMultipleSegmentsForCompaction() throws Throwable + { + createTable("CREATE TABLE %s (pk int, val vector<float, " + dimension + ">, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + List<float[]> vectors = new ArrayList<>(); + int rowsPerSSTable = 10; + int sstables = 5; + int pk = 0; + for (int i = 0; i < sstables; i++) + { + for (int row = 0; row < rowsPerSSTable; row++) + { + float[] vector = nextVector(); + execute("INSERT INTO %s (pk, val) VALUES (?, ?)", pk++, vector(vector)); + vectors.add(vector); + } + + flush(); + } + + int limit = 30; + float[] queryVector = nextVector(); + UntypedResultSet resultSet = execute("SELECT * FROM %s ORDER BY val ANN OF ? LIMIT " + limit, vector(queryVector)); + assertThat(resultSet.size()).isEqualTo(limit); + + List<float[]> resultVectors = getVectorsFromResult(resultSet); + double recall = rawIndexedRecall(vectors, queryVector, resultVectors, limit); + assertThat(recall).isGreaterThanOrEqualTo(0.99); + + + SegmentBuilder.updateLastValidSegmentRowId(11); // 11 rows per segment + compact(); + + queryVector = nextVector(); + resultSet = execute("SELECT * FROM %s ORDER BY val ANN OF ? 
LIMIT " + limit, vector(queryVector)); + assertThat(resultSet.size()).isEqualTo(limit); + + resultVectors = getVectorsFromResult(resultSet); + recall = rawIndexedRecall(vectors, queryVector, resultVectors, limit); + assertThat(recall).isGreaterThanOrEqualTo(0.99); + } + + protected Vector<Float> vector(float[] values) + { + Float[] floats = new Float[values.length]; + for (int i = 0; i < values.length; i++) + floats[i] = values[i]; + + return new Vector<>(floats); + } + + private float[] nextVector() Review Comment: Nit: can be `static` ########## test/unit/org/apache/cassandra/index/sai/cql/VectorTypeTest.java: ########## @@ -0,0 +1,673 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.cql; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.Test; + +import org.apache.cassandra.config.CassandraRelevantProperties; +import org.apache.cassandra.cql3.UntypedResultSet; +import org.apache.cassandra.db.marshal.Int32Type; +import org.apache.cassandra.dht.IPartitioner; +import org.apache.cassandra.dht.Murmur3Partitioner; +import org.apache.cassandra.dht.Token; +import org.apache.cassandra.exceptions.InvalidRequestException; +import org.apache.cassandra.index.sai.StorageAttachedIndex; +import org.apache.cassandra.service.ClientWarn; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class VectorTypeTest extends VectorTester +{ + private static final IPartitioner partitioner = Murmur3Partitioner.instance; + + @Test + public void endToEndTest() + { + createTable("CREATE TABLE %s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'B', [2.0, 3.0, 4.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'C', [3.0, 4.0, 5.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (3, 'D', [4.0, 5.0, 6.0])"); + + UntypedResultSet result = execute("SELECT * FROM %s ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 3"); + assertThat(result).hasSize(3); + + flush(); + result = execute("SELECT * FROM %s ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 3"); + assertThat(result).hasSize(3); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (4, 'E', [5.0, 2.0, 3.0])"); + execute("INSERT 
INTO %s (pk, str_val, val) VALUES (5, 'F', [6.0, 3.0, 4.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (6, 'G', [7.0, 4.0, 5.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (7, 'H', [8.0, 5.0, 6.0])"); + + flush(); + compact(); + + result = execute("SELECT * FROM %s ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 5"); + assertThat(result).hasSize(5); + + // some data that only lives in memtable + execute("INSERT INTO %s (pk, str_val, val) VALUES (8, 'I', [9.0, 5.0, 6.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (9, 'J', [10.0, 6.0, 7.0])"); + result = execute("SELECT * FROM %s ORDER BY val ann of [9.5, 5.5, 6.5] LIMIT 5"); + assertContainsInt(result, "pk", 8); + assertContainsInt(result, "pk", 9); + + // data from sstables + result = execute("SELECT * FROM %s ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 2"); + assertContainsInt(result, "pk", 1); + assertContainsInt(result, "pk", 2); + } + + @Test + public void warningIsIssuedOnIndexCreation() + { + ClientWarn.instance.captureWarnings(); + createTable("CREATE TABLE %s (pk int, val vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + List<String> warnings = ClientWarn.instance.getWarnings(); + + assertTrue(warnings.size() > 0); + assertEquals(StorageAttachedIndex.VECTOR_USAGE_WARNING, warnings.get(0)); + } + + @Test + public void createIndexAfterInsertTest() + { + createTable("CREATE TABLE %s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'B', [2.0, 3.0, 4.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'C', [3.0, 4.0, 5.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (3, 'D', [4.0, 5.0, 6.0])"); + + flush(); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + UntypedResultSet result = execute("SELECT * FROM %s ORDER 
BY val ann of [2.5, 3.5, 4.5] LIMIT 3"); + assertThat(result).hasSize(3); + } + + public static void assertContainsInt(UntypedResultSet result, String columnName, int columnValue) + { + for (UntypedResultSet.Row row : result) + { + if (row.has(columnName)) + { + int value = row.getInt(columnName); + if (value == columnValue) + { + return; + } + } + } + throw new AssertionError("Result set does not contain a row with " + columnName + " = " + columnValue); + } + + @Test + public void testTwoPredicates() + { + createTable("CREATE TABLE %s (pk int, b boolean, v vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(b) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, b, v) VALUES (0, true, [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, b, v) VALUES (1, true, [2.0, 3.0, 4.0])"); + execute("INSERT INTO %s (pk, b, v) VALUES (2, false, [3.0, 4.0, 5.0])"); + + // the vector given is closest to row 2, but we exclude that row because b=false + var result = execute("SELECT * FROM %s WHERE b=true ORDER BY v ANN OF [3.1, 4.1, 5.1] LIMIT 2"); + // VSTODO assert specific row keys + assertThat(result).hasSize(2); + + flush(); + compact(); + + result = execute("SELECT * FROM %s WHERE b=true ORDER BY v ANN OF [3.1, 4.1, 5.1] LIMIT 2"); + assertThat(result).hasSize(2); + } + + @Test + public void testTwoPredicatesManyRows() + { + createTable("CREATE TABLE %s (pk int, b boolean, v vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(b) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + + for (int i = 0; i < 100; i++) + execute("INSERT INTO %s (pk, b, v) VALUES (?, true, ?)", + i, vector((float) i, (float) (i + 1), (float) (i + 2))); + + var result = execute("SELECT * FROM %s WHERE b=true ORDER BY v ANN OF [3.1, 4.1, 5.1] LIMIT 2"); + assertThat(result).hasSize(2); + + flush(); + 
compact(); + + result = execute("SELECT * FROM %s WHERE b=true ORDER BY v ANN OF [3.1, 4.1, 5.1] LIMIT 2"); + assertThat(result).hasSize(2); + } + + @Test + public void testThreePredicates() + { + createTable("CREATE TABLE %s (pk int, b boolean, v vector<float, 3>, str text, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(b) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(str) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, b, v, str) VALUES (0, true, [1.0, 2.0, 3.0], 'A')"); + execute("INSERT INTO %s (pk, b, v, str) VALUES (1, true, [2.0, 3.0, 4.0], 'B')"); + execute("INSERT INTO %s (pk, b, v, str) VALUES (2, false, [3.0, 4.0, 5.0], 'C')"); + + // the vector given is closest to row 2, but we exclude that row because b=false and str!='B' + var result = execute("SELECT * FROM %s WHERE b=true AND str='B' ORDER BY v ANN OF [3.1, 4.1, 5.1] LIMIT 2"); + // VSTODO assert specific row keys + assertThat(result).hasSize(1); + + flush(); + compact(); + + result = execute("SELECT * FROM %s WHERE b=true AND str='B' ORDER BY v ANN OF [3.1, 4.1, 5.1] LIMIT 2"); + assertThat(result).hasSize(1); + } + + @Test + public void testSameVectorMultipleRows() + { + createTable("CREATE TABLE %s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'A', [1.0, 2.0, 3.0])"); + + var result = execute("SELECT * FROM %s ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 3"); + assertThat(result).hasSize(3); + + flush(); + compact(); + + result = execute("SELECT * FROM %s ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 3"); + assertThat(result).hasSize(3); + } + + @Test + public 
void testQueryEmptyTable() + { + createTable("CREATE TABLE %s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + var result = execute("SELECT * FROM %s ORDER BY val ANN OF [2.5, 3.5, 4.5] LIMIT 1"); + assertThat(result).hasSize(0); + } + + @Test + public void testQueryTableWithNulls() + { + createTable("CREATE TABLE %s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', null)"); + var result = execute("SELECT * FROM %s ORDER BY val ANN OF [2.5, 3.5, 4.5] LIMIT 1"); + assertThat(result).hasSize(0); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'B', [4.0, 5.0, 6.0])"); + result = execute("SELECT pk FROM %s ORDER BY val ANN OF [2.5, 3.5, 4.5] LIMIT 1"); + assertRows(result, row(1)); + } + + @Test + public void testLimitLessThanInsertedRowCount() + { + createTable("CREATE TABLE %s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + // Insert more rows than the query limit + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'B', [4.0, 5.0, 6.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'C', [7.0, 8.0, 9.0])"); + + // Query with limit less than inserted row count + var result = execute("SELECT * FROM %s ORDER BY val ANN OF [2.5, 3.5, 4.5] LIMIT 2"); + assertThat(result).hasSize(2); + } + + @Test + public void testQueryMoreRowsThanInserted() + { + createTable("CREATE TABLE %s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', [1.0, 2.0, 3.0])"); + + var result = 
execute("SELECT * FROM %s ORDER BY val ANN OF [2.5, 3.5, 4.5] LIMIT 2"); + assertThat(result).hasSize(1); + } + + @Test + public void changingOptionsTest() + { + createTable("CREATE TABLE %s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + if (CassandraRelevantProperties.SAI_VECTOR_ALLOW_CUSTOM_PARAMETERS.getBoolean()) + { + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex' WITH OPTIONS = " + + "{'maximum_node_connections' : 10, 'construction_beam_width' : 200, 'similarity_function' : 'euclidean' }"); + } + else + { + assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex' WITH OPTIONS = " + + "{'maximum_node_connections' : 10, 'construction_beam_width' : 200, 'similarity_function' : 'euclidean' }")) + .isInstanceOf(InvalidRequestException.class); + return; + } + + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'B', [2.0, 3.0, 4.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'C', [3.0, 4.0, 5.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (3, 'D', [4.0, 5.0, 6.0])"); + + UntypedResultSet result = execute("SELECT * FROM %s ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 3"); + assertThat(result).hasSize(3); + + flush(); + result = execute("SELECT * FROM %s ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 3"); + assertThat(result).hasSize(3); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (4, 'E', [5.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (5, 'F', [6.0, 3.0, 4.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (6, 'G', [7.0, 4.0, 5.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (7, 'H', [8.0, 5.0, 6.0])"); + + flush(); + compact(); + + result = execute("SELECT * FROM %s ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 5"); + assertThat(result).hasSize(5); + } + + @Test + public void bindVariablesTest() + { + createTable("CREATE TABLE 
%s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', ?)", vector(1.0f, 2.0f ,3.0f)); + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'B', ?)", vector(2.0f ,3.0f, 4.0f)); + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'C', ?)", vector(3.0f, 4.0f, 5.0f)); + execute("INSERT INTO %s (pk, str_val, val) VALUES (3, 'D', ?)", vector(4.0f, 5.0f, 6.0f)); + + UntypedResultSet result = execute("SELECT * FROM %s ORDER BY val ann of ? LIMIT 3", vector(2.5f, 3.5f, 4.5f)); + assertThat(result).hasSize(3); + } + + @Test + public void intersectedSearcherTest() + { + // check that we correctly get back the two rows with str_val=B even when those are not + // the closest rows to the query vector + createTable("CREATE TABLE %s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(str_val) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', ?)", vector(1.0f, 2.0f ,3.0f)); + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'B', ?)", vector(2.0f ,3.0f, 4.0f)); + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'C', ?)", vector(3.0f, 4.0f, 5.0f)); + execute("INSERT INTO %s (pk, str_val, val) VALUES (3, 'B', ?)", vector(4.0f, 5.0f, 6.0f)); + execute("INSERT INTO %s (pk, str_val, val) VALUES (4, 'E', ?)", vector(5.0f, 6.0f, 7.0f)); + + UntypedResultSet result = execute("SELECT * FROM %s WHERE str_val = 'B' ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 2"); + assertThat(result).hasSize(2); + + flush(); + result = execute("SELECT * FROM %s WHERE str_val = 'B' ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 2"); + assertThat(result).hasSize(2); + } + + @Test + public void nullVectorTest() + { + createTable("CREATE TABLE %s (pk int, str_val text, val 
vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(str_val) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', ?)", vector(1.0f, 2.0f ,3.0f)); + execute("INSERT INTO %s (pk, str_val) VALUES (1, 'B')"); // no vector + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'C', ?)", vector(3.0f, 4.0f, 5.0f)); + execute("INSERT INTO %s (pk, str_val) VALUES (3, 'D')"); // no vector + execute("INSERT INTO %s (pk, str_val, val) VALUES (4, 'E', ?)", vector(5.0f, 6.0f, 7.0f)); + + UntypedResultSet result = execute("SELECT * FROM %s WHERE str_val = 'B' ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 2"); + assertThat(result).hasSize(0); + + result = execute("SELECT * FROM %s WHERE str_val = 'A' ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 2"); + assertThat(result).hasSize(1); + + flush(); + + result = execute("SELECT * FROM %s WHERE str_val = 'B' ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 2"); + assertThat(result).hasSize(0); + + result = execute("SELECT * FROM %s WHERE str_val = 'A' ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 2"); + assertThat(result).hasSize(1); + } + + @Test + public void lwtTest() + { + createTable("CREATE TABLE %s (p int, c int, v text, vec vector<float, 2>, PRIMARY KEY(p, c))"); + createIndex("CREATE CUSTOM INDEX ON %s(vec) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (p, c, v) VALUES (?, ?, ?)", 0, 0, "test"); + execute("INSERT INTO %s (p, c, v) VALUES (?, ?, ?)", 0, 1, "00112233445566"); + + execute("UPDATE %s SET v='00112233', vec=[0.9, 0.7] WHERE p = 0 AND c = 0 IF v = 'test'"); + + UntypedResultSet result = execute("SELECT * FROM %s ORDER BY vec ANN OF [0.1, 0.9] LIMIT 100"); + + assertThat(result).hasSize(1); + } + + @Test + public void twoVectorFieldsTest() + { + createTable("CREATE TABLE %s (pk int, v2 vector<float, 2>, v3 vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX 
ON %s(v2) USING 'StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v3) USING 'StorageAttachedIndex'"); + } + + @Test + public void primaryKeySearchTest() + { + createTable("CREATE TABLE %s (pk int, val vector<float, 3>, i int, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + var N = 5; + for (int i = 0; i < N; i++) + execute("INSERT INTO %s (pk, val) VALUES (?, ?)", i, vector(1.0f + i, 2.0f + i, 3.0f + i)); + + for (int i = 0; i < N; i++) + { + UntypedResultSet result = execute("SELECT pk FROM %s WHERE pk = ? ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 2", i); + assertThat(result).hasSize(1); + assertRows(result, row(i)); + } + + flush(); + for (int i = 0; i < N; i++) + { + UntypedResultSet result = execute("SELECT pk FROM %s WHERE pk = ? ORDER BY val ann of [2.5, 3.5, 4.5] LIMIT 2", i); + assertThat(result).hasSize(1); + assertRows(result, row(i)); + } + } + + @Test + public void partitionKeySearchTest() + { + createTable("CREATE TABLE %s (partition int, row int, val vector<float, 2>, PRIMARY KEY(partition, row))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function' : 'euclidean'}"); + + var nPartitions = 5; + var rowsPerPartition = 10; + Map<Integer, List<float[]>> vectorsByPartition = new HashMap<>(); + + for (int i = 1; i <= nPartitions; i++) + { + for (int j = 1; j <= rowsPerPartition; j++) + { + logger.debug("Inserting partition {} row {}: [{}, {}]", i, j, i, j); + execute("INSERT INTO %s (partition, row, val) VALUES (?, ?, ?)", i, j, vector((float) i, (float) j)); + float[] vector = {(float) i, (float) j}; + vectorsByPartition.computeIfAbsent(i, k -> new ArrayList<>()).add(vector); + } + } + + var queryVector = vector(new float[] { 1.5f, 1.5f }); + for (int i = 1; i <= nPartitions; i++) + { + UntypedResultSet result = execute("SELECT partition, row FROM %s WHERE partition = ? ORDER BY val ann of ? 
LIMIT 2", i, queryVector); + assertThat(result).hasSize(2); + assertRowsIgnoringOrder(result, + row(i, 1), + row(i, 2)); + } + + flush(); + for (int i = 1; i <= nPartitions; i++) + { + UntypedResultSet result = execute("SELECT partition, row FROM %s WHERE partition = ? ORDER BY val ann of ? LIMIT 2", i, queryVector); + assertThat(result).hasSize(2); + assertRowsIgnoringOrder(result, + row(i, 1), + row(i, 2)); + } + } + + @Test + public void clusteringKeyIndexTest() + { + createTable("CREATE TABLE %s (pk int, ck vector<float, 2>, PRIMARY KEY(pk, ck))"); + createIndex("CREATE CUSTOM INDEX ON %s(ck) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, ck) VALUES (1, [1.0, 2.0])"); + + assertRows(execute("SELECT * FROM %s ORDER BY ck ANN OF [1.0, 2.0] LIMIT 1"), row(1, vector(1.0F, 2.0F))); + } + + @Test + public void rangeSearchTest() throws Throwable + { + createTable("CREATE TABLE %s (partition int, val vector<float, 2>, PRIMARY KEY(partition))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex' WITH OPTIONS = {'similarity_function' : 'euclidean'}"); + + var nPartitions = 100; + Map<Integer, float[]> vectorsByKey = new HashMap<>(); + + for (int i = 1; i <= nPartitions; i++) + { + float[] vector = {(float) i, (float) i}; + execute("INSERT INTO %s (partition, val) VALUES (?, ?)", i, vector(vector)); + vectorsByKey.put(i, vector); + } + + var queryVector = vector(new float[] { 1.5f, 1.5f }); + CheckedFunction tester = () -> { + for (int i = 1; i <= nPartitions; i++) + { + UntypedResultSet result = execute("SELECT partition FROM %s WHERE token(partition) > token(?) ORDER BY val ann of ? LIMIT 1000", i, queryVector); + assertThat(keys(result)).containsExactlyInAnyOrderElementsOf(keysWithLowerBound(vectorsByKey.keySet(), i, false)); + + result = execute("SELECT partition FROM %s WHERE token(partition) >= token(?) ORDER BY val ann of ? 
LIMIT 1000", i, queryVector); + assertThat(keys(result)).containsExactlyInAnyOrderElementsOf(keysWithLowerBound(vectorsByKey.keySet(), i, true)); + + result = execute("SELECT partition FROM %s WHERE token(partition) < token(?) ORDER BY val ann of ? LIMIT 1000", i, queryVector); + assertThat(keys(result)).containsExactlyInAnyOrderElementsOf(keysWithUpperBound(vectorsByKey.keySet(), i, false)); + + result = execute("SELECT partition FROM %s WHERE token(partition) <= token(?) ORDER BY val ann of ? LIMIT 1000", i, queryVector); + assertThat(keys(result)).containsExactlyInAnyOrderElementsOf(keysWithUpperBound(vectorsByKey.keySet(), i, true)); + + for (int j = 1; j <= nPartitions; j++) + { + result = execute("SELECT partition FROM %s WHERE token(partition) >= token(?) AND token(partition) <= token(?) ORDER BY val ann of ? LIMIT 1000", i, j, queryVector); + assertThat(keys(result)).containsExactlyInAnyOrderElementsOf(keysInBounds(vectorsByKey.keySet(), i, true, j, true)); + + result = execute("SELECT partition FROM %s WHERE token(partition) > token(?) AND token(partition) <= token(?) ORDER BY val ann of ? LIMIT 1000", i, j, queryVector); + assertThat(keys(result)).containsExactlyInAnyOrderElementsOf(keysInBounds(vectorsByKey.keySet(), i, false, j, true)); + + result = execute("SELECT partition FROM %s WHERE token(partition) >= token(?) AND token(partition) < token(?) ORDER BY val ann of ? LIMIT 1000", i, j, queryVector); + assertThat(keys(result)).containsExactlyInAnyOrderElementsOf(keysInBounds(vectorsByKey.keySet(), i, true, j, false)); + + result = execute("SELECT partition FROM %s WHERE token(partition) > token(?) AND token(partition) < token(?) ORDER BY val ann of ? 
LIMIT 1000", i, j, queryVector); + assertThat(keys(result)).containsExactlyInAnyOrderElementsOf(keysInBounds(vectorsByKey.keySet(), i, false, j, false)); + } + } + }; + + tester.apply(); + + flush(); + + tester.apply(); + } + + private Collection<Integer> keys(UntypedResultSet result) + { + List<Integer> keys = new ArrayList<>(result.size()); + for (UntypedResultSet.Row row : result) + keys.add(row.getInt("partition")); + return keys; + } + + private Collection<Integer> keysWithLowerBound(Collection<Integer> keys, int leftKey, boolean leftInclusive) + { + return keysInTokenRange(keys, partitioner.getToken(Int32Type.instance.decompose(leftKey)), leftInclusive, + partitioner.getMaximumToken().getToken(), true); + } + + private Collection<Integer> keysWithUpperBound(Collection<Integer> keys, int rightKey, boolean rightInclusive) + { + return keysInTokenRange(keys, partitioner.getMinimumToken().getToken(), true, + partitioner.getToken(Int32Type.instance.decompose(rightKey)), rightInclusive); + } + + private Collection<Integer> keysInBounds(Collection<Integer> keys, int leftKey, boolean leftInclusive, int rightKey, boolean rightInclusive) + { + return keysInTokenRange(keys, partitioner.getToken(Int32Type.instance.decompose(leftKey)), leftInclusive, + partitioner.getToken(Int32Type.instance.decompose(rightKey)), rightInclusive); + } + + private Collection<Integer> keysInTokenRange(Collection<Integer> keys, Token leftToken, boolean leftInclusive, Token rightToken, boolean rightInclusive) + { + long left = leftToken.getLongValue(); + long right = rightToken.getLongValue(); + return keys.stream() + .filter(k -> { + long t = partitioner.getToken(Int32Type.instance.decompose(k)).getLongValue(); + return (left < t || left == t && leftInclusive) && (t < right || t == right && rightInclusive); + }).collect(Collectors.toSet()); + } + + @Test + public void selectFloatVectorFunctions() + { + createTable(KEYSPACE, "CREATE TABLE %s (pk int primary key, value vector<float, 2>)"); + + 
// basic functionality + Vector<Float> q = vector(1f, 2f); + execute("INSERT INTO %s (pk, value) VALUES (0, ?)", vector(1f, 2f)); + execute("SELECT similarity_cosine(value, value) FROM %s WHERE pk=0"); + + // type inference checks + var result = execute("SELECT similarity_cosine(value, ?) FROM %s WHERe pk=0", q); + assertRows(result, row(1f)); + result = execute("SELECT similarity_euclidean(value, ?) FROM %s WHERe pk=0", q); + assertRows(result, row(1f)); + execute("SELECT similarity_cosine(?, value) FROM %s WHERE pk=0", q); + assertThatThrownBy(() -> execute("SELECT similarity_cosine(?, ?) FROM %s WHERE pk=0", q, q)) + .hasMessageContaining("Cannot infer type of argument ?"); + + // with explicit typing + execute("SELECT similarity_cosine((vector<float, 2>) ?, ?) FROM %s WHERE pk=0", q, q); + execute("SELECT similarity_cosine(?, (vector<float, 2>) ?) FROM %s WHERE pk=0", q, q); + execute("SELECT similarity_cosine((vector<float, 2>) ?, (vector<float, 2>) ?) FROM %s WHERE pk=0", q, q); + } + + @Test + public void selectSimilarityWithAnn() + { + createTable("CREATE TABLE %s (pk int, str_val text, val vector<float, 3>, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + execute("INSERT INTO %s (pk, str_val, val) VALUES (0, 'A', [1.0, 2.0, 3.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (1, 'B', [2.0, 3.0, 4.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (2, 'C', [3.0, 4.0, 5.0])"); + execute("INSERT INTO %s (pk, str_val, val) VALUES (3, 'D', [4.0, 5.0, 6.0])"); + + Vector<Float> q = vector(1.5f, 2.5f, 3.5f); + var result = execute("SELECT str_val, similarity_cosine(val, ?) FROM %s ORDER BY val ANN OF ? 
LIMIT 2", + q, q); + + assertRowsIgnoringOrder(result, + row("A", 0.9987074f), + row("B", 0.9993764f)); + } + + @Test + public void castedTerminalFloatVectorFunctions() + { + createTable(KEYSPACE, "CREATE TABLE %s (pk int primary key, value vector<float, 2>)"); + + execute("INSERT INTO %s (pk, value) VALUES (0, ?)", vector(1f, 2f)); + execute("SELECT similarity_cosine(value, (vector<float, 2>) [1.0, 1.0]) FROM %s WHERE pk=0"); + execute("SELECT similarity_cosine((vector<float, 2>) [1.0, 1.0], value) FROM %s WHERE pk=0"); + execute("SELECT similarity_cosine((vector<float, 2>) [1.0, 1.0], (vector<float, 2>) [1.0, 1.0]) FROM %s WHERE pk=0"); + } + + @Test + public void inferredTerminalFloatVectorFunctions() throws Throwable Review Comment: It seems there is some overlapping between these two tests and `VectorFctsTest` ########## test/unit/org/apache/cassandra/inject/injections.md: ########## @@ -0,0 +1,356 @@ +<!--- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> + +## Injecting hooks dynamically in testing + +The testing infrastructure is equipped with Byteman, which allows to +inject hooks into the running code. 
In other words, at any point in +the test, you can tell the particular nodes, to do something at exact +location in the code - like you would set a conditional breakpoint +and when the debugger breaks, you can check something or perform some +additional actions. In short, it is particularly useful to: + +- synchronizing the test code and the nodes +- synchronizing the nodes between each other +- running actions step by step +- counting invocations +- tracing invocations +- force throwing exceptions +- and combinations of the above... + +Here you can find a short introduction and some examples how can it be +used. + +### Basics + +Byteman works on rules which are the units defining what, when, and +where should be invoked. That is, a single rule includes: the invoke +point, bindings, run condition and actions. + +**Invoke point** is a location in the code where the actions should +be hooked. Usually it is a class and method, but it can be specified +more precisely, like - entering method, exiting method, invoking some +method inside, particular line of code and so on. + +**Bindings** are just some constant definitions which can be used then +in run condition and actions. + +**Run condition** is a logical expression which determines whether +the rule should be invoked. + +**Actions** are the statements to be invoked. + +In DSE, we have `Rule` class which reflects a single Byteman rule. Review Comment: This file contains multiple references to DSE. 
########## src/java/org/apache/cassandra/io/util/MmappedRegions.java: ########## @@ -268,6 +270,18 @@ public ByteBuffer buffer() return buffer.duplicate(); } + public FloatBuffer floatBuffer() + { + // this does an implicit duplicate(), so we need to expose it directly to avoid doing it twice unnecessarily + return buffer.asFloatBuffer(); + } + + public IntBuffer intBuffer() Review Comment: Nit: Add `@Override` ########## src/java/org/apache/cassandra/index/sai/plan/Expression.java: ########## @@ -172,6 +178,9 @@ public Expression add(Operator op, ByteBuffer value) */ public boolean isSatisfiedBy(ByteBuffer columnValue) { + if (validator.isVector()) Review Comment: Nit: A brief comment on why we do this could be useful for readers without context. ########## src/java/org/apache/cassandra/index/sai/VectorQueryContext.java: ########## @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.NavigableSet; +import java.util.Set; +import java.util.TreeSet; + +import io.github.jbellis.jvector.util.Bits; +import org.apache.cassandra.db.ReadCommand; +import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; +import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata; +import org.apache.cassandra.index.sai.disk.v1.vector.CassandraDiskAnn; +import org.apache.cassandra.index.sai.disk.v1.vector.CassandraOnHeapGraph; +import org.apache.cassandra.index.sai.utils.PrimaryKey; + + +/** + * This represents the state of a vector query. It is repsonsible for maintaining a list of any {@link PrimaryKey}s + * that have been updated or deleted during a search of the indexes. + * <p> + * The number of {@link #shadowedPrimaryKeys} is compared before and after a search is performed. If it changes, it + * means that a {@link PrimaryKey} was found to have been changed. In this case the whole search is repeated until the + * counts match. + * <p> + * When this process has completed, a {@link Bits} array is generated. This is used by the vector graph search to + * identify which nodes in the graph to include in the results. + */ +public class VectorQueryContext +{ + private TreeSet<PrimaryKey> shadowedPrimaryKeys; // allocate when needed Review Comment: It might be helpful to add a brief comment saying what's a shadowed primary key. ########## src/java/org/apache/cassandra/io/util/MmappedRegions.java: ########## @@ -268,6 +270,18 @@ public ByteBuffer buffer() return buffer.duplicate(); } + public FloatBuffer floatBuffer() Review Comment: Nit: Add `@Override` ########## src/java/org/apache/cassandra/index/sai/disk/v1/vector/OnDiskOrdinalsMap.java: ########## @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.vector; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; + +import com.google.common.base.Preconditions; + +import io.github.jbellis.jvector.util.Bits; +import org.apache.cassandra.io.util.FileHandle; +import org.apache.cassandra.io.util.RandomAccessReader; + +public class OnDiskOrdinalsMap +{ + private final FileHandle fh; + private final long ordToRowOffset; + private final long segmentEnd; + private final int size; + // the offset where we switch from recording ordinal -> rows, to row -> ordinal + private final long rowOrdinalOffset; + private final Set<Integer> deletedOrdinals; + + public OnDiskOrdinalsMap(FileHandle fh, long segmentOffset, long segmentLength) + { + deletedOrdinals = new HashSet<>(); + + this.segmentEnd = segmentOffset + segmentLength; + this.fh = fh; + try (var reader = fh.createReader()) + { + reader.seek(segmentOffset); + int deletedCount = reader.readInt(); + for (var i = 0; i < deletedCount; i++) + { + deletedOrdinals.add(reader.readInt()); + } + + this.ordToRowOffset = reader.getFilePointer(); + this.size = reader.readInt(); + reader.seek(segmentEnd - 8); + this.rowOrdinalOffset = reader.readLong(); + assert rowOrdinalOffset < segmentEnd : "rowOrdinalOffset " + 
rowOrdinalOffset + " is not less than segmentEnd " + segmentEnd; + } + catch (Exception e) + { + throw new RuntimeException("Error initializing OnDiskOrdinalsMap at segment " + segmentOffset, e); + } + } + + public RowIdsView getRowIdsView() + { + return new RowIdsView(); + } + + public Bits ignoringDeleted(Bits acceptBits) + { + return BitsUtil.bitsIgnoringDeleted(acceptBits, deletedOrdinals); + } + + public class RowIdsView implements AutoCloseable + { + final RandomAccessReader reader = fh.createReader(); + + public int[] getSegmentRowIdsMatching(int vectorOrdinal) throws IOException + { + Preconditions.checkArgument(vectorOrdinal < size, "vectorOrdinal %s is out of bounds %s", vectorOrdinal, size); + + // read index entry + try + { + reader.seek(ordToRowOffset + 4L + vectorOrdinal * 8L); + } + catch (Exception e) + { + throw new RuntimeException(String.format("Error seeking to index offset for ordinal %d with ordToRowOffset %d", + vectorOrdinal, ordToRowOffset), e); + } + var offset = reader.readLong(); + // seek to and read rowIds + try + { + reader.seek(offset); + } + catch (Exception e) + { + throw new RuntimeException(String.format("Error seeking to rowIds offset for ordinal %d with ordToRowOffset %d", + vectorOrdinal, ordToRowOffset), e); + } + var postingsSize = reader.readInt(); + var rowIds = new int[postingsSize]; + for (var i = 0; i < rowIds.length; i++) + { + rowIds[i] = reader.readInt(); + } + return rowIds; + } + + @Override + public void close() + { + reader.close(); + } + } + + public OrdinalsView getOrdinalsView() + { + return new OrdinalsView(); + } + + public class OrdinalsView implements AutoCloseable + { + final RandomAccessReader reader = fh.createReader(); + private final long high = (segmentEnd - 8 - rowOrdinalOffset) / 8; + + /** + * @return order if given row id is found; otherwise return -1 + */ + public int getOrdinalForRowId(int rowId) throws IOException + { + // Compute the offset of the start of the rowId to vectorOrdinal mapping + 
long index = DiskBinarySearch.searchInt(0, Math.toIntExact(high), rowId, i -> { + try + { + long offset = rowOrdinalOffset + i * 8; + reader.seek(offset); + return reader.readInt(); + } + catch (IOException e) + { + throw new RuntimeException(e); + } + }); + + // not found + if (index < 0) + return -1; + + return reader.readInt(); + } + + @Override + public void close() + { + reader.close(); + } + } + + public void close() Review Comment: Nit: Maybe this class could implement `Closeable`, so this is annotated with override and more visible for callers? ########## test/distributed/org/apache/cassandra/distributed/test/sai/VectorDistributedTest.java: ########## @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.distributed.test.sai; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.Multimap; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; + +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import org.apache.cassandra.cql3.statements.SelectStatement; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.marshal.Int32Type; +import org.apache.cassandra.dht.Murmur3Partitioner; +import org.apache.cassandra.distributed.Cluster; +import org.apache.cassandra.distributed.api.ConsistencyLevel; +import org.apache.cassandra.distributed.test.TestBaseImpl; +import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig; + +import static org.apache.cassandra.distributed.api.Feature.GOSSIP; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class VectorDistributedTest extends TestBaseImpl +{ + + @Rule + public SAITester.FailureWatcher failureRule = new SAITester.FailureWatcher(); + + private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}"; + private static final String CREATE_TABLE = "CREATE TABLE %%s (pk int primary key, val vector<float, %d>)"; + private static final String CREATE_INDEX = "CREATE CUSTOM INDEX ON %%s(%s) USING 'StorageAttachedIndex'"; + + private static final VectorSimilarityFunction function = 
IndexWriterConfig.DEFAULT_SIMILARITY_FUNCTION; + + private static final String INVALID_LIMIT_MESSAGE = "Use of ANN OF in an ORDER BY clause requires a LIMIT that is not greater than 1000"; Review Comment: Nit: unused constant ########## src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java: ########## @@ -105,7 +104,25 @@ public PartitionIterator filterReplicaFilteringProtection(PartitionIterator full @Override public UnfilteredPartitionIterator search(ReadExecutionController executionController) throws RequestTimeoutException { - return new ResultRetriever(queryController, executionController, queryContext, keyFactory); + if (!command.isTopK()) + return new ResultRetriever(queryController, executionController, queryContext, false); + else + { + Supplier<ResultRetriever> resultSupplier = () -> new ResultRetriever(queryController, executionController, queryContext, true); + + // VSTODO performance: if there is shadowed primary keys, we have to at least query twice. + // First time to find out there are shawdow keys, second time to find out there are no more shadow keys. Review Comment: ```suggestion // First time to find out there are shadow keys, second time to find out there are no more shadow keys. ``` ########## src/java/org/apache/cassandra/index/sai/memory/MemtableOrdering.java: ########## @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.memory; + +import org.apache.cassandra.index.sai.QueryContext; +import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import org.apache.cassandra.index.sai.plan.Expression; + +/*** + * Analogue of SegmentOrdering, but for memtables. Review Comment: ```suggestion * Analogue of {@link org.apache.cassandra.index.sai.disk.v1.segment.SegmentOrdering}, but for memtables. ``` ########## test/unit/org/apache/cassandra/index/sai/cql/VectorSegmentationTest.java: ########## @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.cql; + +import java.util.ArrayList; +import java.util.List; + +import org.junit.Test; + +import org.apache.cassandra.cql3.UntypedResultSet; +import org.apache.cassandra.db.marshal.FloatType; +import org.apache.cassandra.db.marshal.VectorType; +import org.apache.cassandra.index.sai.disk.v1.segment.SegmentBuilder; + +import static org.assertj.core.api.Assertions.assertThat; + +public class VectorSegmentationTest extends VectorTester +{ + private static final int dimension = 100; + + @Test + public void testMultipleSegmentsForCreatingIndex() throws Throwable + { + createTable("CREATE TABLE %s (pk int, val vector<float, " + dimension + ">, PRIMARY KEY(pk))"); + + int vectorCount = 100; + List<float[]> vectors = new ArrayList<>(); + for (int row = 0; row < vectorCount; row++) + { + float[] vector = nextVector(); + vectors.add(vector); + execute("INSERT INTO %s (pk, val) VALUES (?, ?)", row, vector(vector)); + } + + flush(); + + SegmentBuilder.updateLastValidSegmentRowId(17); // 17 rows per segment + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + int limit = 35; + float[] queryVector = nextVector(); + UntypedResultSet resultSet = execute("SELECT * FROM %s ORDER BY val ANN OF ? 
LIMIT " + limit, vector(queryVector)); + assertThat(resultSet.size()).isEqualTo(limit); + + List<float[]> resultVectors = getVectorsFromResult(resultSet); + double recall = rawIndexedRecall(vectors, queryVector, resultVectors, limit); + assertThat(recall).isGreaterThanOrEqualTo(0.99); + } + + @Test + public void testMultipleSegmentsForCompaction() throws Throwable + { + createTable("CREATE TABLE %s (pk int, val vector<float, " + dimension + ">, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + List<float[]> vectors = new ArrayList<>(); + int rowsPerSSTable = 10; + int sstables = 5; + int pk = 0; + for (int i = 0; i < sstables; i++) + { + for (int row = 0; row < rowsPerSSTable; row++) + { + float[] vector = nextVector(); + execute("INSERT INTO %s (pk, val) VALUES (?, ?)", pk++, vector(vector)); + vectors.add(vector); + } + + flush(); + } + + int limit = 30; + float[] queryVector = nextVector(); + UntypedResultSet resultSet = execute("SELECT * FROM %s ORDER BY val ANN OF ? LIMIT " + limit, vector(queryVector)); + assertThat(resultSet.size()).isEqualTo(limit); + + List<float[]> resultVectors = getVectorsFromResult(resultSet); + double recall = rawIndexedRecall(vectors, queryVector, resultVectors, limit); + assertThat(recall).isGreaterThanOrEqualTo(0.99); + + + SegmentBuilder.updateLastValidSegmentRowId(11); // 11 rows per segment + compact(); + + queryVector = nextVector(); + resultSet = execute("SELECT * FROM %s ORDER BY val ANN OF ? 
LIMIT " + limit, vector(queryVector)); + assertThat(resultSet.size()).isEqualTo(limit); + + resultVectors = getVectorsFromResult(resultSet); + recall = rawIndexedRecall(vectors, queryVector, resultVectors, limit); + assertThat(recall).isGreaterThanOrEqualTo(0.99); + } + + protected Vector<Float> vector(float[] values) + { + Float[] floats = new Float[values.length]; + for (int i = 0; i < values.length; i++) + floats[i] = values[i]; + + return new Vector<>(floats); + } + + private float[] nextVector() + { + float[] rawVector = new float[dimension]; + for (int i = 0; i < dimension; i++) + { + rawVector[i] = getRandom().nextFloat(); + } + return rawVector; + } + + private List<float[]> getVectorsFromResult(UntypedResultSet result) Review Comment: Nit: can be `static` ########## test/unit/org/apache/cassandra/index/sai/utils/Glove.java: ########## @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.index.sai.utils; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class Glove Review Comment: This could have some class JavaDoc -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]

