This is an automated email from the ASF dual-hosted git repository. spmallette pushed a commit to branch TINKERPOP-3158 in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 91fac89656576dbb1a69a43c4dc58e25ad753dbb Author: Stephen Mallette <[email protected]> AuthorDate: Fri May 2 09:50:11 2025 -0400 vector search in transaction graph --- .../tinkergraph/structure/AbstractTinkerGraph.java | 140 ++++++++++++++- .../structure/AbstractTinkerVectorIndex.java | 54 ++++++ .../gremlin/tinkergraph/structure/TinkerGraph.java | 127 +------------- .../tinkergraph/structure/TinkerTransaction.java | 27 ++- .../structure/TinkerTransactionGraph.java | 44 +---- ...ionalIndex.java => TinkerTransactionIndex.java} | 4 +- ...ndex.java => TinkerTransactionVectorIndex.java} | 81 ++++++--- .../tinkergraph/structure/TinkerVectorIndex.java | 11 +- .../structure/TinkerGraphVectorIndexTest.java | 188 +++++++++++++++------ .../structure/TinkerTransactionGraphTest.java | 18 +- 10 files changed, 429 insertions(+), 265 deletions(-) diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java index ef2f65af01..5ace0fd93a 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java @@ -37,9 +37,11 @@ import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerServiceRegistry; import java.io.File; import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; @@ -72,8 +74,8 @@ public abstract class AbstractTinkerGraph implements Graph { protected TinkerGraphComputerView graphComputerView = null; protected AbstractTinkerIndex<TinkerVertex> vertexIndex = null; protected AbstractTinkerIndex<TinkerEdge> edgeIndex = null; - protected TinkerVectorIndex<TinkerVertex> vertexVectorIndex = null; - protected TinkerVectorIndex<TinkerEdge> edgeVectorIndex = null; + protected AbstractTinkerVectorIndex<TinkerVertex> vertexVectorIndex = null; + protected AbstractTinkerVectorIndex<TinkerEdge> edgeVectorIndex = null; protected IdManager<Vertex> vertexIdManager; protected IdManager<Edge> edgeIdManager; @@ -292,6 +294,64 @@ public abstract class AbstractTinkerGraph implements Graph { return configuration; } + ///////////// Vector Search ///////////////// + + /** + * Find the nearest vertices to the given vector in the vector index for the specified property key. + * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of vertices sorted by distance + */ + public List<Vertex> findNearestVertices(final String key, final float[] vector, final int k) { + if (null == this.vertexVectorIndex) + throw new IllegalStateException("Vector index not created for vertices on key: '" + key + "'"); + return new ArrayList<>(this.vertexVectorIndex.findNearest(key, vector, k)); + } + + /** + * Find the nearest vertices to the given vector in the vector index for the specified property key. + * Uses the default number of nearest neighbors. + * + * @param key the property key + * @param vector the query vector + * @return a list of vertices sorted by distance + */ + public List<Vertex> findNearestVertices(final String key, final float[] vector) { + if (null == this.vertexVectorIndex) + throw new IllegalStateException("Vector index not created for vertices on key: '" + key + "'"); + return new ArrayList<>(this.vertexVectorIndex.findNearest(key, vector)); + } + + /** + * Find the nearest edges to the given vector in the vector index for the specified property key. + * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of edges sorted by distance + */ + public List<Edge> findNearestEdges(final String key, final float[] vector, final int k) { + if (null == this.edgeVectorIndex) + throw new IllegalStateException("Vector index not created for edges on key: '" + key + "'"); + return new ArrayList<>(this.edgeVectorIndex.findNearest(key, vector, k)); + } + + /** + * Find the nearest edges to the given vector in the vector index for the specified property key. + * Uses the default number of nearest neighbors. + * + * @param key the property key + * @param vector the query vector + * @return a list of edges sorted by distance + */ + public List<Edge> findNearestEdges(final String key, final float[] vector) { + if (null == this.edgeVectorIndex) + throw new IllegalStateException("Vector index not created for edges on key: '" + key + "'"); + return new ArrayList<>(this.edgeVectorIndex.findNearest(key, vector)); + } + ///////////// Utility methods /////////////// protected abstract void addOutEdge(final TinkerVertex vertex, final String label, final Edge edge); @@ -392,6 +452,82 @@ public abstract class AbstractTinkerGraph implements Graph { ///////////// GRAPH SPECIFIC INDEXING METHODS /////////////// + /** + * Provides a mechanism to internally construct an {@link AbstractTinkerIndex} of the appropriate type for a + * particular implementation of this class. + */ + protected abstract <E extends Element> AbstractTinkerIndex<E> constructTinkerIndex(final TinkerIndexType indexType, + final Class<E> elementClass); + + /** + * Create an index for said element class ({@link Vertex} or {@link Edge}) and said property key. + * Whenever an element has the specified key mutated, the index is updated. + * When the index is created, all existing elements are indexed to ensure that they are captured by the index. + * + * @param key the property key to index + * @param elementClass the element class to index + * @param <E> The type of the element class + */ + public <E extends Element> void createIndex(final String key, final Class<E> elementClass) { + createIndex(TinkerIndexType.DEFAULT, key, elementClass, Collections.emptyMap()); + } + + /** + * Create an index for said element class ({@link Vertex} or {@link Edge}) and said property key with the specified index type. + * Whenever an element has the specified key mutated, the index is updated. + * When the index is created, all existing elements are indexed to ensure that they are captured by the index. + * + * @param indexType the type of the index + * @param key the property key to index + * @param elementClass the element class to index + * @param configuration the configuration options + * @param <E> The type of the element class + */ + public <E extends Element> void createIndex(final TinkerIndexType indexType, final String key, + final Class<E> elementClass, final Map<String, Object> configuration) { + if (TinkerIndexType.VECTOR == indexType) { + if (Vertex.class.isAssignableFrom(elementClass)) { + if (null == this.vertexVectorIndex) this.vertexVectorIndex = (AbstractTinkerVectorIndex<TinkerVertex>) constructTinkerIndex(TinkerIndexType.VECTOR, TinkerVertex.class); + this.vertexVectorIndex.createIndex(key, configuration); + } else if (Edge.class.isAssignableFrom(elementClass)) { + if (null == this.edgeVectorIndex) this.edgeVectorIndex = (AbstractTinkerVectorIndex<TinkerEdge>) constructTinkerIndex(TinkerIndexType.VECTOR, TinkerEdge.class); + this.edgeVectorIndex.createIndex(key, configuration); + } else { + throw new IllegalArgumentException("Class is not indexable: " + elementClass); + } + } else { + // Create a standard index + if (Vertex.class.isAssignableFrom(elementClass)) { + if (null == this.vertexIndex) this.vertexIndex = constructTinkerIndex(TinkerIndexType.DEFAULT, TinkerVertex.class); + this.vertexIndex.createIndex(key); + } else if (Edge.class.isAssignableFrom(elementClass)) { + if (null == this.edgeIndex) this.edgeIndex = constructTinkerIndex(TinkerIndexType.DEFAULT, TinkerEdge.class); + this.edgeIndex.createIndex(key); + } else { + throw new IllegalArgumentException("Class is not indexable: " + elementClass); + } + } + } + + /** + * Drop the index for the specified element class ({@link Vertex} or {@link Edge}) and key. + * + * @param key the property key to stop indexing + * @param elementClass the element class of the index to drop + * @param <E> The type of the element class + */ + public <E extends Element> void dropIndex(final String key, final Class<E> elementClass) { + if (Vertex.class.isAssignableFrom(elementClass)) { + if (null != this.vertexIndex) this.vertexIndex.dropIndex(key); + if (null != this.vertexVectorIndex) this.vertexVectorIndex.dropIndex(key); + } else if (Edge.class.isAssignableFrom(elementClass)) { + if (null != this.edgeIndex) this.edgeIndex.dropIndex(key); + if (null != this.edgeVectorIndex) this.edgeVectorIndex.dropIndex(key); + } else { + throw new IllegalArgumentException("Class is not indexable: " + elementClass); + } + } + /** * Return all the keys currently being index for said element class ({@link Vertex} or {@link Edge}). * diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java new file mode 100644 index 0000000000..3a4e67cb25 --- /dev/null +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.tinkergraph.structure; + +import org.apache.tinkerpop.gremlin.structure.Element; + +import java.util.List; +/** + * Base class for representing a vector index for performing nearest neighbor searches. + * + * @param <T> the type of elements stored in the vector index + */ +public abstract class AbstractTinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> { + + protected AbstractTinkerVectorIndex(final AbstractTinkerGraph graph, final Class<T> indexClass) { + super(graph, indexClass); + } + + /** + * Searches for nearest neighbors in the vector index. + * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @return a list of elements sorted by distance + */ + public abstract List<T> findNearest(final String key, final float[] vector, final int k); + + /** + * Searches for nearest neighbors in the vector index with the default k. + * + * @param key the property key + * @param vector the query vector + * @return a list of elements sorted by distance + */ + public abstract List<T> findNearest(final String key, final float[] vector); + +} diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java index a54df6a858..177378b620 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraph.java @@ -372,130 +372,9 @@ public class TinkerGraph extends AbstractTinkerGraph { } - ///////////// GRAPH SPECIFIC INDEXING METHODS /////////////// - - /** - * Create a default index for said element class ({@link Vertex} or {@link Edge}) and said property key. - * Whenever an element has the specified key mutated, the index is updated. - * When the index is created, all existing elements are indexed to ensure that they are captured by the index. - * - * @param key the property key to index - * @param elementClass the element class to index - * @param <E> The type of the element class - */ - public <E extends Element> void createIndex(final String key, final Class<E> elementClass) { - createIndex(TinkerIndexType.DEFAULT, key, elementClass, Collections.emptyMap()); - } - - /** - * Create an index for said element class ({@link Vertex} or {@link Edge}) and said property key with the given - * configuration options. Whenever an element has the specified key mutated, the index is updated. When the index - * is created, all existing elements are indexed to ensure that they are captured by the index. - * - * @param indexType the type of the index - * @param key the property key to index - * @param elementClass the element class to index - * @param configuration the configuration options - * @param <E> The type of the element class - */ - public <E extends Element> void createIndex(final TinkerIndexType indexType, final String key, - final Class<E> elementClass, final Map<String, Object> configuration) { - if (TinkerIndexType.VECTOR == indexType) { - if (Vertex.class.isAssignableFrom(elementClass)) { - if (null == this.vertexVectorIndex) this.vertexVectorIndex = new TinkerVectorIndex<>(this, TinkerVertex.class); - this.vertexVectorIndex.createIndex(key, configuration); - } else if (Edge.class.isAssignableFrom(elementClass)) { - if (null == this.edgeVectorIndex) this.edgeVectorIndex = new TinkerVectorIndex<>(this, TinkerEdge.class); - this.edgeVectorIndex.createIndex(key, configuration); - } else { - throw new IllegalArgumentException("Class is not indexable: " + elementClass); - } - } else { - // Create a standard index - if (Vertex.class.isAssignableFrom(elementClass)) { - if (null == this.vertexIndex) this.vertexIndex = new TinkerIndex<>(this, TinkerVertex.class); - this.vertexIndex.createIndex(key); - } else if (Edge.class.isAssignableFrom(elementClass)) { - if (null == this.edgeIndex) this.edgeIndex = new TinkerIndex<>(this, TinkerEdge.class); - this.edgeIndex.createIndex(key); - } else { - throw new IllegalArgumentException("Class is not indexable: " + elementClass); - } - } - } - - /** - * Drop the index for the specified element class ({@link Vertex} or {@link Edge}) and key. - * - * @param key the property key to stop indexing - * @param elementClass the element class of the index to drop - * @param <E> The type of the element class - */ - public <E extends Element> void dropIndex(final String key, final Class<E> elementClass) { - if (Vertex.class.isAssignableFrom(elementClass)) { - if (null != this.vertexIndex) this.vertexIndex.dropIndex(key); - if (null != this.vertexVectorIndex) this.vertexVectorIndex.dropIndex(key); - } else if (Edge.class.isAssignableFrom(elementClass)) { - if (null != this.edgeIndex) this.edgeIndex.dropIndex(key); - if (null != this.edgeVectorIndex) this.edgeVectorIndex.dropIndex(key); - } else { - throw new IllegalArgumentException("Class is not indexable: " + elementClass); - } - } - - /** - * Find the nearest vertices to the given vector in the vector index for the specified property key. - * - * @param key the property key - * @param vector the query vector - * @param k the number of nearest neighbors to return - * @return a list of vertices sorted by distance - */ - public List<Vertex> findNearestVertices(final String key, final float[] vector, final int k) { - if (null == this.vertexVectorIndex) - return Collections.emptyList(); - return new ArrayList<>(this.vertexVectorIndex.findNearest(key, vector, k)); - } - - /** - * Find the nearest vertices to the given vector in the vector index for the specified property key. - * Uses the default number of nearest neighbors. - * - * @param key the property key - * @param vector the query vector - * @return a list of vertices sorted by distance - */ - public List<Vertex> findNearestVertices(final String key, final float[] vector) { - if (null == this.vertexVectorIndex) - return Collections.emptyList(); - return new ArrayList<>(this.vertexVectorIndex.findNearest(key, vector)); - } - - /** - * Find the nearest edges to the given vector in the vector index for the specified property key. - * - * @param key the property key - * @param vector the query vector - * @param k the number of nearest neighbors to return - * @return a list of edges sorted by distance - */ - public List<Edge> findNearestEdges(final String key, final float[] vector, final int k) { - if (null == this.edgeVectorIndex) - return Collections.emptyList(); - return new ArrayList<>(this.edgeVectorIndex.findNearest(key, vector, k)); + @Override + protected <E extends Element> AbstractTinkerIndex<E> constructTinkerIndex(final TinkerIndexType indexType, final Class<E> elementClass) { + return indexType == TinkerIndexType.VECTOR ? new TinkerVectorIndex<E>(this, elementClass) : new TinkerIndex<E>(this, elementClass); } - /** - * Find the nearest edges to the given vector in the vector index for the specified property key. - * Uses the default number of nearest neighbors. - * - * @param key the property key - * @param vector the query vector - * @return a list of edges sorted by distance - */ - public List<Edge> findNearestEdges(final String key, final float[] vector) { - if (null == this.edgeVectorIndex) - return Collections.emptyList(); - return new ArrayList<>(this.edgeVectorIndex.findNearest(key, vector)); - } } diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransaction.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransaction.java index 4cf8b50fca..b214849249 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransaction.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransaction.java @@ -19,7 +19,6 @@ package org.apache.tinkerpop.gremlin.tinkergraph.structure; import org.apache.tinkerpop.gremlin.process.traversal.TraversalSource; -import org.apache.tinkerpop.gremlin.structure.Transaction; import org.apache.tinkerpop.gremlin.structure.util.AbstractThreadLocalTransaction; import org.apache.tinkerpop.gremlin.structure.util.TransactionException; @@ -176,11 +175,15 @@ final class TinkerTransaction extends AbstractThreadLocalTransaction { throw new TransactionException(TX_CONFLICT); // update indices - final TinkerTransactionalIndex vertexIndex = (TinkerTransactionalIndex) graph.vertexIndex; + final TinkerTransactionIndex vertexIndex = (TinkerTransactionIndex) graph.vertexIndex; if (vertexIndex != null) vertexIndex.commit(changedVertices); - final TinkerTransactionalIndex edgeIndex = (TinkerTransactionalIndex) graph.edgeIndex; + final TinkerTransactionIndex edgeIndex = (TinkerTransactionIndex) graph.edgeIndex; if (edgeIndex != null) edgeIndex.commit(changedEdges); + // update vector indices + if (graph.vertexVectorIndex != null) ((TinkerTransactionVectorIndex) graph.vertexVectorIndex).commit(changedVertices); + if (graph.edgeVectorIndex != null) ((TinkerTransactionVectorIndex) graph.edgeVectorIndex).commit(changedEdges); + // commit all changes changedVertices.forEach(v -> v.commit(txVersion)); changedEdges.forEach(e -> e.commit(txVersion)); @@ -190,11 +193,15 @@ final class TinkerTransaction extends AbstractThreadLocalTransaction { changedEdges.forEach(e -> e.rollback()); // also revert indices update - final TinkerTransactionalIndex vertexIndex = (TinkerTransactionalIndex) graph.vertexIndex; + final TinkerTransactionIndex vertexIndex = (TinkerTransactionIndex) graph.vertexIndex; if (vertexIndex != null) vertexIndex.rollback(); - final TinkerTransactionalIndex edgeIndex = (TinkerTransactionalIndex) graph.edgeIndex; + final TinkerTransactionIndex edgeIndex = (TinkerTransactionIndex) graph.edgeIndex; if (edgeIndex != null) edgeIndex.rollback(); + // revert vector indices update + if (graph.vertexVectorIndex != null) ((TinkerTransactionVectorIndex) graph.vertexVectorIndex).rollback(); + if (graph.edgeVectorIndex != null) ((TinkerTransactionVectorIndex) graph.edgeVectorIndex).rollback(); + throw ex; } finally { // remove elements from graph if not used in other tx's @@ -234,10 +241,14 @@ final class TinkerTransaction extends AbstractThreadLocalTransaction { if (null != changedEdges) changedEdges.forEach(e -> e.rollback()); // rollback indices - final TinkerTransactionalIndex vertexIndex = (TinkerTransactionalIndex) graph.vertexIndex; + final TinkerTransactionIndex vertexIndex = (TinkerTransactionIndex) graph.vertexIndex; if (vertexIndex != null) vertexIndex.rollback(); - final TinkerTransactionalIndex edgeIndex = (TinkerTransactionalIndex) graph.edgeIndex; - if (vertexIndex != null) edgeIndex.rollback(); + final TinkerTransactionIndex edgeIndex = (TinkerTransactionIndex) graph.edgeIndex; + if (edgeIndex != null) edgeIndex.rollback(); + + // rollback vector indices + if (graph.vertexVectorIndex != null) ((TinkerTransactionVectorIndex) graph.vertexVectorIndex).rollback(); + if (graph.edgeVectorIndex != null) ((TinkerTransactionVectorIndex) graph.edgeVectorIndex).rollback(); // cleanup unused containers if (null != changedVertices) diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraph.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraph.java index f4cd761a2c..6b53a43f96 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraph.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraph.java @@ -35,8 +35,10 @@ import org.apache.tinkerpop.gremlin.tinkergraph.process.traversal.strategy.optim import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerServiceRegistry; import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -461,44 +463,8 @@ public final class TinkerTransactionGraph extends AbstractTinkerGraph { } - - /////////////// GRAPH SPECIFIC INDEXING METHODS /////////////// - - /** - * Create an index for said element class ({@link Vertex} or {@link Edge}) and said property key. - * Whenever an element has the specified key mutated, the index is updated. - * When the index is created, all existing elements are indexed to ensure that they are captured by the index. - * - * @param key the property key to index - * @param elementClass the element class to index - * @param <E> The type of the element class - */ - public <E extends Element> void createIndex(final String key, final Class<E> elementClass) { - if (Vertex.class.isAssignableFrom(elementClass)) { - if (null == this.vertexIndex) this.vertexIndex = new TinkerTransactionalIndex<>(this, TinkerVertex.class); - this.vertexIndex.createIndex(key); - } else if (Edge.class.isAssignableFrom(elementClass)) { - if (null == this.edgeIndex) this.edgeIndex = new TinkerTransactionalIndex<>(this, TinkerEdge.class); - this.edgeIndex.createIndex(key); - } else { - throw new IllegalArgumentException("Class is not indexable: " + elementClass); - } - } - - /** - * Drop the index for the specified element class ({@link Vertex} or {@link Edge}) and key. - * - * @param key the property key to stop indexing - * @param elementClass the element class of the index to drop - * @param <E> The type of the element class - */ - public <E extends Element> void dropIndex(final String key, final Class<E> elementClass) { - if (Vertex.class.isAssignableFrom(elementClass)) { - if (null != this.vertexIndex) this.vertexIndex.dropIndex(key); - } else if (Edge.class.isAssignableFrom(elementClass)) { - if (null != this.edgeIndex) this.edgeIndex.dropIndex(key); - } else { - throw new IllegalArgumentException("Class is not indexable: " + elementClass); - } + @Override + protected <E extends Element> AbstractTinkerIndex<E> constructTinkerIndex(final TinkerIndexType indexType, final Class<E> elementClass) { + return indexType == TinkerIndexType.VECTOR ? new TinkerTransactionVectorIndex(this, elementClass) : new TinkerTransactionIndex(this, elementClass); } } diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionalIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionIndex.java similarity index 97% rename from tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionalIndex.java rename to tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionIndex.java index 0edfdaaaa3..cc1bc7521d 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionalIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionIndex.java @@ -32,13 +32,13 @@ import java.util.stream.Collectors; /** * @author Valentyn Kahamlyk */ -final class TinkerTransactionalIndex<T extends TinkerElement> extends AbstractTinkerIndex<T> { +final class TinkerTransactionIndex<T extends TinkerElement> extends AbstractTinkerIndex<T> { protected Map<String, Map<Object, Set<TinkerElementContainer<T>>>> index = new ConcurrentHashMap<>(); protected ThreadLocal<Map<String, Map<Object, Set<T>>>> txIndex = ThreadLocal.withInitial(() -> new ConcurrentHashMap<>()); - public TinkerTransactionalIndex(final TinkerTransactionGraph graph, final Class<T> indexClass) { + public TinkerTransactionIndex(final TinkerTransactionGraph graph, final Class<T> indexClass) { super(graph, indexClass); } diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java similarity index 81% copy from tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java copy to tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java index a31c4148cd..111d4fcb7b 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java @@ -18,13 +18,10 @@ */ package org.apache.tinkerpop.gremlin.tinkergraph.structure; -import com.github.jelmerk.hnswlib.core.DistanceFunction; -import com.github.jelmerk.hnswlib.core.DistanceFunctions; import com.github.jelmerk.hnswlib.core.Item; import com.github.jelmerk.hnswlib.core.SearchResult; import com.github.jelmerk.hnswlib.core.hnsw.HnswIndex; import com.github.jelmerk.hnswlib.core.Index; -import org.apache.tinkerpop.gremlin.structure.Element; import org.apache.tinkerpop.gremlin.structure.Graph; import org.apache.tinkerpop.gremlin.structure.Property; import org.apache.tinkerpop.gremlin.structure.Vertex; @@ -33,15 +30,16 @@ import java.io.Serializable; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; /** - * A vector index implementation for TinkerGraph using hnswlib. + * A vector index implementation for TinkerTransactionGraph using hnswlib. * * @param <T> the element type (Vertex or Edge) */ -final class TinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> { +final class TinkerTransactionVectorIndex<T extends TinkerElement> extends AbstractTinkerVectorIndex<T> { /** * Map of property key to vector index @@ -103,18 +101,13 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> */ public static final String CONFIG_DISTANCE_FUNCTION = "distanceFunction"; - /** - * Configuration key for the default number of nearest neighbors to return - */ - public static final String CONFIG_DEFAULT_K = "defaultK"; - /** * Creates a new vector index for the specified graph and element class. * * @param graph the graph * @param indexClass the element class */ - public TinkerVectorIndex(final TinkerGraph graph, final Class<T> indexClass) { + public TinkerTransactionVectorIndex(final TinkerTransactionGraph graph, final Class<T> indexClass) { super(graph, indexClass); } @@ -202,18 +195,25 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> this.vectorIndices.put(key, index); // Index existing elements - (Vertex.class.isAssignableFrom(this.indexClass) ? - ((TinkerGraph) this.graph).vertices.values().parallelStream() : - ((TinkerGraph) this.graph).edges.values().parallelStream()) - .map(e -> new Object[]{((T) e).property(key), e}) - .filter(a -> ((Property) a[0]).isPresent()) - .forEach(a -> { - // values for the key that don't match the dimensions of the index won't be added - final Object value = ((Property) a[0]).value(); + final Map elements = + Vertex.class.isAssignableFrom(indexClass) ? + ((TinkerTransactionGraph) graph).getVertices() : + ((TinkerTransactionGraph) graph).getEdges(); + + for (Object element : elements.values()) { + TinkerElementContainer container = (TinkerElementContainer) element; + Object e = container.get(); + if (e != null && indexClass.isInstance(e)) { + T tinkerElement = (T) e; + Property property = tinkerElement.property(key); + if (property.isPresent()) { + Object value = property.value(); if (value instanceof float[] && ((float[]) value).length == dimension) { - this.addToIndex(key, (float[]) value, (T) a[1]); + this.addToIndex(key, (float[]) value, tinkerElement); } - }); + } + } + } } /** @@ -225,7 +225,7 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> */ public void addToIndex(final String key, final float[] vector, final T element) { if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) - return; + throw new IllegalArgumentException("The key '" + key + "' is not indexed"); final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); final ElementItem item = new ElementItem(element.id(), vector, element); @@ -242,7 +242,7 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> */ public List<T> findNearest(final String key, final float[] vector, final int k) { if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) - return Collections.emptyList(); + throw new IllegalArgumentException("The key '" + key + "' is not indexed"); final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k); @@ -380,4 +380,39 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> updateIndex(key, (float[]) newValue, element); } } + + /** + * Commit changes to the index. + * + * @param updatedElements the set of updated elements + */ + public void commit(final Set<TinkerElementContainer> updatedElements) { + for (final TinkerElementContainer container : updatedElements) { + Object element = container.get(); + if (element != null && !container.isDeleted() && indexClass.isInstance(element)) { + T tinkerElement = (T) element; + for (String key : this.indexedKeys) { + Property property = tinkerElement.property(key); + if (property.isPresent() && property.value() instanceof float[]) { + updateIndex(key, (float[]) property.value(), tinkerElement); + } + } + } else if (container.isDeleted()) { + Object oldElement = container.getUnmodified(); + if (oldElement != null && indexClass.isInstance(oldElement)) { + T tinkerOldElement = (T) oldElement; + for (String key : this.indexedKeys) { + removeFromIndex(key, tinkerOldElement); + } + } + } + } + } + + /** + * Rollback changes to the index. + */ + public void rollback() { + // No specific action needed for rollback in the current implementation + } } diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java index a31c4148cd..f280d0a3fb 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java @@ -18,8 +18,6 @@ */ package org.apache.tinkerpop.gremlin.tinkergraph.structure; -import com.github.jelmerk.hnswlib.core.DistanceFunction; -import com.github.jelmerk.hnswlib.core.DistanceFunctions; import com.github.jelmerk.hnswlib.core.Item; import com.github.jelmerk.hnswlib.core.SearchResult; import com.github.jelmerk.hnswlib.core.hnsw.HnswIndex; @@ -41,7 +39,7 @@ import java.util.stream.Collectors; * * @param <T> the element type (Vertex or Edge) */ -final class TinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> { +final class TinkerVectorIndex<T extends Element> extends AbstractTinkerVectorIndex<T> { /** * Map of property key to vector index @@ -103,11 +101,6 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> */ public static final String CONFIG_DISTANCE_FUNCTION = "distanceFunction"; - /** - * Configuration key for the default number of nearest neighbors to return - */ - public static final String CONFIG_DEFAULT_K = "defaultK"; - /** * Creates a new vector index for the specified graph and element class. * @@ -242,7 +235,7 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerIndex<T> */ public List<T> findNearest(final String key, final float[] vector, final int k) { if (!this.indexedKeys.contains(key) || !this.vectorIndices.containsKey(key)) - return Collections.emptyList(); + throw new IllegalArgumentException("The key '" + key + "' is not indexed"); final Index<Object, float[], ElementItem, Float> index = this.vectorIndices.get(key); final List<SearchResult<ElementItem, Float>> nearest = index.findNearest(vector, k); diff --git a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java index 6b8267be55..0b1d63af86 100644 --- a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java +++ b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphVectorIndexTest.java @@ -20,55 +20,75 @@ package org.apache.tinkerpop.gremlin.tinkergraph.structure; import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; import org.apache.tinkerpop.gremlin.structure.Edge; +import org.apache.tinkerpop.gremlin.structure.Graph; import org.apache.tinkerpop.gremlin.structure.Vertex; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import static org.apache.tinkerpop.gremlin.process.traversal.AnonymousTraversalSource.traversal; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; /** * Tests for TinkerGraph vector index functionality. */ +@RunWith(Parameterized.class) public class TinkerGraphVectorIndexTest { - private static final Map<String,Object> indexConfig = new HashMap<String,Object>() {{ + protected static final Map<String,Object> indexConfig = new HashMap<String,Object>() {{ put(TinkerVectorIndex.CONFIG_DIMENSION, 3); }}; + @Parameterized.Parameter + public AbstractTinkerGraph graph; + + @Parameterized.Parameters + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][]{ + {TinkerGraph.open()}, + {TinkerTransactionGraph.open()} + }); + } + + @Before + public void setUp() throws Exception { + graph.clear(); + tryCommitChanges(graph); + } + @Test public void shouldCreateEdgeVectorIndex() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); - graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, indexConfig); - assertThat(graph.getIndexedKeys(Edge.class), hasItem("embedding")); + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); + assertThat(graph.getIndexedKeys(Edge.class).contains("embedding"), is(true)); } @Test public void shouldCreateVectorIndex() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); - graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); - assertThat(graph.getIndexedKeys(Vertex.class), hasItem("embedding")); + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + assertThat(graph.getIndexedKeys(Vertex.class).contains("embedding"), is(true)); } @Test public void shouldFindNearestVertices() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); + final GraphTraversalSource g = traversal().with(graph); g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); g.addV("person").property("name", "Charlie").property("embedding", new float[]{0.0f, 0.0f, 1.0f}).iterate(); g.addV("person").property("name", "Dave").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); + tryCommitChanges(graph); + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); @@ -80,15 +100,16 @@ public class TinkerGraphVectorIndexTest { @Test public void shouldUpdateVectorIndex() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); + final GraphTraversalSource g = traversal().with(graph); g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); + tryCommitChanges(graph); graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); // Update a vertex property g.V().has("name", "Bob").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); + tryCommitChanges(graph); final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); assertNotNull(nearest); @@ -99,16 +120,17 @@ public class TinkerGraphVectorIndexTest { @Test public void shouldRemoveFromVectorIndex() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); + final GraphTraversalSource g = traversal().with(graph); g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); g.addV("person").property("name", "Charlie").property("embedding", new float[]{0.0f, 0.0f, 1.0f}).iterate(); + tryCommitChanges(graph); graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); // Remove a vertex g.V().has("name", "Bob").drop().iterate(); + tryCommitChanges(graph); final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{0.0f, 1.0f, 0.0f}, 2); assertNotNull(nearest); @@ -118,28 +140,27 @@ public class TinkerGraphVectorIndexTest { @Test public void shouldDropVectorIndex() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); + final GraphTraversalSource g = traversal().with(graph); g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); + tryCommitChanges(graph); graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); - assertThat(graph.getIndexedKeys(Vertex.class), hasItem("embedding")); + assertThat(graph.getIndexedKeys(Vertex.class).contains("embedding"), is(true)); // Drop index graph.dropIndex("embedding", Vertex.class); - assertThat(graph.getIndexedKeys(Vertex.class), not(hasItem("embedding"))); + assertThat(graph.getIndexedKeys(Vertex.class).contains("embedding"), is(false)); - // Search for nearest neighbors should return empty list - final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); - assertNotNull(nearest); - assertEquals(0, nearest.size()); + try { + graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + fail("Should have thrown exception since the index was removed"); + } catch (IllegalArgumentException ex) { } } @Test public void shouldFindNearestEdges() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); + final GraphTraversalSource g = traversal().with(graph); final Vertex alice = g.addV("person").property("name", "Alice").next(); final Vertex bob = g.addV("person").property("name", "Bob").next(); final Vertex charlie = g.addV("person").property("name", "Charlie").next(); @@ -148,6 +169,7 @@ public class TinkerGraphVectorIndexTest { g.addE("knows").from(bob).to(charlie).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).property("strength", 0.6f).iterate(); g.addE("knows").from(charlie).to(dave).property("embedding", new float[]{0.0f, 0.0f, 1.0f}).property("strength", 0.7f).iterate(); g.addE("knows").from(alice).to(dave).property("embedding", new float[]{0.9f, 0.1f, 0.0f}).property("strength", 0.9f).iterate(); + tryCommitChanges(graph); graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, indexConfig); @@ -160,17 +182,18 @@ public class TinkerGraphVectorIndexTest { @Test public void shouldUpdateEdgeVectorIndex() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); + final GraphTraversalSource g = traversal().with(graph); final Vertex alice = g.addV("person").property("name", "Alice").next(); final Vertex bob = g.addV("person").property("name", "Bob").next(); g.addE("knows").from(alice).to(bob).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).property("strength", 0.8f).iterate(); final Edge edge = g.addE("knows").from(bob).to(alice).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).property("strength", 0.6f).next(); + tryCommitChanges(graph); graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, indexConfig); // Update an edge property g.E(edge.id()).property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); + tryCommitChanges(graph); final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); assertNotNull(nearest); @@ -181,19 +204,20 @@ public class TinkerGraphVectorIndexTest { @Test public void shouldRemoveEdgeFromVectorIndex() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); + final GraphTraversalSource g = traversal().with(graph); final Vertex alice = g.addV("person").property("name", "Alice").next(); final Vertex bob = g.addV("person").property("name", "Bob").next(); final Vertex charlie = g.addV("person").property("name", "Charlie").next(); g.addE("knows").from(alice).to(bob).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).property("strength", 0.8f).iterate(); final Edge edge = g.addE("knows").from(bob).to(charlie).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).property("strength", 0.6f).next(); g.addE("knows").from(charlie).to(alice).property("embedding", new float[]{0.0f, 0.0f, 1.0f}).property("strength", 0.7f).iterate(); + tryCommitChanges(graph); - graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, indexConfig); // Remove an edge g.E(edge.id()).drop().iterate(); + tryCommitChanges(graph); final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{0.0f, 1.0f, 0.0f}, 2); assertNotNull(nearest); @@ -203,47 +227,113 @@ public class TinkerGraphVectorIndexTest { @Test public void shouldDropEdgeVectorIndex() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); + final GraphTraversalSource g = traversal().with(graph); final Vertex alice = g.addV("person").property("name", "Alice").next(); final Vertex bob = g.addV("person").property("name", "Bob").next(); g.addE("knows").from(alice).to(bob).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).property("strength", 0.8f).iterate(); g.addE("knows").from(bob).to(alice).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).property("strength", 0.6f).iterate(); + tryCommitChanges(graph); graph.createIndex(TinkerIndexType.VECTOR,"embedding", Edge.class, indexConfig); - assertThat(graph.getIndexedKeys(Edge.class), hasItem("embedding")); + assertThat(graph.getIndexedKeys(Edge.class).contains("embedding"), is(true)); // Drop index graph.dropIndex("embedding", Edge.class); - assertThat(graph.getIndexedKeys(Edge.class), not(hasItem("embedding"))); + assertThat(graph.getIndexedKeys(Edge.class).contains("embedding"), is(false)); - // Search for nearest neighbors should return empty list - final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); - assertNotNull(nearest); - assertEquals(0, nearest.size()); + try { + graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + fail("Should have thrown exception since the index was removed"); + } catch (IllegalArgumentException ex) { } } @Test(expected = IllegalArgumentException.class) public void shouldThrowExceptionWhenVectorDimensionExceedsConfigured() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); + final GraphTraversalSource g = traversal().with(graph); // Create a vector index with dimension 3 - graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); // Try to add a vertex with a vector of dimension 4 (exceeds configured dimension 3) g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f, 0.0f}).iterate(); + tryCommitChanges(graph); } @Test(expected = IllegalArgumentException.class) public void shouldThrowExceptionWhenVectorDimensionIsSmallerThanConfigured() { - final TinkerGraph graph = TinkerGraph.open(); - final GraphTraversalSource g = traversal().withEmbedded(graph); + final GraphTraversalSource g = traversal().with(graph); // Create a vector index with dimension 3 - graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + graph.createIndex(TinkerIndexType.VECTOR,"embedding", Vertex.class, indexConfig); // Try to add a vertex with a vector of dimension 2 (smaller than configured dimension 3) g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f}).iterate(); + tryCommitChanges(graph); + } + + @Test + public void shouldRollbackVectorIndexChanges() { + final GraphTraversalSource g = traversal().with(graph); + g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); + g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate(); + tryCommitChanges(graph); + + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + + // Update a vertex property but rollback + g.V().has("name", "Bob").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); + tryRollbackChanges(graph); + + // Bob's embedding should still be [0.0f, 1.0f, 0.0f] + final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{0.0f, 1.0f, 0.0f}, 1); + assertNotNull(nearest); + assertEquals(1, nearest.size()); + assertEquals("Bob", nearest.get(0).value("name")); + } + + @Test + public void shouldHandleEmptyGraphForNearestVertices() { + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final List<Vertex> nearest = graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(0, nearest.size()); + } + + @Test + public void shouldHandleEmptyGraphForNearestEdges() { + graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); + final List<Edge> nearest = graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + assertNotNull(nearest); + assertEquals(0, nearest.size()); + } + + @Test(expected = IllegalStateException.class) + public void shouldThrowExceptionWhenIndexNotCreatedForNearestVertices() { + graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + } + + @Test(expected = IllegalStateException.class) + public void shouldThrowExceptionWhenIndexNotCreatedForNearestEdges() { + graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}, 2); + } + + @Test(expected = IllegalStateException.class) + public void shouldThrowExceptionWhenIndexNotCreatedForNearestVerticesNoDefaultCount() { + graph.findNearestVertices("embedding", new float[]{1.0f, 0.0f, 0.0f}); + } + + @Test(expected = IllegalStateException.class) + public void shouldThrowExceptionWhenIndexNotCreatedForNearestEdgesNoDefaultCount() { + graph.findNearestEdges("embedding", new float[]{1.0f, 0.0f, 0.0f}); + } + + private void tryCommitChanges(final Graph graph) { + if (graph.features().graph().supportsTransactions()) + graph.tx().commit(); + } + + private void tryRollbackChanges(final Graph graph) { + if (graph.features().graph().supportsTransactions()) + graph.tx().rollback(); } -} +} \ No newline at end of file diff --git a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraphTest.java b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraphTest.java index 5df8de254f..338693eaae 100644 --- a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraphTest.java +++ b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionGraphTest.java @@ -399,7 +399,7 @@ public class TinkerTransactionGraphTest { countElementsInNewThreadTx(g, 2, 0); final Map<Object, Set<TinkerElementContainer<?>>> index = - (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionalIndex) g.vertexIndex).index.get("test-property"); + (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionIndex) g.vertexIndex).index.get("test-property"); assertNotNull(index); // should be only vertex vid in set assertEquals(1, index.get(1).size()); @@ -441,7 +441,7 @@ public class TinkerTransactionGraphTest { countElementsInNewThreadTx(g, 2, 0); final Map<Object, Set<TinkerElementContainer<?>>> index = - (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionalIndex) g.vertexIndex).index.get("test-property"); + (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionIndex) g.vertexIndex).index.get("test-property"); assertNotNull(index); // should be only vertex vid in set assertEquals(1, index.get(AbstractTinkerIndex.IndexedNull.instance()).size()); @@ -481,7 +481,7 @@ public class TinkerTransactionGraphTest { countElementsInNewThreadTx(g, 0, 0); final Map<Object, Set<TinkerElementContainer<?>>> index = - (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionalIndex) g.vertexIndex).index.get("test-property"); + (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionIndex) g.vertexIndex).index.get("test-property"); assertNotNull(index); // should be only vertex vid in set assertEquals(0, index.size()); @@ -520,7 +520,7 @@ public class TinkerTransactionGraphTest { countElementsInNewThreadTx(g, 1, 0); final Map<Object, Set<TinkerElementContainer<?>>> index = - (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionalIndex) g.vertexIndex).index.get("test-property"); + (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionIndex) g.vertexIndex).index.get("test-property"); assertNotNull(index); // should be only vertex vid in set assertEquals(0, index.size()); @@ -560,7 +560,7 @@ public class TinkerTransactionGraphTest { countElementsInNewThreadTx(g, 2, 1); final Map<Object, Set<TinkerElementContainer<?>>> index = - (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionalIndex) g.edgeIndex).index.get("test-property"); + (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionIndex) g.edgeIndex).index.get("test-property"); assertNotNull(index); // should be only vertex vid in set assertEquals(1, index.size()); @@ -606,7 +606,7 @@ public class TinkerTransactionGraphTest { countElementsInNewThreadTx(g, 2, 1); final Map<Object, Set<TinkerElementContainer<?>>> index = - (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionalIndex) g.edgeIndex).index.get("test-property"); + (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionIndex) g.edgeIndex).index.get("test-property"); assertNotNull(index); // should be only vertex vid in set assertEquals(1, index.size()); @@ -649,7 +649,7 @@ public class TinkerTransactionGraphTest { countElementsInNewThreadTx(g, 2, 0); final Map<Object, Set<TinkerElementContainer<?>>> index = - (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionalIndex) g.edgeIndex).index.get("test-property"); + (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionIndex) g.edgeIndex).index.get("test-property"); assertNotNull(index); // should be only vertex vid in set assertEquals(0, index.size()); @@ -690,7 +690,7 @@ public class TinkerTransactionGraphTest { countElementsInNewThreadTx(g, 2, 1); final Map<Object, Set<TinkerElementContainer<?>>> index = - (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionalIndex) g.edgeIndex).index.get("test-property"); + (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionIndex) g.edgeIndex).index.get("test-property"); assertNotNull(index); // should be only vertex vid in set assertEquals(1, index.size()); @@ -738,7 +738,7 @@ public class TinkerTransactionGraphTest { countElementsInNewThreadTx(g, 2, 1); final Map<Object, Set<TinkerElementContainer<?>>> index = - (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionalIndex) g.edgeIndex).index.get("test-property"); + (Map<Object, Set<TinkerElementContainer<?>>>) ((TinkerTransactionIndex) g.edgeIndex).index.get("test-property"); assertNotNull(index); // should be only vertex vid in set assertEquals(0, index.size());
