This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new 899216db Optimize numeric comparison via interval analysis in KQIR (#2257)
899216db is described below

commit 899216db7316083b5de9cef03dd0e818ededed96
Author: Twice <[email protected]>
AuthorDate: Sun Apr 21 15:17:24 2024 +0900

    Optimize numeric comparison via interval analysis in KQIR (#2257)
    
    Co-authored-by: hulk <[email protected]>
---
 src/search/ir.h                       |   1 +
 src/search/ir_pass.h                  |   2 +
 src/search/passes/interval_analysis.h | 115 ++++++++++++++++++++++++++++------
 src/search/passes/manager.h           |   1 +
 tests/cppunit/ir_pass_test.cc         |  32 ++++++++++
 5 files changed, 133 insertions(+), 18 deletions(-)

diff --git a/src/search/ir.h b/src/search/ir.h
index d7da716a..3acc5ae8 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -85,6 +85,7 @@ struct FieldRef : Ref {
   const FieldInfo *info = nullptr;
 
   explicit FieldRef(std::string name) : name(std::move(name)) {}
+  FieldRef(std::string name, const FieldInfo *info) : name(std::move(name)), 
info(info) {}
 
   std::string_view Name() const override { return "FieldRef"; }
   std::string Dump() const override { return name; }
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
index 5fa57b1a..2068a45a 100644
--- a/src/search/ir_pass.h
+++ b/src/search/ir_pass.h
@@ -28,6 +28,8 @@ namespace kqir {
 struct Pass {
   virtual std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) = 0;
 
+  virtual void Reset() {}
+
   virtual ~Pass() = default;
 };
 
diff --git a/src/search/passes/interval_analysis.h b/src/search/passes/interval_analysis.h
index 3010ca73..59959979 100644
--- a/src/search/passes/interval_analysis.h
+++ b/src/search/passes/interval_analysis.h
@@ -21,6 +21,7 @@
 #pragma once
 
 #include <algorithm>
+#include <cmath>
 #include <memory>
 #include <set>
 #include <type_traits>
@@ -33,61 +34,139 @@
 namespace kqir {
 
 struct IntervalAnalysis : Visitor {
-  std::map<Node *, std::pair<std::string, IntervalSet>> result;
+  struct IntervalInfo {
+    std::string field_name;
+    const FieldInfo *field_info;
+    IntervalSet intervals;
+  };
+
+  std::map<Node *, IntervalInfo> result;
+
+  void Reset() override { result.clear(); }
 
   template <typename T>
   std::unique_ptr<Node> VisitImpl(std::unique_ptr<T> node) {
     node = Node::MustAs<T>(Visitor::Visit(std::move(node)));
 
-    std::map<std::string, std::pair<IntervalSet, std::set<Node *>>> 
interval_map;
+    struct LocalIntervalInfo {
+      IntervalSet intervals;
+      std::set<Node *> nodes;
+      const FieldInfo *field;
+    };
+
+    std::map<std::string, LocalIntervalInfo> interval_map;
     for (const auto &n : node->inners) {
       IntervalSet new_interval;
+      const FieldInfo *new_field_info = nullptr;
       std::string new_field;
 
       if (auto v = dynamic_cast<NumericCompareExpr *>(n.get())) {
         new_interval = IntervalSet(v->op, v->num->val);
         new_field = v->field->name;
+        new_field_info = v->field->info;
       } else if (auto iter = result.find(n.get()); iter != result.end()) {
-        new_interval = iter->second.second;
-        new_field = iter->second.first;
+        new_interval = iter->second.intervals;
+        new_field = iter->second.field_name;
+        new_field_info = iter->second.field_info;
       } else {
         continue;
       }
 
       if (auto iter = interval_map.find(new_field); iter != 
interval_map.end()) {
         if constexpr (std::is_same_v<T, OrExpr>) {
-          iter->second.first = iter->second.first | new_interval;
+          iter->second.intervals = iter->second.intervals | new_interval;
         } else if constexpr (std::is_same_v<T, AndExpr>) {
-          iter->second.first = iter->second.first & new_interval;
+          iter->second.intervals = iter->second.intervals & new_interval;
         } else {
           static_assert(AlwaysFalse<T>);
         }
-        iter->second.second.emplace(n.get());
+        iter->second.nodes.emplace(n.get());
+        iter->second.field = new_field_info;
       } else {
-        interval_map.emplace(new_field, std::make_pair(new_interval, 
std::set<Node *>{n.get()}));
+        interval_map.emplace(new_field, LocalIntervalInfo{new_interval, 
std::set<Node *>{n.get()}, new_field_info});
       }
     }
 
     if (interval_map.size() == 1) {
       const auto &elem = *interval_map.begin();
-      result.emplace(node.get(), std::make_pair(elem.first, 
elem.second.first));
+      result.emplace(node.get(), IntervalInfo{elem.first, elem.second.field, 
elem.second.intervals});
     }
 
     for (const auto &[field, info] : interval_map) {
-      if (info.first.IsEmpty() || info.first.IsFull()) {
-        auto iter = std::remove_if(node->inners.begin(), node->inners.end(),
-                                   [&info = info](const auto &n) { return 
info.second.count(n.get()) == 1; });
-        node->inners.erase(iter, node->inners.end());
+      auto iter = std::remove_if(node->inners.begin(), node->inners.end(),
+                                 [&info = info](const auto &n) { return 
info.nodes.count(n.get()) == 1; });
+      node->inners.erase(iter, node->inners.end());
+
+      auto field_node = std::make_unique<FieldRef>(field, info.field);
+      node->inners.emplace_back(GenerateFromInterval(info.intervals, 
field_node.get()));
+    }
+
+    return node;
+  }
+
+  static std::unique_ptr<QueryExpr> GenerateFromInterval(const IntervalSet 
&intervals, FieldRef *field) {
+    if (intervals.IsEmpty()) {
+      return std::make_unique<BoolLiteral>(false);
+    }
+
+    if (intervals.IsFull()) {
+      return std::make_unique<BoolLiteral>(true);
+    }
+
+    std::vector<std::unique_ptr<QueryExpr>> exprs;
+
+    if (intervals.intervals.size() > 1 && 
std::isinf(intervals.intervals.front().first) &&
+        std::isinf(intervals.intervals.back().second)) {
+      bool is_all_ne = true;
+      auto iter = intervals.intervals.begin();
+      auto last = iter->second;
+      ++iter;
+      while (iter != intervals.intervals.end()) {
+        if (iter->first != IntervalSet::NextNum(last)) {
+          is_all_ne = false;
+          break;
+        }
+
+        last = iter->second;
+        ++iter;
       }
 
-      if (info.first.IsEmpty()) {
-        node->inners.emplace_back(std::make_unique<BoolLiteral>(false));
-      } else if (info.first.IsFull()) {
-        node->inners.emplace_back(std::make_unique<BoolLiteral>(true));
+      if (is_all_ne) {
+        for (auto i = intervals.intervals.begin(); i != 
intervals.intervals.end() && !std::isinf(i->second); ++i) {
+          
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::NE, 
field->CloneAs<FieldRef>(),
+                                                                  
std::make_unique<NumericLiteral>(i->second)));
+        }
+
+        return std::make_unique<AndExpr>(std::move(exprs));
       }
     }
 
-    return node;
+    for (auto [l, r] : intervals.intervals) {
+      if (std::isinf(l)) {
+        
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT, 
field->CloneAs<FieldRef>(),
+                                                                
std::make_unique<NumericLiteral>(r)));
+      } else if (std::isinf(r)) {
+        
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET,
 field->CloneAs<FieldRef>(),
+                                                                
std::make_unique<NumericLiteral>(l)));
+      } else if (r == IntervalSet::NextNum(l)) {
+        
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::EQ, 
field->CloneAs<FieldRef>(),
+                                                                
std::make_unique<NumericLiteral>(l)));
+      } else {
+        std::vector<std::unique_ptr<QueryExpr>> sub_expr;
+        
sub_expr.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET,
 field->CloneAs<FieldRef>(),
+                                                                   
std::make_unique<NumericLiteral>(l)));
+        
sub_expr.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT,
 field->CloneAs<FieldRef>(),
