This is an automated email from the ASF dual-hosted git repository. spmallette pushed a commit to branch TINKERPOP-3158 in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 36e90a12611ef2d08c9213c3c6d54663e22f7a12 Author: Stephen Mallette <[email protected]> AuthorDate: Wed Apr 30 14:08:02 2025 -0400 Intro vector search via HNSW --- tinkergraph-gremlin/pom.xml | 5 + .../tinkergraph/structure/AbstractTinkerGraph.java | 21 +- .../tinkergraph/structure/AbstractTinkerIndex.java | 16 +- .../gremlin/tinkergraph/structure/TinkerGraph.java | 109 +++++- .../gremlin/tinkergraph/structure/TinkerIndex.java | 4 +- .../tinkergraph/structure/TinkerIndexHelper.java | 63 +++- .../tinkergraph/structure/TinkerIndexType.java | 66 ++++ .../structure/TinkerTransactionGraph.java | 8 +- .../structure/TinkerTransactionalIndex.java | 4 +- .../tinkergraph/structure/TinkerVectorIndex.java | 383 +++++++++++++++++++++ .../structure/TinkerGraphVectorIndexTest.java | 249 ++++++++++++++ 11 files changed, 902 insertions(+), 26 deletions(-) diff --git a/tinkergraph-gremlin/pom.xml b/tinkergraph-gremlin/pom.xml index 47c5c1c1f3..024fce3a03 100644 --- a/tinkergraph-gremlin/pom.xml +++ b/tinkergraph-gremlin/pom.xml @@ -35,6 +35,11 @@ limitations under the License. <groupId>org.apache.commons</groupId> <artifactId>commons-lang3</artifactId> </dependency> + <dependency> + <groupId>com.github.jelmerk</groupId> + <artifactId>hnswlib-core</artifactId> + <version>1.2.1</version> + </dependency> <dependency> <groupId>com.google.inject</groupId> <artifactId>guice</artifactId> diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java index 34ad59357a..ef2f65af01 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java @@ -38,6 +38,7 @@ import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerServiceRegistry; import java.io.File; import java.lang.reflect.InvocationTargetException; import java.util.Collections; +import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; @@ -71,6 +72,8 @@ public abstract class AbstractTinkerGraph implements Graph { protected TinkerGraphComputerView graphComputerView = null; protected AbstractTinkerIndex<TinkerVertex> vertexIndex = null; protected AbstractTinkerIndex<TinkerEdge> edgeIndex = null; + protected TinkerVectorIndex<TinkerVertex> vertexVectorIndex = null; + protected TinkerVectorIndex<TinkerEdge> edgeVectorIndex = null; protected IdManager<Vertex> vertexIdManager; protected IdManager<Edge> edgeIdManager; @@ -266,6 +269,8 @@ public abstract class AbstractTinkerGraph implements Graph { this.currentId.set(-1L); this.vertexIndex = null; this.edgeIndex = null; + this.vertexVectorIndex = null; + this.edgeVectorIndex = null; this.graphComputerView = null; this.vertexProperties.clear(); } @@ -396,9 +401,19 @@ public abstract class AbstractTinkerGraph implements Graph { */ public <E extends Element> Set<String> getIndexedKeys(final Class<E> elementClass) { if (Vertex.class.isAssignableFrom(elementClass)) { - return null == this.vertexIndex ? Collections.emptySet() : this.vertexIndex.getIndexedKeys(); + Set<String> keys = new HashSet<>(); + if (this.vertexIndex != null) + keys.addAll(this.vertexIndex.getIndexedKeys()); + if (this.vertexVectorIndex != null) + keys.addAll(this.vertexVectorIndex.getIndexedKeys()); + return keys; } else if (Edge.class.isAssignableFrom(elementClass)) { - return null == this.edgeIndex ? Collections.emptySet() : this.edgeIndex.getIndexedKeys(); + Set<String> keys = new HashSet<>(); + if (this.edgeIndex != null) + keys.addAll(this.edgeIndex.getIndexedKeys()); + if (this.edgeVectorIndex != null) + keys.addAll(this.edgeVectorIndex.getIndexedKeys()); + return keys; } else { throw new IllegalArgumentException("Class is not indexable: " + elementClass); } @@ -575,7 +590,7 @@ public abstract class AbstractTinkerGraph implements Graph { return true; } }, - + /** * Manages identifiers of type {@code String}. */ diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerIndex.java index 877e926c51..80c50266db 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerIndex.java @@ -21,8 +21,10 @@ package org.apache.tinkerpop.gremlin.tinkergraph.structure; import org.apache.tinkerpop.gremlin.structure.Element; +import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -85,16 +87,24 @@ public abstract class AbstractTinkerIndex<T extends Element> { public abstract void autoUpdate(final String key, final Object newValue, final Object oldValue, final T element); /** - * Create new index + * Create new index with no configuration. * @param key property key */ - public abstract void createKeyIndex(final String key); + public void createIndex(final String key) { + createIndex(key, Collections.emptyMap()); + } + + /** + * Create new index with a specified configuration. + * @param key property key + */ + public abstract void createIndex(final String key, final Map<String, Object> configuration); /** * Drop index * @param key property key */ - public abstract void dropKeyIndex(final String key); + public abstract void dropIndex(final String key); /** * Get all index keys for Graph diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java index 28f3120e3b..a54df6a858 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java @@ -35,6 +35,7 @@ import org.apache.tinkerpop.gremlin.tinkergraph.process.traversal.strategy.optim import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerServiceRegistry; import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -371,11 +372,10 @@ public class TinkerGraph extends AbstractTinkerGraph { } - ///////////// GRAPH SPECIFIC INDEXING METHODS /////////////// /** - * Create an index for said element class ({@link Vertex} or {@link Edge}) and said property key. + * Create a default index for said element class ({@link Vertex} or {@link Edge}) and said property key. * Whenever an element has the specified key mutated, the index is updated. * When the index is created, all existing elements are indexed to ensure that they are captured by the index. * @@ -384,14 +384,43 @@ public class TinkerGraph extends AbstractTinkerGraph { * @param <E> The type of the element class */ public <E extends Element> void createIndex(final String key, final Class<E> elementClass) { - if (Vertex.class.isAssignableFrom(elementClass)) { - if (null == this.vertexIndex) this.vertexIndex = new TinkerIndex<>(this, TinkerVertex.class); - this.vertexIndex.createKeyIndex(key); - } else if (Edge.class.isAssignableFrom(elementClass)) { - if (null == this.edgeIndex) this.edgeIndex = new TinkerIndex<>(this, TinkerEdge.class); - this.edgeIndex.createKeyIndex(key); + createIndex(TinkerIndexType.DEFAULT, key, elementClass, Collections.emptyMap()); + } + + /** + * Create an index for said element class ({@link Vertex} or {@link Edge}) and said property key with the given + * configuration options. Whenever an element has the specified key mutated, the index is updated. When the index + * is created, all existing elements are indexed to ensure that they are captured by the index. + * + * @param indexType the type of the index + * @param key the property key to index + * @param elementClass the element class to index + * @param configuration the configuration options + * @param <E> The type of the element class + */ + public <E extends Element> void createIndex(final TinkerIndexType indexType, final String key, + final Class<E> elementClass, final Map<String, Object> configuration) { + if (TinkerIndexType.VECTOR == indexType) { + if (Vertex.class.isAssignableFrom(elementClass)) { + if (null == this.vertexVectorIndex) this.vertexVectorIndex = new TinkerVectorIndex<>(this, TinkerVertex.class); + this.vertexVectorIndex.createIndex(key, configuration); + } else if (Edge.class.isAssignableFrom(elementClass)) { + if (null == this.edgeVectorIndex) this.edgeVectorIndex = new TinkerVectorIndex<>(this, TinkerEdge.class); + this.edgeVectorIndex.createIndex(key, configuration); + } else { + throw new IllegalArgumentException("Class is not indexable: " + elementClass); + } } else { - throw new IllegalArgumentException("Class is not indexable: " + elementClass); + // Create a standard index + if (Vertex.class.isAssignableFrom(elementClass)) { + if (null == this.vertexIndex) this.vertexIndex = new TinkerIndex<>(this, TinkerVertex.class); + this.vertexIndex.createIndex(key); + } else if (Edge.class.isAssignableFrom(elementClass)) { + if (null == this.edgeIndex) this.edgeIndex = new TinkerIndex<>(this, TinkerEdge.class); + this.edgeIndex.createIndex(key); + } else { + throw new IllegalArgumentException("Class is not indexable: " + elementClass); + } } } @@ -404,11 +433,69 @@ public class TinkerGraph extends AbstractTinkerGraph { */ public <E extends Element> void dropIndex(final String key, final Class<E> elementClass) { if (Vertex.class.isAssignableFrom(elementClass)) { - if (null != this.vertexIndex) this.vertexIndex.dropKeyIndex(key); + if (null != this.vertexIndex) this.vertexIndex.dropIndex(key); + if (null != this.vertexVectorIndex) this.vertexVectorIndex.dropIndex(key); } else if (Edge.class.isAssignableFrom(elementClass)) { - if (null != this.edgeIndex) this.edgeIndex.dropKeyIndex(key); + if (null != this.edgeIndex) this.edgeIndex.dropIndex(key); + if (null != this.edgeVectorIndex) this.edgeVectorIndex.dropIndex(key); } else { throw new IllegalArgumentException("Class is not indexable: " + elementClass); } } + + /** + * Find the nearest vertices to the given vector in the vector index for the specified property key. + * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of vertices sorted by distance + */ + public List<Vertex> findNearestVertices(final String key, final float[] vector, final int k) { + if (null == this.vertexVectorIndex) + return Collections.emptyList(); + return new ArrayList<>(this.vertexVectorIndex.findNearest(key, vector, k)); + } + + /** + * Find the nearest vertices to the given vector in the vector index for the specified property key. + * Uses the default number of nearest neighbors. + * + * @param key the property key + * @param vector the query vector + * @return a list of vertices sorted by distance + */ + public List<Vertex> findNearestVertices(final String key, final float[] vector) { + if (null == this.vertexVectorIndex) + return Collections.emptyList(); + return new ArrayList<>(this.vertexVectorIndex.findNearest(key, vector)); + } + + /** + * Find the nearest edges to the given vector in the vector index for the specified property key. + * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of edges sorted by distance + */ + public List<Edge> findNearestEdges(final String key, final float[] vector, final int k) { + if (null == this.edgeVectorIndex) + return Collections.emptyList(); + return new ArrayList<>(this.edgeVectorIndex.findNearest(key, vector, k)); + } + + /** + * Find the nearest edges to the given vector in the vector index for the specified property key. + * Uses the default number of nearest neighbors. + * + * @param key the property key + * @param vector the query vector + * @return a list of edges sorted by distance + */ + public List<Edge> findNearestEdges(final String key, final float[] vector) { + if (null == this.edgeVectorIndex) + return Collections.emptyList(); + return new ArrayList<>(this.edgeVectorIndex.findNearest(key, vector)); + } } diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndex.java index 366bc71b4e..8b4ef39fb1 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndex.java @@ -118,7 +118,7 @@ final class TinkerIndex<T extends Element> extends AbstractTinkerIndex<T> { } @Override - public void createKeyIndex(final String key) { + public void createIndex(final String key, final Map<String,Object> configuration) { if (null == key) throw Graph.Exceptions.argumentCanNotBeNull("key"); if (key.isEmpty()) @@ -138,7 +138,7 @@ final class TinkerIndex<T extends Element> extends AbstractTinkerIndex<T> { } @Override - public void dropKeyIndex(final String key) { + public void dropIndex(final String key) { if (this.index.containsKey(key)) this.index.remove(key).clear(); diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java index b4a9e7e79c..7c37b59757 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexHelper.java @@ -23,7 +23,56 @@ import java.util.List; public final class TinkerIndexHelper { - private TinkerIndexHelper() { + private TinkerIndexHelper() {} + + /** + * Find nearest neighbors in the vertex vector index. + * + * @param graph the graph + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of vertices sorted by distance + */ + public static List<TinkerVertex> findNearestVertices(final AbstractTinkerGraph graph, final String key, final float[] vector, final int k) { + return null == graph.vertexVectorIndex ? Collections.emptyList() : graph.vertexVectorIndex.findNearest(key, vector, k); + } + + /** + * Find nearest neighbors in the vertex vector index with default k. + * + * @param graph the graph + * @param key the property key + * @param vector the query vector + * @return a list of vertices sorted by distance + */ + public static List<TinkerVertex> findNearestVertices(final AbstractTinkerGraph graph, final String key, final float[] vector) { + return null == graph.vertexVectorIndex ? Collections.emptyList() : graph.vertexVectorIndex.findNearest(key, vector); + } + + /** + * Find nearest neighbors in the edge vector index. + * + * @param graph the graph + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of edges sorted by distance + */ + public static List<TinkerEdge> findNearestEdges(final AbstractTinkerGraph graph, final String key, final float[] vector, final int k) { + return null == graph.edgeVectorIndex ? Collections.emptyList() : graph.edgeVectorIndex.findNearest(key, vector, k); + } + + /** + * Find nearest neighbors in the edge vector index with default k. + * + * @param graph the graph + * @param key the property key + * @param vector the query vector + * @return a list of edges sorted by distance + */ + public static List<TinkerEdge> findNearestEdges(final AbstractTinkerGraph graph, final String key, final float[] vector) { + return null == graph.edgeVectorIndex ? Collections.emptyList() : graph.edgeVectorIndex.findNearest(key, vector); } public static List<TinkerVertex> queryVertexIndex(final AbstractTinkerGraph graph, final String key, final Object value) { @@ -38,35 +87,47 @@ public final class TinkerIndexHelper { final AbstractTinkerGraph graph = (AbstractTinkerGraph) edge.graph(); if (graph.edgeIndex != null) graph.edgeIndex.autoUpdate(key, newValue, oldValue, edge); + if (graph.edgeVectorIndex != null && newValue instanceof float[]) + graph.edgeVectorIndex.autoUpdate(key, newValue, oldValue, edge); } public static void autoUpdateIndex(final TinkerVertex vertex, final String key, final Object newValue, final Object oldValue) { final AbstractTinkerGraph graph = (AbstractTinkerGraph) vertex.graph(); if (graph.vertexIndex != null) graph.vertexIndex.autoUpdate(key, newValue, oldValue, vertex); + if (graph.vertexVectorIndex != null && newValue instanceof float[]) + graph.vertexVectorIndex.autoUpdate(key, newValue, oldValue, vertex); } public static void removeElementIndex(final TinkerVertex vertex) { final AbstractTinkerGraph graph = (AbstractTinkerGraph) vertex.graph(); if (graph.vertexIndex != null) graph.vertexIndex.removeElement(vertex); + if (graph.vertexVectorIndex != null) + graph.vertexVectorIndex.removeElement(vertex); } public static void removeElementIndex(final TinkerEdge edge) { final AbstractTinkerGraph graph = (AbstractTinkerGraph) edge.graph(); if (graph.edgeIndex != null) graph.edgeIndex.removeElement(edge); + if (graph.edgeVectorIndex != null) + graph.edgeVectorIndex.removeElement(edge); } public static void removeIndex(final TinkerVertex vertex, final String key, final Object value) { final AbstractTinkerGraph graph = (AbstractTinkerGraph) vertex.graph(); if (graph.vertexIndex != null) graph.vertexIndex.remove(key, value, vertex); + if (graph.vertexVectorIndex != null && value instanceof float[]) + graph.vertexVectorIndex.remove(key, value, vertex); } public static void removeIndex(final TinkerEdge edge, final String key, final Object value) { final AbstractTinkerGraph graph = (AbstractTinkerGraph) edge.graph(); if (graph.edgeIndex != null) graph.edgeIndex.remove(key, value, edge); + if (graph.edgeVectorIndex != null && value instanceof float[]) + graph.edgeVectorIndex.remove(key, value, edge); } } diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexType.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexType.java new file mode 100644 index 0000000000..84004345e2 --- /dev/null +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerIndexType.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.tinkergraph.structure; + +import com.github.jelmerk.hnswlib.core.DistanceFunction; +import com.github.jelmerk.hnswlib.core.DistanceFunctions; + +/** + * Enum for the different types of indices supported by TinkerGraph + */ +public enum TinkerIndexType { + /** + * Standard key-based index + */ + DEFAULT, + + /** + * Vector index for similarity search + */ + VECTOR; + + /** + * Distance functions for vector index. + */ + public enum Vector implements VectorDistance<float[], Float> { + + COSINE(DistanceFunctions.FLOAT_COSINE_DISTANCE), + EUCLIDEAN(DistanceFunctions.FLOAT_EUCLIDEAN_DISTANCE), + MANHATTAN(DistanceFunctions.FLOAT_MANHATTAN_DISTANCE), + INNER_PRODUCT(DistanceFunctions.FLOAT_INNER_PRODUCT), + BRAY_CURTIS(DistanceFunctions.FLOAT_BRAY_CURTIS_DISTANCE), + CANBERRA(DistanceFunctions.FLOAT_CANBERRA_DISTANCE), + CORRELATION(DistanceFunctions.FLOAT_CORRELATION_DISTANCE); + + private final DistanceFunction<float[], Float> distanceFunction; + + Vector(final DistanceFunction<float[], Float> distanceFunction) { + this.distanceFunction = distanceFunction; + } + + @Override + public DistanceFunction<float[], Float> getDistanceFunction() { + return distanceFunction; + } + } + + interface VectorDistance<V, T> { + DistanceFunction<V, T> getDistanceFunction(); + } +} diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraph.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraph.java index 3e6931a1d8..f4cd761a2c 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraph.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraph.java @@ -476,10 +476,10 @@ public final class TinkerTransactionGraph extends AbstractTinkerGraph { public <E extends Element> void createIndex(final String key, final Class<E> elementClass) { if (Vertex.class.isAssignableFrom(elementClass)) { if (null == this.vertexIndex) this.vertexIndex = new TinkerTransactionalIndex<>(this, TinkerVertex.class); - this.vertexIndex.createKeyIndex(key); + this.vertexIndex.createIndex(key); } else if (Edge.class.isAssignableFrom(elementClass)) { if (null == this.edgeIndex) this.edgeIndex = new TinkerTransactionalIndex<>(this, TinkerEdge.class); - this.edgeIndex.createKeyIndex(key); + this.edgeIndex.createIndex(key); } else { throw new IllegalArgumentException("Class is not indexable: " + elementClass); } @@ -494,9 +494,9 @@ public final class TinkerTransactionGraph extends AbstractTinkerGraph { */ public <E extends Element> void dropIndex(final String key, final Class<E> elementClass) { if (Vertex.class.isAssignableFrom(elementClass)) { - if (null != this.vertexIndex) this.vertexIndex.dropKeyIndex(key); + if (null != this.vertexIndex) this.vertexIndex.dropIndex(key); } else if (Edge.class.isAssignableFrom(elementClass)) { - if (null != this.edgeIndex) this.edgeIndex.dropKeyIndex(key); + if (null != this.edgeIndex) this.edgeIndex.dropIndex(key); } else { throw new IllegalArgumentException("Class is not indexable: " + elementClass); } diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionalIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionalIndex.java index d23dcb591a..0edfdaaaa3 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionalIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionalIndex.java @@ -171,7 +171,7 @@ final class TinkerTransactionalIndex<T extends TinkerElement> extends AbstractTi } @Override - public void createKeyIndex(final String key) { + public void createIndex(final String key, final Map<String,Object> configuration) { if (null == key) throw Graph.Exceptions.argumentCanNotBeNull("key"); if (key.isEmpty()) @@ -192,7 +192,7 @@ final class TinkerTransactionalIndex<T extends TinkerElement> extends AbstractTi } @Override - public void dropKeyIndex(final String key) { + public void dropIndex(final String key) { if (index.containsKey(key)) index.remove(key).clear(); diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java new file mode 100644 index 0000000000..a31c4148cd --- /dev/null +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.tinkergraph.structure; + +import com.github.jelmerk.hnswlib.core.DistanceFunction; +import com.github.jelmerk.hnswlib.core.DistanceFunctions; +import com.github.jelmerk.hnswlib.core.Item; +import com.github.jelmerk.hnswlib.core.SearchResult; +import com.github.jelmerk.hnswlib.core.hnsw.HnswIndex; +import com.github.jelmerk.hnswlib.core.Index; +import org.apache.tinkerpop.gremlin.structure.Element; +import org.apache.tinkerpop.gremlin.structure.Graph; +import org.apache.tinkerpop.gremlin.structure.Property; +import org.apache.tinkerpop.gremlin.structure.Vertex; + +import java.io.Serializable; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; + +/** + * A vector index implementation for TinkerGraph using hnswlib. + * + * @param <T> the element type (Vertex or Edge) + */ +final class TinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> { + + /** + * Map of property key to vector index + */ + protected Map<String, Index<Object, float[], ElementItem, Float>> vectorIndices = new ConcurrentHashMap<>(); + + /** + * Default number of nearest neighbors to return + */ + private static final int DEFAULT_K = 10; + + /** + * Default M parameter for HNSW index + */ + private static final int DEFAULT_M = 16; + + /** + * Default ef construction parameter for HNSW index + */ + private static final int DEFAULT_EF_CONSTRUCTION = 200; + + /** + * Default ef parameter for HNSW index + */ + private static final int DEFAULT_EF = 10; + + /** + * Default maximum number of items in the index + */ + private static final int DEFAULT_MAX_ITEMS = 100; + + /** + * Configuration key for the dimension of the vector + */ + public static final String CONFIG_DIMENSION = "dimension"; + + /** + * Configuration key for the M parameter of the HNSW index + */ + public static final String CONFIG_M = "m"; + + /** + * Configuration key for the ef construction parameter of the HNSW index + */ + public static final String CONFIG_EF_CONSTRUCTION = "efConstruction"; + + /** + * Configuration key for the ef parameter of the HNSW index + */ + public static final String CONFIG_EF = "ef"; + + /** + * Configuration key for the maximum number of items in the index + */ + public static final String CONFIG_MAX_ITEMS = "maxItems"; + + /** + * Configuration key for the distance function of the HNSW index + */ + public static final String CONFIG_DISTANCE_FUNCTION = "distanceFunction"; + + /** + * Configuration key for the default number of nearest neighbors to return + */ + public static final String CONFIG_DEFAULT_K = "defaultK"; + + /** + * Creates a new vector index for the specified graph and element class. + * + * @param graph the graph + * @param indexClass the element class + */ + public TinkerVectorIndex(final TinkerGraph graph, final Class<T> indexClass) { + super(graph, indexClass); + } + + /** + * Creates a vector index for the specified property key with the given configuration options. + * + * @param key the property key + * @param configuration the configuration options + */ + @Override + public void createIndex(final String key, final Map<String, Object> configuration) { + if (null == key) + throw Graph.Exceptions.argumentCanNotBeNull("key"); + if (key.isEmpty()) + throw new IllegalArgumentException("The key for the index cannot be an empty string"); + + // Get dimension from configuration or throw exception if not provided + if (!configuration.containsKey(CONFIG_DIMENSION)) + throw new IllegalArgumentException("The dimension must be provided in the configuration"); + + final int dimension; + final Object dimObj = configuration.get(CONFIG_DIMENSION); + if (dimObj instanceof Number) { + dimension = ((Number) dimObj).intValue(); + } else { + throw new IllegalArgumentException("The dimension must be a number"); + } + + if (dimension <= 0) + throw new IllegalArgumentException("The dimension must be greater than 0"); + + if (this.indexedKeys.contains(key)) + return; + this.indexedKeys.add(key); + + int m = DEFAULT_M; + if (configuration.containsKey(CONFIG_M)) { + final Object mObj = configuration.get(CONFIG_M); + if (mObj instanceof Number) { + m = ((Number) mObj).intValue(); + } + } + + int efConstruction = DEFAULT_EF_CONSTRUCTION; + if (configuration.containsKey(CONFIG_EF_CONSTRUCTION)) { + final Object efObj = configuration.get(CONFIG_EF_CONSTRUCTION); + if (efObj instanceof Number) { + efConstruction = ((Number) efObj).intValue(); + } + } + + int ef = DEFAULT_EF; + if (configuration.containsKey(CONFIG_EF)) { + final Object efObj = configuration.get(CONFIG_EF); + if (efObj instanceof Number) { + ef = ((Number) efObj).intValue(); + } + } + + int maxItems = DEFAULT_MAX_ITEMS; + if (configuration.containsKey(CONFIG_MAX_ITEMS)) { + final Object maxObj = configuration.get(CONFIG_MAX_ITEMS); + if (maxObj instanceof Number) { + maxItems = ((Number) maxObj).intValue(); + } + } + + TinkerIndexType.Vector vector = TinkerIndexType.Vector.COSINE; + if (configuration.containsKey(CONFIG_DISTANCE_FUNCTION)) { + final Object vec = configuration.get(CONFIG_DISTANCE_FUNCTION); + if (vec instanceof TinkerIndexType.Vector) { + vector = ((TinkerIndexType.Vector) vec); + } + } + + // Create a new HNSW index for this property key + final Index<Object, float[], ElementItem, Float> index = HnswIndex + .newBuilder(dimension, vector.getDistanceFunction(), Float::compare, maxItems) + .withM(m) + .withEfConstruction(efConstruction) + .withEf(ef) + .withRemoveEnabled() + .build(); + + this.vectorIndices.put(key, index); + + // Index existing elements + (Vertex.class.isAssignableFrom(this.indexClass) ? + ((TinkerGraph) this.graph).vertices.values().parallelStream() : + ((TinkerGraph) this.graph).edges.values().parallelStream()) + .map(e -> new Object[]{((T) e).property(key), e}) + .filter(a -> ((Property) a[0]).isPresent()) + .forEach(a -> { + // values for the key that don't match the dimensions of the index won't be added + final Object value = ((Property) a[0]).value(); + if (value instanceof float[] && ((float[]) value).length == dimension) { + this.addToIndex(key, (float[]) value, (T) a[1]); + } + }); + } + + /** + * Adds an element with a vector to the index. + * + * @param key the property key + * @param vector the vector + * @param element the element + */ + public void addToIndex(final String key, final float[] vector, final T element) { + if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) + return; + + final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); + final ElementItem item = new ElementItem(element.id(), vector, element); + index.add(item); + } + + /** + * Searches for nearest neighbors in the vector index. + * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of elements sorted by distance + */ + public List<T> findNearest(final String key, final float[] vector, final int k) { + if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) + return Collections.emptyList(); + + final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); + final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k); + return nearest.stream().map(sr -> sr.item().element).collect(Collectors.toList()); + } + + /** + * Searches for nearest neighbors in the vector index with the default k. + * + * @param key the property key + * @param vector the query vector + * @return a list of elements sorted by distance + */ + public List<T> findNearest(final String key, final float[] vector) { + return findNearest(key, vector, DEFAULT_K); + } + + /** + * Removes an element from the vector index. + * + * @param key the property key + * @param element the element + */ + public void removeFromIndex(final String key, final T element) { + if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) + return; + + final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); + try { + index.remove(element.id(), 0); + } catch (Exception e) { + // If the element is not in the index, just ignore the exception + } + } + + /** + * Updates the vector index when an element's property changes. + * + * @param key the property key + * @param newValue the new vector value + * @param element the element + */ + public void updateIndex(final String key, final float[] newValue, final T element) { + if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) + return; + + final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); + try { + index.remove(element.id(), 0); + } catch (Exception e) { + // If the element is not in the index, just ignore the exception + } + final ElementItem item = new ElementItem(element.id(), newValue, element); + index.add(item); + } + + /** + * Drops the vector index for the specified property key. + * + * @param key the property key + */ + @Override + public void dropIndex(final String key) { + if (this.vectorIndices.containsKey(key)) { + this.vectorIndices.remove(key); + } + + this.indexedKeys.remove(key); + } + + /** + * A class that wraps an element with its vector for use in the HNSW index. + */ + private class ElementItem implements Item<Object, float[]>, Serializable { + private final Object id; + private final float[] vector; + private final T element; + + public ElementItem(final Object id, final float[] vector, final T element) { + this.id = id; + this.vector = vector; + this.element = element; + } + + @Override + public Object id() { + return id; + } + + @Override + public float[] vector() { + return vector; + } + + @Override + public int dimensions() { + return vector.length; + } + } + + // AbstractTinkerIndex implementation methods + + @Override + public List<T> get(final String key, final Object value) { + // This method is for regular indices, not vector indices + return Collections.emptyList(); + } + + @Override + public long count(final String key, final Object value) { + // This method is for regular indices, not vector indices + return 0; + } + + @Override + public void remove(final String key, final Object value, final T element) { + // For vector indices, we use removeFromIndex + if (value instanceof float[]) { + removeFromIndex(key, element); + } + } + + @Override + public void removeElement(final T element) { + if (this.indexClass.isAssignableFrom(element.getClass())) { + for (String key : this.indexedKeys) { + removeFromIndex(key, element); + } + } + } + + @Override + public void autoUpdate(final String key, final Object newValue, final Object oldValue, final T element) { + if (this.indexedKeys.contains(key) && newValue instanceof float[]) { + updateIndex(key, (float[]) newValue, element); + } + } +} diff --git a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java new file mode 100644 index 0000000000..6b8267be55 --- /dev/null +++ b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.tinkergraph.structure; + +import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; +import org.apache.tinkerpop.gremlin.structure.Edge; +import org.apache.tinkerpop.gremlin.structure.Vertex; +import org.junit.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource.traversal; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; + +/** + * Tests for TinkerGraph vector index functionality. + */ +public class TinkerGraphVectorIndexTest { + + private static final Map<String,Object> indexConfig = new HashMap<String,Object>() {{ + put(TinkerVectorIndex.CONFIG_DIMENSION, 3); + }}; + + @Test + public void shouldCreateEdgeVectorIndex() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, indexConfig); + assertThat(graph.getIndexedKeys(Edge.class), hasItem("embedding")); + } + + @Test + public void shouldCreateVectorIndex() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); + assertThat(graph.getIndexedKeys(Vertex.class), hasItem("embedding")); + } + + @Test + public void shouldFindNearestVertices() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Charlie").property("embedding", new float[]{0.0f, 0.0f, 1.0f}).iterate(); + g.addV("person").property("name", "Dave").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); + + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); + + final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(2, nearest.size()); + assertEquals("Alice", nearest.get(0).value("name")); + assertEquals("Dave", nearest.get(1).value("name")); + } + + @Test + public void shouldUpdateVectorIndex() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); + + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); + + // Update a vertex property + g.V().has("name", "Bob").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); + + final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(2, nearest.size()); + assertEquals("Alice", nearest.get(0).value("name")); + assertEquals("Bob", nearest.get(1).value("name")); + } + + @Test + public void shouldRemoveFromVectorIndex() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Charlie").property("embedding", new float[]{0.0f, 0.0f, 1.0f}).iterate(); + + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); + + // Remove a vertex + g.V().has("name", "Bob").drop().iterate(); + + final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{0.0f, 1.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(2, nearest.size()); + assertThat(nearest.stream().noneMatch(v -> v.value("name").equals("Bob")), is(true)); + } + + @Test + public void shouldDropVectorIndex() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); + + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); + assertThat(graph.getIndexedKeys(Vertex.class), hasItem("embedding")); + + // Drop index + graph.dropIndex("embedding", Vertex.class); + assertThat(graph.getIndexedKeys(Vertex.class), not(hasItem("embedding"))); + + // Search for nearest neighbors should return empty list + final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(0, nearest.size()); + } + + @Test + public void shouldFindNearestEdges() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + final Vertex alice = g.addV("person").property("name", "Alice").next(); + final Vertex bob = g.addV("person").property("name", "Bob").next(); + final Vertex charlie = g.addV("person").property("name", "Charlie").next(); + final Vertex dave = g.addV("person").property("name", "Dave").next(); + g.addE("knows").from(alice).to(bob).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).property("strength", 0.8f).iterate(); + g.addE("knows").from(bob).to(charlie).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).property("strength", 0.6f).iterate(); + g.addE("knows").from(charlie).to(dave).property("embedding", new float[]{0.0f, 0.0f, 1.0f}).property("strength", 0.7f).iterate(); + g.addE("knows").from(alice).to(dave).property("embedding", new float[]{0.9f, 0.1f, 0.0f}).property("strength", 0.9f).iterate(); + + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, indexConfig); + + final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(2, nearest.size()); + assertEquals(0.8f, (float) nearest.get(0).value("strength"), 0.0001f); + assertEquals(0.9f, (float) nearest.get(1).value("strength"), 0.0001f); + } + + @Test + public void shouldUpdateEdgeVectorIndex() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + final Vertex alice = g.addV("person").property("name", "Alice").next(); + final Vertex bob = g.addV("person").property("name", "Bob").next(); + g.addE("knows").from(alice).to(bob).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).property("strength", 0.8f).iterate(); + final Edge edge = g.addE("knows").from(bob).to(alice).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).property("strength", 0.6f).next(); + + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, indexConfig); + + // Update an edge property + g.E(edge.id()).property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); + + final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(2, nearest.size()); + assertEquals(0.8f, (float) nearest.get(0).value("strength"), 0.0001f); + assertEquals(0.6f, (float) nearest.get(1).value("strength"), 0.0001f); + } + + @Test + public void shouldRemoveEdgeFromVectorIndex() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + final Vertex alice = g.addV("person").property("name", "Alice").next(); + final Vertex bob = g.addV("person").property("name", "Bob").next(); + final Vertex charlie = g.addV("person").property("name", "Charlie").next(); + g.addE("knows").from(alice).to(bob).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).property("strength", 0.8f).iterate(); + final Edge edge = g.addE("knows").from(bob).to(charlie).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).property("strength", 0.6f).next(); + g.addE("knows").from(charlie).to(alice).property("embedding", new float[]{0.0f, 0.0f, 1.0f}).property("strength", 0.7f).iterate(); + + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); + + // Remove an edge + g.E(edge.id()).drop().iterate(); + + final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{0.0f, 1.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(2, nearest.size()); + assertThat(nearest.stream().noneMatch(e -> e.value("strength").equals(0.6f)), is(true)); + } + + @Test + public void shouldDropEdgeVectorIndex() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + final Vertex alice = g.addV("person").property("name", "Alice").next(); + final Vertex bob = g.addV("person").property("name", "Bob").next(); + g.addE("knows").from(alice).to(bob).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).property("strength", 0.8f).iterate(); + g.addE("knows").from(bob).to(alice).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).property("strength", 0.6f).iterate(); + + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, indexConfig); + assertThat(graph.getIndexedKeys(Edge.class), hasItem("embedding")); + + // Drop index + graph.dropIndex("embedding", Edge.class); + assertThat(graph.getIndexedKeys(Edge.class), not(hasItem("embedding"))); + + // Search for nearest neighbors should return empty list + final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(0, nearest.size()); + } + + @Test(expected = IllegalArgumentException.class) + public void shouldThrowExceptionWhenVectorDimensionExceedsConfigured() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + + // Create a vector index with dimension 3 + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + + // Try to add a vertex with a vector of dimension 4 (exceeds configured dimension 3) + g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f, 0.0f}).iterate(); + } + + @Test(expected = IllegalArgumentException.class) + public void shouldThrowExceptionWhenVectorDimensionIsSmallerThanConfigured() { + final TinkerGraph graph = TinkerGraph.open(); + final GraphTraversalSource g = traversal().withEmbedded(graph); + + // Create a vector index with dimension 3 + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + + // Try to add a vertex with a vector of dimension 2 (smaller than configured dimension 3) + g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f}).iterate(); + } +}
