This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 35c9e937 Add IR visitor and boolean simplification pass (#2211)
35c9e937 is described below
commit 35c9e937341a82e494be678bc56e518ab71fff2e
Author: Twice <[email protected]>
AuthorDate: Sun Mar 31 18:25:01 2024 +0900
Add IR visitor and boolean simplification pass (#2211)
---
.github/workflows/kvrocks.yaml | 4 +-
src/search/ir.h | 11 +--
src/search/ir_pass.h | 144 +++++++++++++++++++++++++++++++++++
src/search/passes/simplify_boolean.h | 92 ++++++++++++++++++++++
tests/cppunit/ir_pass_test.cc | 64 ++++++++++++++++
5 files changed, 308 insertions(+), 7 deletions(-)
diff --git a/.github/workflows/kvrocks.yaml b/.github/workflows/kvrocks.yaml
index 675a23b6..6c2b51c9 100644
--- a/.github/workflows/kvrocks.yaml
+++ b/.github/workflows/kvrocks.yaml
@@ -175,7 +175,7 @@ jobs:
compiler: gcc
with_openssl: -DENABLE_OPENSSL=ON
- name: Ubuntu Clang with OpenSSL
- os: ubuntu-20.04
+ os: ubuntu-22.04
compiler: clang
with_openssl: -DENABLE_OPENSSL=ON
- name: Ubuntu GCC without luaJIT
@@ -191,7 +191,7 @@ jobs:
compiler: gcc
new_encoding: -DENABLE_NEW_ENCODING=TRUE
- name: Ubuntu Clang with new encoding
- os: ubuntu-20.04
+ os: ubuntu-22.04
compiler: clang
new_encoding: -DENABLE_NEW_ENCODING=TRUE
- name: Ubuntu GCC with speedb enabled
diff --git a/src/search/ir.h b/src/search/ir.h
index 02f3766a..6e3dea28 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -55,16 +55,17 @@ struct Node {
return std::unique_ptr<U>(new T(std::forward<Args>(args)...));
}
- template <typename T>
- static std::unique_ptr<T> MustAs(std::unique_ptr<Node> &&original) {
+ // Casts `original` to std::unique_ptr<T> through As<T> and CHECKs that the
+ // result is non-null before returning it.
+ template <typename T, typename U>
+ static std::unique_ptr<T> MustAs(std::unique_ptr<U> &&original) {
auto casted = As<T>(std::move(original));
CHECK(casted != nullptr);
return casted;
}
- template <typename T>
- static std::unique_ptr<T> As(std::unique_ptr<Node> &&original) {
- auto casted = dynamic_cast<T *>(original.release());
+ // Tries to dynamic_cast the object held by `original` to T. When the cast
+ // succeeds, ownership is released from `original` into the returned pointer;
+ // when it fails, `original` still owns its object and a null pointer is
+ // returned.
+ template <typename T, typename U>
+ static std::unique_ptr<T> As(std::unique_ptr<U> &&original) {
+ auto casted = dynamic_cast<T *>(original.get());
+ if (casted) original.release();
return std::unique_ptr<T>(casted);
}
};
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
new file mode 100644
index 00000000..9a67530a
--- /dev/null
+++ b/src/search/ir_pass.h
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include "ir.h"
+
+namespace kqir {
+
+/// Interface of an IR-to-IR transformation.
+struct Pass {
+  /// Virtual so that a concrete pass destroyed through a `Pass` pointer runs
+  /// the derived destructor instead of invoking undefined behavior.
+  virtual ~Pass() = default;
+
+  /// Takes ownership of `node` and returns the node produced by the pass.
+  virtual std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) = 0;
+};
+
+/// A Pass that dispatches on the dynamic type of a node. Every default Visit
+/// overload visits the children of its node and hands the same node back;
+/// subclasses override the overloads for the node kinds they rewrite.
+struct Visitor : Pass {
+  /// Determines the concrete type of `node` and forwards it to the matching
+  /// Visit overload. Candidate types are tried in a fixed order. If none of
+  /// them matches (which includes a null `node`), control reaches
+  /// __builtin_unreachable().
+  std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) override {
+    // Node::As only takes ownership when the cast succeeds, so after a failed
+    // attempt `node` is still intact for the next candidate type.
+    if (auto typed = Node::As<SearchStmt>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<SelectExpr>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<IndexRef>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<Limit>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<SortBy>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<AndExpr>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<OrExpr>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<NotExpr>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<NumericCompareExpr>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<NumericLiteral>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<FieldRef>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<TagContainExpr>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<StringLiteral>(std::move(node))) return Visit(std::move(typed));
+    if (auto typed = Node::As<BoolLiteral>(std::move(node))) return Visit(std::move(typed));
+
+    __builtin_unreachable();
+  }
+
+  /// Runs the Visit overload for T on `n`; Node::MustAs then requires that the
+  /// result is still a T.
+  template <typename T>
+  std::unique_ptr<T> VisitAs(std::unique_ptr<T> n) {
+    std::unique_ptr<Node> visited = Visit(std::move(n));
+    return Node::MustAs<T>(std::move(visited));
+  }
+
+  /// Runs Transform on `n`; Node::MustAs then requires that the result is a T.
+  template <typename T>
+  std::unique_ptr<T> TransformAs(std::unique_ptr<Node> n) {
+    std::unique_ptr<Node> transformed = Transform(std::move(n));
+    return Node::MustAs<T>(std::move(transformed));
+  }
+
+  /// Visits the index and select list, then each optional clause that is set.
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<SearchStmt> node) {
+    node->index = VisitAs<IndexRef>(std::move(node->index));
+    node->select_expr = VisitAs<SelectExpr>(std::move(node->select_expr));
+
+    // The remaining members are only processed when they are non-null.
+    if (node->query_expr) {
+      node->query_expr = TransformAs<QueryExpr>(std::move(node->query_expr));
+    }
+    if (node->sort_by) {
+      node->sort_by = VisitAs<SortBy>(std::move(node->sort_by));
+    }
+    if (node->limit) {
+      node->limit = VisitAs<Limit>(std::move(node->limit));
+    }
+    return node;
+  }
+
+  /// Visits every selected field in place.
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<SelectExpr> node) {
+    for (auto &field : node->fields) {
+      auto visited = VisitAs<FieldRef>(std::move(field));
+      field = std::move(visited);
+    }
+    return node;
+  }
+
+  // These node kinds have no children visited here; they are returned as-is.
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<IndexRef> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<FieldRef> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<BoolLiteral> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<StringLiteral> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericLiteral> node) { return node; }
+
+  /// Visits the field and then the numeric literal of the comparison.
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericCompareExpr> node) {
+    node->field = VisitAs<FieldRef>(std::move(node->field));
+    node->num = VisitAs<NumericLiteral>(std::move(node->num));
+    return node;
+  }
+
+  /// Visits the field and then the tag literal of the expression.
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagContainExpr> node) {
+    node->field = VisitAs<FieldRef>(std::move(node->field));
+    node->tag = VisitAs<StringLiteral>(std::move(node->tag));
+    return node;
+  }
+
+  /// Transforms every operand of the conjunction in place.
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) {
+    for (auto &operand : node->inners) {
+      auto rewritten = TransformAs<QueryExpr>(std::move(operand));
+      operand = std::move(rewritten);
+    }
+    return node;
+  }
+
+  /// Transforms every operand of the disjunction in place.
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) {
+    for (auto &operand : node->inners) {
+      auto rewritten = TransformAs<QueryExpr>(std::move(operand));
+      operand = std::move(rewritten);
+    }
+    return node;
+  }
+
+  /// Transforms the negated operand.
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<NotExpr> node) {
+    node->inner = TransformAs<QueryExpr>(std::move(node->inner));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<Limit> node) { return node; }
+
+  /// Visits the field the sort is keyed on.
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<SortBy> node) {
+    node->field = VisitAs<FieldRef>(std::move(node->field));
+    return node;
+  }
+};
+
+} // namespace kqir
diff --git a/src/search/passes/simplify_boolean.h
b/src/search/passes/simplify_boolean.h
new file mode 100644
index 00000000..99229c2f
--- /dev/null
+++ b/src/search/passes/simplify_boolean.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <iostream>
+#include <memory>
+
+#include "search/ir.h"
+#include "search/ir_pass.h"
+
+namespace kqir {
+
+/// Visitor that folds boolean literals inside OR / AND / NOT expressions.
+/// Children are simplified first (through the base Visitor), then the
+/// literals found among the direct operands are folded.
+struct SimplifyBoolean : Visitor {
+  // Overriding only some Visit overloads would hide the remaining ones
+  // inherited from Visitor for code that calls Visit on this type; re-expose
+  // them.
+  using Visitor::Visit;
+
+  /// OR: `false` operands are dropped, a `true` operand decides the result.
+  std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) override {
+    node = Node::MustAs<OrExpr>(Visitor::Visit(std::move(node)));
+    return FoldLiterals(std::move(node), /*identity=*/false);
+  }
+
+  /// AND: `true` operands are dropped, a `false` operand decides the result.
+  std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) override {
+    node = Node::MustAs<AndExpr>(Visitor::Visit(std::move(node)));
+    return FoldLiterals(std::move(node), /*identity=*/true);
+  }
+
+  /// NOT: a literal operand is replaced by the negated literal.
+  std::unique_ptr<Node> Visit(std::unique_ptr<NotExpr> node) override {
+    node = Node::MustAs<NotExpr>(Visitor::Visit(std::move(node)));
+
+    if (auto literal = Node::As<BoolLiteral>(std::move(node->inner))) {
+      literal->val = !literal->val;
+      return literal;
+    }
+
+    return node;
+  }
+
+ private:
+  /// Shared folding for OrExpr and AndExpr. `identity` is the literal value
+  /// that leaves the operator unaffected (false for OR, true for AND); the
+  /// opposite literal becomes the result of the whole expression.
+  ///
+  /// Returns that deciding literal if one is found, a literal holding
+  /// `identity` if no operand remains, the single remaining operand if exactly
+  /// one is left, and `node` itself otherwise.
+  template <typename Expr>
+  static std::unique_ptr<Node> FoldLiterals(std::unique_ptr<Expr> node, bool identity) {
+    auto &inners = node->inners;
+    for (auto iter = inners.begin(); iter != inners.end();) {
+      // Node::As leaves *iter untouched when the operand is not a literal.
+      auto literal = Node::As<BoolLiteral>(std::move(*iter));
+      if (!literal) {
+        ++iter;
+      } else if (literal->val == identity) {
+        iter = inners.erase(iter);
+      } else {
+        return literal;
+      }
+    }
+
+    if (inners.empty()) return std::make_unique<BoolLiteral>(identity);
+    if (inners.size() == 1) return std::move(inners.front());
+    return node;
+  }
+};
+
+} // namespace kqir
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
new file mode 100644
index 00000000..9f9af8f6
--- /dev/null
+++ b/tests/cppunit/ir_pass_test.cc
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "search/ir_pass.h"
+
+#include "gtest/gtest.h"
+#include "search/passes/simplify_boolean.h"
+#include "search/sql_transformer.h"
+
+using namespace kqir;
+
+// Parses the SQL text `in` into IR via sql::ParseToIR. "test" is the second
+// argument given to peg::string_input (presumably the source name - confirm).
+static auto Parse(const std::string& in) {
+  return sql::ParseToIR(peg::string_input(in, "test"));
+}
+
+// The base Visitor must leave the IR unchanged: the dump after the
+// transformation has to equal the dump taken before it.
+TEST(IRPassTest, Simple) {
+  auto ir = *Parse("select a from b where not c = 1 or d hastag \"x\" and 2 <= e order by e asc limit 0, 10");
+
+  // Dump first: `ir` is moved into Transform below.
+  auto original = ir->Dump();
+
+  Visitor visitor;
+  auto ir2 = visitor.Transform(std::move(ir));
+  ASSERT_EQ(original, ir2->Dump());
+}
+
+// Each case parses a query, runs SimplifyBoolean over it and compares the
+// dumped result with the expected simplified query.
+TEST(IRPassTest, SimplifyBoolean) {
+  SimplifyBoolean sb;
+
+  // NOT applied to literals.
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where not false"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where not not false"))->Dump(), "select a from b where false");
+
+  // AND: `true` operands are dropped, a `false` operand decides the result.
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true and true"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true and false"))->Dump(), "select a from b where false");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where false and true"))->Dump(), "select a from b where false");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true and false and true"))->Dump(),
+            "select a from b where false");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true and true and true"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and false"))->Dump(), "select a from b where false");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and true"))->Dump(), "select a from b where x > 1");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and true and y < 10"))->Dump(),
+            "select a from b where (and x > 1, y < 10)");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where not (false and (not true))"))->Dump(),
+            "select a from b where true");
+
+  // OR: `false` operands are dropped, a `true` operand decides the result.
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true or true"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where false or true"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false or true"))->Dump(), "select a from b where true");
+  // Every operand is `false`: all of them are dropped and the OR folds to `false`.
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where false or false or false"))->Dump(),
+            "select a from b where false");
+
+  // Nested expressions are simplified bottom-up.
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where not ((x < 1 or true) and (y > 2 and true))"))->Dump(),
+            "select a from b where not y > 2");
+}