This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 91509c72 Add plan lowering pass for KQIR (#2247)
91509c72 is described below

commit 91509c724270bbff4742d66fc72d20a0aec0be74
Author: Twice <[email protected]>
AuthorDate: Sun Apr 14 18:59:33 2024 +0900

    Add plan lowering pass for KQIR (#2247)
---
 src/search/index_info.h                            |  2 +
 src/search/ir.h                                    | 17 ++--
 src/search/ir_pass.h                               | 72 ++++++++++++++++-
 src/search/ir_plan.h                               | 91 ++++++++++++++++++----
 src/search/ir_sema_checker.h                       |  5 +-
 .../{index_info.h => passes/lower_to_plan.h}       | 42 ++++------
 src/search/search_encoding.h                       |  4 +-
 src/search/sql_transformer.h                       |  2 +-
 tests/cppunit/ir_pass_test.cc                      | 14 ++++
 9 files changed, 199 insertions(+), 50 deletions(-)

diff --git a/src/search/index_info.h b/src/search/index_info.h
index e6e95b72..df059918 100644
--- a/src/search/index_info.h
+++ b/src/search/index_info.h
@@ -37,6 +37,8 @@ struct FieldInfo {
 
   FieldInfo(std::string name, std::unique_ptr<redis::SearchFieldMetadata> &&metadata)
       : name(std::move(name)), metadata(std::move(metadata)) {}
+
+  bool IsSortable() const { return dynamic_cast<redis::SearchSortableFieldMetadata *>(metadata.get()) != nullptr; }
 };
 
 struct IndexInfo {
diff --git a/src/search/ir.h b/src/search/ir.h
index 2b85b458..d7da716a 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -51,6 +51,11 @@ struct Node {
 
   virtual std::unique_ptr<Node> Clone() const = 0;
 
+  template <typename T>
+  std::unique_ptr<T> CloneAs() const {
+    return Node::MustAs<T>(Clone());
+  }
+
   virtual ~Node() = default;
 
   template <typename T, typename U = Node, typename... Args>
@@ -361,14 +366,14 @@ struct IndexRef : Ref {
   std::unique_ptr<Node> Clone() const override { return std::make_unique<IndexRef>(*this); }
 };
 
-struct SearchStmt : Node {
+struct SearchExpr : Node {
   std::unique_ptr<SelectClause> select;
   std::unique_ptr<IndexRef> index;
   std::unique_ptr<QueryExpr> query_expr;
   std::unique_ptr<LimitClause> limit;     // optional
   std::unique_ptr<SortByClause> sort_by;  // optional
 
-  SearchStmt(std::unique_ptr<IndexRef> &&index, std::unique_ptr<QueryExpr> &&query_expr,
+  SearchExpr(std::unique_ptr<IndexRef> &&index, std::unique_ptr<QueryExpr> &&query_expr,
              std::unique_ptr<LimitClause> &&limit, std::unique_ptr<SortByClause> &&sort_by,
              std::unique_ptr<SelectClause> &&select)
       : select(std::move(select)),
@@ -386,16 +391,16 @@ struct SearchStmt : Node {
   }
 
   static inline const std::vector<std::function<Node *(Node *)>> ChildMap = {
-      NodeIterator::MemFn<&SearchStmt::select>,     NodeIterator::MemFn<&SearchStmt::index>,
-      NodeIterator::MemFn<&SearchStmt::query_expr>, NodeIterator::MemFn<&SearchStmt::limit>,
-      NodeIterator::MemFn<&SearchStmt::sort_by>,
+      NodeIterator::MemFn<&SearchExpr::select>,     NodeIterator::MemFn<&SearchExpr::index>,
+      NodeIterator::MemFn<&SearchExpr::query_expr>, NodeIterator::MemFn<&SearchExpr::limit>,
+      NodeIterator::MemFn<&SearchExpr::sort_by>,
   };
 
   NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); };
   NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); };
 
   std::unique_ptr<Node> Clone() const override {
-    return std::make_unique<SearchStmt>(
+    return std::make_unique<SearchExpr>(
         Node::MustAs<IndexRef>(index->Clone()), Node::MustAs<QueryExpr>(query_expr->Clone()),
         Node::MustAs<LimitClause>(limit->Clone()), Node::MustAs<SortByClause>(sort_by->Clone()),
         Node::MustAs<SelectClause>(select->Clone()));
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
index a252216d..924e8c75 100644
--- a/src/search/ir_pass.h
+++ b/src/search/ir_pass.h
@@ -21,6 +21,7 @@
 #pragma once
 
 #include "ir.h"
+#include "search/ir_plan.h"
 
 namespace kqir {
 
@@ -30,7 +31,7 @@ struct Pass {
 
 struct Visitor : Pass {
   std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) override {
-    if (auto v = Node::As<SearchStmt>(std::move(node))) {
+    if (auto v = Node::As<SearchExpr>(std::move(node))) {
       return Visit(std::move(v));
     } else if (auto v = Node::As<SelectClause>(std::move(node))) {
       return Visit(std::move(v));
@@ -58,6 +59,26 @@ struct Visitor : Pass {
       return Visit(std::move(v));
     } else if (auto v = Node::As<BoolLiteral>(std::move(node))) {
       return Visit(std::move(v));
+    } else if (auto v = Node::As<FullIndexScan>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<NumericFieldScan>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<TagFieldScan>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<Filter>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<Limit>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<Merge>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<Sort>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<TopNSort>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<Projection>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<Noop>(std::move(node))) {
+      return Visit(std::move(v));
     }
 
     __builtin_unreachable();
@@ -73,7 +94,7 @@ struct Visitor : Pass {
     return Node::MustAs<T>(Transform(std::move(n)));
   }
 
-  virtual std::unique_ptr<Node> Visit(std::unique_ptr<SearchStmt> node) {
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) {
     node->index = VisitAs<IndexRef>(std::move(node->index));
     node->select = VisitAs<SelectClause>(std::move(node->select));
     node->query_expr = TransformAs<QueryExpr>(std::move(node->query_expr));
@@ -139,6 +160,53 @@ struct Visitor : Pass {
     node->field = VisitAs<FieldRef>(std::move(node->field));
     return node;
   }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<Noop> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<FullIndexScan> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericFieldScan> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagFieldScan> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<Filter> node) {
+    node->source = TransformAs<PlanOperator>(std::move(node->source));
+    node->filter_expr = TransformAs<QueryExpr>(std::move(node->filter_expr));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<Limit> node) {
+    node->op = TransformAs<PlanOperator>(std::move(node->op));
+    node->limit = VisitAs<LimitClause>(std::move(node->limit));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<Sort> node) {
+    node->op = TransformAs<PlanOperator>(std::move(node->op));
+    node->order = VisitAs<SortByClause>(std::move(node->order));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<TopNSort> node) {
+    node->op = TransformAs<PlanOperator>(std::move(node->op));
+    node->limit = VisitAs<LimitClause>(std::move(node->limit));
+    node->order = VisitAs<SortByClause>(std::move(node->order));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<Projection> node) {
+    node->source = TransformAs<PlanOperator>(std::move(node->source));
+    node->select = VisitAs<SelectClause>(std::move(node->select));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<Merge> node) {
+    for (auto &n : node->ops) {
+      n = TransformAs<PlanOperator>(std::move(n));
+    }
+
+    return node;
+  }
 };
 
 }  // namespace kqir
diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h
index 9d40d948..da805846 100644
--- a/src/search/ir_plan.h
+++ b/src/search/ir_plan.h
@@ -31,22 +31,30 @@ namespace kqir {
 
 struct PlanOperator : Node {};
 
+struct Noop : PlanOperator {
+  std::string_view Name() const override { return "Noop"; };
+  std::string Dump() const override { return "noop"; }
+
+  std::unique_ptr<Node> Clone() const override { return std::make_unique<Noop>(*this); }
+};
+
 struct FullIndexScan : PlanOperator {
-  IndexInfo *index;
+  std::unique_ptr<IndexRef> index;
 
-  explicit FullIndexScan(IndexInfo *index) : index(index) {}
+  explicit FullIndexScan(std::unique_ptr<IndexRef> index) : index(std::move(index)) {}
 
   std::string_view Name() const override { return "FullIndexScan"; };
-  std::string Content() const override { return index->name; };
-  std::string Dump() const override { return fmt::format("full-scan {}", Content()); }
+  std::string Dump() const override { return fmt::format("full-scan {}", index->name); }
 
-  std::unique_ptr<Node> Clone() const override { return std::make_unique<FullIndexScan>(*this); }
+  std::unique_ptr<Node> Clone() const override {
+    return std::make_unique<FullIndexScan>(Node::MustAs<IndexRef>(index->Clone()));
+  }
 };
 
 struct FieldScan : PlanOperator {
-  FieldInfo *field;
+  std::unique_ptr<FieldRef> field;
 
-  explicit FieldScan(FieldInfo *field) : field(field) {}
+  explicit FieldScan(std::unique_ptr<FieldRef> field) : field(std::move(field)) {}
 };
 
 struct Interval {
@@ -60,25 +68,29 @@ struct Interval {
 struct NumericFieldScan : FieldScan {
   Interval range;
 
-  NumericFieldScan(FieldInfo *field, Interval range) : FieldScan(field), range(range) {}
+  NumericFieldScan(std::unique_ptr<FieldRef> field, Interval range) : FieldScan(std::move(field)), range(range) {}
 
   std::string_view Name() const override { return "NumericFieldScan"; };
   std::string Content() const override { return fmt::format("{}, {}", field->name, range.ToString()); };
   std::string Dump() const override { return fmt::format("numeric-scan {}", Content()); }
 
-  std::unique_ptr<Node> Clone() const override { return std::make_unique<NumericFieldScan>(*this); }
+  std::unique_ptr<Node> Clone() const override {
+    return std::make_unique<NumericFieldScan>(field->CloneAs<FieldRef>(), range);
+  }
 };
 
 struct TagFieldScan : FieldScan {
   std::string tag;
 
-  TagFieldScan(FieldInfo *field, std::string tag) : FieldScan(field), tag(std::move(tag)) {}
+  TagFieldScan(std::unique_ptr<FieldRef> field, std::string tag) : FieldScan(std::move(field)), tag(std::move(tag)) {}
 
   std::string_view Name() const override { return "TagFieldScan"; };
   std::string Content() const override { return fmt::format("{}, {}", field->name, tag); };
   std::string Dump() const override { return fmt::format("tag-scan {}", Content()); }
 
-  std::unique_ptr<Node> Clone() const override { return std::make_unique<TagFieldScan>(*this); }
+  std::unique_ptr<Node> Clone() const override {
+    return std::make_unique<TagFieldScan>(field->CloneAs<FieldRef>(), tag);
+  }
 };
 
 struct Filter : PlanOperator {
@@ -89,7 +101,7 @@ struct Filter : PlanOperator {
       : source(std::move(source)), filter_expr(std::move(filter_expr)) {}
 
   std::string_view Name() const override { return "Filter"; };
-  std::string Dump() const override { return fmt::format("(filter {}, {})", source->Dump(), Content()); }
+  std::string Dump() const override { return fmt::format("(filter {}: {})", filter_expr->Dump(), source->Dump()); }
 
   NodeIterator ChildBegin() override { return {source.get(), filter_expr.get()}; }
   NodeIterator ChildEnd() override { return {}; }
@@ -143,6 +155,55 @@ struct Limit : PlanOperator {
   }
 };
 
+struct Sort : PlanOperator {
+  std::unique_ptr<PlanOperator> op;
+  std::unique_ptr<SortByClause> order;
+
+  Sort(std::unique_ptr<PlanOperator> &&op, std::unique_ptr<SortByClause> &&order)
+      : op(std::move(op)), order(std::move(order)) {}
+
+  std::string_view Name() const override { return "Sort"; };
+  std::string Dump() const override {
+    return fmt::format("(sort {}, {}: {})", order->field->Dump(), order->OrderToString(order->order), op->Dump());
+  }
+
+  NodeIterator ChildBegin() override { return NodeIterator{op.get(), order.get()}; }
+  NodeIterator ChildEnd() override { return {}; }
+
+  std::unique_ptr<Node> Clone() const override {
+    return std::make_unique<Sort>(Node::MustAs<PlanOperator>(op->Clone()), Node::MustAs<SortByClause>(order->Clone()));
+  }
+};
+
+// operator fusion: Sort + Limit
+struct TopNSort : PlanOperator {
+  std::unique_ptr<PlanOperator> op;
+  std::unique_ptr<SortByClause> order;
+  std::unique_ptr<LimitClause> limit;
+
+  TopNSort(std::unique_ptr<PlanOperator> &&op, std::unique_ptr<SortByClause> &&order,
+           std::unique_ptr<LimitClause> &&limit)
+      : op(std::move(op)), order(std::move(order)), limit(std::move(limit)) {}
+
+  std::string_view Name() const override { return "TopNSort"; };
+  std::string Dump() const override {
+    return fmt::format("(top-n sort {}, {}, {}, {}: {})", order->field->Dump(), order->OrderToString(order->order),
+                       limit->offset, limit->count, op->Dump());
+  }
+
+  static inline const std::vector<std::function<Node *(Node *)>> ChildMap = {
+      NodeIterator::MemFn<&TopNSort::op>, NodeIterator::MemFn<&TopNSort::order>, NodeIterator::MemFn<&TopNSort::limit>};
+
+  NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); }
+  NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); }
+
+  std::unique_ptr<Node> Clone() const override {
+    return std::make_unique<TopNSort>(Node::MustAs<PlanOperator>(op->Clone()),
+                                      Node::MustAs<SortByClause>(order->Clone()),
+                                      Node::MustAs<LimitClause>(limit->Clone()));
+  }
+};
+
 struct Projection : PlanOperator {
   std::unique_ptr<PlanOperator> source;
   std::unique_ptr<SelectClause> select;
@@ -151,7 +212,11 @@ struct Projection : PlanOperator {
       : source(std::move(source)), select(std::move(select)) {}
 
   std::string_view Name() const override { return "Projection"; };
-  std::string Dump() const override { return fmt::format("(project {}: {})", select, source); }
+  std::string Dump() const override {
+    auto select_str =
+        select->fields.empty() ? "*" : util::StringJoin(select->fields, [](const auto &v) { return v->Dump(); });
+    return fmt::format("project {}: {}", select_str, source->Dump());
+  }
 
   NodeIterator ChildBegin() override { return {source.get(), select.get()}; }
   NodeIterator ChildEnd() override { return {}; }
diff --git a/src/search/ir_sema_checker.h b/src/search/ir_sema_checker.h
index 471c3a3f..40f316cb 100644
--- a/src/search/ir_sema_checker.h
+++ b/src/search/ir_sema_checker.h
@@ -23,6 +23,7 @@
 #include <map>
 #include <memory>
 
+#include "fmt/core.h"
 #include "index_info.h"
 #include "ir.h"
 #include "search_encoding.h"
@@ -38,7 +39,7 @@ struct SemaChecker {
   explicit SemaChecker(const IndexMap &index_map) : index_map(index_map) {}
 
   Status Check(Node *node) {
-    if (auto v = dynamic_cast<SearchStmt *>(node)) {
+    if (auto v = dynamic_cast<SearchExpr *>(node)) {
       auto index_name = v->index->name;
       if (auto iter = index_map.find(index_name); iter != index_map.end()) {
         current_index = &iter->second;
@@ -56,6 +57,8 @@ struct SemaChecker {
     } else if (auto v = dynamic_cast<SortByClause *>(node)) {
       if (auto iter = current_index->fields.find(v->field->name); iter == current_index->fields.end()) {
         return {Status::NotOK, fmt::format("field `{}` not found in index `{}`", v->field->name, current_index->name)};
+      } else if (!iter->second.IsSortable()) {
+        return {Status::NotOK, fmt::format("field `{}` is not sortable", v->field->name)};
       } else {
         v->field->info = &iter->second;
       }
diff --git a/src/search/index_info.h b/src/search/passes/lower_to_plan.h
similarity index 51%
copy from src/search/index_info.h
copy to src/search/passes/lower_to_plan.h
index e6e95b72..dad1db39 100644
--- a/src/search/index_info.h
+++ b/src/search/passes/lower_to_plan.h
@@ -20,42 +20,32 @@
 
 #pragma once
 
-#include <map>
 #include <memory>
-#include <string>
 
-#include "search_encoding.h"
+#include "search/ir.h"
+#include "search/ir_pass.h"
+#include "search/ir_plan.h"
 
 namespace kqir {
 
-struct IndexInfo;
+struct LowerToPlan : Visitor {
+  std::unique_ptr<Node> Visit(std::unique_ptr<SearchExpr> node) override {
+    auto scan = std::make_unique<FullIndexScan>(node->index->CloneAs<IndexRef>());
+    auto filter = std::make_unique<Filter>(std::move(scan), std::move(node->query_expr));
 
-struct FieldInfo {
-  std::string name;
-  IndexInfo *index = nullptr;
-  std::unique_ptr<redis::SearchFieldMetadata> metadata;
+    std::unique_ptr<PlanOperator> op = std::move(filter);
 
-  FieldInfo(std::string name, std::unique_ptr<redis::SearchFieldMetadata> &&metadata)
-      : name(std::move(name)), metadata(std::move(metadata)) {}
-};
-
-struct IndexInfo {
-  using FieldMap = std::map<std::string, FieldInfo>;
+    // order is important here, since limit(sort(op)) is different from sort(limit(op))
+    if (node->sort_by) {
+      op = std::make_unique<Sort>(std::move(op), std::move(node->sort_by));
+    }
 
-  std::string name;
-  SearchMetadata metadata;
-  FieldMap fields;
-  redis::SearchPrefixesMetadata prefixes;
+    if (node->limit) {
+      op = std::make_unique<Limit>(std::move(op), std::move(node->limit));
+    }
 
-  IndexInfo(std::string name, SearchMetadata metadata) : name(std::move(name)), metadata(std::move(metadata)) {}
-
-  void Add(FieldInfo &&field) {
-    const auto &name = field.name;
-    field.index = this;
-    fields.emplace(name, std::move(field));
+    return std::make_unique<Projection>(std::move(op), std::move(node->select));
   }
 };
 
-using IndexMap = std::map<std::string, IndexInfo>;
-
 }  // namespace kqir
diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h
index 1637a504..14bf2923 100644
--- a/src/search/search_encoding.h
+++ b/src/search/search_encoding.h
@@ -125,7 +125,9 @@ inline std::string ConstructNumericFieldMetadataSubkey(std::string_view field_na
   return res;
 }
 
-struct SearchNumericFieldMetadata : SearchFieldMetadata {};
+struct SearchSortableFieldMetadata : SearchFieldMetadata {};
+
+struct SearchNumericFieldMetadata : SearchSortableFieldMetadata {};
 
 inline std::string ConstructTagFieldSubkey(std::string_view field_name, std::string_view tag, std::string_view key) {
   std::string res = {(char)SearchSubkeyType::TAG_FIELD};
diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h
index 871949fd..8386a66e 100644
--- a/src/search/sql_transformer.h
+++ b/src/search/sql_transformer.h
@@ -168,7 +168,7 @@ struct Transformer : ir::TreeTransformer {
         query_expr = std::make_unique<BoolLiteral>(true);
       }
 
-      return Node::Create<ir::SearchStmt>(std::move(index), std::move(query_expr), std::move(limit), std::move(sort_by),
+      return Node::Create<ir::SearchExpr>(std::move(index), std::move(query_expr), std::move(limit), std::move(sort_by),
                                           std::move(select));
     } else if (IsRoot(node)) {
       CHECK(node->children.size() == 1);
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index 70d39f96..0f0952bf 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -21,6 +21,7 @@
 #include "search/ir_pass.h"
 
 #include "gtest/gtest.h"
+#include "search/passes/lower_to_plan.h"
 #include "search/passes/manager.h"
 #include "search/passes/push_down_not_expr.h"
 #include "search/passes/simplify_and_or_expr.h"
@@ -103,3 +104,16 @@ TEST(IRPassTest, Manager) {
       PassManager::Default(*Parse("select * from a where not (x > 1 or (y < 2 or z = 3)) and (true or x = 1)"))->Dump(),
       "select * from a where (and x <= 1, y >= 2, z != 3)");
 }
+
+TEST(IRPassTest, LowerToPlan) {
+  LowerToPlan ltp;
+
+  ASSERT_EQ(ltp.Transform(*Parse("select * from a"))->Dump(), "project *: (filter true: full-scan a)");
+  ASSERT_EQ(ltp.Transform(*Parse("select * from a where b > 1"))->Dump(), "project *: (filter b > 1: full-scan a)");
+  ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d"))->Dump(),
+            "project a: (sort d, asc: (filter c = 1: full-scan b))");
+  ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 limit 1"))->Dump(),
+            "project a: (limit 0, 1: (filter c = 1: full-scan b))");
+  ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit 1"))->Dump(),
+            "project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan b)))");
+}

Reply via email to