This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 7e063030 Add support for dumping DOT graphs from KQIR (#2205)
7e063030 is described below
commit 7e063030519b1b5a5367a5921a0a5ee2b389d0b7
Author: Twice <[email protected]>
AuthorDate: Fri Mar 29 18:22:47 2024 +0900
Add support for dumping DOT graphs from KQIR (#2205)
---
src/common/type_util.h | 8 ++
src/search/ir.h | 79 +++++++++++++++++--
src/{common/type_util.h => search/ir_dot_dumper.h} | 41 ++++++----
src/search/ir_iterator.h | 92 ++++++++++++++++++++++
src/search/redis_query_transformer.h | 30 +++----
src/search/sql_transformer.h | 25 +++---
tests/cppunit/ir_dot_dumper_test.cc | 55 +++++++++++++
7 files changed, 284 insertions(+), 46 deletions(-)
diff --git a/src/common/type_util.h b/src/common/type_util.h
index d55019f7..98439243 100644
--- a/src/common/type_util.h
+++ b/src/common/type_util.h
@@ -39,3 +39,11 @@ using RemoveCVRef = typename std::remove_cv_t<typename
std::remove_reference_t<T
// dependent false for static_assert with constexpr if, see CWG2518/P2593R1
template <typename T>
constexpr bool AlwaysFalse = false;
+
+// Maps a pointer-to-member type `Member Class::*` to the class that declares the member.
+// The primary template is only declared, never defined, so asking for `::type` with
+// anything that is not a pointer to member is a compile-time error.
+template <typename MemberPointer>
+struct GetClassFromMember;
+
+template <typename Class, typename Member>
+struct GetClassFromMember<Member Class::*> {
+  using type = Class;  // NOLINT
+};
diff --git a/src/search/ir.h b/src/search/ir.h
index 87ed6488..02f3766a 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -22,22 +22,32 @@
#include <fmt/format.h>
+#include <initializer_list>
#include <limits>
#include <memory>
#include <optional>
#include <string>
+#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
#include "fmt/core.h"
+#include "ir_iterator.h"
#include "string_util.h"
+#include "type_util.h"
// kqir stands for Kvrocks Query Intermediate Representation
namespace kqir {
struct Node {
virtual std::string Dump() const = 0;
+ virtual std::string_view Name() const = 0;
+ virtual std::string Content() const { return {}; }
+
+ virtual NodeIterator ChildBegin() { return {}; };
+ virtual NodeIterator ChildEnd() { return {}; };
+
virtual ~Node() = default;
template <typename T, typename U = Node, typename... Args>
@@ -45,10 +55,16 @@ struct Node {
return std::unique_ptr<U>(new T(std::forward<Args>(args)...));
}
+ // Downcasts an owned node to `T` and requires the cast to succeed.
+ // Delegates to As<T>(), which yields a null pointer when the dynamic type of
+ // `original` is not `T`; that case is rejected here through CHECK.
+ template <typename T>
+ static std::unique_ptr<T> MustAs(std::unique_ptr<Node> &&original) {
+ auto casted = As<T>(std::move(original));
+ CHECK(casted != nullptr);
+ return casted;
+ }
+
template <typename T>
static std::unique_ptr<T> As(std::unique_ptr<Node> &&original) {
auto casted = dynamic_cast<T *>(original.release());
- CHECK(casted);
return std::unique_ptr<T>(casted);
}
};
@@ -58,7 +74,9 @@ struct FieldRef : Node {
explicit FieldRef(std::string name) : name(std::move(name)) {}
+ std::string_view Name() const override { return "FieldRef"; }
std::string Dump() const override { return name; }
+ std::string Content() const override { return Dump(); }
};
struct StringLiteral : Node {
@@ -66,7 +84,9 @@ struct StringLiteral : Node {
explicit StringLiteral(std::string val) : val(std::move(val)) {}
+ std::string_view Name() const override { return "StringLiteral"; }
std::string Dump() const override { return fmt::format("\"{}\"",
util::EscapeString(val)); }
+ std::string Content() const override { return Dump(); }
};
struct QueryExpr : Node {};
@@ -80,7 +100,11 @@ struct TagContainExpr : BoolAtomExpr {
TagContainExpr(std::unique_ptr<FieldRef> &&field,
std::unique_ptr<StringLiteral> &&tag)
: field(std::move(field)), tag(std::move(tag)) {}
+ std::string_view Name() const override { return "TagContainExpr"; }
std::string Dump() const override { return fmt::format("{} hastag {}",
field->Dump(), tag->Dump()); }
+
+ NodeIterator ChildBegin() override { return {field.get(), tag.get()}; };
+ NodeIterator ChildEnd() override { return {}; };
};
struct NumericLiteral : Node {
@@ -88,7 +112,9 @@ struct NumericLiteral : Node {
explicit NumericLiteral(double val) : val(val) {}
+ std::string_view Name() const override { return "NumericLiteral"; }
std::string Dump() const override { return fmt::format("{}", val); }
+ std::string Content() const override { return Dump(); }
};
// NOLINTNEXTLINE
@@ -156,7 +182,12 @@ struct NumericCompareExpr : BoolAtomExpr {
__builtin_unreachable();
}
+ std::string_view Name() const override { return "NumericCompareExpr"; }
std::string Dump() const override { return fmt::format("{} {} {}",
field->Dump(), ToOperator(op), num->Dump()); };
+ std::string Content() const override { return ToOperator(op); }
+
+ NodeIterator ChildBegin() override { return {field.get(), num.get()}; };
+ NodeIterator ChildEnd() override { return {}; };
};
struct BoolLiteral : BoolAtomExpr {
@@ -164,7 +195,9 @@ struct BoolLiteral : BoolAtomExpr {
explicit BoolLiteral(bool val) : val(val) {}
+ std::string_view Name() const override { return "BoolLiteral"; }
std::string Dump() const override { return val ? "true" : "false"; }
+ std::string Content() const override { return Dump(); }
};
struct QueryExpr;
@@ -174,7 +207,11 @@ struct NotExpr : QueryExpr {
explicit NotExpr(std::unique_ptr<QueryExpr> &&inner) :
inner(std::move(inner)) {}
+ std::string_view Name() const override { return "NotExpr"; }
std::string Dump() const override { return fmt::format("not {}",
inner->Dump()); }
+
+ NodeIterator ChildBegin() override { return NodeIterator{inner.get()}; };
+ NodeIterator ChildEnd() override { return {}; };
};
struct AndExpr : QueryExpr {
@@ -182,9 +219,13 @@ struct AndExpr : QueryExpr {
explicit AndExpr(std::vector<std::unique_ptr<QueryExpr>> &&inners) :
inners(std::move(inners)) {}
+ std::string_view Name() const override { return "AndExpr"; }
std::string Dump() const override {
return fmt::format("(and {})", util::StringJoin(inners, [](const auto &v)
{ return v->Dump(); }));
}
+
+ NodeIterator ChildBegin() override { return NodeIterator(inners.begin()); };
+ NodeIterator ChildEnd() override { return NodeIterator(inners.end()); };
};
struct OrExpr : QueryExpr {
@@ -192,9 +233,13 @@ struct OrExpr : QueryExpr {
explicit OrExpr(std::vector<std::unique_ptr<QueryExpr>> &&inners) :
inners(std::move(inners)) {}
+ std::string_view Name() const override { return "OrExpr"; }
std::string Dump() const override {
return fmt::format("(or {})", util::StringJoin(inners, [](const auto &v) {
return v->Dump(); }));
}
+
+ NodeIterator ChildBegin() override { return NodeIterator(inners.begin()); };
+ NodeIterator ChildEnd() override { return NodeIterator(inners.end()); };
};
struct Limit : Node {
@@ -203,7 +248,9 @@ struct Limit : Node {
Limit(size_t offset, size_t count) : offset(offset), count(count) {}
+ std::string_view Name() const override { return "Limit"; }
std::string Dump() const override { return fmt::format("limit {}, {}",
offset, count); }
+ std::string Content() const override { return fmt::format("{}, {}", offset,
count); }
};
struct SortBy : Node {
@@ -213,7 +260,13 @@ struct SortBy : Node {
SortBy(Order order, std::unique_ptr<FieldRef> &&field) : order(order),
field(std::move(field)) {}
static constexpr const char *OrderToString(Order order) { return order ==
ASC ? "asc" : "desc"; }
+
+ std::string_view Name() const override { return "SortBy"; }
std::string Dump() const override { return fmt::format("sortby {}, {}",
field->Dump(), OrderToString(order)); }
+ std::string Content() const override { return OrderToString(order); }
+
+ NodeIterator ChildBegin() override { return NodeIterator(field.get()); };
+ NodeIterator ChildEnd() override { return {}; };
};
struct SelectExpr : Node {
@@ -221,10 +274,14 @@ struct SelectExpr : Node {
explicit SelectExpr(std::vector<std::unique_ptr<FieldRef>> &&fields) :
fields(std::move(fields)) {}
+ std::string_view Name() const override { return "SelectExpr"; }
std::string Dump() const override {
if (fields.empty()) return "select *";
return fmt::format("select {}", util::StringJoin(fields, [](const auto &v)
{ return v->Dump(); }));
}
+
+ NodeIterator ChildBegin() override { return NodeIterator(fields.begin()); };
+ NodeIterator ChildEnd() override { return NodeIterator(fields.end()); };
};
struct IndexRef : Node {
@@ -232,24 +289,27 @@ struct IndexRef : Node {
explicit IndexRef(std::string name) : name(std::move(name)) {}
+ std::string_view Name() const override { return "IndexRef"; }
std::string Dump() const override { return name; }
+ std::string Content() const override { return Dump(); }
};
struct SearchStmt : Node {
+ std::unique_ptr<SelectExpr> select_expr;
std::unique_ptr<IndexRef> index;
std::unique_ptr<QueryExpr> query_expr; // optional
std::unique_ptr<Limit> limit; // optional
std::unique_ptr<SortBy> sort_by; // optional
- std::unique_ptr<SelectExpr> select_expr;
SearchStmt(std::unique_ptr<IndexRef> &&index, std::unique_ptr<QueryExpr>
&&query_expr, std::unique_ptr<Limit> &&limit,
std::unique_ptr<SortBy> &&sort_by, std::unique_ptr<SelectExpr>
&&select_expr)
- : index(std::move(index)),
+ : select_expr(std::move(select_expr)),
+ index(std::move(index)),
query_expr(std::move(query_expr)),
limit(std::move(limit)),
- sort_by(std::move(sort_by)),
- select_expr(std::move(select_expr)) {}
+ sort_by(std::move(sort_by)) {}
+ std::string_view Name() const override { return "SearchStmt"; }
std::string Dump() const override {
std::string opt;
if (query_expr) opt += " where " + query_expr->Dump();
@@ -257,6 +317,15 @@ struct SearchStmt : Node {
if (limit) opt += " " + limit->Dump();
return fmt::format("{} from {}{}", select_expr->Dump(), index->Dump(),
opt);
}
+
+ static inline const std::vector<std::function<Node *(Node *)>> ChildMap = {
+ NodeIterator::MemFn<&SearchStmt::select_expr>,
NodeIterator::MemFn<&SearchStmt::index>,
+ NodeIterator::MemFn<&SearchStmt::query_expr>,
NodeIterator::MemFn<&SearchStmt::limit>,
+ NodeIterator::MemFn<&SearchStmt::sort_by>,
+ };
+
+ NodeIterator ChildBegin() override { return NodeIterator(this,
ChildMap.begin()); };
+ NodeIterator ChildEnd() override { return NodeIterator(this,
ChildMap.end()); };
};
} // namespace kqir
diff --git a/src/common/type_util.h b/src/search/ir_dot_dumper.h
similarity index 53%
copy from src/common/type_util.h
copy to src/search/ir_dot_dumper.h
index d55019f7..5bb6dc7b 100644
--- a/src/common/type_util.h
+++ b/src/search/ir_dot_dumper.h
@@ -20,22 +20,35 @@
#pragma once
-#include <utility>
+#include "ir.h"
+#include "string_util.h"
-template <typename F, F *f>
-struct StaticFunction {
- template <typename... Ts>
- auto operator()(Ts &&...args) const ->
decltype(f(std::forward<Ts>(args)...)) { // NOLINT
- return f(std::forward<Ts>(args)...);
// NOLINT
+namespace kqir {
+
+struct DotDumper {
+ std::ostream &os;
+
+ // Writes the tree rooted at `node` to `os` as one DOT document: the recursive
+ // node/edge listing produced by dump() wrapped in a `digraph { ... }` block.
+ void Dump(Node *node) {
+ os << "digraph {\n";
+ dump(node);
+ os << "}\n";
}
-};
-template <typename... Ts>
-using FirstElement = typename std::tuple_element_t<0, std::tuple<Ts...>>;
+ private:
+  // Derives the DOT identifier of a node from its address: "x" followed by the
+  // address in hexadecimal, so distinct node objects get distinct identifiers.
+  static std::string nodeId(Node *node) {
+    const auto address = reinterpret_cast<uint64_t>(node);
+    return fmt::format("x{:x}", address);
+  }
-template <typename T>
-using RemoveCVRef = typename std::remove_cv_t<typename
std::remove_reference_t<T>>;
+  // Emits the DOT statement for `node` followed by, for every child, the edge
+  // `node -> child` and the child's own subtree (depth-first, pre-order).
+  // The label is the node's Name(), plus its escaped Content() in parentheses
+  // when the node has any.
+  void dump(Node *node) {
+    // Nothing to draw for a missing tree; dereferencing it would be undefined.
+    if (node == nullptr) return;
+
+    os << " " << nodeId(node) << " [ label = \"" << node->Name();
+    if (auto content = node->Content(); !content.empty()) {
+      os << " (" << util::EscapeString(content) << ")";
+    }
+    os << "\" ];\n";
+
+    const auto end = node->ChildEnd();
+    for (auto i = node->ChildBegin(); i != end; ++i) {
+      Node *child = *i;
+      // Child slots may be empty (e.g. a search statement without a where, limit
+      // or sort clause); skip them instead of emitting an edge to a null node.
+      if (child == nullptr) continue;
+
+      os << " " << nodeId(node) << " -> " << nodeId(child) << ";\n";
+      dump(child);
+    }
+  }
+};
-// dependent false for static_assert with constexpr if, see CWG2518/P2593R1
-template <typename T>
-constexpr bool AlwaysFalse = false;
+} // namespace kqir
diff --git a/src/search/ir_iterator.h b/src/search/ir_iterator.h
new file mode 100644
index 00000000..2ead473c
--- /dev/null
+++ b/src/search/ir_iterator.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <array>
+#include <functional>
+#include <memory>
+#include <type_traits>
+#include <utility>
+#include <variant>
+#include <vector>
+
+#include "type_util.h"
+
+namespace kqir {
+
+struct Node;
+
+// A type-erased forward iterator over the children of an IR node.
+//
+// Depending on how the parent stores its children, `val` holds one of:
+//   0. a single `Node *`         - at most one child; a null pointer is the end marker
+//   1. an array of two `Node *`  - two children; after the first step it degrades to form 0
+//   2. a (parent, getter) pair   - children obtained by applying accessor functions to the parent
+//   3. a vector iterator         - children stored in a vector of `std::unique_ptr`
+//
+// A default-constructed iterator is form 0 holding null, which is what forms 0 and 1
+// compare against as their end position. Forms 2 and 3 end at the end iterator of the
+// underlying container.
+struct NodeIterator {
+  using ChildGetterIter = std::vector<std::function<Node *(Node *)>>::const_iterator;
+  using ChildVectorIter = std::vector<std::unique_ptr<Node>>::iterator;
+
+  std::variant<Node *, std::array<Node *, 2>, std::pair<Node *, ChildGetterIter>, ChildVectorIter> val;
+
+  NodeIterator() : val(nullptr) {}
+  explicit NodeIterator(Node *node) : val(node) {}
+  NodeIterator(Node *n1, Node *n2) : val(std::array<Node *, 2>{n1, n2}) {}
+  explicit NodeIterator(Node *parent, ChildGetterIter iter) : val(std::make_pair(parent, iter)) {}
+  // Accepts an iterator into a vector of `std::unique_ptr<Derived>` for any `Derived`
+  // that is a `Node`, and stores it as if it iterated over `std::unique_ptr<Node>`.
+  template <typename Iterator,
+            std::enable_if_t<std::is_base_of_v<Node, typename Iterator::value_type::element_type>, int> = 0>
+  explicit NodeIterator(Iterator iter) : val(*CastToNodeIter(&iter)) {}
+
+  // Reinterprets a pointer to a concrete vector iterator as a pointer to the
+  // `std::unique_ptr<Node>` vector iterator; the result is marked `may_alias`.
+  template <typename Iterator>
+  static auto CastToNodeIter(Iterator *iter) {
+    auto res __attribute__((__may_alias__)) = reinterpret_cast<ChildVectorIter *>(iter);
+    return res;
+  }
+
+  // Accessor usable in a child-getter table: treats `parent` as the class that declares
+  // the smart-pointer member `F` and returns the raw pointer that member currently holds.
+  template <auto F>
+  static Node *MemFn(Node *parent) {
+    using Owner = typename GetClassFromMember<decltype(F)>::type;
+    return (reinterpret_cast<Owner *>(parent)->*F).get();
+  }
+
+  friend bool operator==(const NodeIterator &l, const NodeIterator &r) { return l.val == r.val; }
+
+  friend bool operator!=(const NodeIterator &l, const NodeIterator &r) { return l.val != r.val; }
+
+  // Returns the child the iterator currently refers to.
+  Node *operator*() {
+    switch (val.index()) {
+      case 0:
+        return std::get<0>(val);
+      case 1:
+        return std::get<1>(val)[0];
+      case 2: {
+        auto &[parent, iter] = std::get<2>(val);
+        return (*iter)(parent);
+      }
+      default:
+        return std::get<3>(val)->get();
+    }
+  }
+
+  // Advances to the next child and returns `*this`.
+  NodeIterator &operator++() {
+    switch (val.index()) {
+      case 0:
+        val = nullptr;
+        break;
+      case 1: {
+        // Copy the second pointer out first: assigning to `val` replaces the array
+        // alternative, so the source must not be a reference into `val` itself.
+        Node *second = std::get<1>(val)[1];
+        val = second;
+        break;
+      }
+      case 2:
+        ++std::get<2>(val).second;
+        break;
+      default:
+        ++std::get<3>(val);
+        break;
+    }
+
+    return *this;
+  }
+};
+
+} // namespace kqir
diff --git a/src/search/redis_query_transformer.h
b/src/search/redis_query_transformer.h
index ebd05143..45734b31 100644
--- a/src/search/redis_query_transformer.h
+++ b/src/search/redis_query_transformer.h
@@ -83,13 +83,13 @@ struct Transformer : ir::TreeTransformer {
const auto& rhs = query->children[1];
if (Is<ExclusiveNumber>(lhs)) {
- exprs.push_back(
- std::make_unique<NumericCompareExpr>(NumericCompareExpr::GT,
std::make_unique<FieldRef>(field),
-
Node::As<NumericLiteral>(GET_OR_RET(Transform(lhs->children[0])))));
+ exprs.push_back(std::make_unique<NumericCompareExpr>(
+ NumericCompareExpr::GT, std::make_unique<FieldRef>(field),
+
Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(lhs->children[0])))));
} else if (Is<Number>(lhs)) {
-
exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET,
-
std::make_unique<FieldRef>(field),
-
Node::As<NumericLiteral>(GET_OR_RET(Transform(lhs)))));
+ exprs.push_back(
+ std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET,
std::make_unique<FieldRef>(field),
+
Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(lhs)))));
} else { // Inf
if (lhs->string_view() == "+inf") {
return {Status::NotOK, "it's not allowed to set the lower bound as
positive infinity"};
@@ -97,13 +97,13 @@ struct Transformer : ir::TreeTransformer {
}
if (Is<ExclusiveNumber>(rhs)) {
- exprs.push_back(
- std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT,
std::make_unique<FieldRef>(field),
-
Node::As<NumericLiteral>(GET_OR_RET(Transform(rhs->children[0])))));
+ exprs.push_back(std::make_unique<NumericCompareExpr>(
+ NumericCompareExpr::LT, std::make_unique<FieldRef>(field),
+
Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(rhs->children[0])))));
} else if (Is<Number>(rhs)) {
-
exprs.push_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LET,
-
std::make_unique<FieldRef>(field),
-
Node::As<NumericLiteral>(GET_OR_RET(Transform(rhs)))));
+ exprs.push_back(
+ std::make_unique<NumericCompareExpr>(NumericCompareExpr::LET,
std::make_unique<FieldRef>(field),
+
Node::MustAs<NumericLiteral>(GET_OR_RET(Transform(rhs)))));
} else { // Inf
if (rhs->string_view() == "-inf") {
return {Status::NotOK, "it's not allowed to set the upper bound as
negative infinity"};
@@ -121,12 +121,12 @@ struct Transformer : ir::TreeTransformer {
} else if (Is<NotExpr>(node)) {
CHECK(node->children.size() == 1);
- return
Node::Create<ir::NotExpr>(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
+ return
Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
for (const auto& child : node->children) {
- exprs.push_back(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(child))));
+
exprs.push_back(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(child))));
}
return Node::Create<ir::AndExpr>(std::move(exprs));
@@ -134,7 +134,7 @@ struct Transformer : ir::TreeTransformer {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
for (const auto& child : node->children) {
- exprs.push_back(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(child))));
+
exprs.push_back(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(child))));
}
return Node::Create<ir::OrExpr>(std::move(exprs));
diff --git a/src/search/sql_transformer.h b/src/search/sql_transformer.h
index 5c300654..63397e40 100644
--- a/src/search/sql_transformer.h
+++ b/src/search/sql_transformer.h
@@ -63,8 +63,9 @@ struct Transformer : ir::TreeTransformer {
} else if (Is<HasTagExpr>(node)) {
CHECK(node->children.size() == 2);
- return
Node::Create<ir::TagContainExpr>(std::make_unique<ir::FieldRef>(node->children[0]->string()),
-
Node::As<ir::StringLiteral>(GET_OR_RET(Transform(node->children[1]))));
+ return Node::Create<ir::TagContainExpr>(
+ std::make_unique<ir::FieldRef>(node->children[0]->string()),
+
Node::MustAs<ir::StringLiteral>(GET_OR_RET(Transform(node->children[1]))));
} else if (Is<NumericCompareExpr>(node)) {
CHECK(node->children.size() == 3);
@@ -74,23 +75,23 @@ struct Transformer : ir::TreeTransformer {
auto op =
ir::NumericCompareExpr::FromOperator(node->children[1]->string_view()).value();
if (Is<Identifier>(lhs) && Is<Number>(rhs)) {
return Node::Create<ir::NumericCompareExpr>(op,
std::make_unique<ir::FieldRef>(lhs->string()),
-
Node::As<ir::NumericLiteral>(GET_OR_RET(Transform(rhs))));
+
Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(rhs))));
} else if (Is<Number>(lhs) && Is<Identifier>(rhs)) {
return
Node::Create<ir::NumericCompareExpr>(ir::NumericCompareExpr::Flip(op),
std::make_unique<ir::FieldRef>(rhs->string()),
-
Node::As<ir::NumericLiteral>(GET_OR_RET(Transform(lhs))));
+
Node::MustAs<ir::NumericLiteral>(GET_OR_RET(Transform(lhs))));
} else {
return {Status::NotOK, "the left and right side of numeric comparison
should be an identifier and a number"};
}
} else if (Is<NotExpr>(node)) {
CHECK(node->children.size() == 1);
- return
Node::Create<ir::NotExpr>(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
+ return
Node::Create<ir::NotExpr>(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[0]))));
} else if (Is<AndExpr>(node)) {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
for (const auto& child : node->children) {
- exprs.push_back(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(child))));
+
exprs.push_back(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(child))));
}
return Node::Create<ir::AndExpr>(std::move(exprs));
@@ -98,7 +99,7 @@ struct Transformer : ir::TreeTransformer {
std::vector<std::unique_ptr<ir::QueryExpr>> exprs;
for (const auto& child : node->children) {
- exprs.push_back(Node::As<ir::QueryExpr>(GET_OR_RET(Transform(child))));
+
exprs.push_back(Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(child))));
}
return Node::Create<ir::OrExpr>(std::move(exprs));
@@ -146,8 +147,8 @@ struct Transformer : ir::TreeTransformer {
} else if (Is<SearchStmt>(node)) { // root node
CHECK(node->children.size() >= 2 && node->children.size() <= 5);
- auto index =
Node::As<ir::IndexRef>(GET_OR_RET(Transform(node->children[1])));
- auto select =
Node::As<ir::SelectExpr>(GET_OR_RET(Transform(node->children[0])));
+ auto index =
Node::MustAs<ir::IndexRef>(GET_OR_RET(Transform(node->children[1])));
+ auto select =
Node::MustAs<ir::SelectExpr>(GET_OR_RET(Transform(node->children[0])));
std::unique_ptr<ir::QueryExpr> query_expr;
std::unique_ptr<ir::Limit> limit;
@@ -155,11 +156,11 @@ struct Transformer : ir::TreeTransformer {
for (size_t i = 2; i < node->children.size(); ++i) {
if (Is<WhereClause>(node->children[i])) {
- query_expr =
Node::As<ir::QueryExpr>(GET_OR_RET(Transform(node->children[i])));
+ query_expr =
Node::MustAs<ir::QueryExpr>(GET_OR_RET(Transform(node->children[i])));
} else if (Is<LimitClause>(node->children[i])) {
- limit =
Node::As<ir::Limit>(GET_OR_RET(Transform(node->children[i])));
+ limit =
Node::MustAs<ir::Limit>(GET_OR_RET(Transform(node->children[i])));
} else if (Is<OrderByClause>(node->children[i])) {
- sort_by =
Node::As<ir::SortBy>(GET_OR_RET(Transform(node->children[i])));
+ sort_by =
Node::MustAs<ir::SortBy>(GET_OR_RET(Transform(node->children[i])));
}
}
diff --git a/tests/cppunit/ir_dot_dumper_test.cc
b/tests/cppunit/ir_dot_dumper_test.cc
new file mode 100644
index 00000000..0310eb21
--- /dev/null
+++ b/tests/cppunit/ir_dot_dumper_test.cc
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "search/ir_dot_dumper.h"
+
+#include <regex>
+#include <sstream>
+
+#include "gtest/gtest.h"
+#include "search/sql_transformer.h"
+
+using namespace kqir;
+
+// Test helper: feeds the SQL text `in` (labelled "test" as its input source) to
+// sql::ParseToIR and returns whatever that call returns.
+static auto Parse(const std::string& in) { return
sql::ParseToIR(peg::string_input(in, "test")); }
+
+// Dumps a query that uses every clause and verifies that the DOT output contains
+// the edges SearchStmt -> OrExpr and OrExpr -> AndExpr.
+TEST(DotDumperTest, Simple) {
+  auto ir = *Parse("select a from b where c = 1 or d hastag \"x\" and 2 <= e order by e asc limit 0, 10");
+
+  std::stringstream ss;
+  DotDumper dumper{ss};
+
+  dumper.Dump(ir.get());
+
+  std::string dot = ss.str();
+  std::smatch matches;
+
+  // Every lookup has to succeed: an unmatched `matches[1]` is an empty string, and the
+  // edge assertions below would then pass on any " -> " found in the output.
+  ASSERT_TRUE(std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "SearchStmt)")));
+  auto search_stmt = matches[1].str();
+
+  ASSERT_TRUE(std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "OrExpr)")));
+  auto or_expr = matches[1].str();
+
+  ASSERT_TRUE(std::regex_search(dot, matches, std::regex(R"((\w+) \[ label = "AndExpr)")));
+  auto and_expr = matches[1].str();
+
+  ASSERT_NE(dot.find(fmt::format("{} -> {}", search_stmt, or_expr)), std::string::npos);
+  ASSERT_NE(dot.find(fmt::format("{} -> {}", or_expr, and_expr)), std::string::npos);
+}