This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new dd4a2686 Add a simple AND/OR expression simplification pass for KQIR (#2216)
dd4a2686 is described below

commit dd4a26867c356dad1f8478f879eeebf4b0607854
Author: Twice <[email protected]>
AuthorDate: Wed Apr 3 08:47:48 2024 +0900

    Add a simple AND/OR expression simplification pass for KQIR (#2216)
---
 .../{simplify_boolean.h => simplify_and_or_expr.h} | 56 ++++++----------------
 src/search/passes/simplify_boolean.h               |  1 -
 tests/cppunit/ir_pass_test.cc                      | 18 +++++++
 3 files changed, 33 insertions(+), 42 deletions(-)

diff --git a/src/search/passes/simplify_boolean.h b/src/search/passes/simplify_and_or_expr.h
similarity index 51%
copy from src/search/passes/simplify_boolean.h
copy to src/search/passes/simplify_and_or_expr.h
index 99229c2f..0ddaed61 100644
--- a/src/search/passes/simplify_boolean.h
+++ b/src/search/passes/simplify_and_or_expr.h
@@ -20,7 +20,6 @@
 
 #pragma once
 
-#include <iostream>
 #include <memory>
 
 #include "search/ir.h"
@@ -28,64 +27,39 @@
 
 namespace kqir {
 
-struct SimplifyBoolean : Visitor {
+struct SimplifyAndOrExpr : Visitor {
   std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) override {
     node = Node::MustAs<OrExpr>(Visitor::Visit(std::move(node)));
 
-    for (auto iter = node->inners.begin(); iter != node->inners.end();) {
-      if (auto v = Node::As<BoolLiteral>(std::move(*iter))) {
-        if (!v->val) {
-          iter = node->inners.erase(iter);
-        } else {
-          return v;
+    std::vector<std::unique_ptr<QueryExpr>> merged_nodes;
+    for (auto &n : node->inners) {
+      if (auto v = Node::As<OrExpr>(std::move(n))) {
+        for (auto &m : v->inners) {
+          merged_nodes.push_back(std::move(m));
         }
       } else {
-        ++iter;
+        merged_nodes.push_back(std::move(n));
       }
     }
 
-    if (node->inners.size() == 0) {
-      return std::make_unique<BoolLiteral>(false);
-    } else if (node->inners.size() == 1) {
-      return std::move(node->inners[0]);
-    }
-
-    return node;
+    return std::make_unique<OrExpr>(std::move(merged_nodes));
   }
 
   std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) override {
     node = Node::MustAs<AndExpr>(Visitor::Visit(std::move(node)));
 
-    for (auto iter = node->inners.begin(); iter != node->inners.end();) {
-      if (auto v = Node::As<BoolLiteral>(std::move(*iter))) {
-        if (v->val) {
-          iter = node->inners.erase(iter);
-        } else {
-          return v;
+    std::vector<std::unique_ptr<QueryExpr>> merged_nodes;
+    for (auto &n : node->inners) {
+      if (auto v = Node::As<AndExpr>(std::move(n))) {
+        for (auto &m : v->inners) {
+          merged_nodes.push_back(std::move(m));
         }
       } else {
-        ++iter;
+        merged_nodes.push_back(std::move(n));
       }
     }
 
-    if (node->inners.size() == 0) {
-      return std::make_unique<BoolLiteral>(true);
-    } else if (node->inners.size() == 1) {
-      return std::move(node->inners[0]);
-    }
-
-    return node;
-  }
-
-  std::unique_ptr<Node> Visit(std::unique_ptr<NotExpr> node) override {
-    node = Node::MustAs<NotExpr>(Visitor::Visit(std::move(node)));
-
-    if (auto v = Node::As<BoolLiteral>(std::move(node->inner))) {
-      v->val = !v->val;
-      return v;
-    }
-
-    return node;
+    return std::make_unique<AndExpr>(std::move(merged_nodes));
   }
 };
 
diff --git a/src/search/passes/simplify_boolean.h b/src/search/passes/simplify_boolean.h
index 99229c2f..79281e8e 100644
--- a/src/search/passes/simplify_boolean.h
+++ b/src/search/passes/simplify_boolean.h
@@ -20,7 +20,6 @@
 
 #pragma once
 
-#include <iostream>
 #include <memory>
 
 #include "search/ir.h"
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index 9f9af8f6..76886dae 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -21,6 +21,7 @@
 #include "search/ir_pass.h"
 
 #include "gtest/gtest.h"
+#include "search/passes/simplify_and_or_expr.h"
 #include "search/passes/simplify_boolean.h"
 #include "search/sql_transformer.h"
 
@@ -62,3 +63,20 @@ TEST(IRPassTest, SimplifyBoolean) {
   ASSERT_EQ(sb.Transform(*Parse("select a from b where not ((x < 1 or true) and (y > 2 and true))"))->Dump(),
             "select a from b where not y > 2");
 }
+
+TEST(IRPassTest, SimplifyAndOrExpr) {
+  SimplifyAndOrExpr saoe;
+
+  ASSERT_EQ(Parse("select a from b where true and (false and true)").GetValue()->Dump(),
+            "select a from b where (and true, (and false, true))");
+  ASSERT_EQ(saoe.Transform(*Parse("select a from b where true and (false and true)"))->Dump(),
+            "select a from b where (and true, false, true)");
+  ASSERT_EQ(saoe.Transform(*Parse("select a from b where true or (false or true)"))->Dump(),
+            "select a from b where (or true, false, true)");
+  ASSERT_EQ(saoe.Transform(*Parse("select a from b where true and (false or true)"))->Dump(),
+            "select a from b where (and true, (or false, true))");
+  ASSERT_EQ(saoe.Transform(*Parse("select a from b where true or (false and true)"))->Dump(),
+            "select a from b where (or true, (and false, true))");
+  ASSERT_EQ(saoe.Transform(*Parse("select a from b where x > 1 or (y < 2 or z = 3)"))->Dump(),
+            "select a from b where (or x > 1, y < 2, z = 3)");
+}

Reply via email to