This is an automated email from the ASF dual-hosted git repository.

twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git


The following commit(s) were added to refs/heads/unstable by this push:
     new bc1fc9fc Add a univariate interval analysis pass for KQIR (#2256)
bc1fc9fc is described below

commit bc1fc9fc6611b11009d35dac131dcc0142f9affd
Author: Twice <[email protected]>
AuthorDate: Sat Apr 20 12:16:08 2024 +0900

    Add a univariate interval analysis pass for KQIR (#2256)
---
 src/search/interval.h                    | 185 +++++++++++++++++++++++++++++++
 src/search/ir_pass.h                     |   2 +
 src/search/ir_plan.h                     |   9 +-
 src/search/passes/interval_analysis.h    |  98 ++++++++++++++++
 src/search/passes/manager.h              |  25 +++--
 src/search/passes/simplify_and_or_expr.h |  12 +-
 tests/cppunit/interval_test.cc           |  76 +++++++++++++
 tests/cppunit/ir_pass_test.cc            |  20 +++-
 8 files changed, 403 insertions(+), 24 deletions(-)

diff --git a/src/search/interval.h b/src/search/interval.h
new file mode 100644
index 00000000..d6ebbf5c
--- /dev/null
+++ b/src/search/interval.h
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <cmath>
+#include <limits>
+#include <map>
+#include <string>
+
+#include "fmt/format.h"
+#include "search/ir.h"
+#include "string_util.h"
+
+namespace kqir {
+
+struct Interval {  // A range of doubles described by two endpoints.
+  double l, r;  // [l, r): left endpoint included, right endpoint excluded
+  // Stores the endpoints exactly as given; nothing is validated or reordered.
+  Interval(double l, double r) : l(l), r(r) {}
+  // Empty when l >= r, i.e. the half-open range contains no value.
+  bool IsEmpty() const { return l >= r; }
+  // Equality is endpoint-wise, so two empty intervals with different endpoints compare unequal.
+  bool operator==(const Interval &other) const { return l == other.l && r == 
other.r; }
+  bool operator!=(const Interval &other) const { return !(*this == other); }
+  // Formats the two endpoints through fmt as "[l, r)".
+  std::string ToString() const { return fmt::format("[{}, {})", l, r); }
+};
+
+template <typename Iter1, typename Iter2, typename F>  // Calls f on every element of two ranges in merged order.
+void ForEachMerged(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2, F 
&&f) {
+  while (first1 != last1) {
+    if (first2 == last2) {
+      std::for_each(first1, last1, std::forward<F>(f));  // range 2 is exhausted: hand over the rest of range 1
+      return;
+    }
+    // Pick the smaller front element; the test is a strict <, so on a tie the element of range 1 goes first.
+    if (*first2 < *first1) {
+      std::forward<F>(f)(*first2);
+      ++first2;
+    } else {
+      std::forward<F>(f)(*first1);
+      ++first1;
+    }
+  }
+  std::for_each(first2, last2, std::forward<F>(f));  // range 1 is exhausted: hand over the rest of range 2
+}
+
+struct IntervalSet {  // A set of doubles stored as a list of (first, second) ranges, printed as [first, second).
+  // NOTE: the elements of this vector must be kept sorted,
+  // but there is no need to use a map for that
+  using DataType = std::vector<std::pair<double, double>>;
+  DataType intervals;
+  // Positive and negative infinity, used as the endpoints of unbounded ranges.
+  static inline const double inf = std::numeric_limits<double>::infinity();
+  static inline const double minf = -inf;
+  // Next representable double after val in the direction of +infinity.
+  static double NextNum(double val) { return std::nextafter(val, inf); }
+  // Next representable double after val in the direction of -infinity.
+  static double PrevNum(double val) { return std::nextafter(val, minf); }
+  // Default construction stores no range, i.e. the empty set.
+  explicit IntervalSet() = default;
+  // Tag type and tag value selecting the constructor below.
+  struct Full {};
+  static constexpr const Full full{};
+  // Builds the set made of the single range (minf, inf).
+  explicit IntervalSet(Full) { intervals.emplace_back(minf, inf); }
+  // Builds a set from one interval; an empty interval yields the empty set.
+  explicit IntervalSet(Interval range) {
+    if (!range.IsEmpty()) intervals.emplace_back(range.l, range.r);
+  }
+  // Builds the set for "x <op> val"; bounds that must include val itself use NextNum(val) as the excluded end.
+  IntervalSet(NumericCompareExpr::Op op, double val) {
+    if (op == NumericCompareExpr::EQ) {
+      intervals.emplace_back(val, NextNum(val));  // only val: [val, next(val))
+    } else if (op == NumericCompareExpr::NE) {
+      intervals.emplace_back(minf, val);  // everything below val ...
+      intervals.emplace_back(NextNum(val), inf);  // ... and everything above it
+    } else if (op == NumericCompareExpr::LT) {
+      intervals.emplace_back(minf, val);
+    } else if (op == NumericCompareExpr::GT) {
+      intervals.emplace_back(NextNum(val), inf);
+    } else if (op == NumericCompareExpr::LET) {
+      intervals.emplace_back(minf, NextNum(val));
+    } else if (op == NumericCompareExpr::GET) {
+      intervals.emplace_back(val, inf);
+    }  // any other op value leaves the set empty
+  }
+  // Equality compares the stored range vectors.
+  bool operator==(const IntervalSet &other) const { return intervals == 
other.intervals; }
+  bool operator!=(const IntervalSet &other) const { return intervals != 
other.intervals; }
+  // Returns "empty set" when nothing is stored, otherwise the ranges formatted by Interval and joined with " or ".
+  std::string ToString() const {
+    if (IsEmpty()) return "empty set";
+    return util::StringJoin(
+        intervals, [](const auto &i) { return Interval(i.first, 
i.second).ToString(); }, " or ");
+  }
+  // Streams the ToString() text.
+  friend std::ostream &operator<<(std::ostream &os, const IntervalSet &is) { 
return os << is.ToString(); }
+  // IsEmpty: no range stored. IsFull: exactly one range whose endpoints are infinities of opposite sign.
+  bool IsEmpty() const { return intervals.empty(); }
+  bool IsFull() const {
+    if (intervals.size() != 1) return false;
+    // Both endpoints infinite with a negative product means they have opposite signs.
+    const auto &v = *intervals.begin();
+    return std::isinf(v.first) && std::isinf(v.second) && v.first * v.second < 
0;
+  }
+  // Intersection: empty if either operand is empty, otherwise computed as ~(~l | ~r).
+  friend IntervalSet operator&(const IntervalSet &l, const IntervalSet &r) {
+    if (l.IsEmpty() || r.IsEmpty()) {
+      return IntervalSet();
+    }
+
+    return ~(~l | ~r);
+  }
+  // Union: walks both range lists in merged order and fuses ranges that overlap or touch.
+  friend IntervalSet operator|(const IntervalSet &l, const IntervalSet &r) {
+    if (l.IsEmpty()) {
+      return r;
+    }
+
+    if (r.IsEmpty()) {
+      return l;
+    }
+
+    IntervalSet result;
+    ForEachMerged(l.intervals.begin(), l.intervals.end(), r.intervals.begin(), 
r.intervals.end(),
+                  [&result](const auto &v) {
+                    if (result.IsEmpty() || result.intervals.rbegin()->second 
< v.first) {
+                      result.intervals.emplace_back(v.first, v.second);  // v starts after the last range ends: new range
+                    } else {
+                      result.intervals.rbegin()->second = 
std::max(result.intervals.rbegin()->second, v.second);  // otherwise extend the last range
+                    }
+                  });
+
+    return result;
+  }
+  // Complement: the gaps before, between and after the stored ranges; the complement of the empty set is the full set.
+  friend IntervalSet operator~(const IntervalSet &v) {
+    if (v.IsEmpty()) {
+      return IntervalSet(full);
+    }
+
+    IntervalSet result;
+
+    auto iter = v.intervals.begin();
+    if (!std::isinf(iter->first)) {
+      result.intervals.emplace_back(minf, iter->first);  // everything below the first range
+    }
+    // last holds the right end of the range visited most recently.
+    double last = iter->second;
+    ++iter;
+    while (iter != v.intervals.end()) {
+      result.intervals.emplace_back(last, iter->first);  // gap between two consecutive ranges
+
+      last = iter->second;
+      ++iter;
+    }
+    // Unless the final range ends at an infinity, everything from its end upwards belongs to the complement.
+    if (!std::isinf(last)) {
+      result.intervals.emplace_back(last, inf);
+    }
+
+    return result;
+  }
+};
+
+}  // namespace kqir
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
index 924e8c75..5fa57b1a 100644
--- a/src/search/ir_pass.h
+++ b/src/search/ir_pass.h
@@ -27,6 +27,8 @@ namespace kqir {
 
 struct Pass {
   virtual std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) = 0;
+
+  virtual ~Pass() = default;
 };
 
 struct Visitor : Pass {
diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h
index da805846..94f801f4 100644
--- a/src/search/ir_plan.h
+++ b/src/search/ir_plan.h
@@ -24,6 +24,7 @@
 #include <memory>
 
 #include "ir.h"
+#include "search/interval.h"
 #include "search/ir_sema_checker.h"
 #include "string_util.h"
 
@@ -57,14 +58,6 @@ struct FieldScan : PlanOperator {
   explicit FieldScan(std::unique_ptr<FieldRef> field) : 
field(std::move(field)) {}
 };
 
-struct Interval {
-  double l, r;  // [l, r)
-
-  explicit Interval(double l, double r) : l(l), r(r) {}
-
-  std::string ToString() const { return fmt::format("[{}, {})", l, r); }
-};
-
 struct NumericFieldScan : FieldScan {
   Interval range;
 
diff --git a/src/search/passes/interval_analysis.h b/src/search/passes/interval_analysis.h
new file mode 100644
index 00000000..3010ca73
--- /dev/null
+++ b/src/search/passes/interval_analysis.h
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <type_traits>
+
+#include "search/interval.h"
+#include "search/ir.h"
+#include "search/ir_pass.h"
+#include "type_util.h"
+
+namespace kqir {
+
+struct IntervalAnalysis : Visitor {  // Pass that combines numeric comparisons inside AND/OR nodes into interval sets.
+  std::map<Node *, std::pair<std::string, IntervalSet>> result;  // visited AND/OR node -> (field name, combined set)
+  // Shared implementation for OrExpr and AndExpr; any other T trips the static_assert below.
+  template <typename T>
+  std::unique_ptr<Node> VisitImpl(std::unique_ptr<T> node) {
+    node = Node::MustAs<T>(Visitor::Visit(std::move(node)));  // base visit first; presumably recurses into children - confirm in Visitor
+    // interval_map: field name -> (set accumulated so far, inner nodes that contributed to it).
+    std::map<std::string, std::pair<IntervalSet, std::set<Node *>>> 
interval_map;
+    for (const auto &n : node->inners) {
+      IntervalSet new_interval;
+      std::string new_field;
+      // A numeric comparison contributes its own set, an inner node found in `result` its recorded set; others are skipped.
+      if (auto v = dynamic_cast<NumericCompareExpr *>(n.get())) {
+        new_interval = IntervalSet(v->op, v->num->val);
+        new_field = v->field->name;
+      } else if (auto iter = result.find(n.get()); iter != result.end()) {
+        new_interval = iter->second.second;
+        new_field = iter->second.first;
+      } else {
+        continue;
+      }
+      // Merge into the entry for this field: union under OR, intersection under AND; first sighting creates the entry.
+      if (auto iter = interval_map.find(new_field); iter != 
interval_map.end()) {
+        if constexpr (std::is_same_v<T, OrExpr>) {
+          iter->second.first = iter->second.first | new_interval;
+        } else if constexpr (std::is_same_v<T, AndExpr>) {
+          iter->second.first = iter->second.first & new_interval;
+        } else {
+          static_assert(AlwaysFalse<T>);
+        }
+        iter->second.second.emplace(n.get());
+      } else {
+        interval_map.emplace(new_field, std::make_pair(new_interval, 
std::set<Node *>{n.get()}));
+      }
+    }
+    // NOTE(review): only the number of fields is checked; inners skipped by `continue` above are not considered - confirm intended.
+    if (interval_map.size() == 1) {
+      const auto &elem = *interval_map.begin();
+      result.emplace(node.get(), std::make_pair(elem.first, 
elem.second.first));
+    }
+    // For a field whose set is empty or full, drop the inners that produced it and append the literal false or true.
+    for (const auto &[field, info] : interval_map) {
+      if (info.first.IsEmpty() || info.first.IsFull()) {
+        auto iter = std::remove_if(node->inners.begin(), node->inners.end(),
+                                   [&info = info](const auto &n) { return 
info.second.count(n.get()) == 1; });
+        node->inners.erase(iter, node->inners.end());
+      }
+
+      if (info.first.IsEmpty()) {
+        node->inners.emplace_back(std::make_unique<BoolLiteral>(false));
+      } else if (info.first.IsFull()) {
+        node->inners.emplace_back(std::make_unique<BoolLiteral>(true));
+      }
+    }
+
+    return node;
+  }
+  // Both overrides forward to VisitImpl with their concrete node type.
+  std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) override { return 
VisitImpl(std::move(node)); }
+
+  std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) override { return 
VisitImpl(std::move(node)); }
+};
+
+}  // namespace kqir
diff --git a/src/search/passes/manager.h b/src/search/passes/manager.h
index 5d11e670..094faa23 100644
--- a/src/search/passes/manager.h
+++ b/src/search/passes/manager.h
@@ -31,21 +31,24 @@
 
 namespace kqir {
 
+using PassSequence = std::vector<std::unique_ptr<Pass>>;
+
 struct PassManager {
-  template <typename... PN>
-  static std::unique_ptr<Node> Execute(std::unique_ptr<Node> node) {
-    return executeImpl<PN...>(std::move(node), 
std::make_index_sequence<sizeof...(PN)>{});
+  static std::unique_ptr<Node> Execute(const PassSequence &seq, 
std::unique_ptr<Node> node) {
+    for (auto &pass : seq) {
+      node = pass->Transform(std::move(node));
+    }
+    return node;
   }
 
-  static constexpr auto Default = Execute<SimplifyAndOrExpr, PushDownNotExpr, 
SimplifyBoolean>;
-
- private:
-  template <typename... PN, size_t... I>
-  static std::unique_ptr<Node> executeImpl(std::unique_ptr<Node> node, 
std::index_sequence<I...>) {
-    std::tuple<PN...> passes;
-
-    return std::move(((node = std::get<I>(passes).Transform(std::move(node))), 
...));
+  template <typename... Passes>
+  static PassSequence GeneratePasses() {
+    PassSequence result;
+    (result.push_back(std::make_unique<Passes>()), ...);
+    return result;
   }
+
+  static PassSequence ExprPasses() { return GeneratePasses<SimplifyAndOrExpr, 
PushDownNotExpr, SimplifyBoolean>(); }
 };
 
 }  // namespace kqir
diff --git a/src/search/passes/simplify_and_or_expr.h b/src/search/passes/simplify_and_or_expr.h
index 0ddaed61..22ac7b57 100644
--- a/src/search/passes/simplify_and_or_expr.h
+++ b/src/search/passes/simplify_and_or_expr.h
@@ -42,7 +42,11 @@ struct SimplifyAndOrExpr : Visitor {
       }
     }
 
-    return std::make_unique<OrExpr>(std::move(merged_nodes));
+    if (merged_nodes.size() == 1) {
+      return std::move(merged_nodes.front());
+    } else {
+      return std::make_unique<OrExpr>(std::move(merged_nodes));
+    }
   }
 
   std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) override {
@@ -59,7 +63,11 @@ struct SimplifyAndOrExpr : Visitor {
       }
     }
 
-    return std::make_unique<AndExpr>(std::move(merged_nodes));
+    if (merged_nodes.size() == 1) {
+      return std::move(merged_nodes.front());
+    } else {
+      return std::make_unique<AndExpr>(std::move(merged_nodes));
+    }
   }
 };
 
