This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 35348467 Add NotExpr pushing pass for KQIR (#2218)
35348467 is described below
commit 35348467a639bf786ea974057aad3ec338ba0850
Author: Twice <[email protected]>
AuthorDate: Wed Apr 3 23:37:45 2024 +0900
Add NotExpr pushing pass for KQIR (#2218)
---
src/search/passes/push_down_not_expr.h | 59 ++++++++++++++++++++++++++++++++++
tests/cppunit/ir_pass_test.cc | 16 +++++++++
2 files changed, 75 insertions(+)
diff --git a/src/search/passes/push_down_not_expr.h b/src/search/passes/push_down_not_expr.h
new file mode 100644
index 00000000..3c286c09
--- /dev/null
+++ b/src/search/passes/push_down_not_expr.h
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <memory>
+
+#include "search/ir.h"
+#include "search/ir_pass.h"
+
+namespace kqir {
+
+/// IR pass that pushes `NotExpr` nodes towards the leaves of a query expression.
+///
+/// For `not <inner>` the rewrite depends on the dynamic type of `<inner>`:
+///  - NumericCompareExpr: the operator is replaced by `Negative(op)` and the `not` disappears;
+///  - TagContainExpr: the expression is kept as `not <tag expr>`;
+///  - AndExpr / OrExpr: every operand is wrapped in `not`, the connective is swapped
+///    (De Morgan), and the result is transformed again so the new `not` nodes are pushed further;
+///  - NotExpr: the double negation is dropped and the remaining expression is transformed again;
+///  - anything else: the node is returned untouched.
+struct PushDownNotExpr : Visitor {
+  /// Rewrites a single `NotExpr` as described on the struct.
+  ///
+  /// @param node the `not` expression to rewrite; ownership is taken.
+  /// @return the rewritten expression, or `node` itself when its operand is of an unhandled type.
+  std::unique_ptr<Node> Visit(std::unique_ptr<NotExpr> node) override {
+    // NOTE(review): probing `node->inner` against several types in a row relies on Node::As
+    // leaving its argument intact when the cast fails -- confirm against search/ir.h.
+    if (auto compare = Node::As<NumericCompareExpr>(std::move(node->inner))) {
+      compare->op = compare->Negative(compare->op);
+      return compare;
+    }
+    if (auto tag = Node::As<TagContainExpr>(std::move(node->inner))) {
+      return std::make_unique<NotExpr>(std::move(tag));
+    }
+
+    std::unique_ptr<Node> rewritten;
+    if (auto conj = Node::As<AndExpr>(std::move(node->inner))) {
+      rewritten = std::make_unique<OrExpr>(NegateEach(conj->inners));
+    } else if (auto disj = Node::As<OrExpr>(std::move(node->inner))) {
+      rewritten = std::make_unique<AndExpr>(NegateEach(disj->inners));
+    } else if (auto inner_not = Node::As<NotExpr>(std::move(node->inner))) {
+      rewritten = std::move(inner_not->inner);
+    } else {
+      // No rewrite rule matches this operand. Keep the `not` as it is instead of handing
+      // a null pointer to Transform (and losing the expression) as the code used to do.
+      return node;
+    }
+
+    // The freshly created `not` operands (or the unwrapped expression) may be rewritable again.
+    return Visitor::Transform(std::move(rewritten));
+  }
+
+ private:
+  /// Moves every operand out of `inners` and wraps each one in a new `NotExpr`.
+  template <typename Inners>
+  static std::vector<std::unique_ptr<QueryExpr>> NegateEach(Inners &inners) {
+    std::vector<std::unique_ptr<QueryExpr>> negated;
+    negated.reserve(inners.size());
+    for (auto &operand : inners) {
+      negated.push_back(std::make_unique<NotExpr>(std::move(operand)));
+    }
+    return negated;
+  }
+};
+
+} // namespace kqir
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index 76886dae..7ac7dfd0 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -21,6 +21,7 @@
#include "search/ir_pass.h"
#include "gtest/gtest.h"
+#include "search/passes/push_down_not_expr.h"
#include "search/passes/simplify_and_or_expr.h"
#include "search/passes/simplify_boolean.h"
#include "search/sql_transformer.h"
@@ -80,3 +81,18 @@ TEST(IRPassTest, SimplifyAndOrExpr) {
ASSERT_EQ(saoe.Transform(*Parse("select a from b where x > 1 or (y < 2 or z
= 3)"))->Dump(),
"select a from b where (or x > 1, y < 2, z = 3)");
}
+
+// Runs PushDownNotExpr over parsed queries and compares the dumped IR with the expected text.
+TEST(IRPassTest, PushDownNotExpr) {
+  PushDownNotExpr pdne;
+
+  // `not` over a numeric comparison is folded into the comparison operator.
+  ASSERT_EQ(pdne.Transform(*Parse("select * from a where not a > 1"))->Dump(), "select * from a where a <= 1");
+  // `not` over a tag expression is kept as is.
+  ASSERT_EQ(pdne.Transform(*Parse("select * from a where not a hastag \"\""))->Dump(),
+            "select * from a where not a hastag \"\"");
+  // Double negation cancels out.
+  ASSERT_EQ(pdne.Transform(*Parse("select * from a where not not a > 1"))->Dump(), "select * from a where a > 1");
+  // `not (x and y)` becomes `(not x) or (not y)`.
+  ASSERT_EQ(pdne.Transform(*Parse("select * from a where not (a > 1 and b <= 3)"))->Dump(),
+            "select * from a where (or a <= 1, b > 3)");
+  // `not (x or y)` becomes `(not x) and (not y)`.
+  ASSERT_EQ(pdne.Transform(*Parse("select * from a where not (a > 1 or b <= 3)"))->Dump(),
+            "select * from a where (and a <= 1, b > 3)");
+  // Nested case: the rewrite is applied recursively down to the leaves.
+  ASSERT_EQ(pdne.Transform(*Parse("select * from a where not (not a > 1 or (b < 3 and c hastag \"\"))"))->Dump(),
+            "select * from a where (and a > 1, (or b >= 3, not c hastag \"\"))");
+}