+                                                                   
std::make_unique<NumericLiteral>(r)));
+
+        exprs.emplace_back(std::make_unique<AndExpr>(std::move(sub_expr)));
+      }
+    }
+
+    if (exprs.size() == 1) {
+      return std::move(exprs.front());
+    } else {
+      return std::make_unique<OrExpr>(std::move(exprs));
+    }
   }
 
   std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) override { return 
VisitImpl(std::move(node)); }
diff --git a/src/search/passes/manager.h b/src/search/passes/manager.h
index 094faa23..480e27a7 100644
--- a/src/search/passes/manager.h
+++ b/src/search/passes/manager.h
@@ -36,6 +36,7 @@ using PassSequence = std::vector<std::unique_ptr<Pass>>;
 struct PassManager {
   static std::unique_ptr<Node> Execute(const PassSequence &seq, 
std::unique_ptr<Node> node) {
     for (auto &pass : seq) {
+      pass->Reset();
       node = pass->Transform(std::move(node));
     }
     return node;
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index f40dd8a4..bfb630a5 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -21,6 +21,8 @@
 #include "search/ir_pass.h"
 
 #include "gtest/gtest.h"
+#include "search/interval.h"
+#include "search/ir_sema_checker.h"
 #include "search/passes/interval_analysis.h"
 #include "search/passes/lower_to_plan.h"
 #include "search/passes/manager.h"
@@ -130,4 +132,34 @@ TEST(IRPassTest, IntervalAnalysis) {
             "select * from a where false");
   ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a > 
3 or a < 1) and a = 2"))->Dump(),
             "select * from a where false");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where b = 
1 and (a = 1 or a != 1)"))->Dump(),
+            "select * from a where b = 1");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 
1 or b = 1 or a != 1"))->Dump(),
+            "select * from a where true");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a < 
3 or a > 1) and b >= 1"))->Dump(),
+            "select * from a where b >= 1");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 
1 or a != 2"))->Dump(),
+            "select * from a where true");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 
1 and a = 2"))->Dump(),
+            "select * from a where false");
+
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a < 
1 and a < 3"))->Dump(),
+            "select * from a where a < 1");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a < 
1 or a < 3"))->Dump(),
+            "select * from a where a < 3");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 
1 and a < 3"))->Dump(),
+            "select * from a where a = 1");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 
1 or a < 3"))->Dump(),
+            "select * from a where a < 3");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a = 
1 or a = 3"))->Dump(),
+            "select * from a where (or a = 1, a = 3)");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 
1"))->Dump(),
+            "select * from a where a != 1");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 
1 and a != 2"))->Dump(),
+            "select * from a where (and a != 1, a != 2)");
+  ASSERT_EQ(
+      PassManager::Execute(ia_passes, *Parse("select * from a where a >= 0 and 
a >= 1 and a < 4 and a != 2"))->Dump(),
+      fmt::format("select * from a where (or (and a >= 1, a < 2), (and a >= 
{}, a < 4))", IntervalSet::NextNum(2)));
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a != 
1 and b > 1 and b = 2"))->Dump(),
+            "select * from a where (and a != 1, b = 2)");
 }

Reply via email to