This is an automated email from the ASF dual-hosted git repository. spmallette pushed a commit to branch TINKERPOP-3158 in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 8c9ce0e5f0a7eab5ddbf6c9493c02b9fef98f071 Author: Stephen Mallette <[email protected]> AuthorDate: Tue Jun 2 19:18:48 2026 -0400 Add filter support to TinkerGraph vector topK search services Both topK.byElement and topK.byEmbedding now accept an optional 'filter' parameter (Map<String, Object>) that restricts which candidates are considered during ANN computation via JVector's Bits mask. Filtering happens inside the ANN index search rather than as a post-filter, so up to topK results are returned from the matching subset. The filter supports equality predicates on element properties (all entries must match) and the special key "~label" for label matching. For byElement, self-exclusion of the source element is also applied via the Bits mask, replacing the previous approach of requesting k+1 results and filtering the source element out of the result stream. (tinkerpop-djo) Assisted-by: Claude Code:claude-sonnet-4-6 --- .../reference/implementations-tinkergraph.asciidoc | 32 ++++++++++ .../TinkerVectorSearchByElementFactory.java | 21 ++++--- .../TinkerVectorSearchByEmbeddingFactory.java | 14 ++++- .../tinkergraph/structure/AbstractTinkerGraph.java | 36 +++++++++++ .../structure/AbstractTinkerVectorIndex.java | 21 ++++++- .../structure/TinkerTransactionVectorIndex.java | 37 ++++++++++-- .../tinkergraph/structure/TinkerVectorIndex.java | 37 ++++++++++-- .../structure/TinkerGraphServiceTest.java | 70 ++++++++++++++++++++++ 8 files changed, 246 insertions(+), 22 deletions(-) diff --git a/docs/src/reference/implementations-tinkergraph.asciidoc b/docs/src/reference/implementations-tinkergraph.asciidoc index 84a63e35dd..c395d7e410 100644 --- a/docs/src/reference/implementations-tinkergraph.asciidoc +++ b/docs/src/reference/implementations-tinkergraph.asciidoc @@ -327,6 +327,38 @@ The `call()` step returns a list of maps, each containing: TIP: Vector indices can also be created for edges. 
+==== Filtering Vector Search + +Both `topK.byElement` and `topK.byEmbedding` accept an optional `filter` parameter that restricts +which candidates are considered during the ANN computation. Filtering happens inside the index +traversal rather than as a post-filter, so up to `topK` results are returned from the matching +subset. + +The filter is a map of property key to required value, and an element is a candidate only if every +entry matches. The special key `"~label"` matches the element's label: + +[gremlin-groovy] +---- +graph.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graph)) +graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, [dimension: 3]) +g.addV("person").property("name", "Alice").property("active", true).property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate() +g.addV("person").property("name", "Bob").property("active", false).property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate() +g.addV("robot").property("name", "R2D2").property("active", true).property("embedding", new float[]{0.95f, 0.05f, 0.0f}).iterate() + +// only robots, even though Bob (a person) is also nearby +g.V().has("name","Alice").call("tinker.search.vector.topK.byElement", + [key: "embedding", topK: 2, filter: ["~label": "robot"]]) +---- + +Property value equality can also be used. Values are compared with `equals()`, so a filter value must have the same type as the stored property value (for example, `1` does not match a stored `1L`): + +[gremlin-groovy] +---- +// only active vertices (Bob is excluded) +g.V().has("name","Alice").call("tinker.search.vector.topK.byElement", + [key: "embedding", topK: 5, filter: [active: true]]) +---- + TinkerGraph supports the following distance functions for vector similarity search: * `COSINE`: Measures the cosine of the angle between two vectors (default) diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java index 9dffdeb6ac..6967b4b330 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java +++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java @@ -31,6 +31,7 @@ import java.util.Collections; import java.util.Map; import java.util.Set; +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.FILTER; import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.KEY; import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.TOP_K; import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap; @@ -52,10 +53,16 @@ public class TinkerVectorSearchByElementFactory extends TinkerServiceRegistry.Ti * Number of results to return */ String TOP_K = "topK"; + /** + * Optional map of property key to required value restricting which candidates are searched. + * Use the special key {@code "~label"} to filter by element label. + */ + String FILTER = "filter"; Map DESCRIBE = asMap( KEY, "Specify they key storing the embedding for the vector search", - TOP_K, "Number of results to return (optional, defaults to 10)" + TOP_K, "Number of results to return (optional, defaults to 10)", + FILTER, "Map of property key to required value to restrict candidates (optional); use \"~label\" to filter by label" ); } @@ -94,23 +101,19 @@ public class TinkerVectorSearchByElementFactory extends TinkerServiceRegistry.Ti @Override public CloseableIterator<Map<String,Object>> execute(final ServiceCallContext ctx, final Traverser.Admin<Element> in, final Map params) { final String key = (String) params.get(KEY); - - // add 1 because we always filter 1 out of the index because it will return a match on itself - final int k = (int) params.getOrDefault(TOP_K, 10) + 1; + final int k = (int) params.getOrDefault(TOP_K, 10); + final Map<String, Object> filter = (Map<String, Object>) params.get(FILTER); final Element e = in.get(); - // if the current element does not have the 
specified key, then return no results if (!e.keys().contains(key)) return CloseableIterator.empty(); final float[] embedding = e.value(key); if (e instanceof Vertex) { - return CloseableIterator.of(graph.findNearestVertices(key, embedding, k).stream() - .filter(tie -> !tie.getElement().equals(e)) + return CloseableIterator.of(graph.findNearestVertices(key, embedding, k, filter, e.id()).stream() .map(TinkerIndexElement::toMap).iterator()); } else if (e instanceof Edge) { - return CloseableIterator.of(graph.findNearestEdges(key, embedding, k).stream() - .filter(tie -> !tie.getElement().equals(e)) + return CloseableIterator.of(graph.findNearestEdges(key, embedding, k, filter, e.id()).stream() .map(TinkerIndexElement::toMap).iterator()); } else { return CloseableIterator.empty(); diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java index 22498b76db..2a5099ea0d 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java @@ -33,6 +33,7 @@ import java.util.Set; import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.KEY; import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.TOP_K; import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByEmbeddingFactory.Params.ELEMENT; +import static org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByEmbeddingFactory.Params.FILTER; import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap; /** @@ -55,11 +56,17 @@ public class TinkerVectorSearchByEmbeddingFactory 
extends TinkerServiceRegistry. * Specify whether the search should be for a "vertex" or "edge" */ String ELEMENT = "element"; + /** + * Optional map of property key to required value restricting which candidates are searched. + * Use the special key {@code "~label"} to filter by element label. + */ + String FILTER = "filter"; Map DESCRIBE = asMap( KEY, "Specify they key storing the embedding for the vector search", TOP_K, "Number of results to return (optional, defaults to 10)", - ELEMENT, "Specify whether the search should be for a \"vertex\" or \"edge\"" + ELEMENT, "Specify whether the search should be for a \"vertex\" or \"edge\"", + FILTER, "Map of property key to required value to restrict candidates (optional); use \"~label\" to filter by label" ); } @@ -103,6 +110,7 @@ public class TinkerVectorSearchByEmbeddingFactory extends TinkerServiceRegistry. public CloseableIterator<Map<String,Object>> execute(final ServiceCallContext ctx, final Traverser.Admin<Float[]> in, final Map params) { final String key = (String) params.get(KEY); final int k = (int) params.getOrDefault(TOP_K, 10); + final Map<String, Object> filter = (Map<String, Object>) params.get(FILTER); final Object traverserObject = in.get(); final float[] embedding = traverserObject instanceof Float[] ? @@ -110,10 +118,10 @@ public class TinkerVectorSearchByEmbeddingFactory extends TinkerServiceRegistry. 
final String elementType = (String) params.get(ELEMENT); if ("vertex".equals(elementType)) { - return CloseableIterator.of(graph.findNearestVertices(key, embedding, k).stream() + return CloseableIterator.of(graph.findNearestVertices(key, embedding, k, filter, null).stream() .map(TinkerIndexElement::toMap).iterator()); } else if ("edge".equals(elementType)) { - return CloseableIterator.of(graph.findNearestEdges(key, embedding, k).stream() + return CloseableIterator.of(graph.findNearestEdges(key, embedding, k, filter, null).stream() .map(TinkerIndexElement::toMap).iterator()); } else { return CloseableIterator.empty(); diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java index 763e837e41..e3705ce84f 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java @@ -406,6 +406,42 @@ public abstract class AbstractTinkerGraph implements Graph { return new ArrayList<>(this.edgeVectorIndex.findNearestElements(key, vector)); } + /** + * Find the nearest vertices to the given vector, restricting candidates to those matching the filter. 
+ * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @param filter map of property key to required value; {@code "~label"} matches element label + * @param excludeId element id to exclude from results + * @return a list of vertices sorted by distance + */ + public List<TinkerIndexElement<TinkerVertex>> findNearestVertices(final String key, final float[] vector, + final int k, final Map<String, Object> filter, + final Object excludeId) { + if (null == this.vertexVectorIndex) + throw new IllegalStateException("Vector index not created for vertices on key: '" + key + "'"); + return this.vertexVectorIndex.findNearest(key, vector, k, filter, excludeId); + } + + /** + * Find the nearest edges to the given vector, restricting candidates to those matching the filter. + * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @param filter map of property key to required value; {@code "~label"} matches element label + * @param excludeId element id to exclude from results + * @return a list of edges sorted by distance + */ + public List<TinkerIndexElement<TinkerEdge>> findNearestEdges(final String key, final float[] vector, + final int k, final Map<String, Object> filter, + final Object excludeId) { + if (null == this.edgeVectorIndex) + throw new IllegalStateException("Vector index not created for edges on key: '" + key + "'"); + return this.edgeVectorIndex.findNearest(key, vector, k, filter, excludeId); + } + ///////////// Utility methods /////////////// protected abstract void addOutEdge(final TinkerVertex vertex, final String label, final Edge edge); diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java index 603e26f39a..deacc5bbf0 
100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java @@ -20,7 +20,9 @@ package org.apache.tinkerpop.gremlin.tinkergraph.structure; import org.apache.tinkerpop.gremlin.structure.Element; +import java.util.Collections; import java.util.List; +import java.util.Map; /** * Base class for representing a vector index for performing nearest neighbor searches. @@ -52,7 +54,9 @@ abstract class AbstractTinkerVectorIndex<T extends Element> extends AbstractTink * @param k the number of nearest neighbors to return * @return a list of elements sorted by distance */ - public abstract List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k); + public List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k) { + return findNearest(key, vector, k, Collections.emptyMap(), null); + } /** * Searches for nearest neighbors in the vector index with the default k. @@ -62,9 +66,22 @@ abstract class AbstractTinkerVectorIndex<T extends Element> extends AbstractTink * @return a list of elements sorted by distance */ public List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector) { - return findNearest(key, vector, DEFAULT_K); + return findNearest(key, vector, DEFAULT_K, Collections.emptyMap(), null); } + /** + * Searches for nearest neighbors in the vector index with optional property/label filtering. 
+ * + * @param key the property key + * @param vector the query vector + * @param k the number of nearest neighbors to return + * @param filter map of property key to required value; use {@code "~label"} to filter by element label + * @param excludeId element id to exclude from results (used to exclude the source element in byElement searches) + * @return a list of elements sorted by distance + */ + public abstract List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k, + final Map<String, Object> filter, final Object excludeId); + /** * Searches for nearest neighbors in the vector index. * diff --git a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java index 5f984536c6..7bf527264f 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java @@ -119,11 +119,12 @@ final class TinkerTransactionVectorIndex<T extends TinkerElement> extends Abstra } @Override - public List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k) { + public List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k, + final Map<String, Object> filter, final Object excludeId) { final IndexState<T> state = this.vectorIndices.get(key); if (state == null) throw new IllegalArgumentException("The key '" + key + "' is not indexed"); - return state.search(vector, k); + return state.search(vector, k, filter, excludeId); } @Override @@ -243,14 +244,16 @@ final class TinkerTransactionVectorIndex<T extends TinkerElement> extends Abstra builder.markNodeDeleted(ordinal); } - List<TinkerIndexElement<T>> search(final float[] 
queryVector, final int k) { + List<TinkerIndexElement<T>> search(final float[] queryVector, final int k, + final Map<String, Object> filter, final Object excludeId) { if (vectors.isEmpty()) return Collections.emptyList(); final VectorFloat<?> query = VTS.createFloatVector(queryVector); final ListRandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension); final var ssp = SearchScoreProvider.exact(query, similarityFunction, ravv); + final io.github.jbellis.jvector.util.Bits bits = buildBits(filter, excludeId); try (final GraphSearcher searcher = new GraphSearcher(builder.getGraph())) { - final SearchResult result = searcher.search(ssp, k, io.github.jbellis.jvector.util.Bits.ALL); + final SearchResult result = searcher.search(ssp, k, bits); return Arrays.stream(result.getNodes()) .map(ns -> new TinkerIndexElement<>(elements.get(ns.node), 1.0f - ns.score)) .collect(Collectors.toList()); @@ -259,6 +262,32 @@ final class TinkerTransactionVectorIndex<T extends TinkerElement> extends Abstra } } + private io.github.jbellis.jvector.util.Bits buildBits(final Map<String, Object> filter, final Object excludeId) { + if ((filter == null || filter.isEmpty()) && excludeId == null) + return io.github.jbellis.jvector.util.Bits.ALL; + return new io.github.jbellis.jvector.util.Bits() { + @Override + public boolean get(final int ordinal) { + if (ordinal >= elements.size()) return false; + final T el = elements.get(ordinal); + if (el == null) return false; + if (excludeId != null && el.id().equals(excludeId)) return false; + if (filter != null) { + for (final Map.Entry<String, Object> entry : filter.entrySet()) { + if ("~label".equals(entry.getKey())) { + if (!el.label().equals(entry.getValue())) return false; + } else { + final Property<?> prop = el.property(entry.getKey()); + if (!prop.isPresent() || !prop.value().equals(entry.getValue())) return false; + } + } + } + return true; + } + + }; + } + void close() { try { builder.close(); diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java index 0b6382ad7f..f27c630db3 100644 --- a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java +++ b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java @@ -111,11 +111,12 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerVectorInd } @Override - public List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k) { + public List<TinkerIndexElement<T>> findNearest(final String key, final float[] vector, final int k, + final Map<String, Object> filter, final Object excludeId) { final IndexState<T> state = this.vectorIndices.get(key); if (state == null) throw new IllegalArgumentException("The key '" + key + "' is not indexed"); - return state.search(vector, k); + return state.search(vector, k, filter, excludeId); } @Override @@ -215,14 +216,16 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerVectorInd builder.markNodeDeleted(ordinal); } - List<TinkerIndexElement<T>> search(final float[] queryVector, final int k) { + List<TinkerIndexElement<T>> search(final float[] queryVector, final int k, + final Map<String, Object> filter, final Object excludeId) { if (vectors.isEmpty()) return Collections.emptyList(); final VectorFloat<?> query = VTS.createFloatVector(queryVector); final ListRandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension); final var ssp = SearchScoreProvider.exact(query, similarityFunction, ravv); + final io.github.jbellis.jvector.util.Bits bits = buildBits(filter, excludeId); try (final GraphSearcher searcher = new GraphSearcher(builder.getGraph())) { - final SearchResult result = searcher.search(ssp, k, 
io.github.jbellis.jvector.util.Bits.ALL); + final SearchResult result = searcher.search(ssp, k, bits); return java.util.Arrays.stream(result.getNodes()) .map(ns -> new TinkerIndexElement<>(elements.get(ns.node), 1.0f - ns.score)) .collect(Collectors.toList()); @@ -231,6 +234,32 @@ final class TinkerVectorIndex<T extends Element> extends AbstractTinkerVectorInd } } + private io.github.jbellis.jvector.util.Bits buildBits(final Map<String, Object> filter, final Object excludeId) { + if ((filter == null || filter.isEmpty()) && excludeId == null) + return io.github.jbellis.jvector.util.Bits.ALL; + return new io.github.jbellis.jvector.util.Bits() { + @Override + public boolean get(final int ordinal) { + if (ordinal >= elements.size()) return false; + final T el = elements.get(ordinal); + if (el == null) return false; + if (excludeId != null && el.id().equals(excludeId)) return false; + if (filter != null) { + for (final Map.Entry<String, Object> entry : filter.entrySet()) { + if ("~label".equals(entry.getKey())) { + if (!el.label().equals(entry.getValue())) return false; + } else { + final Property<?> prop = el.property(entry.getKey()); + if (!prop.isPresent() || !prop.value().equals(entry.getValue())) return false; + } + } + } + return true; + } + + }; + } + void close() { try { builder.close(); diff --git a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java index 1843cb0990..44e7a8427c 100644 --- a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java +++ b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java @@ -881,6 +881,76 @@ public class TinkerGraphServiceTest { gv.V(vAlice).outE().inV().path().call(TinkerVectorDistanceFactory.NAME, params).next(); } + @Test + public void 
g_V_callXvector_topK_byElement_filterByLabelX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graf)); + graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final GraphTraversalSource gv = graf.traversal(); + + final Vertex vAlice = gv.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).next(); + gv.addV("person").property("name", "Bob").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).iterate(); + gv.addV("robot").property("name", "R2D2").property("embedding", new float[]{0.95f, 0.05f, 0.0f}).iterate(); + + final Map<String, Object> params = new HashMap<String, Object>() {{ + put("key", "embedding"); + put("topK", 2); + put("filter", new HashMap<String, Object>() {{ put("~label", "robot"); }}); + }}; + final List<Object> list = gv.V(vAlice).call(TinkerVectorSearchByElementFactory.NAME, params).toList(); + + // only the robot should be returned even though Bob is also close + assertEquals(1, list.size()); + assertEquals("R2D2", ((Vertex) ((Map) list.get(0)).get("element")).value("name")); + } + + @Test + public void g_V_callXvector_topK_byElement_filterByPropertyX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByElementFactory(graf)); + graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final GraphTraversalSource gv = graf.traversal(); + + final Vertex vAlice = gv.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).property("active", true).next(); + gv.addV("person").property("name", "Bob").property("embedding", new float[]{0.9f, 0.1f, 0.0f}).property("active", false).iterate(); + gv.addV("person").property("name", "Charlie").property("embedding", new float[]{0.8f, 0.2f, 0.0f}).property("active", true).iterate(); + + final Map<String, Object> params = new 
HashMap<String, Object>() {{ + put("key", "embedding"); + put("topK", 2); + put("filter", new HashMap<String, Object>() {{ put("active", true); }}); + }}; + final List<Object> list = gv.V(vAlice).call(TinkerVectorSearchByElementFactory.NAME, params).toList(); + + // only active vertices returned (Bob excluded despite being close), self excluded + assertEquals(1, list.size()); + assertEquals("Charlie", ((Vertex) ((Map) list.get(0)).get("element")).value("name")); + } + + @Test + public void g_inject_callXvector_topK_byEmbedding_filterByLabelX() { + final TinkerGraph graf = TinkerGraph.open(); + graf.getServiceRegistry().registerService(new TinkerVectorSearchByEmbeddingFactory(graf)); + graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, indexConfig); + final GraphTraversalSource gv = graf.traversal(); + + gv.addV("person").property("name", "Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).iterate(); + gv.addV("robot").property("name", "R2D2").property("embedding", new float[]{0.95f, 0.05f, 0.0f}).iterate(); + + final Float[] query = new Float[]{1.0f, 0.0f, 0.0f}; + final Map<String, Object> params = new HashMap<String, Object>() {{ + put("key", "embedding"); + put("element", "vertex"); + put("topK", 2); + put("filter", new HashMap<String, Object>() {{ put("~label", "robot"); }}); + }}; + final List<Object> list = gv.inject(0).constant(query) + .call(TinkerVectorSearchByEmbeddingFactory.NAME, params).toList(); + + assertEquals(1, list.size()); + assertEquals("R2D2", ((Vertex) ((Map) list.get(0)).get("element")).value("name")); + } + private String toResultString(final Traversal traversal) { return (String) IteratorUtils.stream(traversal).map(Object::toString).collect(Collectors.joining(",", "[", "]")); }
