This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 7e063030 Add support for dumping DOT graphs from KQIR (#2205)
7e063030 is described below

commit 7e063030519b1b5a5367a5921a0a5ee2b389d0b7
Author: Twice <[email protected]>
AuthorDate: Fri Mar 29 18:22:47 2024 +0900

    Add support for dumping DOT graphs from KQIR (#2205)
---
 src/common/type_util.h                             |  8 ++
 src/search/ir.h                                    | 79 +++++++++++++++++--
 src/{common/type_util.h => search/ir_dot_dumper.h} | 41 ++++++----
 src/search/ir_iterator.h                           | 92 ++++++++++++++++++++++
 src/search/redis_query_transformer.h               | 30 +++----
 src/search/sql_transformer.h                       | 25 +++---
 tests/cppunit/ir_dot_dumper_test.cc                | 55 +++++++++++++
 7 files changed, 284 insertions(+), 46 deletions(-)

diff --git a/src/common/type_util.h b/src/common/type_util.h
index d55019f7..98439243 100644
--- a/src/common/type_util.h
+++ b/src/common/type_util.h
@@ -39,3 +39,12 @@ using RemoveCVRef = typename std::remove_cv_t<typename std::remove_reference_t<T
 // dependent false for static_assert with constexpr if, see CWG2518/P2593R1
 template <typename T>
 constexpr bool AlwaysFalse = false;
+
+// Maps a pointer-to-member type `T C::*` to its owning class: GetClassFromMember<T C::*>::type is C.
+template <typename>
+struct GetClassFromMember;
+
+template <typename C, typename T>
+struct GetClassFromMember<T C::*> {
+  using type = C;  // NOLINT
+};
diff --git a/src/search/ir.h b/src/search/ir.h
index 87ed6488..02f3766a 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -22,22 +22,37 @@
 
 #include <fmt/format.h>
 
+#include <functional>
+#include <initializer_list>
 #include <limits>
 #include <memory>
 #include <optional>
 #include <string>
+#include <string_view>
+#include <type_traits>
 #include <utility>
 #include <variant>
 #include <vector>
 
 #include "fmt/core.h"
+#include "ir_iterator.h"
 #include "string_util.h"
+#include "type_util.h"
 
 // kqir stands for Kvorcks Query Intermediate Representation
 namespace kqir {
 
 struct Node {
   virtual std::string Dump() const = 0;
+  // Constant name of the concrete node type (e.g. "AndExpr"); DotDumper uses it as the vertex label.
+  virtual std::string_view Name() const = 0;
+  // Short payload shown next to Name() (a literal, an operator, ...); empty when there is none.
+  virtual std::string Content() const { return {}; }
+
+  // Range over the direct children; optional children (see SearchStmt) may be yielded as nullptr.
+  virtual NodeIterator ChildBegin() { return {}; }
+  virtual NodeIterator ChildEnd() { return {}; }
+
   virtual ~Node() = default;
 
   template <typename T, typename U = Node, typename... Args>
@@ -45,10 +60,20 @@ struct Node {
     return std::unique_ptr<U>(new T(std::forward<Args>(args)...));
   }
 
+  // Like As<T>, but CHECK-fails instead of returning nullptr when `original` is not a T.
+  template <typename T>
+  static std::unique_ptr<T> MustAs(std::unique_ptr<Node> &&original) {
+    auto casted = As<T>(std::move(original));
+    CHECK(casted != nullptr);
+    return casted;
+  }
+
+  // Downcasts `original` to T. On success ownership moves into the result; on failure nullptr is
+  // returned and `original` keeps ownership, so the node is not leaked.
   template <typename T>
   static std::unique_ptr<T> As(std::unique_ptr<Node> &&original) {
-    auto casted = dynamic_cast<T *>(original.release());
-    CHECK(casted);
+    auto casted = dynamic_cast<T *>(original.get());
+    if (casted) static_cast<void>(original.release());
     return std::unique_ptr<T>(casted);
   }
 };
@@ -58,7 +83,9 @@ struct FieldRef : Node {
 
   explicit FieldRef(std::string name) : name(std::move(name)) {}
 
+  std::string_view Name() const override { return "FieldRef"; }
   std::string Dump() const override { return name; }
+  std::string Content() const override { return Dump(); }
 };
 
 struct StringLiteral : Node {
@@ -66,7 +93,9 @@ struct StringLiteral : Node {
 
   explicit StringLiteral(std::string val) : val(std::move(val)) {}
 
+  std::string_view Name() const override { return "StringLiteral"; }
   std::string Dump() const override { return fmt::format("\"{}\"", util::EscapeString(val)); }
+  std::string Content() const override { return Dump(); }
 };
 
 struct QueryExpr : Node {};
@@ -80,7 +109,11 @@ struct TagContainExpr : BoolAtomExpr {
   TagContainExpr(std::unique_ptr<FieldRef> &&field, std::unique_ptr<StringLiteral> &&tag)
       : field(std::move(field)), tag(std::move(tag)) {}
 
+  std::string_view Name() const override { return "TagContainExpr"; }
   std::string Dump() const override { return fmt::format("{} hastag {}", field->Dump(), tag->Dump()); }
+
+  NodeIterator ChildBegin() override { return {field.get(), tag.get()}; }
+  NodeIterator ChildEnd() override { return {}; }
 };
 
 struct NumericLiteral : Node {
@@ -88,7 +121,9 @@ struct NumericLiteral : Node {
 
   explicit NumericLiteral(double val) : val(val) {}
 
+  std::string_view Name() const override { return "NumericLiteral"; }
   std::string Dump() const override { return fmt::format("{}", val); }
+  std::string Content() const override { return Dump(); }
 };
 
 // NOLINTNEXTLINE
@@ -156,7 +191,12 @@ struct NumericCompareExpr : BoolAtomExpr {
     __builtin_unreachable();
   }
 
+  std::string_view Name() const override { return "NumericCompareExpr"; }
   std::string Dump() const override { return fmt::format("{} {} {}", field->Dump(), ToOperator(op), num->Dump()); };
+  std::string Content() const override { return ToOperator(op); }
+
+  NodeIterator ChildBegin() override { return {field.get(), num.get()}; }
+  NodeIterator ChildEnd() override { return {}; }
 };
 
 struct BoolLiteral : BoolAtomExpr {
@@ -164,7 +204,9 @@ struct BoolLiteral : BoolAtomExpr {
 
   explicit BoolLiteral(bool val) : val(val) {}
 
+  std::string_view Name() const override { return "BoolLiteral"; }
   std::string Dump() const override { return val ? "true" : "false"; }
+  std::string Content() const override { return Dump(); }
 };
 
 struct QueryExpr;
@@ -174,7 +216,11 @@ struct NotExpr : QueryExpr {
 
   explicit NotExpr(std::unique_ptr<QueryExpr> &&inner) : inner(std::move(inner)) {}
 
+  std::string_view Name() const override { return "NotExpr"; }
   std::string Dump() const override { return fmt::format("not {}", inner->Dump()); }
+
+  NodeIterator ChildBegin() override { return NodeIterator{inner.get()}; }
+  NodeIterator ChildEnd() override { return {}; }
 };
 
 struct AndExpr : QueryExpr {
@@ -182,9 +228,13 @@ struct AndExpr : QueryExpr {
 
   explicit AndExpr(std::vector<std::unique_ptr<QueryExpr>> &&inners) : inners(std::move(inners)) {}
 
+  std::string_view Name() const override { return "AndExpr"; }
   std::string Dump() const override {
     return fmt::format("(and {})", util::StringJoin(inners, [](const auto &v) { return v->Dump(); }));
   }
+
+  NodeIterator ChildBegin() override { return NodeIterator(inners.begin()); }
+  NodeIterator ChildEnd() override { return NodeIterator(inners.end()); }
 };
 
 struct OrExpr : QueryExpr {
@@ -192,9 +242,13 @@ struct OrExpr : QueryExpr {
 
   explicit OrExpr(std::vector<std::unique_ptr<QueryExpr>> &&inners) : inners(std::move(inners)) {}
 
+  std::string_view Name() const override { return "OrExpr"; }
   std::string Dump() const override {
     return fmt::format("(or {})", util::StringJoin(inners, [](const auto &v) { return v->Dump(); }));
   }
+
+  NodeIterator ChildBegin() override { return NodeIterator(inners.begin()); }
+  NodeIterator ChildEnd() override { return NodeIterator(inners.end()); }
 };
 
 struct Limit : Node {
@@ -203,7 +257,9 @@ struct Limit : Node {
 
   Limit(size_t offset, size_t count) : offset(offset), count(count) {}
 
+  std::string_view Name() const override { return "Limit"; }
   std::string Dump() const override { return fmt::format("limit {}, {}", offset, count); }
+  std::string Content() const override { return fmt::format("{}, {}", offset, count); }
 };
 
 struct SortBy : Node {
@@ -213,7 +269,13 @@ struct SortBy : Node {
   SortBy(Order order, std::unique_ptr<FieldRef> &&field) : order(order), field(std::move(field)) {}
 
   static constexpr const char *OrderToString(Order order) { return order == ASC ? "asc" : "desc"; }
+
+  std::string_view Name() const override { return "SortBy"; }
   std::string Dump() const override { return fmt::format("sortby {}, {}", field->Dump(), OrderToString(order)); }
+  std::string Content() const override { return OrderToString(order); }
+
+  NodeIterator ChildBegin() override { return NodeIterator(field.get()); }
+  NodeIterator ChildEnd() override { return {}; }
 };
 
 struct SelectExpr : Node {
@@ -221,10 +283,14 @@ struct SelectExpr : Node {
 
   explicit SelectExpr(std::vector<std::unique_ptr<FieldRef>> &&fields) : fields(std::move(fields)) {}
 
+  std::string_view Name() const override { return "SelectExpr"; }
   std::string Dump() const override {
     if (fields.empty()) return "select *";
     return fmt::format("select {}", util::StringJoin(fields, [](const auto &v) { return v->Dump(); }));
   }
+
+  NodeIterator ChildBegin() override { return NodeIterator(fields.begin()); }
+  NodeIterator ChildEnd() override { return NodeIterator(fields.end()); }
 };
 
 struct IndexRef : Node {
@@ -232,24 +298,27 @@ struct IndexRef : Node {
 
   explicit IndexRef(std::string name) : name(std::move(name)) {}
 
+  std::string_view Name() const override { return "IndexRef"; }
   std::string Dump() const override { return name; }
+  std::string Content() const override { return Dump(); }
 };
 
 struct SearchStmt : Node {
+  std::unique_ptr<SelectExpr> select_expr;
   std::unique_ptr<IndexRef> index;
   std::unique_ptr<QueryExpr> query_expr;  // optional
   std::unique_ptr<Limit> limit;           // optional
   std::unique_ptr<SortBy> sort_by;        // optional
-  std::unique_ptr<SelectExpr> select_expr;
 
   SearchStmt(std::unique_ptr<IndexRef> &&index, std::unique_ptr<QueryExpr> &&query_expr, std::unique_ptr<Limit> &&limit,
              std::unique_ptr<SortBy> &&sort_by, std::unique_ptr<SelectExpr> &&select_expr)
-      : index(std::move(index)),
+      : select_expr(std::move(select_expr)),
+        index(std::move(index)),
         query_expr(std::move(query_expr)),
         limit(std::move(limit)),
-        sort_by(std::move(sort_by)),
-        select_expr(std::move(select_expr)) {}
+        sort_by(std::move(sort_by)) {}
 
+  std::string_view Name() const override { return "SearchStmt"; }
   std::string Dump() const override {
     std::string opt;
     if (query_expr) opt += " where " + query_expr->Dump();
@@ -257,6 +326,16 @@ struct SearchStmt : Node {
     if (limit) opt += " " + limit->Dump();
     return fmt::format("{} from {}{}", select_expr->Dump(), index->Dump(), opt);
   }
+
+  // Child accessors in field order; the optional members yield nullptr when they are unset.
+  static inline const std::vector<std::function<Node *(Node *)>> ChildMap = {
+      NodeIterator::MemFn<&SearchStmt::select_expr>, NodeIterator::MemFn<&SearchStmt::index>,
+      NodeIterator::MemFn<&SearchStmt::query_expr>,  NodeIterator::MemFn<&SearchStmt::limit>,
+      NodeIterator::MemFn<&SearchStmt::sort_by>,
+  };
+
+  NodeIterator ChildBegin() override { return NodeIterator(this, ChildMap.begin()); }
+  NodeIterator ChildEnd() override { return NodeIterator(this, ChildMap.end()); }
 };
 
 }  // namespace kqir
diff --git a/src/common/type_util.h b/src/search/ir_dot_dumper.h
similarity index 53%
copy from src/common/type_util.h
copy to src/search/ir_dot_dumper.h
index d55019f7..5bb6dc7b 100644
--- a/src/common/type_util.h
+++ b/src/search/ir_dot_dumper.h
@@ -20,22 +20,48 @@
 
 #pragma once
 
-#include <utility>
+#include <cstdint>
+#include <ostream>
+#include <string>
+
+#include "ir.h"
+#include "string_util.h"
 
-template <typename F, F *f>
-struct StaticFunction {
-  template <typename... Ts>
-  auto operator()(Ts &&...args) const -> decltype(f(std::forward<Ts>(args)...)) {  // NOLINT
-    return f(std::forward<Ts>(args)...);                                           // NOLINT
+namespace kqir {
+
+// Renders a KQIR tree as a Graphviz DOT digraph: one vertex per node, labelled with Node::Name()
+// (followed by the escaped Node::Content() in parentheses when it is non-empty), and one edge per
+// parent -> child link.
+struct DotDumper {
+  std::ostream &os;
+
+  // Writes the complete `digraph { ... }` document for the tree rooted at `node` to `os`.
+  void Dump(Node *node) {
+    os << "digraph {\n";
+    dump(node);
+    os << "}\n";
   }
-};
 
-template <typename... Ts>
-using FirstElement = typename std::tuple_element_t<0, std::tuple<Ts...>>;
+ private:
+  // Vertex id derived from the node's address, hence distinct for every live node.
+  static std::string nodeId(Node *node) { return fmt::format("x{:x}", reinterpret_cast<std::uintptr_t>(node)); }
 
-template <typename T>
-using RemoveCVRef = typename std::remove_cv_t<typename std::remove_reference_t<T>>;
+  // Emits the vertex for `node`, then an edge to and a recursive dump of every non-null child.
+  void dump(Node *node) {
+    os << "  " << nodeId(node) << " [ label = \"" << node->Name();
+    if (auto content = node->Content(); !content.empty()) {
+      os << " (" << util::EscapeString(content) << ")\" ];\n";
+    } else {
+      os << "\" ];\n";
+    }
+    for (auto i = node->ChildBegin(); i != node->ChildEnd(); ++i) {
+      Node *child = *i;
+      // unset optional children (e.g. SearchStmt::query_expr / limit / sort_by) are yielded as nullptr
+      if (child == nullptr) continue;
+      os << "  " << nodeId(node) << " -> " << nodeId(child) << ";\n";
+      dump(child);
+    }
+  }
+};
 
-// dependent false for static_assert with constexpr if, see CWG2518/P2593R1
-template <typename T>
-constexpr bool AlwaysFalse = false;
+}  // namespace kqir
diff --git a/src/search/ir_iterator.h b/src/search/ir_iterator.h
new file mode 100644
index 00000000..2ead473c
--- /dev/null
+++ b/src/search/ir_iterator.h
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <array>
+#include <functional>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "type_util.h"
+
+namespace kqir {
+
+struct Node;
+
+// Type-erased forward iterator over the direct children of a Node, yielding `Node *`.
+// The alternatives of `val` cover the child layouts used in ir.h:
+//   0: a single child; nullptr doubles as the end sentinel held by a default-constructed iterator,
+//   1: exactly two children,
+//   2: a parent plus a position in a table of child accessors (e.g. SearchStmt::ChildMap),
+//   3: a position in a std::vector<std::unique_ptr<Derived>> where Derived is a subclass of Node.
+struct NodeIterator {
+  std::variant<Node *, std::array<Node *, 2>,
+               std::pair<Node *, std::vector<std::function<Node *(Node *)>>::const_iterator>,
+               std::vector<std::unique_ptr<Node>>::iterator>
+      val;
+
+  NodeIterator() : val(nullptr) {}
+  explicit NodeIterator(Node *node) : val(node) {}
+  NodeIterator(Node *n1, Node *n2) : val(std::array<Node *, 2>{n1, n2}) {}
+  explicit NodeIterator(Node *parent, std::vector<std::function<Node *(Node *)>>::const_iterator iter)
+      : val(std::make_pair(parent, iter)) {}
+  template <typename Iterator,
+            std::enable_if_t<std::is_base_of_v<Node, typename Iterator::value_type::element_type>, int> = 0>
+  explicit NodeIterator(Iterator iter) : val(*CastToNodeIter(&iter)) {}
+
+  // NOTE(review): reinterprets a vector<unique_ptr<Derived>>::iterator as a vector<unique_ptr<Node>>::iterator.
+  // This depends on implementation details (matching iterator/unique_ptr layouts, Node at offset 0 of Derived)
+  // rather than on guarantees of the standard.
+  template <typename Iterator>
+  static auto CastToNodeIter(Iterator *iter) {
+    auto res __attribute__((__may_alias__)) = reinterpret_cast<std::vector<std::unique_ptr<Node>>::iterator *>(iter);
+    return res;
+  }
+
+  // Child accessor for the pointer-to-member F (a smart-pointer member of a Node subclass): returns the raw
+  // pointer held by that member of `parent`, which is nullptr for an unset optional child.
+  template <auto F>
+  static Node *MemFn(Node *parent) {
+    // static_cast (not reinterpret_cast) so the base-to-derived pointer adjustment is always correct
+    return (static_cast<typename GetClassFromMember<decltype(F)>::type *>(parent)->*F).get();
+  }
+
+  friend bool operator==(NodeIterator l, NodeIterator r) { return l.val == r.val; }
+
+  friend bool operator!=(NodeIterator l, NodeIterator r) { return l.val != r.val; }
+
+  // Returns the current child (nullptr for an unset optional child); do not call it on an end iterator.
+  Node *operator*() {
+    if (val.index() == 0) {
+      return std::get<0>(val);
+    } else if (val.index() == 1) {
+      return std::get<1>(val)[0];
+    } else if (val.index() == 2) {
+      auto &[parent, iter] = std::get<2>(val);
+      return (*iter)(parent);
+    } else {
+      return std::get<3>(val)->get();
+    }
+  }
+
+  // Moves to the next child. A single child steps to the nullptr end sentinel; a pair steps to its
+  // second element, which is then held as a single child.
+  NodeIterator &operator++() {
+    if (val.index() == 0) {
+      val = nullptr;
+    } else if (val.index() == 1) {
+      val = std::get<1>(val)[1];
+    } else if (val.index() == 2) {
+      ++std::get<2>(val).second;
+    } else {
+      ++std::get<3>(val);
+    }
+
+    return *this;
+  }
+};
+
+}  // namespace kqir
diff --git a/src/search/redis_query_transformer.h b/src/search/redis_query_transformer.h
index ebd05143..45734b31 100644
--- a/src/search/redis_query_transformer.h
+++ b/src/search/redis_query_transformer.h
@@ -83,13 +83,13 @@ struct Transformer : ir::TreeTransformer {
         const auto& rhs = query->children[1];
 
         if (Is<ExclusiveNumber>(lhs)) {
-          exprs.push_back(
-              std::make_unique<NumericCompareExpr>(NumericCompareExpr::GT, std::make_unique<FieldRef>(field),
-                                                   Node::As<NumericLiteral>(GET_OR_RET(Transform(lhs->children[0])))));
+          exprs.push_back(std::make_unique<NumericCompareExpr>(
+              NumericCompareExpr::GT, std::make_unique<FieldRef>(field),
+              Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(lhs->children[0])))));
         } else if (Is<Number>(lhs)) {
-          exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET,
-                                                               std::make_unique<FieldRef>(field),
-                                                               Node::As<NumericLiteral>(GET_OR_RET(Transform(lhs)))));
+          exprs.push_back(
+              std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET, std::make_unique<FieldRef>(field),
+                                                   Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(lhs)))));
         } else {  // Inf
           if (lhs->string_view() == "+inf") {
             return {Status::NotOK, "it's not allowed to set the lower bound as positive infinity"};
@@ -97,13 +97,13 @@ struct Transformer : ir::TreeTransformer {
         }
 
         if (Is<ExclusiveNumber>(rhs)) {
-          exprs.push_back(
-              std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT, std::make_unique<FieldRef>(field),
-                                                   Node::As<NumericLiteral>(GET_OR_RET(Transform(rhs->children[0])))));
+          exprs.push_back(std::make_unique<NumericCompareExpr>(
+              NumericCompareExpr::LT, std::make_unique<FieldRef>(field),
+              Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(rhs->children[0])))));
         } else if (Is<Number>(rhs)) {
-          exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LET,
-                                                               std::make_unique<FieldRef>(field),
-                                                               Node::As<NumericLiteral>(GET_OR_RET(Transform(rhs)))));
+          exprs.push_back(
+              std::make_unique<NumericCompareExpr>(NumericCompareExpr::LET, std::make_unique<FieldRef>(field),
+                                                   Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(rhs)))));
         } else {  // Inf
           if (rhs->string_view() == "-inf") {
             return {Status::NotOK, "it's not allowed to set the upper bound as negative infinity"};
@@ -121,12 +121,12 @@ struct Transformer : ir::TreeTransformer {
     } else if (Is<NotExpr>(node)) {
       CHECK(node->children.size() == 1);
 
-      return Node::Create<ir::NotExpr>(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
+      return Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
     } else if (Is<AndExpr>(node)) {
       std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
 
       for (const auto& child : node->children) {
-        exprs.push_back(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(child))));
+        exprs.push_back(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(child))));
       }
 
       return Node::Create<ir::AndExpr>(std::move(exprs));
@@ -134,7 +134,7 @@ struct Transformer : ir::TreeTransformer {
       std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
 
       for (const auto& child : node->children) {
-        exprs.push_back(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(child))));
+        exprs.push_back(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(child))));
       }
 
       return Node::Create<ir::OrExpr>(std::move(exprs));
diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h
index 5c300654..63397e40 100644
--- a/src/search/sql_transformer.h
+++ b/src/search/sql_transformer.h
@@ -63,8 +63,9 @@ struct Transformer : ir::TreeTransformer {
     } else if (Is<HasTagExpr>(node)) {
       CHECK(node->children.size() == 2);
 
-      return Node::Create<ir::TagContainExpr>(std::make_unique<ir::FieldRef>(node->children[0]->string()),
-                                              Node::As<ir::StringLiteral>(GET_OR_RET(Transform(node->children[1]))));
+      return Node::Create<ir::TagContainExpr>(
+          std::make_unique<ir::FieldRef>(node->children[0]->string()),
+          Node::MustAs<ir::StringLiteral>(GET_OR_RET(Transform(node->children[1]))));
     } else if (Is<NumericCompareExpr>(node)) {
       CHECK(node->children.size() == 3);
 
@@ -74,23 +75,23 @@ struct Transformer : ir::TreeTransformer {
       auto op = ir::NumericCompareExpr::FromOperator(node->children[1]->string_view()).value();
       if (Is<Identifier>(lhs) && Is<Number>(rhs)) {
         return Node::Create<ir::NumericCompareExpr>(op, std::make_unique<ir::FieldRef>(lhs->string()),
-                                                    Node::As<ir::NumericLiteral>(GET_OR_RET(Transform(rhs))));
+                                                    Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(rhs))));
       } else if (Is<Number>(lhs) && Is<Identifier>(rhs)) {
         return Node::Create<ir::NumericCompareExpr>(ir::NumericCompareExpr::Flip(op),
                                                     std::make_unique<ir::FieldRef>(rhs->string()),
-                                                    Node::As<ir::NumericLiteral>(GET_OR_RET(Transform(lhs))));
+                                                    Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(lhs))));
       } else {
         return {Status::NotOK, "the left and right side of numeric comparison should be an identifier and a number"};
       }
     } else if (Is<NotExpr>(node)) {
       CHECK(node->children.size() == 1);
 
-      return Node::Create<ir::NotExpr>(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
+      return Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
     } else if (Is<AndExpr>(node)) {
       std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
 
       for (const auto& child : node->children) {
-        exprs.push_back(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(child))));
+        exprs.push_back(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(child))));
       }
 
       return Node::Create<ir::AndExpr>(std::move(exprs));