diff --git a/tests/cppunit/interval_test.cc b/tests/cppunit/interval_test.cc
new file mode 100644
index 00000000..5090c4c8
--- /dev/null
+++ b/tests/cppunit/interval_test.cc
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "search/interval.h"
+
+#include <gtest/gtest.h>
+
+#include "search/ir.h"
+
+using namespace kqir;
+
+TEST(IntervalSet, Simple) {  // Checks construction, |, & and ~ of IntervalSet.
+  ASSERT_TRUE(IntervalSet().IsEmpty());
+  ASSERT_TRUE(!IntervalSet().IsFull());
+  ASSERT_TRUE(IntervalSet(IntervalSet::full).IsFull());
+  ASSERT_TRUE(!IntervalSet(IntervalSet::full).IsEmpty());
+  ASSERT_TRUE((~IntervalSet()).IsFull());
+  ASSERT_TRUE((~IntervalSet(IntervalSet::full)).IsEmpty());
+  // Union and intersection of explicit ranges and of sets built from comparison operators.
+  ASSERT_EQ(IntervalSet(Interval(1, 2)) | IntervalSet(Interval(2, 4)), 
IntervalSet(Interval(1, 4)));
+  ASSERT_EQ((IntervalSet(Interval(1, 2)) | IntervalSet(Interval(2, 
4))).intervals, (IntervalSet::DataType{{1, 4}}));
+  ASSERT_EQ((IntervalSet(Interval(1, 2)) | IntervalSet(Interval(3, 
4))).intervals,
+            (IntervalSet::DataType{{1, 2}, {3, 4}}));
+  ASSERT_EQ((IntervalSet(Interval(1, 4)) | IntervalSet(Interval(2, 
3))).intervals, (IntervalSet::DataType{{1, 4}}));
+  ASSERT_EQ((IntervalSet(Interval(2, 3)) | IntervalSet(Interval(1, 
4))).intervals, (IntervalSet::DataType{{1, 4}}));
+  ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) | 
IntervalSet(NumericCompareExpr::LT, 4)).intervals,
+            (IntervalSet::DataType{{IntervalSet::minf, IntervalSet::inf}}));
+  ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) | 
IntervalSet(NumericCompareExpr::NE, 4)).intervals,
+            (IntervalSet::DataType{{IntervalSet::minf, IntervalSet::inf}}));
+  ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 4) | 
IntervalSet(NumericCompareExpr::LT, 1)).intervals,
+            (IntervalSet::DataType{{IntervalSet::minf, 1}, {4, 
IntervalSet::inf}}));
+  ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 4) | 
IntervalSet(NumericCompareExpr::NE, 1)).intervals,
+            (IntervalSet::DataType{{IntervalSet::minf, 1}, 
{IntervalSet::NextNum(1), IntervalSet::inf}}));
+  ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) & 
IntervalSet(NumericCompareExpr::LT, 4)).intervals,
+            (IntervalSet::DataType{{1, 4}}));
+  ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) & 
IntervalSet(NumericCompareExpr::NE, 4)).intervals,
+            (IntervalSet::DataType{{1, 4}, {IntervalSet::NextNum(4), 
IntervalSet::inf}}));
+  // Combining the full set with itself gives back the full set.
+  ASSERT_EQ(IntervalSet(IntervalSet::full) & IntervalSet(IntervalSet::full), 
IntervalSet(IntervalSet::full));
+  ASSERT_EQ(IntervalSet(IntervalSet::full) | IntervalSet(IntervalSet::full), 
IntervalSet(IntervalSet::full));
+  // Intersection across several ranges, and the complement of a bounded range.
+  ASSERT_EQ((IntervalSet({1, 5}) | IntervalSet({7, 10})) & IntervalSet({2, 8}),
+            IntervalSet({2, 5}) | IntervalSet({7, 8}));
+  ASSERT_EQ(~IntervalSet({2, 8}), IntervalSet({IntervalSet::minf, 2}) | 
IntervalSet({8, IntervalSet::inf}));
+  // Checks both De Morgan identities on unions of intervals whose endpoints come from rand().
+  for (auto i = 0; i < 1000; ++i) {
+    auto gen = [] { return static_cast<double>(rand()) / 100; };
+    auto geni = [&gen] { return IntervalSet({gen(), gen()}); };
+    auto l = geni(), r = geni();
+    for (int j = 0; j < i % 10; ++j) {
+      l = l | geni();
+    }
+    for (int j = 0; j < i % 7; ++j) {
+      r = r | geni();
+    }
+    ASSERT_EQ(~l | ~r, ~(l & r));
+    ASSERT_EQ(~l & ~r, ~(l | r));
+  }
+}
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index 0f0952bf..f40dd8a4 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -21,6 +21,7 @@
 #include "search/ir_pass.h"
 
 #include "gtest/gtest.h"
