This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 0f5f18e1 feat(search): implement vector query for sql/redisearch 
parser & transformer (#2450)
0f5f18e1 is described below

commit 0f5f18e106bf874efc8257b6653b2c9717c24c50
Author: Rebecca Zhou <[email protected]>
AuthorDate: Thu Aug 1 18:22:30 2024 -0700

    feat(search): implement vector query for sql/redisearch parser & 
transformer (#2450)
    
    Co-authored-by: Twice <[email protected]>
---
 src/search/common_transformer.h          | 20 +++++++++
 src/search/ir.h                          | 71 +++++++++++++++++++++++++++++-
 src/search/ir_sema_checker.h             | 63 +++++++++++++++++++++++++++
 src/search/redis_query_parser.h          | 16 +++++--
 src/search/redis_query_transformer.h     | 64 +++++++++++++++++++--------
 src/search/search_encoding.h             |  2 +
 src/search/sql_parser.h                  | 11 ++++-
 src/search/sql_transformer.h             | 74 +++++++++++++++++++++-----------
 tests/cppunit/ir_sema_checker_test.cc    | 35 +++++++++++++++
 tests/cppunit/redis_query_parser_test.cc | 33 ++++++++++++++
 tests/cppunit/sql_parser_test.cc         | 22 ++++++++++
 11 files changed, 362 insertions(+), 49 deletions(-)

diff --git a/src/search/common_transformer.h b/src/search/common_transformer.h
index 5ebbcff6..18b2626d 100644
--- a/src/search/common_transformer.h
+++ b/src/search/common_transformer.h
@@ -105,6 +105,26 @@ struct TreeTransformer {
 
     return result;
   }
+
+  template <typename T = double>  // T: element type decoded from the raw bytes (defaults to double)
+  static StatusOr<std::vector<T>> Binary2Vector(std::string_view str) {  // decodes a packed byte string into a vector of T
+    if (str.size() % sizeof(T) != 0) {  // reject input that cannot be split into whole T-sized elements
+      return {Status::NotOK, "data size is not a multiple of the target type 
size"};
+    }
+
+    std::vector<T> values;
+    const size_t type_size = sizeof(T);
+    values.reserve(str.size() / type_size);  // the element count is known up front
+
+    while (!str.empty()) {  // consume one T-sized chunk per iteration
+      T value;
+      memcpy(&value, str.data(), type_size);  // bytes are copied verbatim; no byte-order conversion happens here
+      values.push_back(value);
+      str.remove_prefix(type_size);
+    }
+
+    return values;  // an empty input yields an empty vector; emptiness is not rejected by this helper
+  }
 };
 
 }  // namespace kqir
diff --git a/src/search/ir.h b/src/search/ir.h
index 116fe9a7..3ba980da 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -229,6 +229,63 @@ struct NumericCompareExpr : BoolAtomExpr {
   }
 };
 
