This is an automated email from the ASF dual-hosted git repository.

spmallette pushed a commit to branch TINKERPOP-3158
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git

commit 8c9ce0e5f0a7eab5ddbf6c9493c02b9fef98f071
Author: Stephen Mallette <[email protected]>
AuthorDate: Tue Jun 2 19:18:48 2026 -0400

    Add filter support to TinkerGraph vector topK search services
    
    Both topK.byElement and topK.byEmbedding now accept an optional 'filter'
    parameter (Map<String, Object>) that restricts which candidates are
    considered during ANN computation via JVector's Bits mask. Filtering
    happens inside the index search rather than as a post-filter, so up to
    topK results are returned, all drawn from the matching subset.
    
    The filter supports equality predicates on element properties and the
    special key "~label" for label matching; an element must satisfy every
    entry to be considered. For byElement, the source element is now
    excluded through the same Bits mask, replacing the previous workaround
    of requesting k + 1 results and filtering the source out afterwards.
    
    (tinkerpop-djo)
    Assisted-by: Claude Code:claude-sonnet-4-6
---
 .../reference/implementations-tinkergraph.asciidoc | 32 ++++++++++
 .../TinkerVectorSearchByElementFactory.java        | 21 ++++---
 .../TinkerVectorSearchByEmbeddingFactory.java      | 14 ++++-
 .../tinkergraph/structure/AbstractTinkerGraph.java | 36 +++++++++++
 .../structure/AbstractTinkerVectorIndex.java       | 21 ++++++-
 .../structure/TinkerTransactionVectorIndex.java    | 37 ++++++++++--
 .../tinkergraph/structure/TinkerVectorIndex.java   | 37 ++++++++++--
 .../structure/TinkerGraphServiceTest.java          | 70 ++++++++++++++++++++++
 8 files changed, 246 insertions(+), 22 deletions(-)

diff --git a/docs/src/reference/implementations-tinkergraph.asciidoc 
b/docs/src/reference/implementations-tinkergraph.asciidoc
index 84a63e35dd..c395d7e410 100644
--- a/docs/src/reference/implementations-tinkergraph.asciidoc
+++ b/docs/src/reference/implementations-tinkergraph.asciidoc
@@ -327,6 +327,38 @@ The `call()` step returns a list of maps, each containing:
 
 TIP: Vector indices can also be created for edges.
 
+==== Filtering Vector Search
+
+Both `topK.byElement` and `topK.byEmbedding` accept an optional `filter` 
parameter that restricts
+which candidates are considered during the ANN computation. Filtering happens 
inside the index
+traversal rather than as a post-filter, so up to `topK` results are returned, all from the matching
+subset.
+
+The filter is a map of property key to required value; an element must match every entry. The
+special key `"~label"` matches the element's label:
+
+[gremlin-groovy]
+----
+graph.getServiceRegistry().registerService(new 
TinkerVectorSearchByElementFactory(graph))
+graph.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, 
[dimension: 3])
+g.addV("person").property("name", "Alice").property("embedding", new 
float[]{1.0f, 0.0f, 0.0f}).iterate()
+g.addV("person").property("name", "Bob").property("embedding", new 
float[]{0.9f, 0.1f, 0.0f}).iterate()
+g.addV("robot").property("name", "R2D2").property("embedding", new 
float[]{0.95f, 0.05f, 0.0f}).iterate()
+
+// only robots, even though Bob (a person) is also nearby
+g.V().has("name","Alice").call("tinker.search.vector.topK.byElement",
+  [key: "embedding", topK: 2, filter: ["~label": "robot"]])
+----
+
+Property value equality can also be used:
+
+[gremlin-groovy]
+----
+// only active vertices
+g.V().has("name","Alice").call("tinker.search.vector.topK.byElement",
+  [key: "embedding", topK: 5, filter: [active: true]])
+----
+
 TinkerGraph supports the following distance functions for vector similarity 
search:
 
 * `COSINE`: Measures the cosine of the angle between two vectors (default)
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java
index 9dffdeb6ac..6967b4b330 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByElementFactory.java
@@ -31,6 +31,7 @@ import java.util.Collections;
 import java.util.Map;
 import java.util.Set;
 
+import static 
org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.FILTER;
 import static 
org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.KEY;
 import static 
org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.TOP_K;
 import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap;
