This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 35c9e937 Add IR visitor and boolean simplification pass (#2211)
35c9e937 is described below

commit 35c9e937341a82e494be678bc56e518ab71fff2e
Author: Twice <[email protected]>
AuthorDate: Sun Mar 31 18:25:01 2024 +0900

    Add IR visitor and boolean simplification pass (#2211)
---
 .github/workflows/kvrocks.yaml       |   4 +-
 src/search/ir.h                      |  11 +--
 src/search/ir_pass.h                 | 144 +++++++++++++++++++++++++++++++++++
 src/search/passes/simplify_boolean.h |  92 ++++++++++++++++++++++
 tests/cppunit/ir_pass_test.cc        |  64 ++++++++++++++++
 5 files changed, 308 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/kvrocks.yaml b/.github/workflows/kvrocks.yaml
index 675a23b6..6c2b51c9 100644
--- a/.github/workflows/kvrocks.yaml
+++ b/.github/workflows/kvrocks.yaml
@@ -175,7 +175,7 @@ jobs:
             compiler: gcc
             with_openssl: -DENABLE_OPENSSL=ON
           - name: Ubuntu Clang with OpenSSL
-            os: ubuntu-20.04
+            os: ubuntu-22.04
             compiler: clang
             with_openssl: -DENABLE_OPENSSL=ON
           - name: Ubuntu GCC without luaJIT
@@ -191,7 +191,7 @@ jobs:
             compiler: gcc
             new_encoding: -DENABLE_NEW_ENCODING=TRUE
           - name: Ubuntu Clang with new encoding
-            os: ubuntu-20.04
+            os: ubuntu-22.04
             compiler: clang
             new_encoding: -DENABLE_NEW_ENCODING=TRUE
           - name: Ubuntu GCC with speedb enabled
diff --git a/src/search/ir.h b/src/search/ir.h
index 02f3766a..6e3dea28 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -55,16 +55,17 @@ struct Node {
     return std::unique_ptr<U>(new T(std::forward<Args>(args)...));
   }
 
-  template <typename T>
-  static std::unique_ptr<T> MustAs(std::unique_ptr<Node> &&original) {
+  template <typename T, typename U>
+  static std::unique_ptr<T> MustAs(std::unique_ptr<U> &&original) {
     auto casted = As<T>(std::move(original));
     CHECK(casted != nullptr);
     return casted;
   }
 
-  template <typename T>
-  static std::unique_ptr<T> As(std::unique_ptr<Node> &&original) {
-    auto casted = dynamic_cast<T *>(original.release());
+  template <typename T, typename U>
+  static std::unique_ptr<T> As(std::unique_ptr<U> &&original) {
+    auto casted = dynamic_cast<T *>(original.get());
+    if (casted) original.release();
     return std::unique_ptr<T>(casted);
   }
 };
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
new file mode 100644
index 00000000..9a67530a
--- /dev/null
+++ b/src/search/ir_pass.h
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include "ir.h"
+
+namespace kqir {
+
+struct Pass {
+  virtual std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) = 0;
+};
+
+struct Visitor : Pass {
+  std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) override {
+    if (auto v = Node::As<SearchStmt>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<SelectExpr>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<IndexRef>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<Limit>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<SortBy>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<AndExpr>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<OrExpr>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<NotExpr>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<NumericCompareExpr>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<NumericLiteral>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<FieldRef>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<TagContainExpr>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<StringLiteral>(std::move(node))) {
+      return Visit(std::move(v));
+    } else if (auto v = Node::As<BoolLiteral>(std::move(node))) {
+      return Visit(std::move(v));
+    }
+
+    __builtin_unreachable();
+  }
+
+  template <typename T>
+  std::unique_ptr<T> VisitAs(std::unique_ptr<T> n) {
+    return Node::MustAs<T>(Visit(std::move(n)));
+  }
+
+  template <typename T>
+  std::unique_ptr<T> TransformAs(std::unique_ptr<Node> n) {
+    return Node::MustAs<T>(Transform(std::move(n)));
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<SearchStmt> node) {
+    node->index = VisitAs<IndexRef>(std::move(node->index));
+    node->select_expr = VisitAs<SelectExpr>(std::move(node->select_expr));
+    if (node->query_expr) node->query_expr = TransformAs<QueryExpr>(std::move(node->query_expr));
+    if (node->sort_by) node->sort_by = VisitAs<SortBy>(std::move(node->sort_by));
+    if (node->limit) node->limit = VisitAs<Limit>(std::move(node->limit));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<SelectExpr> node) {
+    for (auto &n : node->fields) {
+      n = VisitAs<FieldRef>(std::move(n));
+    }
+
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<IndexRef> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<FieldRef> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<BoolLiteral> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<StringLiteral> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericLiteral> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<NumericCompareExpr> node) {
+    node->field = VisitAs<FieldRef>(std::move(node->field));
+    node->num = VisitAs<NumericLiteral>(std::move(node->num));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<TagContainExpr> node) {
+    node->field = VisitAs<FieldRef>(std::move(node->field));
+    node->tag = VisitAs<StringLiteral>(std::move(node->tag));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) {
+    for (auto &n : node->inners) {
+      n = TransformAs<QueryExpr>(std::move(n));
+    }
+
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) {
+    for (auto &n : node->inners) {
+      n = TransformAs<QueryExpr>(std::move(n));
+    }
+
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<NotExpr> node) {
+    node->inner = TransformAs<QueryExpr>(std::move(node->inner));
+    return node;
+  }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<Limit> node) { return node; }
+
+  virtual std::unique_ptr<Node> Visit(std::unique_ptr<SortBy> node) {
+    node->field = VisitAs<FieldRef>(std::move(node->field));
+    return node;
+  }
+};
+
+}  // namespace kqir
diff --git a/src/search/passes/simplify_boolean.h b/src/search/passes/simplify_boolean.h
new file mode 100644
index 00000000..99229c2f
--- /dev/null
+++ b/src/search/passes/simplify_boolean.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <iostream>
+#include <memory>
+
+#include "search/ir.h"
+#include "search/ir_pass.h"
+
+namespace kqir {
+
+struct SimplifyBoolean : Visitor {
+  std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) override {
+    node = Node::MustAs<OrExpr>(Visitor::Visit(std::move(node)));
+
+    for (auto iter = node->inners.begin(); iter != node->inners.end();) {
+      if (auto v = Node::As<BoolLiteral>(std::move(*iter))) {
+        if (!v->val) {
+          iter = node->inners.erase(iter);
+        } else {
+          return v;
+        }
+      } else {
+        ++iter;
+      }
+    }
+
+    if (node->inners.size() == 0) {
+      return std::make_unique<BoolLiteral>(false);
+    } else if (node->inners.size() == 1) {
+      return std::move(node->inners[0]);
+    }
+
+    return node;
+  }
+
+  std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) override {
+    node = Node::MustAs<AndExpr>(Visitor::Visit(std::move(node)));
+
+    for (auto iter = node->inners.begin(); iter != node->inners.end();) {
+      if (auto v = Node::As<BoolLiteral>(std::move(*iter))) {
+        if (v->val) {
+          iter = node->inners.erase(iter);
+        } else {
+          return v;
+        }
+      } else {
+        ++iter;
+      }
+    }
+
+    if (node->inners.size() == 0) {
+      return std::make_unique<BoolLiteral>(true);
+    } else if (node->inners.size() == 1) {
+      return std::move(node->inners[0]);
+    }
+
+    return node;
+  }
+
+  std::unique_ptr<Node> Visit(std::unique_ptr<NotExpr> node) override {
+    node = Node::MustAs<NotExpr>(Visitor::Visit(std::move(node)));
+
+    if (auto v = Node::As<BoolLiteral>(std::move(node->inner))) {
+      v->val = !v->val;
+      return v;
+    }
+
+    return node;
+  }
+};
+
+}  // namespace kqir
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
new file mode 100644
index 00000000..9f9af8f6
--- /dev/null
+++ b/tests/cppunit/ir_pass_test.cc
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "search/ir_pass.h"
+
+#include "gtest/gtest.h"
+#include "search/passes/simplify_boolean.h"
+#include "search/sql_transformer.h"
+
+using namespace kqir;
+
+static auto Parse(const std::string& in) { return sql::ParseToIR(peg::string_input(in, "test")); }
+
+TEST(IRPassTest, Simple) {
+  auto ir = *Parse("select a from b where not c = 1 or d hastag \"x\" and 2 <= e order by e asc limit 0, 10");
+
+  auto original = ir->Dump();
+
+  Visitor visitor;
+  auto ir2 = visitor.Transform(std::move(ir));
+  ASSERT_EQ(original, ir2->Dump());
+}
+
+TEST(IRPassTest, SimplifyBoolean) {
+  SimplifyBoolean sb;
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where not false"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where not not false"))->Dump(), "select a from b where false");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true and true"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true and false"))->Dump(), "select a from b where false");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where false and true"))->Dump(), "select a from b where false");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true and false and true"))->Dump(),
+            "select a from b where false");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true and true and true"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and false"))->Dump(), "select a from b where false");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and true"))->Dump(), "select a from b where x > 1");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where x > 1 and true and y < 10"))->Dump(),
+            "select a from b where (and x > 1, y < 10)");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where not (false and (not true))"))->Dump(),
+            "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true or true"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where false or true"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false or true"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where true or false or true"))->Dump(), "select a from b where true");
+  ASSERT_EQ(sb.Transform(*Parse("select a from b where not ((x < 1 or true) and (y > 2 and true))"))->Dump(),
+            "select a from b where not y > 2");
+}

Reply via email to