+#include "search/passes/interval_analysis.h"
 #include "search/passes/lower_to_plan.h"
 #include "search/passes/manager.h"
 #include "search/passes/push_down_not_expr.h"
@@ -100,9 +101,11 @@ TEST(IRPassTest, PushDownNotExpr) {
 }
 
 TEST(IRPassTest, Manager) {
-  ASSERT_EQ(
-      PassManager::Default(*Parse("select * from a where not (x > 1 or (y < 2 
or z = 3)) and (true or x = 1)"))->Dump(),
-      "select * from a where (and x <= 1, y >= 2, z != 3)");
+  auto expr_passes = PassManager::ExprPasses();
+  ASSERT_EQ(PassManager::Execute(expr_passes,
+                                 *Parse("select * from a where not (x > 1 or 
(y < 2 or z = 3)) and (true or x = 1)"))
+                ->Dump(),
+            "select * from a where (and x <= 1, y >= 2, z != 3)");
 }
 
 TEST(IRPassTest, LowerToPlan) {
@@ -117,3 +120,14 @@ TEST(IRPassTest, LowerToPlan) {
   ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit 
1"))->Dump(),
             "project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan 
b)))");
 }
+
+TEST(IRPassTest, IntervalAnalysis) {  // Runs IntervalAnalysis, then SimplifyAndOrExpr and SimplifyBoolean, on predicates over field a.
+  auto ia_passes = PassManager::GeneratePasses<IntervalAnalysis, 
SimplifyAndOrExpr, SimplifyBoolean>();
+  // Expected dumps: the OR predicate collapses to true, both AND predicates collapse to false.
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a > 
1 or a < 3"))->Dump(),
+            "select * from a where true");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a < 
1 and a > 3"))->Dump(),
+            "select * from a where false");
+  ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a > 
3 or a < 1) and a = 2"))->Dump(),
+            "select * from a where false");
+}

Reply via email to