@@ -98,7 +99,7 @@ struct Transformer : ir::TreeTransformer {
       std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
 
       for (const auto& child : node->children) {
-        exprs.push_back(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(child))));
+        exprs.push_back(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(child))));
       }
 
       return Node::Create<ir::OrExpr>(std::move(exprs));
@@ -146,8 +147,8 @@ struct Transformer : ir::TreeTransformer {
     } else if (Is<SearchStmt>(node)) {  // root node
       CHECK(node->children.size() >= 2 && node->children.size() <= 5);
 
-      auto index = Node::As<ir::IndexRef>(GET_OR_RET(Transform(node->children[1])));
-      auto select = Node::As<ir::SelectExpr>(GET_OR_RET(Transform(node->children[0])));
+      auto index = Node::MustAs<ir::IndexRef>(GET_OR_RET(Transform(node->children[1])));
+      auto select = Node::MustAs<ir::SelectExpr>(GET_OR_RET(Transform(node->children[0])));
 
       std::unique_ptr<ir::QueryExpr> query_expr;
       std::unique_ptr<ir::Limit> limit;
@@ -155,11 +156,11 @@ struct Transformer : ir::TreeTransformer {
 
       for (size_t i = 2; i < node->children.size(); ++i) {
         if (Is<WhereClause>(node->children[i])) {
-          query_expr = Node::As<ir::QueryExpr>(GET_OR_RET(Transform(node->children[i])));
+          query_expr = Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[i])));
         } else if (Is<LimitClause>(node->children[i])) {
-          limit = Node::As<ir::Limit>(GET_OR_RET(Transform(node->children[i])));
+          limit = Node::MustAs<ir::Limit>(GET_OR_RET(Transform(node->children[i])));
         } else if (Is<OrderByClause>(node->children[i])) {
-          sort_by = Node::As<ir::SortBy>(GET_OR_RET(Transform(node->children[i])));
+          sort_by = Node::MustAs<ir::SortBy>(GET_OR_RET(Transform(node->children[i])));
         }
       }
 
