This is an automated email from the ASF dual-hosted git repository. spmallette pushed a commit to branch TINKERPOP-3158 in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 63173b1300f095b4f1c6c772ed1685a80b50989f Author: Stephen Mallette <[email protected]> AuthorDate: Thu May 8 11:23:33 2025 -0400 distance calc and search by embedded --- CHANGELOG.asciidoc | 5 +- .../reference/implementations-tinkergraph.asciidoc | 29 +- docs/src/upgrade/release-3.8.x.asciidoc | 2 +- .../jsr223/TinkerGraphGremlinPlugin.java | 4 +- ...ctory.java => TinkerVectorDistanceFactory.java} | 67 +++-- ...ava => TinkerVectorSearchByElementFactory.java} | 11 +- ...a => TinkerVectorSearchByEmbeddingFactory.java} | 55 ++-- .../structure/TinkerGraphServiceTest.java | 319 ++++++++++++++++++++- 8 files changed, 401 insertions(+), 91 deletions(-) diff --git a/CHANGELOG.asciidoc b/CHANGELOG.asciidoc index 826c754433..5fe08746be 100644 --- a/CHANGELOG.asciidoc +++ b/CHANGELOG.asciidoc @@ -25,8 +25,9 @@ image::https://raw.githubusercontent.com/apache/tinkerpop/master/docs/static/ima This release also includes changes from <<release-3-7-XXX, 3.7.XXX>>. -* Added vector search to TinkerGraph. -* Renamed the regex based seach service in TinkerGraph to `tinker.search.text`. +* Added vector indexing to TinkerGraph with search services for `tinker.search.vector.topKByElement` and `tinker.search.vector.topKByEmbedding`. +* Added vector distance calculation functions for TinkerGraph. +* Renamed the regex based search service in TinkerGraph to `tinker.search.text`. * Modified `TraversalStrategy` construction in Javascript where configurations are now supplied as a `Map` of options. * Fixed bug in GraphSON v2 and v3 where full round trip of `TraversalStrategy` implementations was failing. * Added missing strategies to the `TraversalStrategies` global cache as well as `CoreImports` in `gremlin-groovy`. 
diff --git a/docs/src/reference/implementations-tinkergraph.asciidoc b/docs/src/reference/implementations-tinkergraph.asciidoc index 35564d7842..b116c56dc9 100644 --- a/docs/src/reference/implementations-tinkergraph.asciidoc +++ b/docs/src/reference/implementations-tinkergraph.asciidoc @@ -280,20 +280,28 @@ Here's a complete example: [gremlin-groovy] ---- -graph.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graph)) <1> -indexConfig = [dimension: 3] <2> +graph.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graph)) <1> +graph.getServiceRegistry().registerService(new TinkerVectorSearchByEmbeddingFactory(graph)) <2> +indexConfig = [dimension: 3] <3> graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig) g.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate() g.addV("person").property("name", "Bob").property("embedding", new float[]{0.0f, 1.0f, 0.0f}).iterate() g.addV("person").property("name", "Charlie").property("embedding", new float[]{0.0f, 0.0f, 1.0f}).iterate() g.addV("person").property("name", "Dave").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate() -params = [key: "embedding", topK: 2] <3> -g.V().has("name", "Alice").call("tinker.search.vector.topKByElement", params).toList() +byElementParams = [key: "embedding", topK: 2] <4> +g.V().has("name", "Alice").call("tinker.search.vector.topKByElement", byElementParams).toList() <5> +byEmbeddingParams = [key: "embedding", topK: 2, element: "vertex"] <6> +embedding = new float[]{1.0f, 0.0f, 0.0f} +g.inject([embedding]).unfold().call("tinker.search.vector.topKByEmbedding", byEmbeddingParams).toList() <7> ---- -<1> Register the vector search service. -<2> Configuration for the vector index that defines the embedding dimension of size 3. -<3> Specify the property key containing the embedding and number of results to return. +<1> Register the vector search service for "topKByElement". 
+<2> Register the vector search service for "topKByEmbedding". +<3> Configuration for the vector index that defines the embedding dimension of size 3. +<4> Specify the property key containing the embedding and number of results to return. +<5> Search the vector index for vertices like "Alice". +<6> Specify the property key containing the embedding, number of results to return, and the index type to search (i.e. either "vertex" or "edge"). +<7> Given an embedding, search the "vertex" index. The `call()` step returns a list of maps, each containing: @@ -304,7 +312,7 @@ Vector indices can also be created for edges: [gremlin-groovy] ---- -graph.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graph)) +graph.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graph)) graph.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig) alice = g.V().has("name", "Alice").next() bob = g.V().has("name", "Bob").next() @@ -329,7 +337,7 @@ You can specify the distance function when creating the vector index: [gremlin-groovy] ---- -graph.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graph)) +graph.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graph)) indexConfig = [dimension: 3, distanceType: TinkerIndexType.Vector.EUCLIDEAN] graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig) ---- @@ -369,6 +377,9 @@ graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig TIP: Constants for all the configuration values can be found in `TinkerVectorIndex`. They are prefixed with "CONFIG_". For example, "dimension" can be referenced as `TinkerVectorIndex.CONFIG_DIMENSION`. +Note that the distance functions can be used directly with the `TinkerVectorDistanceFactory` service. 
It allows +calculation of the distance between the embeddings of the first and last elements of a `Path`. + [[tinkergraph-gremlin-tx]] === Transactions diff --git a/docs/src/upgrade/release-3.8.x.asciidoc b/docs/src/upgrade/release-3.8.x.asciidoc index 0604fb8fa9..49c381cba7 100644 --- a/docs/src/upgrade/release-3.8.x.asciidoc +++ b/docs/src/upgrade/release-3.8.x.asciidoc @@ -219,7 +219,7 @@ Here's a basic example using Groovy: // Create a graph and register the vector search service graph = TinkerGraph.open() g = traversal().with(graph) -graph.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graph)) +graph.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graph)) // Create a vector index with dimension 3 indexConfig = [dimension: 3] diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/jsr223/TinkerGraphGremlinPlugin.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/jsr223/TinkerGraphGremlinPlugin.java index b218d1940f..daba0498a5 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/jsr223/TinkerGraphGremlinPlugin.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/jsr223/TinkerGraphGremlinPlugin.java @@ -31,7 +31,7 @@ import org.apache.tinkerpop.gremlin.tinkergraph.process.computer.TinkerWorkerPoo import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerDegreeCentralityFactory; import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerServiceRegistry; import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerTextSearchFactory; -import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory; +import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory; import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerEdge; import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerElement; import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerFactory; @@ 
-61,7 +61,7 @@ public final class TinkerGraphGremlinPlugin extends AbstractGremlinPlugin { TinkerIoRegistryV1.class, TinkerIoRegistryV2.class, TinkerIoRegistryV3.class, - TinkerVectorSearchFactory.class, + TinkerVectorSearchByElementFactory.class, TinkerTextSearchFactory.class, TinkerDegreeCentralityFactory.class, TinkerServiceRegistry.TinkerServiceFactory.class, diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorDistanceFactory.java similarity index 53% copy from tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java copy to tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorDistanceFactory.java index b0a793a814..5c00cbbdfb 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorDistanceFactory.java @@ -18,29 +18,29 @@ */ package org.apache.tinkerpop.gremlin.tinkergraph.services; +import com.github.jelmerk.hnswlib.core.DistanceFunction; +import org.apache.tinkerpop.gremlin.process.traversal.Path; import org.apache.tinkerpop.gremlin.process.traversal.Traverser; -import org.apache.tinkerpop.gremlin.structure.Edge; import org.apache.tinkerpop.gremlin.structure.Element; -import org.apache.tinkerpop.gremlin.structure.Vertex; import org.apache.tinkerpop.gremlin.structure.service.Service; import org.apache.tinkerpop.gremlin.structure.util.CloseableIterator; import org.apache.tinkerpop.gremlin.tinkergraph.structure.AbstractTinkerGraph; -import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerIndexElement; +import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerIndexType; import java.util.Collections; import 
java.util.Map; import java.util.Set; -import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory.Params.KEY; -import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory.Params.TOP_K; +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorDistanceFactory.Params.KEY; +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorDistanceFactory.Params.DISTANCE_FUNCTION; import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap; /** - * Service to utilize a {@code TinkerVertexIndex} to do a vector search. + * Service to calculate distance between elements in a {@link Path} */ -public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServiceFactory<Element, Map<String, Object>> implements Service<Element, Map<String, Object>> { +public class TinkerVectorDistanceFactory extends TinkerServiceRegistry.TinkerServiceFactory<Path, Float> implements Service<Path, Float> { - public static final String NAME = "tinker.search.vector.topKByElement"; + public static final String NAME = "tinker.vector.distance"; public interface Params { /** @@ -48,17 +48,17 @@ public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServi */ String KEY = "key"; /** - * Number of results to return + * The distance function to use. 
*/ - String TOP_K = "topK"; + String DISTANCE_FUNCTION = "distanceFunction"; Map DESCRIBE = asMap( KEY, "Specify they key storing the embedding for the vector search", - TOP_K, "Number of results to return (optional, defaults to 10)" + DISTANCE_FUNCTION, "The distance function to use in the calculation as specified by the TinkerIndexType.Vector name (optional, defaults to \"COSINE\")" ); } - public TinkerVectorSearchFactory(final AbstractTinkerGraph graph) { + public TinkerVectorDistanceFactory(final AbstractTinkerGraph graph) { super(graph, NAME); } @@ -78,7 +78,7 @@ public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServi } @Override - public Service<Element, Map<String, Object>> createService(final boolean isStart, final Map params) { + public Service<Path, Float> createService(final boolean isStart, final Map params) { if (isStart) { throw new UnsupportedOperationException(Exceptions.cannotStartTraversal); } @@ -87,33 +87,38 @@ public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServi throw new IllegalArgumentException("The parameter map must contain the key where the embedding is: " + KEY); } + if (params.containsKey(DISTANCE_FUNCTION)) { + try { + TinkerIndexType.Vector.valueOf(params.get(DISTANCE_FUNCTION).toString()); + } catch (IllegalArgumentException ex) { + throw new IllegalArgumentException("Invalid value for " + DISTANCE_FUNCTION + ". 
Must be a valid TinkerIndexType.Vector.", ex); + } + } + return this; } @Override - public CloseableIterator<Map<String,Object>> execute(final ServiceCallContext ctx, final Traverser.Admin<Element> in, final Map params) { + public CloseableIterator<Float> execute(final ServiceCallContext ctx, final Traverser.Admin<Path> in, final Map params) { final String key = (String) params.get(KEY); - // add 1 because we always filter 1 out of the index because it will return a match on itself - final int k = (int) params.getOrDefault(TOP_K, 10) + 1; - final Element e = in.get(); + final TinkerIndexType.Vector vector = TinkerIndexType.Vector.valueOf( + params.getOrDefault(Params.DISTANCE_FUNCTION, TinkerIndexType.Vector.COSINE).toString()); + final DistanceFunction<float[], Float> distanceFunction = vector.getDistanceFunction(); - // if the current element does not have the specified key, then return no results - if (!e.keys().contains(key)) - return CloseableIterator.empty(); + final Path path = in.get(); + final int pathLength = path.size(); + final Element start = path.get(0); + final Element end = path.get(pathLength - 1); - final float[] embedding = e.value(key); - if (e instanceof Vertex) { - return CloseableIterator.of(graph.findNearestVertices(key, embedding, k).stream() - .filter(tie -> !tie.getElement().equals(e)) - .map(TinkerIndexElement::toMap).iterator()); - } else if (e instanceof Edge) { - return CloseableIterator.of(graph.findNearestEdges(key, embedding, k).stream() - .filter(tie -> !tie.getElement().equals(e)) - .map(TinkerIndexElement::toMap).iterator()); - } else { + // if the elements do not have the specified key, then return no results because there's nothing we can + // calculate distance on + if (!start.keys().contains(key) || !end.keys().contains(key)) return CloseableIterator.empty(); - } + + final float[] startEmbedding = start.value(key); + final float[] endEmbedding = end.value(key); + return 
CloseableIterator.of(Collections.singleton(distanceFunction.distance(startEmbedding, endEmbedding)).iterator()); } @Override diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java similarity index 91% copy from tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java copy to tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java index b0a793a814..2542b0980e 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java @@ -31,14 +31,15 @@ import java.util.Collections; import java.util.Map; import java.util.Set; -import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory.Params.KEY; -import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory.Params.TOP_K; +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.KEY; +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.TOP_K; import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap; /** - * Service to utilize a {@code TinkerVertexIndex} to do a vector search. + * Service to utilize a {@code TinkerVertexIndex} to do a vector search using an embedding from a supplied vertex + * or edge. 
*/ -public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServiceFactory<Element, Map<String, Object>> implements Service<Element, Map<String, Object>> { +public class TinkerVectorSearchByElementFactory extends TinkerServiceRegistry.TinkerServiceFactory<Element, Map<String, Object>> implements Service<Element, Map<String, Object>> { public static final String NAME = "tinker.search.vector.topKByElement"; @@ -58,7 +59,7 @@ public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServi ); } - public TinkerVectorSearchFactory(final AbstractTinkerGraph graph) { + public TinkerVectorSearchByElementFactory(final AbstractTinkerGraph graph) { super(graph, NAME); } diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java similarity index 65% rename from tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java rename to tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java index b0a793a814..45d3f3025a 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchFactory.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java @@ -18,29 +18,29 @@ */ package org.apache.tinkerpop.gremlin.tinkergraph.services; +import org.apache.commons.lang3.ArrayUtils; import org.apache.tinkerpop.gremlin.process.traversal.Traverser; -import org.apache.tinkerpop.gremlin.structure.Edge; -import org.apache.tinkerpop.gremlin.structure.Element; -import org.apache.tinkerpop.gremlin.structure.Vertex; import org.apache.tinkerpop.gremlin.structure.service.Service; import 
org.apache.tinkerpop.gremlin.structure.util.CloseableIterator; import org.apache.tinkerpop.gremlin.tinkergraph.structure.AbstractTinkerGraph; import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerIndexElement; +import java.util.Arrays; import java.util.Collections; import java.util.Map; import java.util.Set; -import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory.Params.KEY; -import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory.Params.TOP_K; +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.KEY; +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.TOP_K; +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByEmbeddingFactory.Params.ELEMENT; import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap; /** - * Service to utilize a {@code TinkerVertexIndex} to do a vector search. + * Service to utilize a {@code TinkerVertexIndex} to do a vector search using a specified embedding. 
*/ -public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServiceFactory<Element, Map<String, Object>> implements Service<Element, Map<String, Object>> { +public class TinkerVectorSearchByEmbeddingFactory extends TinkerServiceRegistry.TinkerServiceFactory<Float[], Map<String, Object>> implements Service<Float[], Map<String, Object>> { - public static final String NAME = "tinker.search.vector.topKByElement"; + public static final String NAME = "tinker.search.vector.topKByEmbedding"; public interface Params { /** @@ -51,14 +51,19 @@ public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServi * Number of results to return */ String TOP_K = "topK"; + /** + * Specify whether the search should be for a "vertex" or "edge" + */ + String ELEMENT = "element"; Map DESCRIBE = asMap( KEY, "Specify they key storing the embedding for the vector search", - TOP_K, "Number of results to return (optional, defaults to 10)" + TOP_K, "Number of results to return (optional, defaults to 10)", + ELEMENT, "Specify whether the search should be for a \"vertex\" or \"edge\"" ); } - public TinkerVectorSearchFactory(final AbstractTinkerGraph graph) { + public TinkerVectorSearchByEmbeddingFactory(final AbstractTinkerGraph graph) { super(graph, NAME); } @@ -78,38 +83,37 @@ public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServi } @Override - public Service<Element, Map<String, Object>> createService(final boolean isStart, final Map params) { + public Service<Float[], Map<String, Object>> createService(final boolean isStart, final Map params) { if (isStart) { throw new UnsupportedOperationException(Exceptions.cannotStartTraversal); } if (!params.containsKey(KEY)) { - throw new IllegalArgumentException("The parameter map must contain the key where the embedding is: " + KEY); + throw new IllegalArgumentException("The parameter map must contain the key specifying where the embedding is: " + KEY); + } + + if 
(!params.containsKey(ELEMENT) || !Arrays.asList("vertex", "edge").contains(params.get(ELEMENT))) { + throw new IllegalArgumentException("The parameter map must contain a key: " + ELEMENT + " with a value of \"vertex\" or \"edge\""); } return this; } @Override - public CloseableIterator<Map<String,Object>> execute(final ServiceCallContext ctx, final Traverser.Admin<Element> in, final Map params) { + public CloseableIterator<Map<String,Object>> execute(final ServiceCallContext ctx, final Traverser.Admin<Float[]> in, final Map params) { final String key = (String) params.get(KEY); + final int k = (int) params.getOrDefault(TOP_K, 10); + final Object traverserObject = in.get(); - // add 1 because we always filter 1 out of the index because it will return a match on itself - final int k = (int) params.getOrDefault(TOP_K, 10) + 1; - final Element e = in.get(); + final float[] embedding = traverserObject instanceof Float[] ? + ArrayUtils.toPrimitive((Float[]) traverserObject) : (float[]) traverserObject; - // if the current element does not have the specified key, then return no results - if (!e.keys().contains(key)) - return CloseableIterator.empty(); - - final float[] embedding = e.value(key); - if (e instanceof Vertex) { + final String elementType = (String) params.get(ELEMENT); + if ("vertex".equals(elementType)) { return CloseableIterator.of(graph.findNearestVertices(key, embedding, k).stream() - .filter(tie -> !tie.getElement().equals(e)) .map(TinkerIndexElement::toMap).iterator()); - } else if (e instanceof Edge) { + } else if ("edge".equals(elementType)) { return CloseableIterator.of(graph.findNearestEdges(key, embedding, k).stream() - .filter(tie -> !tie.getElement().equals(e)) .map(TinkerIndexElement::toMap).iterator()); } else { return CloseableIterator.empty(); @@ -120,4 +124,3 @@ public class TinkerVectorSearchFactory extends TinkerServiceRegistry.TinkerServi public void close() {} } - diff --git 
a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java index b7ba27d88e..79633bc958 100644 --- a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java +++ b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java @@ -30,7 +30,9 @@ import org.apache.tinkerpop.gremlin.structure.Vertex; import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerDegreeCentralityFactory; import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerServiceRegistry; import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerTextSearchFactory; -import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchFactory; +import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorDistanceFactory; +import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory; +import org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByEmbeddingFactory; import org.apache.tinkerpop.gremlin.util.function.TriFunction; import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils; import org.apache.tinkerpop.shaded.jackson.databind.ObjectMapper; @@ -61,7 +63,6 @@ import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; - /** * Demonstration of Service API. 
* @@ -455,7 +456,7 @@ public class TinkerGraphServiceTest { @Test public void g_V_callXvector_topKByVertex_key_embeddingX() { final TinkerGraph graf = TinkerGraph.open(); - graf.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graf)); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graf)); graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); final GraphTraversalSource gv = graf.traversal(); @@ -467,7 +468,7 @@ public class TinkerGraphServiceTest { final Map<String,Object> m = new HashMap<String,Object>() {{ put("key", "embedding"); }}; - final List<Object> list = gv.V(vAlice).call(TinkerVectorSearchFactory.NAME, m).toList(); + final List<Object> list = gv.V(vAlice).call(TinkerVectorSearchByElementFactory.NAME, m).toList(); final List<Map<String,Object>> expected = new ArrayList<>(); expected.add(asMap("distance", 0.006116271f, "element", vDave)); @@ -485,7 +486,7 @@ public class TinkerGraphServiceTest { @Test public void g_E_callXvector_topKByEdge_key_embeddingX() { final TinkerGraph graf = TinkerGraph.open(); - graf.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graf)); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graf)); graf.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); final GraphTraversalSource gv = graf.traversal(); @@ -500,7 +501,7 @@ public class TinkerGraphServiceTest { final Map<String, Object> m = new HashMap<String, Object>() {{ put("key", "embedding"); }}; - final List<Object> list = gv.E(e1).call(TinkerVectorSearchFactory.NAME, m).toList(); + final List<Object> list = gv.E(e1).call(TinkerVectorSearchByElementFactory.NAME, m).toList(); final List<Map<String, Object>> expected = new ArrayList<>(); expected.add(asMap("distance", 0.29289323f, "element", e3)); @@ -517,7 +518,7 @@ public class TinkerGraphServiceTest { @Test public void g_V_callXvector_topKByVertex_key_embedding_topK_1X() { 
final TinkerGraph graf = TinkerGraph.open(); - graf.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graf)); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graf)); graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); final GraphTraversalSource gv = graf.traversal(); @@ -530,7 +531,7 @@ public class TinkerGraphServiceTest { put("key", "embedding"); put("topK", 1); }}; - final List<Object> list = gv.V(vAlice).call(TinkerVectorSearchFactory.NAME, m).toList(); + final List<Object> list = gv.V(vAlice).call(TinkerVectorSearchByElementFactory.NAME, m).toList(); final List<Map<String,Object>> expected = new ArrayList<>(); expected.add(asMap("distance", 0.006116271f, "element", vDave)); @@ -546,7 +547,7 @@ public class TinkerGraphServiceTest { @Test(expected = IllegalArgumentException.class) public void g_V_callXvector_missing_keyX() { final TinkerGraph graf = TinkerGraph.open(); - graf.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graf)); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graf)); graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); final GraphTraversalSource gv = graf.traversal(); @@ -554,13 +555,13 @@ public class TinkerGraphServiceTest { // Missing key parameter should cause a IllegalArgumentException when trying to access it final Map<String,Object> emptyParams = new HashMap<>(); - gv.V(vAlice).call(TinkerVectorSearchFactory.NAME, emptyParams).toList(); + gv.V(vAlice).call(TinkerVectorSearchByElementFactory.NAME, emptyParams).toList(); } @Test public void g_V_callXvector_nonexistent_propertyX() { final TinkerGraph graf = TinkerGraph.open(); - graf.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graf)); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graf)); graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); 
final GraphTraversalSource gv = graf.traversal(); @@ -571,13 +572,13 @@ public class TinkerGraphServiceTest { put("key", "embedding"); }}; - assertEquals(0, gv.V(vAlice).call(TinkerVectorSearchFactory.NAME, params).count().next().intValue()); + assertEquals(0, gv.V(vAlice).call(TinkerVectorSearchByElementFactory.NAME, params).count().next().intValue()); } @Test public void g_V_properties_callXvector_nonexistent_propertyX() { final TinkerGraph graf = TinkerGraph.open(); - graf.getServiceRegistry().registerService(new TinkerVectorSearchFactory(graf)); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graf)); graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); final GraphTraversalSource gv = graf.traversal(); @@ -588,7 +589,296 @@ public class TinkerGraphServiceTest { }}; // Referencing things other than vertex or edge should return no results - assertEquals(0, gv.V(vAlice).properties().call(TinkerVectorSearchFactory.NAME, params).count().next().intValue()); + assertEquals(0, gv.V(vAlice).properties().call(TinkerVectorSearchByElementFactory.NAME, params).count().next().intValue()); + } + + @Test + public void g_V_callXvector_topKByEmbedding_vertexX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByEmbeddingFactory(graf)); + graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final GraphTraversalSource gv = graf.traversal(); + + final Vertex vAlice = gv.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).next(); + final Vertex vBob = gv.addV("person").property("name", "Bob").property("embedding", new float[]{0.05f, 1.0f, 0.0f}).next(); + final Vertex vCharlie = gv.addV("person").property("name", "Charlie").property("embedding", new float[]{0.0f, 0.0f, 1.0f}).next(); + final Vertex vDave = gv.addV("person").property("name", "Dave").property("embedding", new 
float[]{0.9f, 0.1f, 0.0f}).next(); + + final Map<String,Object> m = new HashMap<String,Object>() {{ + put("key", "embedding"); + put("element", "vertex"); + }}; + + // Create an embedding to search with + final Float[] searchEmbedding = new Float[]{1.0f, 0.0f, 0.0f}; + + final List<Object> list = gv.inject(0).constant(searchEmbedding).call(TinkerVectorSearchByEmbeddingFactory.NAME, m).toList(); + + final List<Map<String,Object>> expected = new ArrayList<>(); + expected.add(asMap("distance", 0.0f, "element", vAlice)); + expected.add(asMap("distance", 0.006116271f, "element", vDave)); + expected.add(asMap("distance", 0.9500624f, "element", vBob)); + expected.add(asMap("distance", 1.0f, "element", vCharlie)); + + // Use a custom comparison to ensure the lists are equal + assertEquals(expected.size(), list.size()); + for (int i = 0; i < expected.size(); i++) { + assertEquals(expected.get(i).get("distance"), ((Map) list.get(i)).get("distance")); + assertEquals(expected.get(i).get("element"), ((Map) list.get(i)).get("element")); + } + } + + @Test + public void g_E_callXvector_topKByEmbedding_edgeX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByEmbeddingFactory(graf)); + graf.createIndex(TinkerIndexType.VECTOR, "embedding", Edge.class, indexConfig); + final GraphTraversalSource gv = graf.traversal(); + + final Vertex vAlice = gv.addV("person").property("name", "Alice").next(); + final Vertex vBob = gv.addV("person").property("name", "Bob").next(); + final Vertex vCharlie = gv.addV("person").property("name", "Charlie").next(); + + final Edge e1 = gv.addE("knows").from(vAlice).to(vBob).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).next(); + final Edge e2 = gv.addE("knows").from(vAlice).to(vCharlie).property("embedding", new float[]{0.0f, 1.0f, 0.0f}).next(); + final Edge e3 = gv.addE("knows").from(vBob).to(vCharlie).property("embedding", new float[]{0.5f, 0.5f, 0.0f}).next(); + + final 
Map<String, Object> m = new HashMap<String, Object>() {{ + put("key", "embedding"); + put("element", "edge"); + }}; + + // Create an embedding to search with + final Float[] searchEmbedding = new Float[]{1.0f, 0.0f, 0.0f}; + + final List<Object> list = gv.inject(0).constant(searchEmbedding).call(TinkerVectorSearchByEmbeddingFactory.NAME, m).toList(); + + final List<Map<String, Object>> expected = new ArrayList<>(); + expected.add(asMap("distance", 0.0f, "element", e1)); + expected.add(asMap("distance", 0.29289323f, "element", e3)); + expected.add(asMap("distance", 1.0f, "element", e2)); + + // Use a custom comparison to ensure the lists are equal + assertEquals(expected.size(), list.size()); + for (int i = 0; i < expected.size(); i++) { + assertEquals(expected.get(i).get("distance"), ((Map) list.get(i)).get("distance")); + assertEquals(expected.get(i).get("element"), ((Map) list.get(i)).get("element")); + } + } + + @Test + public void g_V_callXvector_topKByEmbedding_vertex_topK_1X() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByEmbeddingFactory(graf)); + graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final GraphTraversalSource gv = graf.traversal(); + + final Vertex vAlice = gv.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).next(); + final Vertex vBob = gv.addV("person").property("name", "Bob").property("embedding", new float[]{0.05f, 1.0f, 0.0f}).next(); + final Vertex vCharlie = gv.addV("person").property("name", "Charlie").property("embedding", new float[]{0.0f, 0.0f, 1.0f}).next(); + final Vertex vDave = gv.addV("person").property("name", "Dave").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).next(); + + final Map<String,Object> m = new HashMap<String,Object>() {{ + put("key", "embedding"); + put("element", "vertex"); + put("topK", 1); + }}; + + // Create an embedding to search with + final Float[] 
searchEmbedding = new Float[]{1.0f, 0.0f, 0.0f}; + + final List<Object> list = gv.inject(0).constant(searchEmbedding).call(TinkerVectorSearchByEmbeddingFactory.NAME, m).toList(); + + final List<Map<String,Object>> expected = new ArrayList<>(); + expected.add(asMap("distance", 0.0f, "element", vAlice)); + + // Use a custom comparison to ensure the lists are equal + assertEquals(expected.size(), list.size()); + for (int i = 0; i < expected.size(); i++) { + assertEquals(expected.get(i).get("distance"), ((Map) list.get(i)).get("distance")); + assertEquals(expected.get(i).get("element"), ((Map) list.get(i)).get("element")); + } + } + + @Test(expected = IllegalArgumentException.class) + public void g_V_callXvector_topKByEmbedding_missing_keyX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByEmbeddingFactory(graf)); + graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final GraphTraversalSource gv = graf.traversal(); + + // Create an embedding to search with + final Float[] searchEmbedding = new Float[]{1.0f, 0.0f, 0.0f}; + + // Missing key parameter should cause a IllegalArgumentException when trying to access it + final Map<String,Object> params = new HashMap<String,Object>() {{ + put("element", "vertex"); + }}; + gv.inject(0).constant(searchEmbedding).call(TinkerVectorSearchByEmbeddingFactory.NAME, params).toList(); + } + + @Test(expected = IllegalArgumentException.class) + public void g_V_callXvector_topKByEmbedding_missing_elementX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByEmbeddingFactory(graf)); + graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final GraphTraversalSource gv = graf.traversal(); + + // Create an embedding to search with + final Float[] searchEmbedding = new Float[]{1.0f, 0.0f, 0.0f}; + + // Missing element parameter should cause a 
IllegalArgumentException + final Map<String,Object> params = new HashMap<String,Object>() {{ + put("key", "embedding"); + }}; + gv.inject(0).constant(searchEmbedding).call(TinkerVectorSearchByEmbeddingFactory.NAME, params).toList(); + } + + @Test(expected = IllegalArgumentException.class) + public void g_V_callXvector_topKByEmbedding_invalid_elementX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByEmbeddingFactory(graf)); + graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final GraphTraversalSource gv = graf.traversal(); + + // Create an embedding to search with + final Float[] searchEmbedding = new Float[]{1.0f, 0.0f, 0.0f}; + + // Invalid element parameter should cause a IllegalArgumentException + final Map<String,Object> params = new HashMap<String,Object>() {{ + put("key", "embedding"); + put("element", "invalid"); + }}; + gv.inject(0).constant(searchEmbedding).call(TinkerVectorSearchByEmbeddingFactory.NAME, params).toList(); + } + + @Test + public void g_path_callXvector_distanceX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorDistanceFactory(graf)); + final GraphTraversalSource gv = graf.traversal(); + + // Create vertices with embeddings + final Vertex vAlice = gv.addV("person").property("name", "Alice") + .property("embedding", new float[]{1.0f, 0.0f, 0.0f}).next(); + final Vertex vBob = gv.addV("person").property("name", "Bob") + .property("embedding", new float[]{0.0f, 1.0f, 0.0f}).next(); + final Vertex vCharlie = gv.addV("person").property("name", "Charlie") + .property("embedding", new float[]{0.0f, 0.0f, 1.0f}).next(); + + // Create edges to connect vertices + gv.addE("knows").from(vAlice).to(vBob).next(); + gv.addE("knows").from(vBob).to(vCharlie).next(); + + // Test path with Alice -> Bob + final Map<String, Object> params = new HashMap<String, Object>() {{ + put("key", 
"embedding"); + }}; + + // Path from Alice to Bob - cosine distance should be 1.0 (orthogonal vectors) + final Float aliceToBobDistance = (Float) gv.V(vAlice).outE().inV().path().call(TinkerVectorDistanceFactory.NAME, params).next(); + assertEquals(1.0f, aliceToBobDistance, 0.0001f); + + // Path from Alice to Charlie - cosine distance should be 1.0 (orthogonal vectors) + final Float aliceToCharlieDistance = (Float) gv.V(vAlice).out().out().path().call(TinkerVectorDistanceFactory.NAME, params).next(); + assertEquals(1.0f, aliceToCharlieDistance, 0.0001f); + + // Path from Bob to Charlie - cosine distance should be 1.0 (orthogonal vectors) + final Float bobToCharlieDistance = (Float) gv.V(vBob).out().path().call(TinkerVectorDistanceFactory.NAME, params).next(); + assertEquals(1.0f, bobToCharlieDistance, 0.0001f); + } + + @Test + public void g_path_callXvector_distance_with_different_distance_functionX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorDistanceFactory(graf)); + final GraphTraversalSource gv = graf.traversal(); + + // Create vertices with embeddings + final Vertex vAlice = gv.addV("person").property("name", "Alice") + .property("embedding", new float[]{1.0f, 0.0f, 0.0f}).next(); + final Vertex vBob = gv.addV("person").property("name", "Bob") + .property("embedding", new float[]{0.0f, 1.0f, 0.0f}).next(); + + // Create edge to connect vertices + gv.addE("knows").from(vAlice).to(vBob).next(); + + // Test with EUCLIDEAN distance function + final Map<String, Object> params = new HashMap<String, Object>() {{ + put("key", "embedding"); + put("distanceFunction", "EUCLIDEAN"); + }}; + + // Path from Alice to Bob - Euclidean distance should be sqrt(2) = 1.414... 
+ final Float aliceToBobDistance = (Float) gv.V(vAlice).outE().inV().path().call(TinkerVectorDistanceFactory.NAME, params).next(); + assertEquals(1.4142f, aliceToBobDistance, 0.0001f); + } + + @Test + public void g_path_callXvector_distance_missing_embeddingX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorDistanceFactory(graf)); + final GraphTraversalSource gv = graf.traversal(); + + // Create vertices, one without embedding + final Vertex vAlice = gv.addV("person").property("name", "Alice") + .property("embedding", new float[]{1.0f, 0.0f, 0.0f}).next(); + final Vertex vBob = gv.addV("person").property("name", "Bob").next(); // No embedding + + // Create edge to connect vertices + gv.addE("knows").from(vAlice).to(vBob).next(); + + // Test path with Alice -> Bob where Bob has no embedding + final Map<String, Object> params = new HashMap<String, Object>() {{ + put("key", "embedding"); + }}; + + // Should return empty result since Bob doesn't have the embedding + final List<Object> results = gv.V(vAlice).outE().inV().path().call(TinkerVectorDistanceFactory.NAME, params).toList(); + assertEquals(0, results.size()); + } + + @Test(expected = IllegalArgumentException.class) + public void g_path_callXvector_distance_missing_keyX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorDistanceFactory(graf)); + final GraphTraversalSource gv = graf.traversal(); + + // Create vertices with embeddings + final Vertex vAlice = gv.addV("person").property("name", "Alice") + .property("embedding", new float[]{1.0f, 0.0f, 0.0f}).next(); + final Vertex vBob = gv.addV("person").property("name", "Bob") + .property("embedding", new float[]{0.0f, 1.0f, 0.0f}).next(); + + // Create edge to connect vertices + gv.addE("knows").from(vAlice).to(vBob).next(); + + // Missing key parameter should cause an IllegalArgumentException + final Map<String, Object> emptyParams = new 
HashMap<>(); + gv.V(vAlice).outE().inV().path().call(TinkerVectorDistanceFactory.NAME, emptyParams).next(); + } + + @Test(expected = IllegalArgumentException.class) + public void g_path_callXvector_distance_invalid_distance_functionX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorDistanceFactory(graf)); + final GraphTraversalSource gv = graf.traversal(); + + // Create vertices with embeddings + final Vertex vAlice = gv.addV("person").property("name", "Alice") + .property("embedding", new float[]{1.0f, 0.0f, 0.0f}).next(); + final Vertex vBob = gv.addV("person").property("name", "Bob") + .property("embedding", new float[]{0.0f, 1.0f, 0.0f}).next(); + + // Create edge to connect vertices + gv.addE("knows").from(vAlice).to(vBob).next(); + + // Invalid distance function should cause an IllegalArgumentException + final Map<String, Object> params = new HashMap<String, Object>() {{ + put("key", "embedding"); + put("distanceFunction", "INVALID_FUNCTION"); + }}; + gv.V(vAlice).outE().inV().path().call(TinkerVectorDistanceFactory.NAME, params).next(); } private String toResultString(final Traversal traversal) { @@ -604,5 +894,4 @@ public class TinkerGraphServiceTest { assertEquals("Did not produce exactly one result", 1, result.size()); assertEquals(expected, result.get(0)); } - }
