This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 91509c72 Add plan lowering pass for KQIR (#2247)
91509c72 is described below
commit 91509c724270bbff4742d66fc72d20a0aec0be74
Author: Twice <[email protected]>
AuthorDate: Sun Apr 14 18:59:33 2024 +0900
Add plan lowering pass for KQIR (#2247)
---
src/search/index_info.h | 2 +
src/search/ir.h | 17 ++--
src/search/ir_pass.h | 72 ++++++++++++++++-
src/search/ir_plan.h | 91 ++++++++++++++++++----
src/search/ir_sema_checker.h | 5 +-
.../{index_info.h => passes/lower_to_plan.h} | 42 ++++------
src/search/search_encoding.h | 4 +-
src/search/sql_transformer.h | 2 +-
tests/cppunit/ir_pass_test.cc | 14 ++++
9 files changed, 199 insertions(+), 50 deletions(-)
diff --git a/src/search/index_info.h b/src/search/index_info.h
index e6e95b72..df059918 100644
--- a/src/search/index_info.h
+++ b/src/search/index_info.h
@@ -37,6 +37,8 @@ struct FieldInfo {
FieldInfo(std::string name, std::unique_ptr<redis::SearchFieldMetadata>
&&metadata)
: name(std::move(name)), metadata(std::move(metadata)) {}
+
+ bool IsSortable() const { return
dynamic_cast<redis::SearchSortableFieldMetadata *>(metadata.get()) != nullptr; }
};
struct IndexInfo {
diff --git a/src/search/ir.h b/src/search/ir.h
index 2b85b458..d7da716a 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -51,6 +51,11 @@ struct Node {
virtual std::unique_ptr<Node> Clone() const = 0;
+ template <typename T>
+ std::unique_ptr<T> CloneAs() const {
+ return Node::MustAs<T>(Clone());
+ }
+
virtual ~Node() = default;
template <typename T, typename U = Node, typename... Args>
@@ -361,14 +366,14 @@ struct IndexRef : Ref {
std::unique_ptr<Node> Clone() const override { return
std::make_unique<IndexRef>(*this); }
};
-struct SearchStmt : Node {
+struct SearchExpr : Node {
std::unique_ptr<SelectClause> select;
std::unique_ptr<IndexRef> index;
std::unique_ptr<QueryExpr> query_expr;
std::unique_ptr<LimitClause> limit; // optional
std::unique_ptr<SortByClause> sort_by; // optional
- SearchStmt(std::unique_ptr<IndexRef> &&index, std::unique_ptr<QueryExpr>
&&query_expr,
+ SearchExpr(std::unique_ptr<IndexRef> &&index, std::unique_ptr<QueryExpr>
&&query_expr,
std::unique_ptr<LimitClause> &&limit,
std::unique_ptr<SortByClause> &&sort_by,
std::unique_ptr<SelectClause> &&select)
: select(std::move(select)),
@@ -386,16 +391,16 @@ struct SearchStmt : Node {
}
static inline const std::vector<std::function<Node *(Node *)>> ChildMap = {
- NodeIterator::MemFn<&SearchStmt::select>,
NodeIterator::MemFn<&SearchStmt::index>,
- NodeIterator::MemFn<&SearchStmt::query_expr>,
NodeIterator::MemFn<&SearchStmt::limit>,
- NodeIterator::MemFn<&SearchStmt::sort_by>,
+ NodeIterator::MemFn<&SearchExpr::select>,
NodeIterator::MemFn<&SearchExpr::index>,
+ NodeIterator::MemFn<&SearchExpr::query_expr>,
NodeIterator::MemFn<&SearchExpr::limit>,
+ NodeIterator::MemFn<&SearchExpr::sort_by>,
};
NodeIterator ChildBegin() override { return NodeIterator(this,
ChildMap.begin()); };
NodeIterator ChildEnd() override { return NodeIterator(this,
ChildMap.end()); };
std::unique_ptr<Node> Clone() const override {
- return std::make_unique<SearchStmt>(
+ return std::make_unique<SearchExpr>(
Node::MustAs<IndexRef>(index->Clone()),
Node::MustAs<QueryExpr>(query_expr->Clone()),
Node::MustAs<LimitClause>(limit->Clone()),
Node::MustAs<SortByClause>(sort_by->Clone()),
Node::MustAs<SelectClause>(select->Clone()));
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
index a252216d..924e8c75 100644
--- a/src/search/ir_pass.h
+++ b/src/search/ir_pass.h
@@ -21,6 +21,7 @@
#pragma once
#include "ir.h"
+#include "search/ir_plan.h"
namespace kqir {
@@ -30,7 +31,7 @@ struct Pass {
struct Visitor : Pass {
std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) override {
- if (auto v = Node::As<SearchStmt>(std::move(node))) {
+ if (auto v = Node::As<SearchExpr>(std::move(node))) {
return Visit(std::move(v));
} else if (auto v = Node::As<SelectClause>(std::move(node))) {
return Visit(std::move(v));
@@ -58,6 +59,26 @@ struct Visitor : Pass {
return Visit(std::move(v));
} else if (auto v = Node::As<BoolLiteral>(std::move(node))) {
return Visit(std::move(v));
+ } else if (auto v = Node::As<FullIndexScan>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<NumericFieldScan>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<TagFieldScan>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<Filter>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<Limit>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<Merge>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<Sort>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<TopNSort>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<Projection>(std::move(node))) {
+ return Visit(std::move(v));
+ } else if (auto v = Node::As<Noop>(std::move(node))) {
+ return Visit(std::move(v));
}
__builtin_unreachable();
@@ -73,7 +94,7 @@ struct Visitor : Pass {
return Node::MustAs<T>(Transform(std::move(n)));
}
- virtual std::unique_ptr<Node> Visit(std::unique_ptr<SearchStmt> node) {
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) {
node->index = VisitAs<IndexRef>(std::move(node->index));
node->select = VisitAs<SelectClause>(std::move(node->select));
node->query_expr = TransformAs<QueryExpr>(std::move(node->query_expr));
@@ -139,6 +160,53 @@ struct Visitor : Pass {
node->field = VisitAs<FieldRef>(std::move(node->field));
return node;
}
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<Noop> node) { return
node; }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<FullIndexScan> node) {
return node; }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericFieldScan> node)
{ return node; }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagFieldScan> node) {
return node; }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<Filter> node) {
+ node->source = TransformAs<PlanOperator>(std::move(node->source));
+ node->filter_expr = TransformAs<QueryExpr>(std::move(node->filter_expr));
+ return node;
+ }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<Limit> node) {
+ node->op = TransformAs<PlanOperator>(std::move(node->op));
+ node->limit = VisitAs<LimitClause>(std::move(node->limit));
+ return node;
+ }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<Sort> node) {
+ node->op = TransformAs<PlanOperator>(std::move(node->op));
+ node->order = VisitAs<SortByClause>(std::move(node->order));
+ return node;
+ }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<TopNSort> node) {
+ node->op = TransformAs<PlanOperator>(std::move(node->op));
+ node->limit = VisitAs<LimitClause>(std::move(node->limit));
+ node->order = VisitAs<SortByClause>(std::move(node->order));
+ return node;
+ }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<Projection> node) {
+ node->source = TransformAs<PlanOperator>(std::move(node->source));
+ node->select = VisitAs<SelectClause>(std::move(node->select));
+ return node;
+ }
+
+ virtual std::unique_ptr<Node> Visit(std::unique_ptr<Merge> node) {
+ for (auto &n : node->ops) {
+ n = TransformAs<PlanOperator>(std::move(n));
+ }
+
+ return node;
+ }
};
} // namespace kqir
diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h
index 9d40d948..da805846 100644
--- a/src/search/ir_plan.h
+++ b/src/search/ir_plan.h
@@ -31,22 +31,30 @@ namespace kqir {
struct PlanOperator : Node {};
+struct Noop : PlanOperator {
+ std::string_view Name() const override { return "Noop"; };
+ std::string Dump() const override { return "noop"; }
+
+ std::unique_ptr<Node> Clone() const override { return
std::make_unique<Noop>(*this); }
+};
+
struct FullIndexScan : PlanOperator {
- IndexInfo *index;
+ std::unique_ptr<IndexRef> index;
- explicit FullIndexScan(IndexInfo *index) : index(index) {}
+ explicit FullIndexScan(std::unique_ptr<IndexRef> index) :
index(std::move(index)) {}
std::string_view Name() const override { return "FullIndexScan"; };
- std::string Content() const override { return index->name; };
- std::string Dump() const override { return fmt::format("full-scan {}",
Content()); }
+ std::string Dump() const override { return fmt::format("full-scan {}",
index->name); }
- std::unique_ptr<Node> Clone() const override { return
std::make_unique<FullIndexScan>(*this); }
+ std::unique_ptr<Node> Clone() const override {
+ return
std::make_unique<FullIndexScan>(Node::MustAs<IndexRef>(index->Clone()));
+ }
};
struct FieldScan : PlanOperator {
- FieldInfo *field;
+ std::unique_ptr<FieldRef> field;
- explicit FieldScan(FieldInfo *field) : field(field) {}
+ explicit FieldScan(std::unique_ptr<FieldRef> field) :
field(std::move(field)) {}
};
struct Interval {
@@ -60,25 +68,29 @@ struct Interval {
struct NumericFieldScan : FieldScan {
Interval range;
- NumericFieldScan(FieldInfo *field, Interval range) : FieldScan(field),
range(range) {}
+ NumericFieldScan(std::unique_ptr<FieldRef> field, Interval range) :
FieldScan(std::move(field)), range(range) {}
std::string_view Name() const override { return "NumericFieldScan"; };
std::string Content() const override { return fmt::format("{}, {}",
field->name, range.ToString()); };
std::string Dump() const override { return fmt::format("numeric-scan {}",
Content()); }
- std::unique_ptr<Node> Clone() const override { return
std::make_unique<NumericFieldScan>(*this); }
+ std::unique_ptr<Node> Clone() const override {
+ return std::make_unique<NumericFieldScan>(field->CloneAs<FieldRef>(),
range);
+ }
};
struct TagFieldScan : FieldScan {
std::string tag;
- TagFieldScan(FieldInfo *field, std::string tag) : FieldScan(field),
tag(std::move(tag)) {}
+ TagFieldScan(std::unique_ptr<FieldRef> field, std::string tag) :
FieldScan(std::move(field)), tag(std::move(tag)) {}
std::string_view Name() const override { return "TagFieldScan"; };
std::string Content() const override { return fmt::format("{}, {}",
field->name, tag); };
std::string Dump() const override { return fmt::format("tag-scan {}",
Content()); }
- std::unique_ptr<Node> Clone() const override { return
std::make_unique<TagFieldScan>(*this); }
+ std::unique_ptr<Node> Clone() const override {
+ return std::make_unique<TagFieldScan>(field->CloneAs<FieldRef>(), tag);
+ }
};
struct Filter : PlanOperator {
@@ -89,7 +101,7 @@ struct Filter : PlanOperator {
: source(std::move(source)), filter_expr(std::move(filter_expr)) {}
std::string_view Name() const override { return "Filter"; };
- std::string Dump() const override { return fmt::format("(filter {}, {})",
source->Dump(), Content()); }
+ std::string Dump() const override { return fmt::format("(filter {}: {})",
filter_expr->Dump(), source->Dump()); }
NodeIterator ChildBegin() override { return {source.get(),
filter_expr.get()}; }
NodeIterator ChildEnd() override { return {}; }
@@ -143,6 +155,55 @@ struct Limit : PlanOperator {
}
};
+struct Sort : PlanOperator {
+ std::unique_ptr<PlanOperator> op;
+ std::unique_ptr<SortByClause> order;
+
+ Sort(std::unique_ptr<PlanOperator> &&op, std::unique_ptr<SortByClause>
&&order)
+ : op(std::move(op)), order(std::move(order)) {}
+
+ std::string_view Name() const override { return "Sort"; };
+ std::string Dump() const override {
+ return fmt::format("(sort {}, {}: {})", order->field->Dump(),
order->OrderToString(order->order), op->Dump());
+ }
+
+ NodeIterator ChildBegin() override { return NodeIterator{op.get(),
order.get()}; }
+ NodeIterator ChildEnd() override { return {}; }
+
+ std::unique_ptr<Node> Clone() const override {
+ return std::make_unique<Sort>(Node::MustAs<PlanOperator>(op->Clone()),
Node::MustAs<SortByClause>(order->Clone()));
+ }
+};
+
+// operator fusion: Sort + Limit
+struct TopNSort : PlanOperator {
+ std::unique_ptr<PlanOperator> op;
+ std::unique_ptr<SortByClause> order;
+ std::unique_ptr<LimitClause> limit;
+
+ TopNSort(std::unique_ptr<PlanOperator> &&op, std::unique_ptr<SortByClause>
&&order,
+ std::unique_ptr<LimitClause> &&limit)
+ : op(std::move(op)), order(std::move(order)), limit(std::move(limit)) {}
+
+ std::string_view Name() const override { return "TopNSort"; };
+ std::string Dump() const override {
+ return fmt::format("(top-n sort {}, {}, {}, {}: {})",
order->field->Dump(), order->OrderToString(order->order),
+ limit->offset, limit->count, op->Dump());
+ }
+
+ static inline const std::vector<std::function<Node *(Node *)>> ChildMap = {
+ NodeIterator::MemFn<&TopNSort::op>,
NodeIterator::MemFn<&TopNSort::order>, NodeIterator::MemFn<&TopNSort::limit>};
+
+ NodeIterator ChildBegin() override { return NodeIterator(this,
ChildMap.begin()); }
+ NodeIterator ChildEnd() override { return NodeIterator(this,
ChildMap.end()); }
+
+ std::unique_ptr<Node> Clone() const override {
+ return std::make_unique<TopNSort>(Node::MustAs<PlanOperator>(op->Clone()),
+
Node::MustAs<SortByClause>(order->Clone()),
+
Node::MustAs<LimitClause>(limit->Clone()));
+ }
+};
+
struct Projection : PlanOperator {
std::unique_ptr<PlanOperator> source;
std::unique_ptr<SelectClause> select;
@@ -151,7 +212,11 @@ struct Projection : PlanOperator {
: source(std::move(source)), select(std::move(select)) {}
std::string_view Name() const override { return "Projection"; };
- std::string Dump() const override { return fmt::format("(project {}: {})",
select, source); }
+ std::string Dump() const override {
+ auto select_str =
+ select->fields.empty() ? "*" : util::StringJoin(select->fields,
[](const auto &v) { return v->Dump(); });
+ return fmt::format("project {}: {}", select_str, source->Dump());
+ }
NodeIterator ChildBegin() override { return {source.get(), select.get()}; }
NodeIterator ChildEnd() override { return {}; }
diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h
index 471c3a3f..40f316cb 100644
--- a/src/search/ir_sema_checker.h
+++ b/src/search/ir_sema_checker.h
@@ -23,6 +23,7 @@
#include <map>
#include <memory>
+#include "fmt/core.h"
#include "index_info.h"
#include "ir.h"
#include "search_encoding.h"
@@ -38,7 +39,7 @@ struct SemaChecker {
explicit SemaChecker(const IndexMap &index_map) : index_map(index_map) {}
Status Check(Node *node) {
- if (auto v = dynamic_cast<SearchStmt *>(node)) {
+ if (auto v = dynamic_cast<SearchExpr *>(node)) {
auto index_name = v->index->name;
if (auto iter = index_map.find(index_name); iter != index_map.end()) {
current_index = &iter->second;
@@ -56,6 +57,8 @@ struct SemaChecker {
} else if (auto v = dynamic_cast<SortByClause *>(node)) {
if (auto iter = current_index->fields.find(v->field->name); iter ==
current_index->fields.end()) {
return {Status::NotOK, fmt::format("field `{}` not found in index
`{}`", v->field->name, current_index->name)};
+ } else if (!iter->second.IsSortable()) {
+ return {Status::NotOK, fmt::format("field `{}` is not sortable",
v->field->name)};
} else {
v->field->info = &iter->second;
}
diff --git a/src/search/index_info.h b/src/search/passes/lower_to_plan.h
similarity index 51%
copy from src/search/index_info.h
copy to src/search/passes/lower_to_plan.h
index e6e95b72..dad1db39 100644
--- a/src/search/index_info.h
+++ b/src/search/passes/lower_to_plan.h
@@ -20,42 +20,32 @@
#pragma once
-#include <map>
#include <memory>
-#include <string>
-#include "search_encoding.h"
+#include "search/ir.h"
+#include "search/ir_pass.h"
+#include "search/ir_plan.h"
namespace kqir {
-struct IndexInfo;
+struct LowerToPlan : Visitor {
+ std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) override {
+ auto scan =
std::make_unique<FullIndexScan>(node->index->CloneAs<IndexRef>());
+ auto filter = std::make_unique<Filter>(std::move(scan),
std::move(node->query_expr));
-struct FieldInfo {
- std::string name;
- IndexInfo *index = nullptr;
- std::unique_ptr<redis::SearchFieldMetadata> metadata;
+ std::unique_ptr<PlanOperator> op = std::move(filter);
- FieldInfo(std::string name, std::unique_ptr<redis::SearchFieldMetadata>
&&metadata)
- : name(std::move(name)), metadata(std::move(metadata)) {}
-};
-
-struct IndexInfo {
- using FieldMap = std::map<std::string, FieldInfo>;
+ // order is important here, since limit(sort(op)) is different from sort(limit(op))
+ if (node->sort_by) {
+ op = std::make_unique<Sort>(std::move(op), std::move(node->sort_by));
+ }
- std::string name;
- SearchMetadata metadata;
- FieldMap fields;
- redis::SearchPrefixesMetadata prefixes;
+ if (node->limit) {
+ op = std::make_unique<Limit>(std::move(op), std::move(node->limit));
+ }
- IndexInfo(std::string name, SearchMetadata metadata) :
name(std::move(name)), metadata(std::move(metadata)) {}
-
- void Add(FieldInfo &&field) {
- const auto &name = field.name;
- field.index = this;
- fields.emplace(name, std::move(field));
+ return std::make_unique<Projection>(std::move(op),
std::move(node->select));
}
};
-using IndexMap = std::map<std::string, IndexInfo>;
-
} // namespace kqir
diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h
index 1637a504..14bf2923 100644
--- a/src/search/search_encoding.h
+++ b/src/search/search_encoding.h
@@ -125,7 +125,9 @@ inline std::string ConstructNumericFieldMetadataSubkey(std::string_view field_na
return res;
}
-struct SearchNumericFieldMetadata : SearchFieldMetadata {};
+struct SearchSortableFieldMetadata : SearchFieldMetadata {};
+
+struct SearchNumericFieldMetadata : SearchSortableFieldMetadata {};
inline std::string ConstructTagFieldSubkey(std::string_view field_name,
std::string_view tag, std::string_view key) {
std::string res = {(char)SearchSubkeyType::TAG_FIELD};
diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h
index 871949fd..8386a66e 100644
--- a/src/search/sql_transformer.h
+++ b/src/search/sql_transformer.h
@@ -168,7 +168,7 @@ struct Transformer : ir::TreeTransformer {
query_expr = std::make_unique<BoolLiteral>(true);
}
- return Node::Create<ir::SearchStmt>(std::move(index),
std::move(query_expr), std::move(limit), std::move(sort_by),
+ return Node::Create<ir::SearchExpr>(std::move(index),
std::move(query_expr), std::move(limit), std::move(sort_by),
std::move(select));
} else if (IsRoot(node)) {
CHECK(node->children.size() == 1);
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index 70d39f96..0f0952bf 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -21,6 +21,7 @@
#include "search/ir_pass.h"
#include "gtest/gtest.h"
+#include "search/passes/lower_to_plan.h"
#include "search/passes/manager.h"
#include "search/passes/push_down_not_expr.h"
#include "search/passes/simplify_and_or_expr.h"
@@ -103,3 +104,16 @@ TEST(IRPassTest, Manager) {
PassManager::Default(*Parse("select * from a where not (x > 1 or (y < 2
or z = 3)) and (true or x = 1)"))->Dump(),
"select * from a where (and x <= 1, y >= 2, z != 3)");
}
+
+TEST(IRPassTest, LowerToPlan) {
+ LowerToPlan ltp;
+
+ ASSERT_EQ(ltp.Transform(*Parse("select * from a"))->Dump(), "project *:
(filter true: full-scan a)");
+ ASSERT_EQ(ltp.Transform(*Parse("select * from a where b > 1"))->Dump(),
"project *: (filter b > 1: full-scan a)");
+ ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by
d"))->Dump(),
+ "project a: (sort d, asc: (filter c = 1: full-scan b))");
+ ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 limit
1"))->Dump(),
+ "project a: (limit 0, 1: (filter c = 1: full-scan b))");
+ ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit
1"))->Dump(),
+ "project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan
b)))");
+}