diff --git a/tests/cppunit/ir_dot_dumper_test.cc b/tests/cppunit/ir_dot_dumper_test.cc
new file mode 100644
index 00000000..0310eb21
--- /dev/null
+++ b/tests/cppunit/ir_dot_dumper_test.cc
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "search/ir_dot_dumper.h"
+
+#include <regex>
+#include <sstream>
+
+#include "gtest/gtest.h"
+#include "search/sql_transformer.h"
+
+using namespace kqir;
+
+static auto Parse(const std::string& in) { return sql::ParseToIR(peg::string_input(in, "test")); }
+
+// Returns the DOT vertex id of the first node whose label starts with `name`, or "" if there is none.
+static std::string FindNodeId(const std::string& dot, const std::string& name) {
+  std::smatch matches;
+  if (!std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = ")" + name))) return "";
+  return matches[1].str();
+}
+
+TEST(DotDumperTest, Simple) {
+  auto ir = *Parse("select a from b where c = 1 or d hastag \"x\" and 2 <= e order by e asc limit 0, 10");
+
+  std::stringstream ss;
+  DotDumper dumper{ss};
+
+  dumper.Dump(ir.get());
+
+  std::string dot = ss.str();
+
+  auto search_stmt = FindNodeId(dot, "SearchStmt");
+  auto or_expr = FindNodeId(dot, "OrExpr");
+  auto and_expr = FindNodeId(dot, "AndExpr");
+  ASSERT_FALSE(search_stmt.empty());
+  ASSERT_FALSE(or_expr.empty());
+  ASSERT_FALSE(and_expr.empty());
+
+  ASSERT_NE(dot.find(fmt::format("{} -> {}", search_stmt, or_expr)), std::string::npos);
+  ASSERT_NE(dot.find(fmt::format("{} -> {}", or_expr, and_expr)), std::string::npos);
+}

Reply via email to