This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new edcb7067 feat(search): Hnsw Vector Search Optimization Pass (#2466)
edcb7067 is described below

commit edcb7067453324f0bff1a4cca2d186730b31898e
Author: Rebecca Zhou <[email protected]>
AuthorDate: Wed Aug 7 06:41:11 2024 -0700

    feat(search): Hnsw Vector Search Optimization Pass (#2466)
    
    Co-authored-by: Twice <[email protected]>
---
 src/search/executors/filter_executor.h   | 22 ++++++++++++++
 src/search/ir.h                          | 19 ++++++------
 src/search/ir_pass.h                     | 29 ++++++++++++++++++
 src/search/ir_plan.h                     |  2 +-
 src/search/ir_sema_checker.h             | 13 +++++----
 src/search/passes/cost_model.h           | 10 +++++++
 src/search/passes/index_selection.h      | 23 +++++++++++++++
 src/search/passes/manager.h              |  4 ++-
 src/search/passes/sort_limit_to_knn.h    | 50 ++++++++++++++++++++++++++++++++
 src/search/redis_query_parser.h          |  5 ++--
 src/search/redis_query_transformer.h     | 16 ++++++----
 src/search/sql_transformer.h             |  1 -
 tests/cppunit/ir_pass_test.cc            | 47 +++++++++++++++++++++++++++++-
 tests/cppunit/ir_sema_checker_test.cc    |  7 +++++
 tests/cppunit/plan_executor_test.cc      | 36 +++++++++++++++++++++++
 tests/cppunit/redis_query_parser_test.cc | 14 ++++++---
 16 files changed, 267 insertions(+), 31 deletions(-)

diff --git a/src/search/executors/filter_executor.h 
b/src/search/executors/filter_executor.h
index df14b29b..1b1febe8 100644
--- a/src/search/executors/filter_executor.h
+++ b/src/search/executors/filter_executor.h
@@ -23,6 +23,7 @@
 #include <variant>
 
 #include "parse_util.h"
+#include "search/hnsw_indexer.h"
 #include "search/ir.h"
 #include "search/plan_executor.h"
 #include "search/search_encoding.h"
@@ -44,6 +45,9 @@ struct QueryExprEvaluator {
     if (auto v = dynamic_cast<NotExpr *>(e)) {
       return Visit(v);
     }
+    if (auto v = dynamic_cast<VectorRangeExpr *>(e)) {
+      return Visit(v);
+    }
     if (auto v = dynamic_cast<NumericCompareExpr *>(e)) {
       return Visit(v);
     }
@@ -112,6 +116,29 @@ struct QueryExprEvaluator {
         __builtin_unreachable();
     }
   }
+
+  // Evaluates `field <-> vector < range` for the current row: matches when ComputeSimilarity()
+  // of the stored and the query vector lies within +/- range, widened by the field's `epsilon`.
+  StatusOr<bool> Visit(VectorRangeExpr *v) const {
+    auto val = GET_OR_RET(ctx->Retrieve(row, v->field->info));
+
+    CHECK(val.Is<kqir::NumericArray>());
+    auto l_values = val.Get<kqir::NumericArray>();
+    auto r_values = v->vector->values;
+    auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
+
+    redis::VectorItem left, right;
+    GET_OR_RET(redis::VectorItem::Create({}, l_values, meta, &left));
+    GET_OR_RET(redis::VectorItem::Create({}, r_values, meta, &right));
+
+    auto dist = GET_OR_RET(redis::ComputeSimilarity(left, right));
+    auto effective_range = v->range->val * (1 + meta->epsilon);
+
+    // Take |effective_range| explicitly: an unqualified `abs` may resolve to the C `int abs(int)`
+    // and silently truncate the floating-point bound.
+    auto bound = effective_range < 0 ? -effective_range : effective_range;
+    return dist >= -bound && dist <= bound;
+  }
 };
 
 struct FilterExecutor : ExecutorNode {
diff --git a/src/search/ir.h b/src/search/ir.h
index 3ba980da..c7aec26b 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -265,24 +265,19 @@ struct VectorRangeExpr : BoolAtomExpr {
 };
 
 struct VectorKnnExpr : BoolAtomExpr {
-  // TODO: Support pre-filter for hybrid query
   std::unique_ptr<FieldRef> field;
-  std::unique_ptr<NumericLiteral> k;
   std::unique_ptr<VectorLiteral> vector;
+  size_t k;  // number of nearest neighbours requested (a plain value, not a child node)
 
-  VectorKnnExpr(std::unique_ptr<FieldRef> &&field, 
std::unique_ptr<NumericLiteral> &&k,
-                std::unique_ptr<VectorLiteral> &&vector)
-      : field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {}
+  VectorKnnExpr(std::unique_ptr<FieldRef> &&field, 
std::unique_ptr<VectorLiteral> &&vector, size_t k)
+      : field(std::move(field)), vector(std::move(vector)), k(k) {}
 
   std::string_view Name() const override { return "VectorKnnExpr"; }
-  std::string Dump() const override {
-    return fmt::format("KNN k={}, {} <-> {}", k->Dump(), field->Dump(), 
vector->Dump());
-  }
+  std::string Dump() const override { return fmt::format("KNN k={}, {} <-> 
{}", k, field->Dump(), vector->Dump()); }
 
   std::unique_ptr<Node> Clone() const override {
     return 
std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(field->Clone()),
-                                           
Node::MustAs<NumericLiteral>(k->Clone()),
-                                           
Node::MustAs<VectorLiteral>(vector->Clone()));
+                                           
Node::MustAs<VectorLiteral>(vector->Clone()), k);
   }
 };
 
@@ -425,6 +420,10 @@ struct SortByClause : Node {
   std::unique_ptr<Node> Clone() const override {
     return std::make_unique<SortByClause>(order, 
Node::MustAs<FieldRef>(field->Clone()));
   }
+
+  std::unique_ptr<FieldRef> TakeFieldRef() { return std::move(field); }  // transfers ownership; `field` is null afterwards
+
+  std::unique_ptr<VectorLiteral> TakeVectorLiteral() { return 
std::move(vector); }  // transfers ownership; `vector` is null afterwards
 };
 
 struct SelectClause : Node {
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
index 2068a45a..e783ca8f 100644
--- a/src/search/ir_pass.h
+++ b/src/search/ir_pass.h
@@ -59,6 +59,12 @@ struct Visitor : Pass {
       return Visit(std::move(v));
     } else if (auto v = Node::As<TagContainExpr>(std::move(node))) {
       return Visit(std::move(v));
+    } else if (auto v = Node::As<VectorLiteral>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<VectorKnnExpr>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<VectorRangeExpr>(std::move(node))) {
+      return Visit(std::move(v));
     } else if (auto v = Node::As<StringLiteral>(std::move(node))) {
       return Visit(std::move(v));
     } else if (auto v = Node::As<BoolLiteral>(std::move(node))) {
@@ -69,6 +75,10 @@ struct Visitor : Pass {
       return Visit(std::move(v));
     } else if (auto v = Node::As<TagFieldScan>(std::move(node))) {
       return Visit(std::move(v));
+    } else if (auto v = Node::As<HnswVectorFieldRangeScan>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<HnswVectorFieldKnnScan>(std::move(node))) {
+      return Visit(std::move(v));
     } else if (auto v = Node::As<Filter>(std::move(node))) {
       return Visit(std::move(v));
     } else if (auto v = Node::As<Limit>(std::move(node))) {
@@ -125,6 +135,8 @@ struct Visitor : Pass {
 
   virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericLiteral> node) { 
return node; }
 
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorLiteral> node) { 
return node; }  // literal: returned unchanged
+
   virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericCompareExpr> 
node) {
     node->field = VisitAs<FieldRef>(std::move(node->field));
     node->num = VisitAs<NumericLiteral>(std::move(node->num));
@@ -137,6 +149,19 @@ struct Visitor : Pass {
     return node;
   }
 
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorKnnExpr> node) {  // `k` is a plain value, so only field and vector are visited
+    node->field = VisitAs<FieldRef>(std::move(node->field));
+    node->vector = VisitAs<VectorLiteral>(std::move(node->vector));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorRangeExpr> node) {
+    node->field = VisitAs<FieldRef>(std::move(node->field));
+    node->range = VisitAs<NumericLiteral>(std::move(node->range));
+    node->vector = VisitAs<VectorLiteral>(std::move(node->vector));
+    return node;
+  }
+
   virtual std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) {
     for (auto &n : node->inners) {
       n = TransformAs<QueryExpr>(std::move(n));
@@ -173,6 +198,10 @@ struct Visitor : Pass {
 
   virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagFieldScan> node) { 
return node; }
 
+  virtual std::unique_ptr<Node> 
Visit(std::unique_ptr<HnswVectorFieldRangeScan> node) { return node; }  // returned unchanged
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<HnswVectorFieldKnnScan> 
node) { return node; }  // returned unchanged
+
   virtual std::unique_ptr<Node> Visit(std::unique_ptr<Filter> node) {
     node->source = TransformAs<PlanOperator>(std::move(node->source));
     node->filter_expr = TransformAs<QueryExpr>(std::move(node->filter_expr));
diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h
index 94e8b589..8743a827 100644
--- a/src/search/ir_plan.h
+++ b/src/search/ir_plan.h
@@ -99,7 +99,7 @@ struct TagFieldScan : FieldScan {
 
 struct HnswVectorFieldKnnScan : FieldScan {
   kqir::NumericArray vector;
-  uint16_t k;
+  uint32_t k;  // number of nearest neighbours to fetch
 
-  HnswVectorFieldKnnScan(std::unique_ptr<FieldRef> field, kqir::NumericArray 
vector, uint16_t k)
+  HnswVectorFieldKnnScan(std::unique_ptr<FieldRef> field, kqir::NumericArray 
vector, uint32_t k)  // widened with the member so a size_t k is not truncated to 16 bits
       : FieldScan(std::move(field)), vector(std::move(vector)), k(k) {}
diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h
index 43d722b4..8d18cd84 100644
--- a/src/search/ir_sema_checker.h
+++ b/src/search/ir_sema_checker.h
@@ -50,8 +50,14 @@ struct SemaChecker {
         GET_OR_RET(Check(v->query_expr.get()));
         if (v->limit) GET_OR_RET(Check(v->limit.get()));
         if (v->sort_by) GET_OR_RET(Check(v->sort_by.get()));
-        if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) {
-          return {Status::NotOK, "expect a LIMIT clause for vector field to 
construct a KNN search"};
+        if (v->sort_by && v->sort_by->IsVectorField()) {
+          if (!v->limit) {  // the LIMIT clause provides `k` for the KNN rewrite
+            return {Status::NotOK, "expect a LIMIT clause for vector field to 
construct a KNN search"};
+          }
+          // TODO: allow hybrid query; until then the query expression must be a bare true/false literal
+          if (auto b = dynamic_cast<BoolLiteral *>(v->query_expr.get()); b == 
nullptr) {
+            return {Status::NotOK, "KNN search cannot be combined with other 
query expressions"};
+          }
         }
       } else {
         return {Status::NotOK, fmt::format("index `{}` not found", 
index_name)};
@@ -129,9 +135,6 @@ struct SemaChecker {
           return {Status::NotOK,
                   fmt::format("field `{}` is marked as NOINDEX and cannot be 
used for KNN search", v->field->name)};
         }
-        if (v->k->val <= 0) {
-          return {Status::NotOK, fmt::format("KNN search parameter `k` must be 
greater than 0")};
-        }
         auto meta = 
v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
         if (v->vector->values.size() != meta->dim) {
           return {Status::NotOK,
diff --git a/src/search/passes/cost_model.h b/src/search/passes/cost_model.h
index 86e0e3a5..960708d7 100644
--- a/src/search/passes/cost_model.h
+++ b/src/search/passes/cost_model.h
@@ -36,6 +36,12 @@ struct CostModel {
     if (auto v = dynamic_cast<const FullIndexScan *>(node)) {
       return Visit(v);
     }
+    if (auto v = dynamic_cast<const HnswVectorFieldKnnScan *>(node)) {
+      return Visit(v);
+    }
+    if (auto v = dynamic_cast<const HnswVectorFieldRangeScan *>(node)) {
+      return Visit(v);
+    }
     if (auto v = dynamic_cast<const NumericFieldScan *>(node)) {
       return Visit(v);
     }
@@ -74,6 +80,10 @@ struct CostModel {
 
   static size_t Visit(const TagFieldScan *node) { return 10; }
 
+  static size_t Visit(const HnswVectorFieldKnnScan *node) { return 3; }  // fixed cost, below TagFieldScan's 10
+
+  static size_t Visit(const HnswVectorFieldRangeScan *node) { return 4; }  // fixed cost, one above the KNN scan
+
   static size_t Visit(const Filter *node) { return 
Transform(node->source.get()) + 1; }
 
   static size_t Visit(const Merge *node) {
diff --git a/src/search/passes/index_selection.h 
b/src/search/passes/index_selection.h
index e60287d4..09e1bcb3 100644
--- a/src/search/passes/index_selection.h
+++ b/src/search/passes/index_selection.h
@@ -112,6 +112,12 @@ struct IndexSelection : Visitor {
     if (auto v = dynamic_cast<OrExpr *>(node)) {
       return VisitExpr(v);
     }
+    if (auto v = dynamic_cast<VectorKnnExpr *>(node)) {
+      return VisitExpr(v);
+    }
+    if (auto v = dynamic_cast<VectorRangeExpr *>(node)) {
+      return VisitExpr(v);
+    }
     if (auto v = dynamic_cast<NumericCompareExpr *>(node)) {
       return VisitExpr(v);
     }
@@ -153,6 +159,23 @@ struct IndexSelection : Visitor {
     return MakeFullIndexFilter(node);
   }
 
+  std::unique_ptr<PlanOperator> VisitExpr(VectorRangeExpr *node) const {
+    if (node->field->info->HasIndex()) {  // indexed vector field: use the HNSW range scan
+      return 
std::make_unique<HnswVectorFieldRangeScan>(node->field->CloneAs<FieldRef>(), 
node->vector->values,
+                                                        node->range->val);
+    }
+
+    return MakeFullIndexFilter(node);  // NOINDEX field: evaluate as a filter over a full index scan
+  }
+
+  std::unique_ptr<PlanOperator> VisitExpr(VectorKnnExpr *node) const {
+    if (node->field->info->HasIndex()) {  // indexed vector field: use the HNSW KNN scan
+      return 
std::make_unique<HnswVectorFieldKnnScan>(node->field->CloneAs<FieldRef>(), 
node->vector->values, node->k);
+    }
+
+    return MakeFullIndexFilter(node);  // NOTE(review): SemaChecker rejects KNN on NOINDEX fields, so this fallback looks unreachable
+  }
+
   template <typename Expr>
   std::unique_ptr<PlanOperator> VisitExprImpl(Expr *node) {
     struct AggregatedNodes {
diff --git a/src/search/passes/manager.h b/src/search/passes/manager.h
index 57f317d2..ce2d6a3d 100644
--- a/src/search/passes/manager.h
+++ b/src/search/passes/manager.h
@@ -35,6 +35,7 @@
 #include "search/passes/simplify_and_or_expr.h"
 #include "search/passes/simplify_boolean.h"
 #include "search/passes/sort_limit_fuse.h"
+#include "search/passes/sort_limit_to_knn.h"
 #include "type_util.h"
 
 namespace kqir {
@@ -86,7 +87,8 @@ struct PassManager {
   }
 
   static PassSequence ExprPasses() {
-    return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, 
SimplifyAndOrExpr{});
+    return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{}, 
SimplifyAndOrExpr{},
+                  SortByWithLimitToKnnExpr{}, SimplifyAndOrExpr{});  // KNN rewrite runs after SimplifyBoolean, presumably so a trivially-true query is already the BoolLiteral it matches
   }
   static PassSequence NumericPasses() { return Create(IntervalAnalysis{true}, 
SimplifyAndOrExpr{}, SimplifyBoolean{}); }
   static PassSequence PlanPasses() { return Create(LowerToPlan{}, 
IndexSelection{}, SortLimitFuse{}); }
diff --git a/src/search/passes/sort_limit_to_knn.h 
b/src/search/passes/sort_limit_to_knn.h
new file mode 100644
index 00000000..e0f7b958
--- /dev/null
+++ b/src/search/passes/sort_limit_to_knn.h
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <memory>
+
+#include "search/ir.h"
+#include "search/ir_pass.h"
+#include "search/ir_plan.h"
+
+namespace kqir {
+
+// Rewrites `ORDER BY <vector field> <-> <vector> LIMIT offset, count` into a VectorKnnExpr
+// query expression, so that index selection can later lower it to an HNSW KNN scan.
+//
+// The LIMIT clause is kept (it still skips `offset` rows), so the KNN search has to produce
+// `offset + count` neighbours. The rewrite only fires when the query expression is the
+// literal `true`: a `false` query is left untouched, and combining KNN with other filters
+// is not supported yet (SemaChecker rejects it).
+struct SortByWithLimitToKnnExpr : Visitor {
+  std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) override {
+    node = Node::MustAs<SearchExpr>(Visitor::Visit(std::move(node)));
+
+    // TODO: allow hybrid query
+    if (node->sort_by && node->sort_by->IsVectorField() && node->limit) {
+      if (auto b = dynamic_cast<BoolLiteral *>(node->query_expr.get()); b && b->val) {
+        size_t k = node->limit->Offset() + node->limit->Count();
+        node->query_expr = std::make_unique<VectorKnnExpr>(node->sort_by->TakeFieldRef(),
+                                                           node->sort_by->TakeVectorLiteral(), k);
+        // The sort clause is consumed: its field and vector now live in the KNN expression.
+        node->sort_by.reset();
+      }
+    }
+
+    return node;
+  }
+};
+
+}  // namespace kqir
diff --git a/src/search/redis_query_parser.h b/src/search/redis_query_parser.h
index 5b0f172c..627910a3 100644
--- a/src/search/redis_query_parser.h
+++ b/src/search/redis_query_parser.h
@@ -43,13 +43,14 @@ struct Tag : sor<Identifier, StringL, Param> {};
 struct TagList : seq<one<'{'>, WSPad<Tag>, star<seq<one<'|'>, WSPad<Tag>>>, 
one<'}'>> {};
 
 struct NumberOrParam : sor<Number, Param> {};
+struct UintOrParam : sor<UnsignedInteger, Param> {};  // unsigned integer literal or $param (used for KNN `k`)
 
 struct Inf : seq<opt<one<'+', '-'>>, string<'i', 'n', 'f'>> {};
 struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
 struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
 struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>, 
WSPad<NumericRangePart>, one<']'>> {};
 
-struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<NumberOrParam>, 
WSPad<Field>, WSPad<Param>, one<']'>> {};
+struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<UintOrParam>, 
WSPad<Field>, WSPad<Param>, one<']'>> {};  // [KNN k @field $vector]; a negative literal k is a syntax error
 struct VectorRange : seq<one<'['>, WSPad<VectorRangeToken>, 
WSPad<NumberOrParam>, WSPad<Param>, one<']'>> {};
 
 struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<VectorRange, 
TagList, NumericRange>>> {};
@@ -70,7 +71,7 @@ struct AndExprP : sor<AndExpr, BooleanExpr> {};
 struct OrExpr : seq<AndExprP, plus<seq<one<'|'>, AndExprP>>> {};
 struct OrExprP : sor<OrExpr, AndExprP> {};
 
-struct PrefilterExpr : seq<WSPad<BooleanExpr>, ArrowOp, WSPad<KnnSearch>> {};
+struct PrefilterExpr : seq<WSPad<Wildcard>, ArrowOp, WSPad<KnnSearch>> {};  // only `*` may precede `=>` until hybrid queries are supported
 
 struct QueryP : sor<PrefilterExpr, OrExprP> {};
 
diff --git a/src/search/redis_query_transformer.h 
b/src/search/redis_query_transformer.h
index c81230e4..ed7c8fc6 100644
--- a/src/search/redis_query_transformer.h
+++ b/src/search/redis_query_transformer.h
@@ -36,7 +36,7 @@ namespace ir = kqir;
 
 template <typename Rule>
 using TreeSelector = parse_tree::selector<
-    Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, 
Inf>,
+    Rule, parse_tree::store_content::on<Number, UnsignedInteger, StringL, 
Param, Identifier, Inf>,
     parse_tree::remove_content::on<TagList, NumericRange, VectorRange, 
ExclusiveNumber, FieldQuery, NotExpr, AndExpr,
                                    OrExpr, PrefilterExpr, KnnSearch, Wildcard, 
VectorRangeToken, KnnToken, ArrowOp>>;
 
@@ -161,17 +161,28 @@ struct Transformer : ir::TreeTransformer {
 
       return 
Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
     } else if (Is<PrefilterExpr>(node)) {
+      // TODO: allow hybrid query
       CHECK(node->children.size() == 3);
 
-      // TODO(Beihao): Support Hybrid Query
-      // const auto& prefilter = node->children[0];
       const auto& knn_search = node->children[2];
       CHECK(knn_search->children.size() == 4);
 
-      return 
std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
-                                             
GET_OR_RET(number_or_param(knn_search->children[1])),
-                                             
GET_OR_RET(Transform2Vector(knn_search->children[3])));
+      // `k` is an unsigned integer literal or a `$param` naming one (the `$param` is the KNN child, not
+      // the prefilter node); parse errors are propagated instead of dereferencing a failed StatusOr.
+      const auto& k_node = knn_search->children[1];
+      std::string k_str;
+      if (Is<UnsignedInteger>(k_node)) {
+        k_str = k_node->string();
+      } else {
+        k_str = GET_OR_RET(GetParam(k_node));
+      }
+      auto k = GET_OR_RET(ParseInt(k_str));
+      if (k <= 0) {
+        return {Status::NotOK, "KNN search parameter `k` must be greater than 0"};
+      }
 
+      return 
std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
+                                             
GET_OR_RET(Transform2Vector(knn_search->children[3])), static_cast<size_t>(k));
     } else if (Is<AndExpr>(node)) {
       std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
 
diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h
index 01705107..49d04307 100644
--- a/src/search/sql_transformer.h
+++ b/src/search/sql_transformer.h
@@ -118,7 +118,6 @@ struct Transformer : ir::TreeTransformer {
         return {Status::NotOK, "the left and right side of numeric comparison 
should be an identifier and a number"};
       }
     } else if (Is<VectorRangeExpr>(node)) {
-      // TODO(Beihao): Handle distance metrics for operator
       CHECK(node->children.size() == 2);
       const auto& vector_comp_expr = node->children[0];
       CHECK(vector_comp_expr->children.size() == 3);
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index 81ed49e8..9d576678 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -111,6 +111,19 @@ TEST(IRPassTest, Manager) {
             "select * from a where (and x <= 1, y >= 2, z != 3)");
 }
 
+TEST(IRPassTest, SortByWithLimitToKnnExpr) {
+  SortByWithLimitToKnnExpr tsbtke;  // the rewrite pass applied on its own
+
+  ASSERT_EQ(tsbtke.Transform(*Parse("select a from b order by embedding <-> 
[3.6] limit 5"))->Dump(),
+            "select a from b where KNN k=5, embedding <-> [3.600000] limit 0, 
5");
+  ASSERT_EQ(tsbtke.Transform(*Parse("select a from b where false order by 
embedding <-> [3,1,2] limit 5"))->Dump(),
+            "select a from b where false sortby embedding <-> [3.000000, 
1.000000, 2.000000] limit 0, 5");
+  ASSERT_EQ(tsbtke.Transform(*Parse("select a from b where true order by 
embedding <-> [3,1,2] limit 5"))->Dump(),
+            "select a from b where KNN k=5, embedding <-> [3.000000, 1.000000, 
2.000000] limit 0, 5");
+  ASSERT_EQ(tsbtke.Transform(*Parse("select a from b where true order by 
embedding <-> [3,1,2] limit 3, 5"))->Dump(),
+            "select a from b where KNN k=8, embedding <-> [3.000000, 1.000000, 
2.000000] limit 3, 5");  // k = offset + count
+}
+
 TEST(IRPassTest, LowerToPlan) {
   LowerToPlan ltp;
 
@@ -118,11 +131,15 @@ TEST(IRPassTest, LowerToPlan) {
   ASSERT_EQ(ltp.Transform(*Parse("select * from a limit 1"))->Dump(), "project 
*: (limit 0, 1: full-scan a)");
   ASSERT_EQ(ltp.Transform(*Parse("select * from a where false"))->Dump(), 
"project *: noop");
   ASSERT_EQ(ltp.Transform(*Parse("select * from a where false limit 
1"))->Dump(), "project *: noop");
+  ASSERT_EQ(ltp.Transform(*Parse("select * from a where false order by 
embedding <-> [3,1,2] limit 5"))->Dump(),
+            "project *: noop");
   ASSERT_EQ(ltp.Transform(*Parse("select * from a where b > 1"))->Dump(), 
"project *: (filter b > 1: full-scan a)");
   ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by 
d"))->Dump(),
             "project a: (sort d, asc: (filter c = 1: full-scan b))");
   ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 limit 
1"))->Dump(),
             "project a: (limit 0, 1: (filter c = 1: full-scan b))");
+  ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 and d = 2 order 
by e limit 1"))->Dump(),
+            "project a: (limit 0, 1: (sort e, asc: (filter (and c = 1, d = 2): 
full-scan b)))");
   ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit 
1"))->Dump(),
             "project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan 
b)))");
 }
@@ -176,12 +193,28 @@ static IndexMap MakeIndexMap() {
   auto f4 = FieldInfo("n2", std::make_unique<redis::NumericFieldMetadata>());
   auto f5 = FieldInfo("n3", std::make_unique<redis::NumericFieldMetadata>());
   f5.metadata->noindex = true;
+
+  auto hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
+  hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
+  hnsw_field_meta->dim = 3;
+  hnsw_field_meta->distance_metric = redis::DistanceMetric::L2;
+  auto f6 = FieldInfo("v1", std::move(hnsw_field_meta));  // indexed vector field
+
+  hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
+  hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
+  hnsw_field_meta->dim = 3;
+  hnsw_field_meta->distance_metric = redis::DistanceMetric::L2;
+  auto f7 = FieldInfo("v2", std::move(hnsw_field_meta));
+  f7.metadata->noindex = true;  // NOINDEX vector field: range queries fall back to a full-scan filter
+
   auto ia = std::make_unique<IndexInfo>("ia", redis::IndexMetadata(), "");
   ia->Add(std::move(f1));
   ia->Add(std::move(f2));
   ia->Add(std::move(f3));
   ia->Add(std::move(f4));
   ia->Add(std::move(f5));
+  ia->Add(std::move(f6));
+  ia->Add(std::move(f7));
 
   IndexMap res;
   res.Insert(std::move(ia));
@@ -238,7 +271,19 @@ TEST(IRPassTest, IndexSelection) {
       "project *: (filter t2 hastag \"a\": tag-scan t1, a)");
   ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where t2 
hastag \"a\""))->Dump(),
             "project *: (filter t2 hastag \"a\": full-scan ia)");
-
+  ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where v1 
<-> [3,1,2] < 5"))->Dump(),
+            "project *: hnsw-vector-range-scan v1, [3.000000, 1.000000, 
2.000000], 5");
+  ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by 
v1 <-> [3,1,2] limit 5"))->Dump(),
+            "project *: (limit 0, 5: hnsw-vector-knn-scan v1, [3.000000, 
1.000000, 2.000000], 5)");
+  ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by 
v1 <-> [3,1,2] limit 2, 7"))->Dump(),
+            "project *: (limit 2, 7: hnsw-vector-knn-scan v1, [3.000000, 
1.000000, 2.000000], 9)");
+  ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where v2 
<-> [3,1,2] < 5"))->Dump(),
+            "project *: (filter v2 <-> [3.000000, 1.000000, 2.000000] < 5: 
full-scan ia)");
+  ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 
>= 1 and v1 <-> [3,1,2] < 5"))->Dump(),
+            "project *: (filter n1 >= 1: hnsw-vector-range-scan v1, [3.000000, 
1.000000, 2.000000], 5)");
+  ASSERT_EQ(
+      PassManager::Execute(passes, ParseS(sc, "select * from ia where v1 <-> 
[3,1,2] < 5 and t1 hastag \"a\""))->Dump(),
+      "project *: (filter t1 hastag \"a\": hnsw-vector-range-scan v1, 
[3.000000, 1.000000, 2.000000], 5)");
   ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 
>= 2 or n1 < 1"))->Dump(),
             "project *: (merge numeric-scan n1, [-inf, 1), asc, numeric-scan 
n1, [2, inf), asc)");
   ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1 
>= 1 or n2 >= 2"))->Dump(),
diff --git a/tests/cppunit/ir_sema_checker_test.cc 
b/tests/cppunit/ir_sema_checker_test.cc
index df8076ce..8926d026 100644
--- a/tests/cppunit/ir_sema_checker_test.cc
+++ b/tests/cppunit/ir_sema_checker_test.cc
@@ -100,6 +100,13 @@ TEST(SemaCheckerTest, Simple) {
               "expect a LIMIT clause for vector field to construct a KNN 
search");
     ASSERT_EQ(checker.Check(Parse("select f5 from ia order by f5 <-> 
[3.6,4.7,5.6] limit 5")->get()).Msg(),
               "field `f5` is marked as NOINDEX and cannot be used for KNN 
search");
+    ASSERT_EQ(checker.Check(Parse("select f5 from ia where f2 = 1 order by f4 
<-> [3.6,4.7,5.6] limit 5")->get()).Msg(),
+              "KNN search cannot be combined with other query expressions");  // hybrid queries are not supported yet
+    ASSERT_EQ(checker.Check(Parse("select f5 from ia where true order by f4 
<-> [3.6,4.7,5.6] limit 5")->get()).Msg(),
+              "ok");
+    ASSERT_EQ(checker.Check(Parse("select f5 from ia where false order by f4 
<-> [3.6,4.7,5.6] limit 5")->get()).Msg(),
+              "ok");
+    ASSERT_EQ(checker.Check(Parse("select f5 from ia where f2 = 1 and f5 <-> 
[3.6,4.7,5.6] < 1")->get()).Msg(), "ok");  // vector range queries may still be combined with other filters
     ASSERT_EQ(checker.Check(Parse("select f5 from ia where f5 <-> 
[3.6,4.7,5.6] < 5")->get()).Msg(),
               "range has to be between 0 and 2 for cosine distance metric");
     ASSERT_EQ(checker.Check(Parse("select f5 from ia where f5 <-> 
[3.6,4.7,5.6] < 0.5")->get()).Msg(), "ok");
diff --git a/tests/cppunit/plan_executor_test.cc 
b/tests/cppunit/plan_executor_test.cc
index 1b80329d..91c8cd2d 100644
--- a/tests/cppunit/plan_executor_test.cc
+++ b/tests/cppunit/plan_executor_test.cc
@@ -93,6 +93,7 @@ static auto FieldI(const std::string& f) -> const FieldInfo* 
{ return &IndexI()-
 
 static auto N(double n) { return MakeValue<Numeric>(n); }
 static auto T(const std::string& v) { return 
MakeValue<StringArray>(util::Split(v, ",")); }
+static auto V(const std::vector<double>& vals) { return 
MakeValue<NumericArray>(vals); }  // vector-valued field value
 
 TEST(PlanExecutorTest, TopNSort) {
   std::vector<ExecutorNode::RowType> data{
@@ -201,6 +202,41 @@ TEST(PlanExecutorTest, Filter) {
     ASSERT_EQ(NextRow(ctx).key, "f");
     ASSERT_EQ(ctx.Next().GetValue(), exe_end);
   }
+
+  data = {{"a", {{FieldI("f4"), V({1, 2, 3})}}, IndexI()}, {"b", 
{{FieldI("f4"), V({9, 10, 11})}}, IndexI()},
+          {"c", {{FieldI("f4"), V({4, 5, 6})}}, IndexI()}, {"d", 
{{FieldI("f4"), V({1, 2, 3})}}, IndexI()},
+          {"e", {{FieldI("f4"), V({2, 3, 4})}}, IndexI()}, {"f", 
{{FieldI("f4"), V({12, 13, 14})}}, IndexI()},
+          {"g", {{FieldI("f4"), V({1, 2, 3})}}, IndexI()}};
+  {
+    auto field = std::make_unique<FieldRef>("f4", FieldI("f4"));
+    std::vector<double> vector = {11, 12, 13};  // only rows b and f are expected within range 4
+    auto op = std::make_unique<Filter>(
+        std::make_unique<Mock>(data),
+        std::make_unique<VectorRangeExpr>(field->CloneAs<FieldRef>(), 
std::make_unique<NumericLiteral>(4),
+                                          
std::make_unique<VectorLiteral>(std::move(vector))));
+
+    auto ctx = ExecutorContext(op.get());
+    ASSERT_EQ(NextRow(ctx).key, "b");
+    ASSERT_EQ(NextRow(ctx).key, "f");
+    ASSERT_EQ(ctx.Next().GetValue(), exe_end);
+  }
+
+  {
+    auto field = std::make_unique<FieldRef>("f4", FieldI("f4"));
+    std::vector<double> vector = {2, 3, 4};  // rows b and f are expected to fall outside range 5
+    auto op = std::make_unique<Filter>(
+        std::make_unique<Mock>(data),
+        std::make_unique<VectorRangeExpr>(field->CloneAs<FieldRef>(), 
std::make_unique<NumericLiteral>(5),
+                                          
std::make_unique<VectorLiteral>(std::move(vector))));
+
+    auto ctx = ExecutorContext(op.get());
+    ASSERT_EQ(NextRow(ctx).key, "a");
+    ASSERT_EQ(NextRow(ctx).key, "c");
+    ASSERT_EQ(NextRow(ctx).key, "d");
+    ASSERT_EQ(NextRow(ctx).key, "e");
+    ASSERT_EQ(NextRow(ctx).key, "g");
+    ASSERT_EQ(ctx.Next().GetValue(), exe_end);
+  }
 }
 
 TEST(PlanExecutorTest, Limit) {
diff --git a/tests/cppunit/redis_query_parser_test.cc 
b/tests/cppunit/redis_query_parser_test.cc
index 4fc25e49..96f6e26a 100644
--- a/tests/cppunit/redis_query_parser_test.cc
+++ b/tests/cppunit/redis_query_parser_test.cc
@@ -115,16 +115,23 @@ TEST(RedisQueryParserTest, Vector) {
   AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}}));
   AssertSyntaxError(Parse("[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}}));
   AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}}));
+  AssertSyntaxError(Parse("* =>[KNN -1 @vector $BLOB]", {{"BLOB", vec_str}}));
   AssertSyntaxError(Parse("*=>[KNN 5 $vector_blob_param]", 
{{"vector_blob_param", vec_str}}));
+  // Only `*` may precede `=>` until hybrid queries are supported; any other prefilter is a syntax error.
+  AssertSyntaxError(Parse("(*) => [KNN 10 @doc_embedding $BLOB]", {{"BLOB", 
vec_str}}));
+  AssertSyntaxError(Parse("(@a:[1 2]) => [KNN 8 @vec_embedding $blob]", 
{{"blob", vec_str}}));
+  AssertSyntaxError(Parse("(@a:{x|y}) => [KNN 8 @vec_embedding $blob]", 
{{"blob", vec_str}}));
+  AssertSyntaxError(Parse("(@a:{x} @b:[1 2]) => [KNN 8 @vec_embedding $blob]", 
{{"blob", vec_str}}));
+  AssertSyntaxError(Parse("(@a:{x}|@b:[1 inf] | @c:{y}) => [KNN 8 
@vec_embedding $blob]", {{"blob", vec_str}}));
+  AssertSyntaxError(Parse("(@a:{x}|@b:[1 inf] | @field:[VECTOR_RANGE 10 
$vector]) => [KNN 8 @vec_embedding $blob]",
+                          {{"blob", vec_str}, {"vector", vec_str}}));
 
   AssertIR(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", vec_str}}),
            "field <-> [1.000000, 2.000000, 3.000000] < 10");
+  AssertIR(Parse("@field:[VECTOR_RANGE 10 $vector]| @b:[1 inf]", {{"vector", 
vec_str}}),
+           "(or field <-> [1.000000, 2.000000, 3.000000] < 10, b >= 1)");
   AssertIR(Parse("*=>[KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}),
            "KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]");
-  AssertIR(Parse("(*) => [KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}),
-           "KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]");
-  AssertIR(Parse("(@a:[1 2]) => [KNN 8 @vec_embedding $blob]", {{"blob", 
vec_str}}),
-           "KNN k=8, vec_embedding <-> [1.000000, 2.000000, 3.000000]");
   AssertIR(Parse("* =>[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}}),
            "KNN k=5, vector <-> [1.000000, 2.000000, 3.000000]");
 

Reply via email to