@@ -52,10 +53,16 @@ public class TinkerVectorSearchByElementFactory extends 
TinkerServiceRegistry.Ti
          * Number of results to return
          */
         String TOP_K = "topK";
+        /**
+         * Optional map of property key to required value restricting which 
candidates are searched.
+         * Use the special key {@code "~label"} to filter by element label.
+         */
+        String FILTER = "filter";
 
         Map DESCRIBE = asMap(
                 KEY, "Specify they key storing the embedding for the vector 
search",
-                TOP_K, "Number of results to return (optional, defaults to 10)"
+                TOP_K, "Number of results to return (optional, defaults to 
10)",
+                FILTER, "Map of property key to required value to restrict 
candidates (optional); use \"~label\" to filter by label"
         );
     }
 
@@ -94,23 +101,19 @@ public class TinkerVectorSearchByElementFactory extends 
TinkerServiceRegistry.Ti
     @Override
     public CloseableIterator<Map<String,Object>> execute(final 
ServiceCallContext ctx, final Traverser.Admin<Element> in, final Map params) {
         final String key = (String) params.get(KEY);
-
-        // add 1 because we always filter 1 out of the index because it will 
return a match on itself
-        final int k = (int) params.getOrDefault(TOP_K, 10) + 1;
+        final int k = (int) params.getOrDefault(TOP_K, 10);
+        final Map<String, Object> filter = (Map<String, Object>) 
params.get(FILTER);
         final Element e = in.get();
 
-        // if the current element does not have the specified key, then return 
no results
         if (!e.keys().contains(key))
             return CloseableIterator.empty();
 
         final float[] embedding = e.value(key);
         if (e instanceof Vertex) {
-            return CloseableIterator.of(graph.findNearestVertices(key, 
embedding, k).stream()
-                    .filter(tie -> !tie.getElement().equals(e))
+            return CloseableIterator.of(graph.findNearestVertices(key, 
embedding, k, filter, e.id()).stream()
                     .map(TinkerIndexElement::toMap).iterator());
         } else if (e instanceof Edge) {
-            return CloseableIterator.of(graph.findNearestEdges(key, embedding, 
k).stream()
-                    .filter(tie -> !tie.getElement().equals(e))
+            return CloseableIterator.of(graph.findNearestEdges(key, embedding, 
k, filter, e.id()).stream()
                     .map(TinkerIndexElement::toMap).iterator());
         } else {
             return CloseableIterator.empty();
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java
index 22498b76db..2a5099ea0d 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/services/TinkerVectorSearchByEmbeddingFactory.java
@@ -33,6 +33,7 @@ import java.util.Set;
 import static 
org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.KEY;
 import static 
org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByElementFactory.Params.TOP_K;
 import static 
org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByEmbeddingFactory.Params.ELEMENT;
+import static 
org.apache.tinkerpop.gremlin.tinkergraph.services.TinkerVectorSearchByEmbeddingFactory.Params.FILTER;
 import static org.apache.tinkerpop.gremlin.util.CollectionUtil.asMap;
 
 /**
@@ -55,11 +56,17 @@ public class TinkerVectorSearchByEmbeddingFactory extends 
TinkerServiceRegistry.
          * Specify whether the search should be for a "vertex" or "edge"
          */
         String ELEMENT = "element";
+        /**
+         * Optional map of property key to required value restricting which 
candidates are searched.
+         * Use the special key {@code "~label"} to filter by element label.
+         */
+        String FILTER = "filter";
 
         Map DESCRIBE = asMap(
                 KEY, "Specify they key storing the embedding for the vector 
search",
                 TOP_K, "Number of results to return (optional, defaults to 
10)",
-                ELEMENT, "Specify whether the search should be for a 
\"vertex\" or \"edge\""
+                ELEMENT, "Specify whether the search should be for a 
\"vertex\" or \"edge\"",
+                FILTER, "Map of property key to required value to restrict 
candidates (optional); use \"~label\" to filter by label"
         );
     }
 
@@ -103,6 +110,7 @@ public class TinkerVectorSearchByEmbeddingFactory extends 
TinkerServiceRegistry.
     public CloseableIterator<Map<String,Object>> execute(final 
ServiceCallContext ctx, final Traverser.Admin<Float[]> in, final Map params) {
         final String key = (String) params.get(KEY);
         final int k = (int) params.getOrDefault(TOP_K, 10);
+        final Map<String, Object> filter = (Map<String, Object>) 
params.get(FILTER);
         final Object traverserObject = in.get();
 
         final float[] embedding = traverserObject instanceof Float[] ?
@@ -110,10 +118,10 @@ public class TinkerVectorSearchByEmbeddingFactory extends 
TinkerServiceRegistry.
 
         final String elementType = (String) params.get(ELEMENT);
         if ("vertex".equals(elementType)) {
-            return CloseableIterator.of(graph.findNearestVertices(key, 
embedding, k).stream()
+            return CloseableIterator.of(graph.findNearestVertices(key, 
embedding, k, filter, null).stream()
                     .map(TinkerIndexElement::toMap).iterator());
         } else if ("edge".equals(elementType)) {
-            return CloseableIterator.of(graph.findNearestEdges(key, embedding, 
k).stream()
+            return CloseableIterator.of(graph.findNearestEdges(key, embedding, 
k, filter, null).stream()
                     .map(TinkerIndexElement::toMap).iterator());
         } else {
             return CloseableIterator.empty();
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java
index 763e837e41..e3705ce84f 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerGraph.java
@@ -406,6 +406,42 @@ public abstract class AbstractTinkerGraph implements Graph 
{
         return new ArrayList<>(this.edgeVectorIndex.findNearestElements(key, 
vector));
     }
 
+    /**
+     * Find the nearest vertices to the given vector, restricting candidates 
to those matching the filter.
+     *
+     * @param key       the property key
+     * @param vector    the query vector
+     * @param k         the number of nearest neighbors to return
+     * @param filter    map of property key to required value; {@code 
"~label"} matches element label
+     * @param excludeId element id to exclude from results
+     * @return a list of vertices sorted by distance
+     */
+    public List<TinkerIndexElement<TinkerVertex>> findNearestVertices(final 
String key, final float[] vector,
+                                                                       final 
int k, final Map<String, Object> filter,
+                                                                       final 
Object excludeId) {
+        if (null == this.vertexVectorIndex)
+            throw new IllegalStateException("Vector index not created for 
vertices on key: '" + key + "'");
+        return this.vertexVectorIndex.findNearest(key, vector, k, filter, 
excludeId);
+    }
+
+    /**
+     * Find the nearest edges to the given vector, restricting candidates to 
those matching the filter.
+     *
+     * @param key       the property key
+     * @param vector    the query vector
+     * @param k         the number of nearest neighbors to return
+     * @param filter    map of property key to required value; {@code 
"~label"} matches element label
+     * @param excludeId element id to exclude from results
+     * @return a list of edges sorted by distance
+     */
+    public List<TinkerIndexElement<TinkerEdge>> findNearestEdges(final String 
key, final float[] vector,
+                                                                  final int k, 
final Map<String, Object> filter,
+                                                                  final Object 
excludeId) {
+        if (null == this.edgeVectorIndex)
+            throw new IllegalStateException("Vector index not created for 
edges on key: '" + key + "'");
+        return this.edgeVectorIndex.findNearest(key, vector, k, filter, 
excludeId);
+    }
+
     ///////////// Utility methods ///////////////
     protected abstract void addOutEdge(final TinkerVertex vertex, final String 
label, final Edge edge);
 
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
index 603e26f39a..deacc5bbf0 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/AbstractTinkerVectorIndex.java
@@ -20,7 +20,9 @@ package org.apache.tinkerpop.gremlin.tinkergraph.structure;
 
 import org.apache.tinkerpop.gremlin.structure.Element;
 
+import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 
 /**
  * Base class for representing a vector index for performing nearest neighbor 
searches.
@@ -52,7 +54,9 @@ abstract class AbstractTinkerVectorIndex<T extends Element> 
extends AbstractTink
      * @param k      the number of nearest neighbors to return
      * @return a list of elements sorted by distance
      */
-    public abstract List<TinkerIndexElement<T>> findNearest(final String key, 
final float[] vector, final int k);
+    public List<TinkerIndexElement<T>> findNearest(final String key, final 
float[] vector, final int k) {
+        return findNearest(key, vector, k, Collections.emptyMap(), null);
+    }
 
     /**
      * Searches for nearest neighbors in the vector index with the default k.
@@ -62,9 +66,22 @@ abstract class AbstractTinkerVectorIndex<T extends Element> 
extends AbstractTink
      * @return a list of elements sorted by distance
      */
     public List<TinkerIndexElement<T>> findNearest(final String key, final 
float[] vector) {
-        return findNearest(key, vector, DEFAULT_K);
+        return findNearest(key, vector, DEFAULT_K, Collections.emptyMap(), 
null);
     }
 
+    /**
+     * Searches for nearest neighbors in the vector index with optional 
property/label filtering.
+     *
+     * @param key       the property key
+     * @param vector    the query vector
+     * @param k         the number of nearest neighbors to return
+     * @param filter    map of property key to required value; use {@code 
"~label"} to filter by element label
+     * @param excludeId element id to exclude from results (used to exclude 
the source element in byElement searches)
+     * @return a list of elements sorted by distance
+     */
+    public abstract List<TinkerIndexElement<T>> findNearest(final String key, 
final float[] vector, final int k,
+                                                            final Map<String, 
Object> filter, final Object excludeId);
+
     /**
      * Searches for nearest neighbors in the vector index.
      *
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
index 5f984536c6..7bf527264f 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerTransactionVectorIndex.java
@@ -119,11 +119,12 @@ final class TinkerTransactionVectorIndex<T extends 
TinkerElement> extends Abstra
     }
 
     @Override
-    public List<TinkerIndexElement<T>> findNearest(final String key, final 
float[] vector, final int k) {
+    public List<TinkerIndexElement<T>> findNearest(final String key, final 
float[] vector, final int k,
+                                                   final Map<String, Object> 
filter, final Object excludeId) {
         final IndexState<T> state = this.vectorIndices.get(key);
         if (state == null)
             throw new IllegalArgumentException("The key '" + key + "' is not 
indexed");
-        return state.search(vector, k);
+        return state.search(vector, k, filter, excludeId);
     }
 
     @Override
@@ -243,14 +244,16 @@ final class TinkerTransactionVectorIndex<T extends 
TinkerElement> extends Abstra
                 builder.markNodeDeleted(ordinal);
         }
 
-        List<TinkerIndexElement<T>> search(final float[] queryVector, final 
int k) {
+        List<TinkerIndexElement<T>> search(final float[] queryVector, final 
int k,
+                                           final Map<String, Object> filter, 
final Object excludeId) {
             if (vectors.isEmpty())
                 return Collections.emptyList();
             final VectorFloat<?> query = VTS.createFloatVector(queryVector);
             final ListRandomAccessVectorValues ravv = new 
ListRandomAccessVectorValues(vectors, dimension);
             final var ssp = SearchScoreProvider.exact(query, 
similarityFunction, ravv);
+            final io.github.jbellis.jvector.util.Bits bits = buildBits(filter, 
excludeId);
             try (final GraphSearcher searcher = new 
GraphSearcher(builder.getGraph())) {
-                final SearchResult result = searcher.search(ssp, k, 
io.github.jbellis.jvector.util.Bits.ALL);
+                final SearchResult result = searcher.search(ssp, k, bits);
                 return Arrays.stream(result.getNodes())
                         .map(ns -> new 
TinkerIndexElement<>(elements.get(ns.node), 1.0f - ns.score))
                         .collect(Collectors.toList());
@@ -259,6 +262,32 @@ final class TinkerTransactionVectorIndex<T extends 
TinkerElement> extends Abstra
             }
         }
 
+        private io.github.jbellis.jvector.util.Bits buildBits(final 
Map<String, Object> filter, final Object excludeId) {
+            if ((filter == null || filter.isEmpty()) && excludeId == null)
+                return io.github.jbellis.jvector.util.Bits.ALL;
+            return new io.github.jbellis.jvector.util.Bits() {
+                @Override
+                public boolean get(final int ordinal) {
+                    if (ordinal >= elements.size()) return false;
+                    final T el = elements.get(ordinal);
+                    if (el == null) return false;
+                    if (excludeId != null && el.id().equals(excludeId)) return 
false;
+                    if (filter != null) {
+                        for (final Map.Entry<String, Object> entry : 
filter.entrySet()) {
+                            if ("~label".equals(entry.getKey())) {
+                                if (!el.label().equals(entry.getValue())) 
return false;
+                            } else {
+                                final Property<?> prop = 
el.property(entry.getKey());
+                                if (!prop.isPresent() || 
!prop.value().equals(entry.getValue())) return false;
+                            }
+                        }
+                    }
+                    return true;
+                }
+
+            };
+        }
+
         void close() {
             try {
                 builder.close();
diff --git 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
index 0b6382ad7f..f27c630db3 100644
--- 
a/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
+++ 
b/tinkergraph-gremlin/src/main/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerVectorIndex.java
@@ -111,11 +111,12 @@ final class TinkerVectorIndex<T extends Element> extends 
AbstractTinkerVectorInd
     }
 
     @Override
-    public List<TinkerIndexElement<T>> findNearest(final String key, final 
float[] vector, final int k) {
+    public List<TinkerIndexElement<T>> findNearest(final String key, final 
float[] vector, final int k,
+                                                   final Map<String, Object> 
filter, final Object excludeId) {
         final IndexState<T> state = this.vectorIndices.get(key);
         if (state == null)
             throw new IllegalArgumentException("The key '" + key + "' is not 
indexed");
-        return state.search(vector, k);
+        return state.search(vector, k, filter, excludeId);
     }
 
     @Override
@@ -215,14 +216,16 @@ final class TinkerVectorIndex<T extends Element> extends 
AbstractTinkerVectorInd
                 builder.markNodeDeleted(ordinal);
         }
 
-        List<TinkerIndexElement<T>> search(final float[] queryVector, final 
int k) {
+        List<TinkerIndexElement<T>> search(final float[] queryVector, final 
int k,
+                                           final Map<String, Object> filter, 
final Object excludeId) {
             if (vectors.isEmpty())
                 return Collections.emptyList();
             final VectorFloat<?> query = VTS.createFloatVector(queryVector);
             final ListRandomAccessVectorValues ravv = new 
ListRandomAccessVectorValues(vectors, dimension);
             final var ssp = SearchScoreProvider.exact(query, 
similarityFunction, ravv);
+            final io.github.jbellis.jvector.util.Bits bits = buildBits(filter, 
excludeId);
             try (final GraphSearcher searcher = new 
GraphSearcher(builder.getGraph())) {
-                final SearchResult result = searcher.search(ssp, k, 
io.github.jbellis.jvector.util.Bits.ALL);
+                final SearchResult result = searcher.search(ssp, k, bits);
                 return java.util.Arrays.stream(result.getNodes())
                         .map(ns -> new 
TinkerIndexElement<>(elements.get(ns.node), 1.0f - ns.score))
                         .collect(Collectors.toList());
@@ -231,6 +234,32 @@ final class TinkerVectorIndex<T extends Element> extends 
AbstractTinkerVectorInd
             }
         }
 
+        private io.github.jbellis.jvector.util.Bits buildBits(final 
Map<String, Object> filter, final Object excludeId) {
+            if ((filter == null || filter.isEmpty()) && excludeId == null)
+                return io.github.jbellis.jvector.util.Bits.ALL;
+            return new io.github.jbellis.jvector.util.Bits() {
+                @Override
+                public boolean get(final int ordinal) {
+                    if (ordinal >= elements.size()) return false;
+                    final T el = elements.get(ordinal);
+                    if (el == null) return false;
+                    if (excludeId != null && el.id().equals(excludeId)) return 
false;
+                    if (filter != null) {
+                        for (final Map.Entry<String, Object> entry : 
filter.entrySet()) {
+                            if ("~label".equals(entry.getKey())) {
+                                if (!el.label().equals(entry.getValue())) 
return false;
+                            } else {
+                                final Property<?> prop = 
el.property(entry.getKey());
+                                if (!prop.isPresent() || 
!prop.value().equals(entry.getValue())) return false;
+                            }
+                        }
+                    }
+                    return true;
+                }
+
+            };
+        }
+
         void close() {
             try {
                 builder.close();
diff --git 
a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java
 
b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java
index 1843cb0990..44e7a8427c 100644
--- 
a/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java
+++ 
b/tinkergraph-gremlin/src/test/java/org/apache/tinkerpop/gremlin/tinkergraph/structure/TinkerGraphServiceTest.java
@@ -881,6 +881,76 @@ public class TinkerGraphServiceTest {
         
gv.V(vAlice).outE().inV().path().call(TinkerVectorDistanceFactory.NAME, 
params).next();
     }
 
+    @Test
+    public void g_V_callXvector_topK_byElement_filterByLabelX() {
+        final TinkerGraph graf = TinkerGraph.open();
+        graf.getServiceRegistry().registerService(new 
TinkerVectorSearchByElementFactory(graf));
+        graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, 
indexConfig);
+        final GraphTraversalSource gv = graf.traversal();
+
+        final Vertex vAlice = gv.addV("person").property("name", 
"Alice").property("embedding", new float[]{1.0f, 0.0f, 0.0f}).next();
+        gv.addV("person").property("name", "Bob").property("embedding", new 
float[]{0.9f, 0.1f, 0.0f}).iterate();
+        gv.addV("robot").property("name", "R2D2").property("embedding", new 
float[]{0.95f, 0.05f, 0.0f}).iterate();
+
+        final Map<String, Object> params = new HashMap<String, Object>() {{
+            put("key", "embedding");
+            put("topK", 2);
+            put("filter", new HashMap<String, Object>() {{ put("~label", 
"robot"); }});
+        }};
+        final List<Object> list = 
gv.V(vAlice).call(TinkerVectorSearchByElementFactory.NAME, params).toList();
+
+        // only the robot should be returned even though Bob is also close
+        assertEquals(1, list.size());
+        assertEquals("R2D2", ((Vertex) ((Map) 
list.get(0)).get("element")).value("name"));
+    }
+
+    @Test
+    public void g_V_callXvector_topK_byElement_filterByPropertyX() {
+        final TinkerGraph graf = TinkerGraph.open();
+        graf.getServiceRegistry().registerService(new 
TinkerVectorSearchByElementFactory(graf));
+        graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, 
indexConfig);
+        final GraphTraversalSource gv = graf.traversal();
+
+        final Vertex vAlice = gv.addV("person").property("name", 
"Alice").property("embedding", new float[]{1.0f, 0.0f, 
0.0f}).property("active", true).next();
+        gv.addV("person").property("name", "Bob").property("embedding", new 
float[]{0.9f, 0.1f, 0.0f}).property("active", false).iterate();
+        gv.addV("person").property("name", "Charlie").property("embedding", 
new float[]{0.8f, 0.2f, 0.0f}).property("active", true).iterate();
+
+        final Map<String, Object> params = new HashMap<String, Object>() {{
+            put("key", "embedding");
+            put("topK", 2);
+            put("filter", new HashMap<String, Object>() {{ put("active", 
true); }});
+        }};
+        final List<Object> list = 
gv.V(vAlice).call(TinkerVectorSearchByElementFactory.NAME, params).toList();
+
+        // only active vertices returned (Bob excluded despite being close), 
self excluded
+        assertEquals(1, list.size());
+        assertEquals("Charlie", ((Vertex) ((Map) 
list.get(0)).get("element")).value("name"));
+    }
+
+    @Test
+    public void g_inject_callXvector_topK_byEmbedding_filterByLabelX() {
+        final TinkerGraph graf = TinkerGraph.open();
+        graf.getServiceRegistry().registerService(new 
TinkerVectorSearchByEmbeddingFactory(graf));
+        graf.createIndex(TinkerIndexType.VECTOR, "embedding", Vertex.class, 
indexConfig);
+        final GraphTraversalSource gv = graf.traversal();
+
+        gv.addV("person").property("name", "Alice").property("embedding", new 
float[]{1.0f, 0.0f, 0.0f}).iterate();
+        gv.addV("robot").property("name", "R2D2").property("embedding", new 
float[]{0.95f, 0.05f, 0.0f}).iterate();
+
+        final Float[] query = new Float[]{1.0f, 0.0f, 0.0f};
+        final Map<String, Object> params = new HashMap<String, Object>() {{
+            put("key", "embedding");
+            put("element", "vertex");
+            put("topK", 2);
+            put("filter", new HashMap<String, Object>() {{ put("~label", 
"robot"); }});
+        }};
+        final List<Object> list = gv.inject(0).constant(query)
+                .call(TinkerVectorSearchByEmbeddingFactory.NAME, 
params).toList();
+
+        assertEquals(1, list.size());
+        assertEquals("R2D2", ((Vertex) ((Map) 
list.get(0)).get("element")).value("name"));
+    }
+
     private String toResultString(final Traversal traversal) {
         return (String) 
IteratorUtils.stream(traversal).map(Object::toString).collect(Collectors.joining(",",
 "[", "]"));
     }

Reply via email to