This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new edcb7067 feat(search): Hnsw Vector Search Optimization Pass (#2466)
edcb7067 is described below
commit edcb7067453324f0bff1a4cca2d186730b31898e
Author: Rebecca Zhou <[email protected]>
AuthorDate: Wed Aug 7 06:41:11 2024 -0700
feat(search): Hnsw Vector Search Optimization Pass (#2466)
Co-authored-by: Twice <[email protected]>
---
src/search/executors/filter_executor.h | 22 ++++++++++++++
src/search/ir.h | 19 ++++++------
src/search/ir_pass.h | 29 ++++++++++++++++++
src/search/ir_plan.h | 2 +-
src/search/ir_sema_checker.h | 13 +++++----
src/search/passes/cost_model.h | 10 +++++++
src/search/passes/index_selection.h | 23 +++++++++++++++
src/search/passes/manager.h | 4 ++-
src/search/passes/sort_limit_to_knn.h | 50 ++++++++++++++++++++++++++++++++
src/search/redis_query_parser.h | 5 ++--
src/search/redis_query_transformer.h | 16 ++++++----
src/search/sql_transformer.h | 1 -
tests/cppunit/ir_pass_test.cc | 47 +++++++++++++++++++++++++++++-
tests/cppunit/ir_sema_checker_test.cc | 7 +++++
tests/cppunit/plan_executor_test.cc | 36 +++++++++++++++++++++++
tests/cppunit/redis_query_parser_test.cc | 14 ++++++---
16 files changed, 267 insertions(+), 31 deletions(-)
diff --git a/src/search/executors/filter_executor.h
b/src/search/executors/filter_executor.h
index df14b29b..1b1febe8 100644
--- a/src/search/executors/filter_executor.h
+++ b/src/search/executors/filter_executor.h
@@ -23,6 +23,7 @@
#include <variant>
#include "parse_util.h"
+#include "search/hnsw_indexer.h"
#include "search/ir.h"
#include "search/plan_executor.h"
#include "search/search_encoding.h"
@@ -44,6 +45,9 @@ struct QueryExprEvaluator {
if (auto v = dynamic_cast<NotExpr *>(e)) {
return Visit(v);
}
+ if (auto v = dynamic_cast<VectorRangeExpr *>(e)) {
+ return Visit(v);
+ }
if (auto v = dynamic_cast<NumericCompareExpr *>(e)) {
return Visit(v);
}
@@ -112,6 +116,24 @@ struct QueryExprEvaluator {
__builtin_unreachable();
}
}
+
+ StatusOr<bool> Visit(VectorRangeExpr *v) const {
+ auto val = GET_OR_RET(ctx->Retrieve(row, v->field->info));
+
+ CHECK(val.Is<kqir::NumericArray>());
+ auto l_values = val.Get<kqir::NumericArray>();
+ auto r_values = v->vector->values;
+ auto meta = v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
+
+ redis::VectorItem left, right;
+ GET_OR_RET(redis::VectorItem::Create({}, l_values, meta, &left));
+ GET_OR_RET(redis::VectorItem::Create({}, r_values, meta, &right));
+
+ auto dist = GET_OR_RET(redis::ComputeSimilarity(left, right));
+ auto effective_range = v->range->val * (1 + meta->epsilon);
+
+ return (dist >= -abs(effective_range) && dist <= abs(effective_range));
+ }
};
struct FilterExecutor : ExecutorNode {
diff --git a/src/search/ir.h b/src/search/ir.h
index 3ba980da..c7aec26b 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -265,24 +265,19 @@ struct VectorRangeExpr : BoolAtomExpr {
};
struct VectorKnnExpr : BoolAtomExpr {
- // TODO: Support pre-filter for hybrid query
std::unique_ptr<FieldRef> field;
- std::unique_ptr<NumericLiteral> k;
std::unique_ptr<VectorLiteral> vector;
+ size_t k;
- VectorKnnExpr(std::unique_ptr<FieldRef> &&field,
std::unique_ptr<NumericLiteral> &&k,
- std::unique_ptr<VectorLiteral> &&vector)
- : field(std::move(field)), k(std::move(k)), vector(std::move(vector)) {}
+ VectorKnnExpr(std::unique_ptr<FieldRef> &&field,
std::unique_ptr<VectorLiteral> &&vector, size_t k)
+ : field(std::move(field)), vector(std::move(vector)), k(k) {}
std::string_view Name() const override { return "VectorKnnExpr"; }
- std::string Dump() const override {
- return fmt::format("KNN k={}, {} <-> {}", k->Dump(), field->Dump(),
vector->Dump());
- }
+ std::string Dump() const override { return fmt::format("KNN k={}, {} <->
{}", k, field->Dump(), vector->Dump()); }
std::unique_ptr<Node> Clone() const override {
return
std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(field->Clone()),
-
Node::MustAs<NumericLiteral>(k->Clone()),
-
Node::MustAs<VectorLiteral>(vector->Clone()));
+
Node::MustAs<VectorLiteral>(vector->Clone()), k);
}
};
@@ -425,6 +420,10 @@ struct SortByClause : Node {
std::unique_ptr<Node> Clone() const override {
return std::make_unique<SortByClause>(order,
Node::MustAs<FieldRef>(field->Clone()));
}
+
+ std::unique_ptr<FieldRef> TakeFieldRef() { return std::move(field); }
+
+ std::unique_ptr<VectorLiteral> TakeVectorLiteral() { return
std::move(vector); }
};
struct SelectClause : Node {
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
index 2068a45a..e783ca8f 100644
--- a/src/search/ir_pass.h
+++ b/src/search/ir_pass.h
@@ -59,6 +59,12 @@ struct Visitor : Pass {
return Visit(std::move(v));
} else if (auto v = Node::As<TagContainExpr>(std::move(node))) {
return Visit(std::move(v));
+ } else if (auto v = Node::As<VectorLiteral>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<VectorKnnExpr>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<VectorRangeExpr>(std::move(node))) {
+ return Visit(std::move(v));
} else if (auto v = Node::As<StringLiteral>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<BoolLiteral>(std::move(node))) {
@@ -69,6 +75,10 @@ struct Visitor : Pass {
return Visit(std::move(v));
} else if (auto v = Node::As<TagFieldScan>(std::move(node))) {
return Visit(std::move(v));
+ } else if (auto v = Node::As<HnswVectorFieldRangeScan>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<HnswVectorFieldKnnScan>(std::move(node))) {
+ return Visit(std::move(v));
} else if (auto v = Node::As<Filter>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<Limit>(std::move(node))) {
@@ -125,6 +135,8 @@ struct Visitor : Pass {
virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericLiteral> node) {
return node; }
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorLiteral> node) {
return node; }
+
virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericCompareExpr>
node) {
node->field = VisitAs<FieldRef>(std::move(node->field));
node->num = VisitAs<NumericLiteral>(std::move(node->num));
@@ -137,6 +149,19 @@ struct Visitor : Pass {
return node;
}
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorKnnExpr> node) {
+ node->field = VisitAs<FieldRef>(std::move(node->field));
+ node->vector = VisitAs<VectorLiteral>(std::move(node->vector));
+ return node;
+ }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<VectorRangeExpr> node) {
+ node->field = VisitAs<FieldRef>(std::move(node->field));
+ node->range = VisitAs<NumericLiteral>(std::move(node->range));
+ node->vector = VisitAs<VectorLiteral>(std::move(node->vector));
+ return node;
+ }
+
virtual std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) {
for (auto &n : node->inners) {
n = TransformAs<QueryExpr>(std::move(n));
@@ -173,6 +198,10 @@ struct Visitor : Pass {
virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagFieldScan> node) {
return node; }
+ virtual std::unique_ptr<Node>
Visit(std::unique_ptr<HnswVectorFieldRangeScan> node) { return node; }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<HnswVectorFieldKnnScan>
node) { return node; }
+
virtual std::unique_ptr<Node> Visit(std::unique_ptr<Filter> node) {
node->source = TransformAs<PlanOperator>(std::move(node->source));
node->filter_expr = TransformAs<QueryExpr>(std::move(node->filter_expr));
diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h
index 94e8b589..8743a827 100644
--- a/src/search/ir_plan.h
+++ b/src/search/ir_plan.h
@@ -99,7 +99,7 @@ struct TagFieldScan : FieldScan {
struct HnswVectorFieldKnnScan : FieldScan {
kqir::NumericArray vector;
- uint16_t k;
+ uint32_t k;
HnswVectorFieldKnnScan(std::unique_ptr<FieldRef> field, kqir::NumericArray
vector, uint16_t k)
: FieldScan(std::move(field)), vector(std::move(vector)), k(k) {}
diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h
index 43d722b4..8d18cd84 100644
--- a/src/search/ir_sema_checker.h
+++ b/src/search/ir_sema_checker.h
@@ -50,8 +50,14 @@ struct SemaChecker {
GET_OR_RET(Check(v->query_expr.get()));
if (v->limit) GET_OR_RET(Check(v->limit.get()));
if (v->sort_by) GET_OR_RET(Check(v->sort_by.get()));
- if (v->sort_by && v->sort_by->IsVectorField() && !v->limit) {
- return {Status::NotOK, "expect a LIMIT clause for vector field to
construct a KNN search"};
+ if (v->sort_by && v->sort_by->IsVectorField()) {
+ if (!v->limit) {
+ return {Status::NotOK, "expect a LIMIT clause for vector field to
construct a KNN search"};
+ }
+ // TODO: allow hybrid query
+ if (auto b = dynamic_cast<BoolLiteral *>(v->query_expr.get()); b ==
nullptr) {
+ return {Status::NotOK, "KNN search cannot be combined with other
query expressions"};
+ }
}
} else {
return {Status::NotOK, fmt::format("index `{}` not found",
index_name)};
@@ -129,9 +135,6 @@ struct SemaChecker {
return {Status::NotOK,
fmt::format("field `{}` is marked as NOINDEX and cannot be
used for KNN search", v->field->name)};
}
- if (v->k->val <= 0) {
- return {Status::NotOK, fmt::format("KNN search parameter `k` must be
greater than 0")};
- }
auto meta =
v->field->info->MetadataAs<redis::HnswVectorFieldMetadata>();
if (v->vector->values.size() != meta->dim) {
return {Status::NotOK,
diff --git a/src/search/passes/cost_model.h b/src/search/passes/cost_model.h
index 86e0e3a5..960708d7 100644
--- a/src/search/passes/cost_model.h
+++ b/src/search/passes/cost_model.h
@@ -36,6 +36,12 @@ struct CostModel {
if (auto v = dynamic_cast<const FullIndexScan *>(node)) {
return Visit(v);
}
+ if (auto v = dynamic_cast<const HnswVectorFieldKnnScan *>(node)) {
+ return Visit(v);
+ }
+ if (auto v = dynamic_cast<const HnswVectorFieldRangeScan *>(node)) {
+ return Visit(v);
+ }
if (auto v = dynamic_cast<const NumericFieldScan *>(node)) {
return Visit(v);
}
@@ -74,6 +80,10 @@ struct CostModel {
static size_t Visit(const TagFieldScan *node) { return 10; }
+ static size_t Visit(const HnswVectorFieldKnnScan *node) { return 3; }
+
+ static size_t Visit(const HnswVectorFieldRangeScan *node) { return 4; }
+
static size_t Visit(const Filter *node) { return
Transform(node->source.get()) + 1; }
static size_t Visit(const Merge *node) {
diff --git a/src/search/passes/index_selection.h
b/src/search/passes/index_selection.h
index e60287d4..09e1bcb3 100644
--- a/src/search/passes/index_selection.h
+++ b/src/search/passes/index_selection.h
@@ -112,6 +112,12 @@ struct IndexSelection : Visitor {
if (auto v = dynamic_cast<OrExpr *>(node)) {
return VisitExpr(v);
}
+ if (auto v = dynamic_cast<VectorKnnExpr *>(node)) {
+ return VisitExpr(v);
+ }
+ if (auto v = dynamic_cast<VectorRangeExpr *>(node)) {
+ return VisitExpr(v);
+ }
if (auto v = dynamic_cast<NumericCompareExpr *>(node)) {
return VisitExpr(v);
}
@@ -153,6 +159,23 @@ struct IndexSelection : Visitor {
return MakeFullIndexFilter(node);
}
+ std::unique_ptr<PlanOperator> VisitExpr(VectorRangeExpr *node) const {
+ if (node->field->info->HasIndex()) {
+ return
std::make_unique<HnswVectorFieldRangeScan>(node->field->CloneAs<FieldRef>(),
node->vector->values,
+ node->range->val);
+ }
+
+ return MakeFullIndexFilter(node);
+ }
+
+ std::unique_ptr<PlanOperator> VisitExpr(VectorKnnExpr *node) const {
+ if (node->field->info->HasIndex()) {
+ return
std::make_unique<HnswVectorFieldKnnScan>(node->field->CloneAs<FieldRef>(),
node->vector->values, node->k);
+ }
+
+ return MakeFullIndexFilter(node);
+ }
+
template <typename Expr>
std::unique_ptr<PlanOperator> VisitExprImpl(Expr *node) {
struct AggregatedNodes {
diff --git a/src/search/passes/manager.h b/src/search/passes/manager.h
index 57f317d2..ce2d6a3d 100644
--- a/src/search/passes/manager.h
+++ b/src/search/passes/manager.h
@@ -35,6 +35,7 @@
#include "search/passes/simplify_and_or_expr.h"
#include "search/passes/simplify_boolean.h"
#include "search/passes/sort_limit_fuse.h"
+#include "search/passes/sort_limit_to_knn.h"
#include "type_util.h"
namespace kqir {
@@ -86,7 +87,8 @@ struct PassManager {
}
static PassSequence ExprPasses() {
- return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{},
SimplifyAndOrExpr{});
+ return Create(SimplifyAndOrExpr{}, PushDownNotExpr{}, SimplifyBoolean{},
SimplifyAndOrExpr{},
+ SortByWithLimitToKnnExpr{}, SimplifyAndOrExpr{});
}
static PassSequence NumericPasses() { return Create(IntervalAnalysis{true},
SimplifyAndOrExpr{}, SimplifyBoolean{}); }
static PassSequence PlanPasses() { return Create(LowerToPlan{},
IndexSelection{}, SortLimitFuse{}); }
diff --git a/src/search/passes/sort_limit_to_knn.h
b/src/search/passes/sort_limit_to_knn.h
new file mode 100644
index 00000000..e0f7b958
--- /dev/null
+++ b/src/search/passes/sort_limit_to_knn.h
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <memory>
+
+#include "search/ir.h"
+#include "search/ir_pass.h"
+#include "search/ir_plan.h"
+
+namespace kqir {
+
+struct SortByWithLimitToKnnExpr : Visitor {
+ std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) override {
+ node = Node::MustAs<SearchExpr>(Visitor::Visit(std::move(node)));
+
+ // TODO: allow hybrid query
+ if (node->sort_by && node->sort_by->IsVectorField() && node->limit) {
+ if (auto b = dynamic_cast<BoolLiteral*>(node->query_expr.get()); b &&
b->val) {
+ node->query_expr =
+
std::make_unique<VectorKnnExpr>(Node::MustAs<FieldRef>(node->sort_by->TakeFieldRef()),
+
Node::MustAs<VectorLiteral>(node->sort_by->TakeVectorLiteral()),
+ node->limit->Offset() +
node->limit->Count());
+ node->sort_by.reset();
+ }
+ }
+
+ return node;
+ }
+};
+
+} // namespace kqir
diff --git a/src/search/redis_query_parser.h b/src/search/redis_query_parser.h
index 5b0f172c..627910a3 100644
--- a/src/search/redis_query_parser.h
+++ b/src/search/redis_query_parser.h
@@ -43,13 +43,14 @@ struct Tag : sor<Identifier, StringL, Param> {};
struct TagList : seq<one<'{'>, WSPad<Tag>, star<seq<one<'|'>, WSPad<Tag>>>,
one<'}'>> {};
struct NumberOrParam : sor<Number, Param> {};
+struct UintOrParam : sor<UnsignedInteger, Param> {};
struct Inf : seq<opt<one<'+', '-'>>, string<'i', 'n', 'f'>> {};
struct ExclusiveNumber : seq<one<'('>, NumberOrParam> {};
struct NumericRangePart : sor<Inf, ExclusiveNumber, NumberOrParam> {};
struct NumericRange : seq<one<'['>, WSPad<NumericRangePart>,
WSPad<NumericRangePart>, one<']'>> {};
-struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<NumberOrParam>,
WSPad<Field>, WSPad<Param>, one<']'>> {};
+struct KnnSearch : seq<one<'['>, WSPad<KnnToken>, WSPad<UintOrParam>,
WSPad<Field>, WSPad<Param>, one<']'>> {};
struct VectorRange : seq<one<'['>, WSPad<VectorRangeToken>,
WSPad<NumberOrParam>, WSPad<Param>, one<']'>> {};
struct FieldQuery : seq<WSPad<Field>, one<':'>, WSPad<sor<VectorRange,
TagList, NumericRange>>> {};
@@ -70,7 +71,7 @@ struct AndExprP : sor<AndExpr, BooleanExpr> {};
struct OrExpr : seq<AndExprP, plus<seq<one<'|'>, AndExprP>>> {};
struct OrExprP : sor<OrExpr, AndExprP> {};
-struct PrefilterExpr : seq<WSPad<BooleanExpr>, ArrowOp, WSPad<KnnSearch>> {};
+struct PrefilterExpr : seq<WSPad<Wildcard>, ArrowOp, WSPad<KnnSearch>> {};
struct QueryP : sor<PrefilterExpr, OrExprP> {};
diff --git a/src/search/redis_query_transformer.h
b/src/search/redis_query_transformer.h
index c81230e4..ed7c8fc6 100644
--- a/src/search/redis_query_transformer.h
+++ b/src/search/redis_query_transformer.h
@@ -36,7 +36,7 @@ namespace ir = kqir;
template <typename Rule>
using TreeSelector = parse_tree::selector<
- Rule, parse_tree::store_content::on<Number, StringL, Param, Identifier,
Inf>,
+ Rule, parse_tree::store_content::on<Number, UnsignedInteger, StringL,
Param, Identifier, Inf>,
parse_tree::remove_content::on<TagList, NumericRange, VectorRange,
ExclusiveNumber, FieldQuery, NotExpr, AndExpr,
OrExpr, PrefilterExpr, KnnSearch, Wildcard,
VectorRangeToken, KnnToken, ArrowOp>>;
@@ -161,17 +161,21 @@ struct Transformer : ir::TreeTransformer {
return
Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
} else if (Is<PrefilterExpr>(node)) {
+ // TODO: allow hybrid query
CHECK(node->children.size() == 3);
- // TODO(Beihao): Support Hybrid Query
- // const auto& prefilter = node->children[0];
const auto& knn_search = node->children[2];
CHECK(knn_search->children.size() == 4);
- return
std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
-
GET_OR_RET(number_or_param(knn_search->children[1])),
-
GET_OR_RET(Transform2Vector(knn_search->children[3])));
+ size_t k = 0;
+ if (Is<UnsignedInteger>(knn_search->children[1])) {
+ k = *ParseInt(knn_search->children[1]->string());
+ } else {
+ k = *ParseInt(GET_OR_RET(GetParam(node)));
+ }
+ return
std::make_unique<VectorKnnExpr>(std::make_unique<FieldRef>(knn_search->children[2]->string()),
+
GET_OR_RET(Transform2Vector(knn_search->children[3])), k);
} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h
index 01705107..49d04307 100644
--- a/src/search/sql_transformer.h
+++ b/src/search/sql_transformer.h
@@ -118,7 +118,6 @@ struct Transformer : ir::TreeTransformer {
return {Status::NotOK, "the left and right side of numeric comparison
should be an identifier and a number"};
}
} else if (Is<VectorRangeExpr>(node)) {
- // TODO(Beihao): Handle distance metrics for operator
CHECK(node->children.size() == 2);
const auto& vector_comp_expr = node->children[0];
CHECK(vector_comp_expr->children.size() == 3);
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index 81ed49e8..9d576678 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -111,6 +111,19 @@ TEST(IRPassTest, Manager) {
"select * from a where (and x <= 1, y >= 2, z != 3)");
}
+TEST(IRPassTest, SortByWithLimitToKnnExpr) {
+ SortByWithLimitToKnnExpr tsbtke;
+
+ ASSERT_EQ(tsbtke.Transform(*Parse("select a from b order by embedding <->
[3.6] limit 5"))->Dump(),
+ "select a from b where KNN k=5, embedding <-> [3.600000] limit 0,
5");
+ ASSERT_EQ(tsbtke.Transform(*Parse("select a from b where false order by
embedding <-> [3,1,2] limit 5"))->Dump(),
+ "select a from b where false sortby embedding <-> [3.000000,
1.000000, 2.000000] limit 0, 5");
+ ASSERT_EQ(tsbtke.Transform(*Parse("select a from b where true order by
embedding <-> [3,1,2] limit 5"))->Dump(),
+ "select a from b where KNN k=5, embedding <-> [3.000000, 1.000000,
2.000000] limit 0, 5");
+ ASSERT_EQ(tsbtke.Transform(*Parse("select a from b where true order by
embedding <-> [3,1,2] limit 3, 5"))->Dump(),
+ "select a from b where KNN k=8, embedding <-> [3.000000, 1.000000,
2.000000] limit 3, 5");
+}
+
TEST(IRPassTest, LowerToPlan) {
LowerToPlan ltp;
@@ -118,11 +131,15 @@ TEST(IRPassTest, LowerToPlan) {
ASSERT_EQ(ltp.Transform(*Parse("select * from a limit 1"))->Dump(), "project
*: (limit 0, 1: full-scan a)");
ASSERT_EQ(ltp.Transform(*Parse("select * from a where false"))->Dump(),
"project *: noop");
ASSERT_EQ(ltp.Transform(*Parse("select * from a where false limit
1"))->Dump(), "project *: noop");
+ ASSERT_EQ(ltp.Transform(*Parse("select * from a where false order by
embedding <-> [3,1,2] limit 5"))->Dump(),
+ "project *: noop");
ASSERT_EQ(ltp.Transform(*Parse("select * from a where b > 1"))->Dump(),
"project *: (filter b > 1: full-scan a)");
ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by
d"))->Dump(),
"project a: (sort d, asc: (filter c = 1: full-scan b))");
ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 limit
1"))->Dump(),
"project a: (limit 0, 1: (filter c = 1: full-scan b))");
+ ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 and d = 2 order
by e limit 1"))->Dump(),
+ "project a: (limit 0, 1: (sort e, asc: (filter (and c = 1, d = 2):
full-scan b)))");
ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit
1"))->Dump(),
"project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan
b)))");
}
@@ -176,12 +193,28 @@ static IndexMap MakeIndexMap() {
auto f4 = FieldInfo("n2", std::make_unique<redis::NumericFieldMetadata>());
auto f5 = FieldInfo("n3", std::make_unique<redis::NumericFieldMetadata>());
f5.metadata->noindex = true;
+
+ auto hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
+ hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
+ hnsw_field_meta->dim = 3;
+ hnsw_field_meta->distance_metric = redis::DistanceMetric::L2;
+ auto f6 = FieldInfo("v1", std::move(hnsw_field_meta));
+
+ hnsw_field_meta = std::make_unique<redis::HnswVectorFieldMetadata>();
+ hnsw_field_meta->vector_type = redis::VectorType::FLOAT64;
+ hnsw_field_meta->dim = 3;
+ hnsw_field_meta->distance_metric = redis::DistanceMetric::L2;
+ auto f7 = FieldInfo("v2", std::move(hnsw_field_meta));
+ f7.metadata->noindex = true;
+
auto ia = std::make_unique<IndexInfo>("ia", redis::IndexMetadata(), "");
ia->Add(std::move(f1));
ia->Add(std::move(f2));
ia->Add(std::move(f3));
ia->Add(std::move(f4));
ia->Add(std::move(f5));
+ ia->Add(std::move(f6));
+ ia->Add(std::move(f7));
IndexMap res;
res.Insert(std::move(ia));
@@ -238,7 +271,19 @@ TEST(IRPassTest, IndexSelection) {
"project *: (filter t2 hastag \"a\": tag-scan t1, a)");
ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where t2
hastag \"a\""))->Dump(),
"project *: (filter t2 hastag \"a\": full-scan ia)");
-
+ ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where v1
<-> [3,1,2] < 5"))->Dump(),
+ "project *: hnsw-vector-range-scan v1, [3.000000, 1.000000,
2.000000], 5");
+ ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by
v1 <-> [3,1,2] limit 5"))->Dump(),
+ "project *: (limit 0, 5: hnsw-vector-knn-scan v1, [3.000000,
1.000000, 2.000000], 5)");
+ ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia order by
v1 <-> [3,1,2] limit 2, 7"))->Dump(),
+ "project *: (limit 2, 7: hnsw-vector-knn-scan v1, [3.000000,
1.000000, 2.000000], 9)");
+ ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where v2
<-> [3,1,2] < 5"))->Dump(),
+ "project *: (filter v2 <-> [3.000000, 1.000000, 2.000000] < 5:
full-scan ia)");
+ ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1
>= 1 and v1 <-> [3,1,2] < 5"))->Dump(),
+ "project *: (filter n1 >= 1: hnsw-vector-range-scan v1, [3.000000,
1.000000, 2.000000], 5)");
+ ASSERT_EQ(
+ PassManager::Execute(passes, ParseS(sc, "select * from ia where v1 <->
[3,1,2] < 5 and t1 hastag \"a\""))->Dump(),
+ "project *: (filter t1 hastag \"a\": hnsw-vector-range-scan v1,
[3.000000, 1.000000, 2.000000], 5)");
ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1
>= 2 or n1 < 1"))->Dump(),
"project *: (merge numeric-scan n1, [-inf, 1), asc, numeric-scan
n1, [2, inf), asc)");
ASSERT_EQ(PassManager::Execute(passes, ParseS(sc, "select * from ia where n1
>= 1 or n2 >= 2"))->Dump(),
diff --git a/tests/cppunit/ir_sema_checker_test.cc
b/tests/cppunit/ir_sema_checker_test.cc
index df8076ce..8926d026 100644
--- a/tests/cppunit/ir_sema_checker_test.cc
+++ b/tests/cppunit/ir_sema_checker_test.cc
@@ -100,6 +100,13 @@ TEST(SemaCheckerTest, Simple) {
"expect a LIMIT clause for vector field to construct a KNN
search");
ASSERT_EQ(checker.Check(Parse("select f5 from ia order by f5 <->
[3.6,4.7,5.6] limit 5")->get()).Msg(),
"field `f5` is marked as NOINDEX and cannot be used for KNN
search");
+ ASSERT_EQ(checker.Check(Parse("select f5 from ia where f2 = 1 order by f4
<-> [3.6,4.7,5.6] limit 5")->get()).Msg(),
+ "KNN search cannot be combined with other query expressions");
+ ASSERT_EQ(checker.Check(Parse("select f5 from ia where true order by f4
<-> [3.6,4.7,5.6] limit 5")->get()).Msg(),
+ "ok");
+ ASSERT_EQ(checker.Check(Parse("select f5 from ia where false order by f4
<-> [3.6,4.7,5.6] limit 5")->get()).Msg(),
+ "ok");
+ ASSERT_EQ(checker.Check(Parse("select f5 from ia where f2 = 1 and f5 <->
[3.6,4.7,5.6] < 1")->get()).Msg(), "ok");
ASSERT_EQ(checker.Check(Parse("select f5 from ia where f5 <->
[3.6,4.7,5.6] < 5")->get()).Msg(),
"range has to be between 0 and 2 for cosine distance metric");
ASSERT_EQ(checker.Check(Parse("select f5 from ia where f5 <->
[3.6,4.7,5.6] < 0.5")->get()).Msg(), "ok");
diff --git a/tests/cppunit/plan_executor_test.cc
b/tests/cppunit/plan_executor_test.cc
index 1b80329d..91c8cd2d 100644
--- a/tests/cppunit/plan_executor_test.cc
+++ b/tests/cppunit/plan_executor_test.cc
@@ -93,6 +93,7 @@ static auto FieldI(const std::string& f) -> const FieldInfo*
{ return &IndexI()-
static auto N(double n) { return MakeValue<Numeric>(n); }
static auto T(const std::string& v) { return
MakeValue<StringArray>(util::Split(v, ",")); }
+static auto V(const std::vector<double>& vals) { return
MakeValue<NumericArray>(vals); }
TEST(PlanExecutorTest, TopNSort) {
std::vector<ExecutorNode::RowType> data{
@@ -201,6 +202,41 @@ TEST(PlanExecutorTest, Filter) {
ASSERT_EQ(NextRow(ctx).key, "f");
ASSERT_EQ(ctx.Next().GetValue(), exe_end);
}
+
+ data = {{"a", {{FieldI("f4"), V({1, 2, 3})}}, IndexI()}, {"b",
{{FieldI("f4"), V({9, 10, 11})}}, IndexI()},
+ {"c", {{FieldI("f4"), V({4, 5, 6})}}, IndexI()}, {"d",
{{FieldI("f4"), V({1, 2, 3})}}, IndexI()},
+ {"e", {{FieldI("f4"), V({2, 3, 4})}}, IndexI()}, {"f",
{{FieldI("f4"), V({12, 13, 14})}}, IndexI()},
+ {"g", {{FieldI("f4"), V({1, 2, 3})}}, IndexI()}};
+ {
+ auto field = std::make_unique<FieldRef>("f4", FieldI("f4"));
+ std::vector<double> vector = {11, 12, 13};
+ auto op = std::make_unique<Filter>(
+ std::make_unique<Mock>(data),
+ std::make_unique<VectorRangeExpr>(field->CloneAs<FieldRef>(),
std::make_unique<NumericLiteral>(4),
+
std::make_unique<VectorLiteral>(std::move(vector))));
+
+ auto ctx = ExecutorContext(op.get());
+ ASSERT_EQ(NextRow(ctx).key, "b");
+ ASSERT_EQ(NextRow(ctx).key, "f");
+ ASSERT_EQ(ctx.Next().GetValue(), exe_end);
+ }
+
+ {
+ auto field = std::make_unique<FieldRef>("f4", FieldI("f4"));
+ std::vector<double> vector = {2, 3, 4};
+ auto op = std::make_unique<Filter>(
+ std::make_unique<Mock>(data),
+ std::make_unique<VectorRangeExpr>(field->CloneAs<FieldRef>(),
std::make_unique<NumericLiteral>(5),
+
std::make_unique<VectorLiteral>(std::move(vector))));
+
+ auto ctx = ExecutorContext(op.get());
+ ASSERT_EQ(NextRow(ctx).key, "a");
+ ASSERT_EQ(NextRow(ctx).key, "c");
+ ASSERT_EQ(NextRow(ctx).key, "d");
+ ASSERT_EQ(NextRow(ctx).key, "e");
+ ASSERT_EQ(NextRow(ctx).key, "g");
+ ASSERT_EQ(ctx.Next().GetValue(), exe_end);
+ }
}
TEST(PlanExecutorTest, Limit) {
diff --git a/tests/cppunit/redis_query_parser_test.cc
b/tests/cppunit/redis_query_parser_test.cc
index 4fc25e49..96f6e26a 100644
--- a/tests/cppunit/redis_query_parser_test.cc
+++ b/tests/cppunit/redis_query_parser_test.cc
@@ -115,16 +115,22 @@ TEST(RedisQueryParserTest, Vector) {
AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}}));
AssertSyntaxError(Parse("[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}}));
AssertSyntaxError(Parse("KNN 5 @vector $BLOB", {{"BLOB", vec_str}}));
+ AssertSyntaxError(Parse("* =>[KNN -1 @vector $BLOB]", {{"BLOB", vec_str}}));
AssertSyntaxError(Parse("*=>[KNN 5 $vector_blob_param]",
{{"vector_blob_param", vec_str}}));
+ AssertSyntaxError(Parse("(*) => [KNN 10 @doc_embedding $BLOB]", {{"BLOB",
vec_str}}));
+ AssertSyntaxError(Parse("(@a:[1 2]) => [KNN 8 @vec_embedding $blob]",
{{"blob", vec_str}}));
+ AssertSyntaxError(Parse("(@a:{x|y}) => [KNN 8 @vec_embedding $blob]",
{{"blob", vec_str}}));
+ AssertSyntaxError(Parse("(@a:{x|y}) => [KNN 8 @vec_embedding $blob]",
{{"blob", vec_str}}));
+ AssertSyntaxError(Parse("(@a:{x}|@b:[1 inf] | @c:{y}) => [KNN 8
@vec_embedding $blob]", {{"blob", vec_str}}));
+ AssertSyntaxError(Parse("(@a:{x}|@b:[1 inf] | @field:[VECTOR_RANGE 10
$vector]) => [KNN 8 @vec_embedding $blob]",
+ {{"blob", vec_str}, {"vector", vec_str}}));
AssertIR(Parse("@field:[VECTOR_RANGE 10 $vector]", {{"vector", vec_str}}),
"field <-> [1.000000, 2.000000, 3.000000] < 10");
+ AssertIR(Parse("@field:[VECTOR_RANGE 10 $vector]| @b:[1 inf]", {{"vector",
vec_str}}),
+ "(or field <-> [1.000000, 2.000000, 3.000000] < 10, b >= 1)");
AssertIR(Parse("*=>[KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}),
"KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]");
- AssertIR(Parse("(*) => [KNN 10 @doc_embedding $BLOB]", {{"BLOB", vec_str}}),
- "KNN k=10, doc_embedding <-> [1.000000, 2.000000, 3.000000]");
- AssertIR(Parse("(@a:[1 2]) => [KNN 8 @vec_embedding $blob]", {{"blob",
vec_str}}),
- "KNN k=8, vec_embedding <-> [1.000000, 2.000000, 3.000000]");
AssertIR(Parse("* =>[KNN 5 @vector $BLOB]", {{"BLOB", vec_str}}),
"KNN k=5, vector <-> [1.000000, 2.000000, 3.000000]");