+struct VectorLiteral : Literal {  // IR literal node carrying a vector of doubles
+  std::vector<double> values;
+
+  explicit VectorLiteral(std::vector<double> &&values) : 
values(std::move(values)){};  // moves the caller's vector into the node
+
+  std::string_view Name() const override { return "VectorLiteral"; }
+  std::string Dump() const override {  // bracketed list; each element is formatted with std::to_string
+    return fmt::format("[{}]", util::StringJoin(values, [](auto v) { return 
std::to_string(v); }));
+  }
+  std::string Content() const override { return Dump(); }  // same text as Dump()
+
+  std::unique_ptr<Node> Clone() const override { return 
std::make_unique<VectorLiteral>(*this); }  // copy-constructs a new node, duplicating the stored values
+};
+
+struct VectorRangeExpr : BoolAtomExpr {  // boolean atom combining a field, a query vector and a numeric range
+  std::unique_ptr<FieldRef> field;
+  std::unique_ptr<NumericLiteral> range;
+  std::unique_ptr<VectorLiteral> vector;
+
+  VectorRangeExpr(std::unique_ptr<FieldRef> &&field, 
std::unique_ptr<NumericLiteral> &&range,
+                  std::unique_ptr<VectorLiteral> &&vector)
+      : field(std::move(field)), range(std::move(range)), 
vector(std::move(vector)) {}
+
+  std::string_view Name() const override { return "VectorRangeExpr"; }
+  std::string Dump() const override {  // printed as: field <-> vector < range
+    return fmt::format("{} <-> {} < {}", field->Dump(), vector->Dump(), 
range->Dump());
+  }
+
+  std::unique_ptr<Node> Clone() const override {  // deep copy: every child node is cloned individually
+    return 
std::make_unique<VectorRangeExpr>(Node::MustAs<FieldRef>(field->Clone()),
+                                             
Node::MustAs<NumericLiteral>(range->Clone()),
+                                             
Node::MustAs<VectorLiteral>(vector->Clone()));
+  }
+};
+
+struct VectorKnnExpr : BoolAtomExpr {  // boolean atom combining a field, a query vector and a neighbour count k
+  // TODO: Support pre-filter for hybrid query
+  std::unique_ptr<FieldRef> field;
+  std::unique_ptr<NumericLiteral> k;  // NOTE(review): stored as a NumericLiteral; nothing visible here restricts k to an integer
+  std::unique_ptr<VectorLiteral> vector;
+
+  VectorKnnExpr(std::unique_ptr<FieldRef> &&field, 
std::unique_ptr<NumericLiteral> &&k,
+                std::unique_ptr<VectorLiteral> &&vector)
+      : field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {}
+
+  std::string_view Name() const override { return "VectorKnnExpr"; }
+  std::string Dump() const override {  // printed as: KNN k=<k>, field <-> vector
+    return fmt::format("KNN k={}, {} <-> {}", k->Dump(), field->Dump(), 
vector->Dump());
+  }
+
+  std::unique_ptr<Node> Clone() const override {  // deep copy: every child node is cloned individually
+    return 
std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(field->Clone()),
+                                           
Node::MustAs<NumericLiteral>(k->Clone()),
+                                           
Node::MustAs<VectorLiteral>(vector->Clone()));
+  }
+};
+
 struct BoolLiteral : BoolAtomExpr, Literal {
   bool val;
 
@@ -336,18 +393,30 @@ struct LimitClause : Node {
   std::string Content() const override { return fmt::format("{}, {}", offset, 
count); }
 
   std::unique_ptr<Node> Clone() const override { return 
std::make_unique<LimitClause>(*this); }
+  size_t Offset() const { return offset; }
+
+  size_t Count() const { return count; }
 };
 
 struct SortByClause : Node {
   enum Order { ASC, DESC } order = ASC;
   std::unique_ptr<FieldRef> field;
+  std::unique_ptr<VectorLiteral> vector = nullptr;
 
   SortByClause(Order order, std::unique_ptr<FieldRef> &&field) : order(order), 
field(std::move(field)) {}
+  SortByClause(std::unique_ptr<FieldRef> &&field, 
std::unique_ptr<VectorLiteral> &&vector)
+      : field(std::move(field)), vector(std::move(vector)) {}
 
   static constexpr const char *OrderToString(Order order) { return order == 
ASC ? "asc" : "desc"; }
+  bool IsVectorField() const { return vector != nullptr; }
 
   std::string_view Name() const override { return "SortByClause"; }
-  std::string Dump() const override { return fmt::format("sortby {}, {}", 
field->Dump(), OrderToString(order)); }
+  std::string Dump() const override {
+    if (!IsVectorField()) {
+      return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order));
+    }
+    return fmt::format("sortby {} <-> {}", field->Dump(), vector->Dump());
+  }
   std::string Content() const override { return OrderToString(order); }
 
   NodeIterator ChildBegin() override { return NodeIterator(field.get()); };
diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h
index a7a76181..43d722b4 100644
--- a/src/search/ir_sema_checker.h
+++ b/src/search/ir_sema_checker.h
@@ -50,6 +50,9 @@ struct SemaChecker {
         GET_OR_RET(Check(v->query_expr.get()));
         if (v->limit) GET_OR_RET(Check(v->limit.get()));
         if (v->sort_by) GET_OR_RET(Check(v->sort_by.get()));
+        if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) {
+          return {Status::NotOK, "expect a LIMIT clause for vector field to 
construct a KNN search"};
+        }
       } else {
         return {Status::NotOK, fmt::format("index `{}` not found", 
index_name)};
       }
@@ -60,8 +63,25 @@ struct SemaChecker {
         return {Status::NotOK, fmt::format("field `{}` not found in index 
`{}`", v->field->name, current_index->name)};
       } else if (!iter->second.IsSortable()) {
         return {Status::NotOK, fmt::format("field `{}` is not sortable", 
v->field->name)};
+      } else if (auto is_vector = 
iter->second.MetadataAs<redis::HnswVectorFieldMetadata>() != nullptr;
+                 is_vector != v->IsVectorField()) {
+        std::string not_str = is_vector ? "" : "not ";
+        return {Status::NotOK,
+                fmt::format("field `{}` is {}a vector field according to 
metadata and does {}expect a vector parameter",
+                            v->field->name, not_str, not_str)};
       } else {
         v->field->info = &iter->second;
+        if (v->IsVectorField()) {
+          auto meta = 
v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
+          if (!v->field->info->HasIndex()) {
+            return {Status::NotOK,
+                    fmt::format("field `{}` is marked as NOINDEX and cannot be 
used for KNN search", v->field->name)};
+          }
+          if (v->vector->values.size() != meta->dim) {
+            return {Status::NotOK,
+                    fmt::format("vector should be of size `{}` for field 
`{}`", meta->dim, v->field->name)};
+          }
+        }
       }
     } else if (auto v = dynamic_cast<AndExpr *>(node)) {
       for (const auto &n : v->inners) {
@@ -97,6 +117,49 @@ struct SemaChecker {
       } else {
         v->field->info = &iter->second;
       }
+    } else if (auto v = dynamic_cast<VectorKnnExpr *>(node)) {
+      if (auto iter = current_index->fields.find(v->field->name); iter == 
current_index->fields.end()) {
+        return {Status::NotOK, fmt::format("field `{}` not found in index 
`{}`", v->field->name, current_index->name)};
+      } else if (!iter->second.MetadataAs<redis::HnswVectorFieldMetadata>()) {
+        return {Status::NotOK, fmt::format("field `{}` is not a vector field", 
v->field->name)};
+      } else {
+        v->field->info = &iter->second;
+
+        if (!v->field->info->HasIndex()) {
+          return {Status::NotOK,
+                  fmt::format("field `{}` is marked as NOINDEX and cannot be 
used for KNN search", v->field->name)};
+        }
+        if (v->k->val <= 0) {
+          return {Status::NotOK, fmt::format("KNN search parameter `k` must be 
greater than 0")};
+        }
+        auto meta = 
v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
+        if (v->vector->values.size() != meta->dim) {
+          return {Status::NotOK,
+                  fmt::format("vector should be of size `{}` for field `{}`", 
meta->dim, v->field->name)};
+        }
+      }
+    } else if (auto v = dynamic_cast<VectorRangeExpr *>(node)) {
+      if (auto iter = current_index->fields.find(v->field->name); iter == 
current_index->fields.end()) {
+        return {Status::NotOK, fmt::format("field `{}` not found in index 
`{}`", v->field->name, current_index->name)};
+      } else if (!iter->second.MetadataAs<redis::HnswVectorFieldMetadata>()) {
+        return {Status::NotOK, fmt::format("field `{}` is not a vector field", 
v->field->name)};
+      } else {
+        v->field->info = &iter->second;
+
+        auto meta = 
v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
+        if (meta->distance_metric == redis::DistanceMetric::L2 && 
v->range->val < 0) {
+          return {Status::NotOK, "range cannot be a negative number for l2 
distance metric"};
+        }
+
+        if (meta->distance_metric == redis::DistanceMetric::COSINE && 
(v->range->val < 0 || v->range->val > 2)) {
+          return {Status::NotOK, "range has to be between 0 and 2 for cosine 
distance metric"};
+        }
+
+        if (v->vector->values.size() != meta->dim) {
+          return {Status::NotOK,
+                  fmt::format("vector should be of size `{}` for field `{}`", 
meta->dim, v->field->name)};
+        }
+      }
     } else if (auto v = dynamic_cast<SelectClause *>(node)) {
       for (const auto &n : v->fields) {
         if (auto iter = current_index->fields.find(n->name); iter == 
current_index->fields.end()) {
diff --git a/src/search/redis_query_parser.h b/src/search/redis_query_parser.h
index 5fe03046..5b0f172c 100644
--- a/src/search/redis_query_parser.h
+++ b/src/search/redis_query_parser.h
@@ -30,6 +30,11 @@ namespace redis_query {
 
 using namespace peg;
 
+struct VectorRangeToken : string<'V', 'E', 'C', 'T', 'O', 'R', '_', 'R', 'A', 
'N', 'G', 'E'> {};  // matches the exact upper-case keyword VECTOR_RANGE
+struct KnnToken : string<'K', 'N', 'N'> {};  // matches the exact upper-case keyword KNN
+struct ArrowOp : string<'=', '>'> {};  // matches the two-character => operator
+struct Wildcard : one<'*'> {};  // matches a single * character
+
 struct Field : seq<one<'@'>, Identifier> {};
 
 struct Param : seq<one<'$'>, Identifier> {};
@@ -44,9 +49,10 @@ struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
 struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
 struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>, 
WSPad<NumericRangePart>, one<']'>> {};
 
-struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<TagList, 
NumericRange>>> {};
+struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<NumberOrParam>, 
WSPad<Field>, WSPad<Param>, one<']'>> {};  // shape: [KNN <number-or-param> @field $param]
+struct VectorRange : seq<one<'['>, WSPad<VectorRangeToken>, 
WSPad<NumberOrParam>, WSPad<Param>, one<']'>> {};  // shape: [VECTOR_RANGE <number-or-param> $param]
 
-struct Wildcard : one<'*'> {};
+struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<VectorRange, 
TagList, NumericRange>>> {};
 
 struct QueryExpr;
 
@@ -64,7 +70,11 @@ struct AndExprP : sor<AndExpr, BooleanExpr> {};
 struct OrExpr : seq<AndExprP, plus<seq<one<'|'>, AndExprP>>> {};
 struct OrExprP : sor<OrExpr, AndExprP> {};
 
-struct QueryExpr : seq<OrExprP> {};
+struct PrefilterExpr : seq<WSPad<BooleanExpr>, ArrowOp, WSPad<KnnSearch>> {};
+
+struct QueryP : sor<PrefilterExpr, OrExprP> {};
+
+struct QueryExpr : seq<QueryP> {};
 
 }  // namespace redis_query
 
diff --git a/src/search/redis_query_transformer.h 
b/src/search/redis_query_transformer.h
index 6ff1581b..c81230e4 100644
--- a/src/search/redis_query_transformer.h
+++ b/src/search/redis_query_transformer.h
@@ -35,10 +35,10 @@ namespace redis_query {
 namespace ir = kqir;
 
 template <typename Rule>
-using TreeSelector =
-    parse_tree::selector<Rule, parse_tree::store_content::on<Number, StringL, 
Param, Identifier, Inf>,
-                         parse_tree::remove_content::on<TagList, NumericRange, 
ExclusiveNumber, FieldQuery, NotExpr,
-                                                        AndExpr, OrExpr, 
Wildcard>>;
+using TreeSelector = parse_tree::selector<
+    Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier, 
Inf>,
+    parse_tree::remove_content::on<TagList, NumericRange, VectorRange, 
ExclusiveNumber, FieldQuery, NotExpr, AndExpr,
+                                   OrExpr, PrefilterExpr, KnnSearch, Wildcard, 
VectorRangeToken, KnnToken, ArrowOp>>;
 
 template <typename Input>
 StatusOr<std::unique_ptr<parse_tree::node>> ParseToTree(Input&& in) {
@@ -53,7 +53,31 @@ StatusOr<std::unique_ptr<parse_tree::node>> 
ParseToTree(Input&& in) {
 struct Transformer : ir::TreeTransformer {
   explicit Transformer(const ParamMap& param_map) : TreeTransformer(param_map) 
{}
 
+  StatusOr<std::unique_ptr<VectorLiteral>> Transform2Vector(const TreeNode& 
node) {  // resolves a parameter node into a VectorLiteral
+    std::string vector_str = GET_OR_RET(GetParam(node));  // raw parameter value; fails if the parameter lookup fails
+
+    std::vector<double> values = GET_OR_RET(Binary2Vector<double>(vector_str));  // NOTE(review): always decoded as double here
+    if (values.empty()) {  // Binary2Vector accepts empty input, so emptiness is rejected at this level
+      return {Status::NotOK, "empty vector is invalid"};
+    }
+    return std::make_unique<ir::VectorLiteral>(std::move(values));
+  };
+
   auto Transform(const TreeNode& node) -> StatusOr<std::unique_ptr<Node>> {
+    auto number_or_param = [this](const TreeNode& node) -> 
StatusOr<std::unique_ptr<NumericLiteral>> {
+      if (Is<Number>(node)) {
+        return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
+      } else if (Is<Param>(node)) {
+        auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
+                                  .Prefixed(fmt::format("parameter {} is not a 
number", node->string_view())));
+
+        return std::make_unique<ir::NumericLiteral>(val);
+      } else {
+        return {Status::NotOK,
+                fmt::format("expected a number or a parameter in numeric 
comparison but got {}", node->type)};
+      }
+    };
+
     if (Is<Number>(node)) {
       return Node::Create<ir::NumericLiteral>(*ParseFloat(node->string()));
     } else if (Is<Wildcard>(node)) {
@@ -88,26 +112,12 @@ struct Transformer : ir::TreeTransformer {
         } else {
           return std::make_unique<ir::OrExpr>(std::move(exprs));
         }
-      } else {  // NumericRange
+      } else if (Is<NumericRange>(query)) {
         std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
 
         const auto& lhs = query->children[0];
         const auto& rhs = query->children[1];
 
-        auto number_or_param = [this](const TreeNode& node) -> 
StatusOr<std::unique_ptr<NumericLiteral>> {
-          if (Is<Number>(node)) {
-            return 
Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
-          } else if (Is<Param>(node)) {
-            auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
-                                      .Prefixed(fmt::format("parameter {} is 
not a number", node->string_view())));
-
-            return std::make_unique<ir::NumericLiteral>(val);
-          } else {
-            return {Status::NotOK,
-                    fmt::format("expected a number or a parameter in numeric 
comparison but got {}", node->type)};
-          }
-        };
-
         if (Is<ExclusiveNumber>(lhs)) {
           
exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GT,
                                                                
std::make_unique<FieldRef>(field),
@@ -141,11 +151,27 @@ struct Transformer : ir::TreeTransformer {
         } else {
           return std::make_unique<ir::AndExpr>(std::move(exprs));
         }
+      } else if (Is<VectorRange>(query)) {
+        return 
std::make_unique<VectorRangeExpr>(std::make_unique<FieldRef>(field),
+                                                 
GET_OR_RET(number_or_param(query->children[1])),
+                                                 
GET_OR_RET(Transform2Vector(query->children[2])));
       }
     } else if (Is<NotExpr>(node)) {
       CHECK(node->children.size() == 1);
 
       return 
Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
+    } else if (Is<PrefilterExpr>(node)) {
+      CHECK(node->children.size() == 3);
+
+      // TODO(Beihao): Support Hybrid Query
+      // const auto& prefilter = node->children[0];
+      const auto& knn_search = node->children[2];
+      CHECK(knn_search->children.size() == 4);
+
+      return 
std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
+                                             
GET_OR_RET(number_or_param(knn_search->children[1])),
+                                             
GET_OR_RET(Transform2Vector(knn_search->children[3])));
+
     } else if (Is<AndExpr>(node)) {
       std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
 
diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h
index 2fbbde8c..26b442ca 100644
--- a/src/search/search_encoding.h
+++ b/src/search/search_encoding.h
@@ -373,6 +373,8 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata {
 
   HnswVectorFieldMetadata() : IndexFieldMetadata(IndexFieldType::VECTOR) {}
 
+  bool IsSortable() const override { return true; }  // always reports this vector field type as sortable
+
   void Encode(std::string *dst) const override {
     IndexFieldMetadata::Encode(dst);
     PutFixed8(dst, uint8_t(vector_type));
diff --git a/src/search/sql_parser.h b/src/search/sql_parser.h
index 22b985fd..751b0b47 100644
--- a/src/search/sql_parser.h
+++ b/src/search/sql_parser.h
@@ -41,7 +41,12 @@ struct NumericAtomExpr : WSPad<sor<NumberOrParam, 
Identifier>> {};
 struct NumericCompareOp : sor<string<'!', '='>, string<'<', '='>, string<'>', 
'='>, one<'=', '<', '>'>> {};
 struct NumericCompareExpr : seq<NumericAtomExpr, NumericCompareOp, 
NumericAtomExpr> {};
 
-struct BooleanAtomExpr : sor<HasTagExpr, NumericCompareExpr, WSPad<Boolean>> 
{};
+struct VectorCompareOp : string<'<', '-', '>'> {};  // matches the three-character <-> operator
+struct VectorLiteral : seq<WSPad<one<'['>>, Number, star<seq<WSPad<one<','>>>, 
Number>, WSPad<one<']'>>> {};  // bracketed, comma-separated numbers; at least one Number is required
+struct VectorCompareExpr : seq<WSPad<Identifier>, VectorCompareOp, 
WSPad<VectorLiteral>> {};  // shape: identifier <-> vector-literal
+struct VectorRangeExpr : seq<VectorCompareExpr, one<'<'>, 
WSPad<NumberOrParam>> {};  // shape: vector-compare-expr < number-or-param
+
+struct BooleanAtomExpr : sor<HasTagExpr, NumericCompareExpr, VectorRangeExpr, 
WSPad<Boolean>> {};  // alternatives are listed in this order; VectorRangeExpr is the newly added one
 
 struct QueryExpr;
 
@@ -84,7 +89,9 @@ struct Limit : string<'l', 'i', 'm', 'i', 't'> {};
 
 struct WhereClause : seq<Where, QueryExpr> {};
 struct AscOrDesc : sor<Asc, Desc> {};
-struct OrderByClause : seq<OrderBy, WSPad<Identifier>, opt<WSPad<AscOrDesc>>> 
{};
+struct SortableFieldExpr : seq<WSPad<Identifier>, opt<AscOrDesc>> {};  // plain field ordering with an optional asc/desc
+struct OrderByExpr : sor<WSPad<VectorCompareExpr>, WSPad<SortableFieldExpr>> 
{};  // the vector form is listed before the plain field form
+struct OrderByClause : seq<OrderBy, OrderByExpr> {};  // the order-by keyword followed by one ordering expression
 struct LimitClause : seq<Limit, opt<seq<WSPad<UnsignedInteger>, one<','>>>, 
WSPad<UnsignedInteger>> {};
 
 struct SearchStmt
diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h
index d2ed8c21..01705107 100644
--- a/src/search/sql_transformer.h
+++ b/src/search/sql_transformer.h
@@ -41,8 +41,9 @@ using TreeSelector = parse_tree::selector<
     Rule,
     parse_tree::store_content::on<Boolean, Number, StringL, Param, Identifier, 
NumericCompareOp, AscOrDesc,
                                   UnsignedInteger>,
-    parse_tree::remove_content::on<HasTagExpr, NumericCompareExpr, NotExpr, 
AndExpr, OrExpr, Wildcard, SelectExpr,
-                                   FromExpr, WhereClause, OrderByClause, 
LimitClause, SearchStmt>>;
+    parse_tree::remove_content::on<HasTagExpr, NumericCompareExpr, 
VectorCompareOp, VectorLiteral, VectorCompareExpr,
+                                   VectorRangeExpr, NotExpr, AndExpr, OrExpr, 
Wildcard, SelectExpr, FromExpr,
+                                   WhereClause, OrderByClause, OrderByExpr, 
LimitClause, SearchStmt>>;
 
 template <typename Input>
 StatusOr<std::unique_ptr<parse_tree::node>> ParseToTree(Input&& in) {
@@ -58,12 +59,32 @@ struct Transformer : ir::TreeTransformer {
   explicit Transformer(const ParamMap& param_map) : TreeTransformer(param_map) 
{}
 
   auto Transform(const TreeNode& node) -> StatusOr<std::unique_ptr<Node>> {
+    auto number_or_param = [this](const TreeNode& node) -> 
StatusOr<std::unique_ptr<NumericLiteral>> {
+      if (Is<Number>(node)) {
+        return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
+      } else if (Is<Param>(node)) {
+        auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
+                                  .Prefixed(fmt::format("parameter {} is not a 
number", node->string_view())));
+
+        return std::make_unique<ir::NumericLiteral>(val);
+      } else {
+        return {Status::NotOK,
+                fmt::format("expected a number or a parameter in numeric 
comparison but got {}", node->type)};
+      }
+    };
+
     if (Is<Boolean>(node)) {
       return Node::Create<ir::BoolLiteral>(node->string_view() == "true");
     } else if (Is<Number>(node)) {
       return Node::Create<ir::NumericLiteral>(*ParseFloat(node->string()));
     } else if (Is<StringL>(node)) {
       return 
Node::Create<ir::StringLiteral>(GET_OR_RET(UnescapeString(node->string_view())));
+    } else if (Is<VectorLiteral>(node)) {
+      std::vector<double> values;
+      for (const auto& child : node->children) {
+        values.push_back(*ParseFloat(child->string()));
+      }
+      return Node::Create<ir::VectorLiteral>(std::move(values));
     } else if (Is<HasTagExpr>(node)) {
       CHECK(node->children.size() == 2);
 
@@ -85,20 +106,6 @@ struct Transformer : ir::TreeTransformer {
       const auto& lhs = node->children[0];
       const auto& rhs = node->children[2];
 
-      auto number_or_param = [this](const TreeNode& node) -> 
StatusOr<std::unique_ptr<NumericLiteral>> {
-        if (Is<Number>(node)) {
-          return Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(node)));
-        } else if (Is<Param>(node)) {
-          auto val = GET_OR_RET(ParseFloat(GET_OR_RET(GetParam(node)))
-                                    .Prefixed(fmt::format("parameter {} is not 
a number", node->string_view())));
-
-          return std::make_unique<ir::NumericLiteral>(val);
-        } else {
-          return {Status::NotOK,
-                  fmt::format("expected a number or a parameter in numeric 
comparison but got {}", node->type)};
-        }
-      };
-
       auto op = 
ir::NumericCompareExpr::FromOperator(node->children[1]->string_view()).value();
       if (Is<Identifier>(lhs) && (Is<Number>(rhs) || Is<Param>(rhs))) {
         return Node::Create<ir::NumericCompareExpr>(op, 
std::make_unique<ir::FieldRef>(lhs->string()),
@@ -110,6 +117,16 @@ struct Transformer : ir::TreeTransformer {
       } else {
         return {Status::NotOK, "the left and right side of numeric comparison 
should be an identifier and a number"};
       }
+    } else if (Is<VectorRangeExpr>(node)) {
+      // TODO(Beihao): Handle distance metrics for operator
+      CHECK(node->children.size() == 2);
+      const auto& vector_comp_expr = node->children[0];
+      CHECK(vector_comp_expr->children.size() == 3);
+
+      return Node::Create<ir::VectorRangeExpr>(
+          
std::make_unique<ir::FieldRef>(vector_comp_expr->children[0]->string()),
+          GET_OR_RET(number_or_param(node->children[1])),
+          
Node::MustAs<ir::VectorLiteral>(GET_OR_RET(Transform(vector_comp_expr->children[2]))));
     } else if (Is<NotExpr>(node)) {
       CHECK(node->children.size() == 1);
 
@@ -161,15 +178,24 @@ struct Transformer : ir::TreeTransformer {
 
       return Node::Create<ir::LimitClause>(offset, count);
     } else if (Is<OrderByClause>(node)) {
-      CHECK(node->children.size() == 1 || node->children.size() == 2);
-
-      auto field = std::make_unique<FieldRef>(node->children[0]->string());
-      auto order = SortByClause::Order::ASC;
-      if (node->children.size() == 2 && node->children[1]->string_view() == 
"desc") {
-        order = SortByClause::Order::DESC;
+      CHECK(node->children.size() == 1);
+      const auto& order_by_expr = node->children[0];
+      CHECK(order_by_expr->children.size() == 1 || 
order_by_expr->children.size() == 2);
+
+      if (Is<VectorCompareExpr>(order_by_expr->children[0])) {
+        const auto& vector_compare_expr = order_by_expr->children[0];
+        CHECK(vector_compare_expr->children.size() == 3);
+        auto field = 
std::make_unique<FieldRef>(vector_compare_expr->children[0]->string());
+        return Node::Create<SortByClause>(
+            std::move(field), 
Node::MustAs<ir::VectorLiteral>(GET_OR_RET(Transform(vector_compare_expr->children[2]))));
+      } else {
+        auto field = 
std::make_unique<FieldRef>(order_by_expr->children[0]->string());
+        auto order = SortByClause::Order::ASC;
+        if (order_by_expr->children.size() == 2 && 
order_by_expr->children[1]->string_view() == "desc") {
+          order = SortByClause::Order::DESC;
+        }
+        return Node::Create<SortByClause>(order, std::move(field));
       }
-
-      return Node::Create<SortByClause>(order, std::move(field));
     } else if (Is<SearchStmt>(node)) {  // root node
       CHECK(node->children.size() >= 2 && node->children.size() <= 5);
 
diff --git a/tests/cppunit/ir_sema_checker_test.cc 
b/tests/cppunit/ir_sema_checker_test.cc
index 3a15dde7..df8076ce 100644
--- a/tests/cppunit/ir_sema_checker_test.cc
+++ b/tests/cppunit/ir_sema_checker_test.cc
@@ -38,10 +38,26 @@ static IndexMap MakeIndexMap() {
   auto f1 = FieldInfo("f1", std::make_unique<redis::TagFieldMetadata>());
   auto f2 = FieldInfo("f2", std::make_unique<redis::NumericFieldMetadata>());
   auto f3 = FieldInfo("f3", std::make_unique<redis::NumericFieldMetadata>());
+
+  auto hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
+  hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
+  hnsw_field_meta->dim = 3;
+  hnsw_field_meta->distance_metric = redis::DistanceMetric::L2;
+  auto f4 = FieldInfo("f4", std::move(hnsw_field_meta));
+
+  hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
+  hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
+  hnsw_field_meta->dim = 3;
+  hnsw_field_meta->distance_metric = redis::DistanceMetric::COSINE;
+  auto f5 = FieldInfo("f5", std::move(hnsw_field_meta));
+  f5.metadata->noindex = true;
+
   auto ia = std::make_unique<IndexInfo>("ia", redis::IndexMetadata(), "");
   ia->Add(std::move(f1));
   ia->Add(std::move(f2));
   ia->Add(std::move(f3));
+  ia->Add(std::move(f4));
+  ia->Add(std::move(f5));
 
   IndexMap res;
   res.Insert(std::move(ia));
@@ -68,6 +84,25 @@ TEST(SemaCheckerTest, Simple) {
     ASSERT_EQ(checker.Check(Parse("select f1 from ia where f1 hastag 
\",\"")->get()).Msg(),
               "tag cannot contain the separator `,`");
     ASSERT_EQ(checker.Check(Parse("select f1 from ia order by 
a")->get()).Msg(), "field `a` not found in index `ia`");
+    ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f4 <-> [3.6,4.7] 
limit 5")->get()).Msg(),
+              "vector should be of size `3` for field `f4`");
+    ASSERT_EQ(checker.Check(Parse("select f4 from ia where f4 <-> [3.6,4.7] < 
5")->get()).Msg(),
+              "vector should be of size `3` for field `f4`");
+    ASSERT_EQ(checker.Check(Parse("select f4 from ia where f4 <-> 
[3.6,4.7,5.6] < -5")->get()).Msg(),
+              "range cannot be a negative number for l2 distance metric");
+    ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f4 limit 
5")->get()).Msg(),
+              "field `f4` is a vector field according to metadata and does 
expect a vector parameter");
+    ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f1 <-> 
[3.6,4.7,5.6] limit 5")->get()).Msg(),
+              "field `f1` is not sortable");
+    ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f2 <-> 
[3.6,4.7,5.6] limit 5")->get()).Msg(),
+              "field `f2` is not a vector field according to metadata and does 
not expect a vector parameter");
+    ASSERT_EQ(checker.Check(Parse("select f4 from ia order by f4 <-> 
[3.6,4.7,5.6]")->get()).Msg(),
+              "expect a LIMIT clause for vector field to construct a KNN 
search");
+    ASSERT_EQ(checker.Check(Parse("select f5 from ia order by f5 <-> 
[3.6,4.7,5.6] limit 5")->get()).Msg(),
+              "field `f5` is marked as NOINDEX and cannot be used for KNN 
search");
+    ASSERT_EQ(checker.Check(Parse("select f5 from ia where f5 <-> 
[3.6,4.7,5.6] < 5")->get()).Msg(),
+              "range has to be between 0 and 2 for cosine distance metric");
+    ASSERT_EQ(checker.Check(Parse("select f5 from ia where f5 <-> 
[3.6,4.7,5.6] < 0.5")->get()).Msg(), "ok");
   }
 
   {
diff --git a/tests/cppunit/redis_query_parser_test.cc 
b/tests/cppunit/redis_query_parser_test.cc
index bd66d41a..4fc25e49 100644
--- a/tests/cppunit/redis_query_parser_test.cc
+++ b/tests/cppunit/redis_query_parser_test.cc
@@ -101,3 +101,36 @@ TEST(RedisQueryParserTest, Params) {
   AssertIR(Parse("@c:{$y} @d:[$zzz inf]", {{"y", "hello"}, {"zzz", "3"}}), 
"(and c hastag \"hello\", d >= 3)");
   ASSERT_EQ(Parse("@c:{$y}", {{"z", "hello"}}).Msg(), "parameter with name `y` 
not found");
 }
+
+TEST(RedisQueryParserTest, Vector) {  // covers VECTOR_RANGE and KNN syntax plus binary vector parameter decoding
+  std::vector<double> vec = {1, 2, 3};
+  std::string vec_str(reinterpret_cast<const char*>(vec.data()), vec.size() * 
sizeof(double));  // raw bytes of the three doubles, passed as the vector parameter
+
+  AssertSyntaxError(Parse("@field:[RANGE 10 $vector]", {{"vector", vec_str}}));  // keyword must be VECTOR_RANGE
+  AssertSyntaxError(Parse("@field:[VECTOR_RANGE 10 not_param"));
+  AssertSyntaxError(Parse("@field:[VECTOR_RANGE $vector]", {{"vector", 
vec_str}}));
+  AssertSyntaxError(Parse("@field:[VECTOR_RANGE $vector 10]", {{"vector", 
vec_str}}));
+  AssertSyntaxError(Parse("* =>[knn 5 @field $BLOB]", {{"BLOB", vec_str}}));  // lower-case knn is rejected
+  AssertSyntaxError(Parse("* =>[KNN 5 @field not_param]"));
+  AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}}));
+  AssertSyntaxError(Parse("[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}}));
+  AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}}));  // NOTE(review): identical to the assertion two lines above
+  AssertSyntaxError(Parse("*=>[KNN 5 $vector_blob_param]", 
{{"vector_blob_param", vec_str}}));
+
+  AssertIR(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", vec_str}}),
+           "field <-> [1.000000, 2.000000, 3.000000] < 10");
+  AssertIR(Parse("*=>[KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}),
+           "KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]");
+  AssertIR(Parse("(*) => [KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}),
+           "KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]");
+  AssertIR(Parse("(@a:[1 2]) => [KNN 8 @vec_embedding $blob]", {{"blob", 
vec_str}}),
+           "KNN k=8, vec_embedding <-> [1.000000, 2.000000, 3.000000]");  // the prefilter part does not appear in the dumped IR
+  AssertIR(Parse("* =>[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}}),
+           "KNN k=5, vector <-> [1.000000, 2.000000, 3.000000]");
+
+  vec_str = vec_str.substr(0, 3);  // 3 bytes is not a multiple of sizeof(double)
+  ASSERT_EQ(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", 
vec_str}}).Msg(),
+            "data size is not a multiple of the target type size");
+  vec_str = "";
+  ASSERT_EQ(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", 
vec_str}}).Msg(), "empty vector is invalid");
+}
diff --git a/tests/cppunit/sql_parser_test.cc b/tests/cppunit/sql_parser_test.cc
index e85368ce..9173d4c3 100644
--- a/tests/cppunit/sql_parser_test.cc
+++ b/tests/cppunit/sql_parser_test.cc
@@ -146,3 +146,25 @@ TEST(SQLParserTest, Params) {
            "select a from b where (and c hastag \"hello\", d = 3)");
   ASSERT_EQ(Parse("select a from b where c hastag @y", {{"z", 
"hello"}}).Msg(), "parameter with name `y` not found");
 }
+
+TEST(SQLParserTest, Vector) {  // covers the SQL vector range predicate and vector ordering syntax
+  AssertSyntaxError(Parse("select a from b where embedding <-> [3,1,2]"));  // a where-clause vector compare needs a trailing range
+  AssertSyntaxError(Parse("select a from b where embedding <-> [3,1,2] <"));
+  AssertSyntaxError(Parse("select a from b where embedding [3,1,2] < 3"));
+  AssertSyntaxError(Parse("select a from b where embedding <> [3,1,2] < 4"));
+  AssertSyntaxError(Parse("select a from b where embedding <- [3,1,2] < 3"));
+  AssertSyntaxError(Parse("select a from b order by embedding <-> [1,2,3] < 
3"));
+  AssertSyntaxError(Parse("select a from b where embedding <-> [1,2,3] limit 
5"));
+  AssertSyntaxError(Parse("select a from b where [3,1,2] <-> embedding < 5"));  // the field must be on the left-hand side
+  AssertSyntaxError(Parse("select a from b where embedding <-> [] < 5"));  // an empty vector literal is rejected
+  AssertSyntaxError(Parse("select a from b order by embedding <-> @vec limit 
5", {{"vec", "[3.6,7.8]"}}));  // a parameter is not accepted in place of a vector literal
+  AssertSyntaxError(Parse("select a from b where embedding <#> [3,1,2] < 5"));
+  AssertSyntaxError(Parse("select a from b order by embedding <-> [3,1,2] desc 
limit 5"));  // asc/desc is not accepted after a vector ordering
+
+  AssertIR(Parse("select a from b where embedding <-> [3,1,2] < 5"),
+           "select a from b where embedding <-> [3.000000, 1.000000, 2.000000] 
< 5");
+  AssertIR(Parse("select a from b where embedding <-> [0.5,0.5] < 10 and c > 
100"),
+           "select a from b where (and embedding <-> [0.500000, 0.500000] < 
10, c > 100)");
+  AssertIR(Parse("select a from b order by embedding <-> [3.6] limit 5"),
+           "select a from b where true sortby embedding <-> [3.600000] limit 
0, 5");
+}

Reply via email to