This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new bc1fc9fc Add a univariate interval analysis pass for KQIR (#2256)
bc1fc9fc is described below
commit bc1fc9fc6611b11009d35dac131dcc0142f9affd
Author: Twice <[email protected]>
AuthorDate: Sat Apr 20 12:16:08 2024 +0900
Add a univariate interval analysis pass for KQIR (#2256)
---
src/search/interval.h | 185 +++++++++++++++++++++++++++++++
src/search/ir_pass.h | 2 +
src/search/ir_plan.h | 9 +-
src/search/passes/interval_analysis.h | 98 ++++++++++++++++
src/search/passes/manager.h | 25 +++--
src/search/passes/simplify_and_or_expr.h | 12 +-
tests/cppunit/interval_test.cc | 76 +++++++++++++
tests/cppunit/ir_pass_test.cc | 20 +++-
8 files changed, 403 insertions(+), 24 deletions(-)
diff --git a/src/search/interval.h b/src/search/interval.h
new file mode 100644
index 00000000..d6ebbf5c
--- /dev/null
+++ b/src/search/interval.h
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <cmath>
+#include <limits>
+#include <map>
+#include <string>
+
+#include "fmt/format.h"
+#include "search/ir.h"
+#include "string_util.h"
+
+namespace kqir {
+
+struct Interval {  // A half-open range of doubles with lower bound l and upper bound r.
+ double l, r; // [l, r): l is the inclusive lower bound, r is the exclusive upper bound
+
+ Interval(double l, double r) : l(l), r(r) {}  // stores the bounds as given; no ordering check is done
+
+ bool IsEmpty() const { return l >= r; }  // empty when the lower bound is not below the upper bound
+
+ bool operator==(const Interval &other) const { return l == other.l && r ==
other.r; }
+ bool operator!=(const Interval &other) const { return !(*this == other); }
+
+ std::string ToString() const { return fmt::format("[{}, {})", l, r); }  // renders the bounds as [l, r)
+};
+
+template <typename Iter1, typename Iter2, typename F>  // Calls f on every element of both ranges, interleaved like a merge.
+void ForEachMerged(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2, F
&&f) {  // presumably both ranges are already sorted by operator< (verify against callers)
+ while (first1 != last1) {
+ if (first2 == last2) {  // range 2 is exhausted: pass the rest of range 1 to f and stop
+ std::for_each(first1, last1, std::forward<F>(f));
+ return;
+ }
+
+ if (*first2 < *first1) {  // range 2 goes first only when strictly smaller, so ties take range 1 first
+ std::forward<F>(f)(*first2);  // NOTE(review): f is forwarded on every iteration; a plain f(...) call may be intended
+ ++first2;
+ } else {
+ std::forward<F>(f)(*first1);
+ ++first1;
+ }
+ }
+ std::for_each(first2, last2, std::forward<F>(f));  // range 1 is exhausted: pass the rest of range 2 to f
+}
+
+struct IntervalSet {  // A set of doubles kept as a list of half-open [first, second) ranges.
+ // NOTE: the elements of this vector must stay sorted (operator| and operator~ walk them in order),
+ // but a plain vector is sufficient; a map is not needed here
+ using DataType = std::vector<std::pair<double, double>>;
+ DataType intervals;
+
+ static inline const double inf = std::numeric_limits<double>::infinity();
+ static inline const double minf = -inf;
+
+ static double NextNum(double val) { return std::nextafter(val, inf); }  // next representable double toward +inf
+
+ static double PrevNum(double val) { return std::nextafter(val, minf); }  // next representable double toward -inf
+
+ explicit IntervalSet() = default;  // the empty set: no ranges stored
+
+ struct Full {};
+ static constexpr const Full full{};  // tag value that selects the full-set constructor below
+
+ explicit IntervalSet(Full) { intervals.emplace_back(minf, inf); }  // the full set, stored as one range [minf, inf)
+
+ explicit IntervalSet(Interval range) {  // a single range; an empty range produces the empty set
+ if (!range.IsEmpty()) intervals.emplace_back(range.l, range.r);
+ }
+
+ IntervalSet(NumericCompareExpr::Op op, double val) {  // presumably the values x for which (x op val) holds
+ if (op == NumericCompareExpr::EQ) {
+ intervals.emplace_back(val, NextNum(val));  // NOTE(review): NextNum(+inf) == +inf, so an infinite val stores a degenerate range
+ } else if (op == NumericCompareExpr::NE) {
+ intervals.emplace_back(minf, val);  // everything below val ...
+ intervals.emplace_back(NextNum(val), inf);  // ... and everything from the next double after val upward
+ } else if (op == NumericCompareExpr::LT) {
+ intervals.emplace_back(minf, val);  // val itself is excluded by the open upper bound
+ } else if (op == NumericCompareExpr::GT) {
+ intervals.emplace_back(NextNum(val), inf);  // starts just after val so val is excluded
+ } else if (op == NumericCompareExpr::LET) {
+ intervals.emplace_back(minf, NextNum(val));  // upper bound pushed one step up so val is included
+ } else if (op == NumericCompareExpr::GET) {
+ intervals.emplace_back(val, inf);  // val is included by the closed lower bound
+ }
+ }
+
+ bool operator==(const IntervalSet &other) const { return intervals ==
other.intervals; }
+ bool operator!=(const IntervalSet &other) const { return intervals !=
other.intervals; }
+
+ std::string ToString() const {  // joins the ranges with " or "; the empty set has its own text
+ if (IsEmpty()) return "empty set";
+ return util::StringJoin(
+ intervals, [](const auto &i) { return Interval(i.first,
i.second).ToString(); }, " or ");
+ }
+
+ friend std::ostream &operator<<(std::ostream &os, const IntervalSet &is) {
return os << is.ToString(); }
+
+ bool IsEmpty() const { return intervals.empty(); }  // empty means no range is stored at all
+ bool IsFull() const {  // full only when the single stored range spans from -inf to +inf
+ if (intervals.size() != 1) return false;
+
+ const auto &v = *intervals.begin();
+ return std::isinf(v.first) && std::isinf(v.second) && v.first * v.second <
0;  // both bounds are infinite and have opposite signs
+ }
+
+ friend IntervalSet operator&(const IntervalSet &l, const IntervalSet &r) {  // intersection
+ if (l.IsEmpty() || r.IsEmpty()) {  // intersecting with the empty set is empty
+ return IntervalSet();
+ }
+
+ return ~(~l | ~r);  // De Morgan: built from union and complement
+ }
+
+ friend IntervalSet operator|(const IntervalSet &l, const IntervalSet &r) {  // union
+ if (l.IsEmpty()) {
+ return r;
+ }
+
+ if (r.IsEmpty()) {
+ return l;
+ }
+
+ IntervalSet result;
+ ForEachMerged(l.intervals.begin(), l.intervals.end(), r.intervals.begin(),
r.intervals.end(),
+ [&result](const auto &v) {
+ if (result.IsEmpty() || result.intervals.rbegin()->second
< v.first) {  // v starts after a gap: open a new range
+ result.intervals.emplace_back(v.first, v.second);
+ } else {  // v overlaps or touches the last range: extend that range instead
+ result.intervals.rbegin()->second =
std::max(result.intervals.rbegin()->second, v.second);
+ }
+ });
+
+ return result;
+ }
+
+ friend IntervalSet operator~(const IntervalSet &v) {  // complement: the gaps between the stored ranges
+ if (v.IsEmpty()) {
+ return IntervalSet(full);
+ }
+
+ IntervalSet result;
+
+ auto iter = v.intervals.begin();
+ if (!std::isinf(iter->first)) {  // leading gap, unless the first lower bound is infinite
+ result.intervals.emplace_back(minf, iter->first);
+ }
+
+ double last = iter->second;  // upper bound of the previously visited range
+ ++iter;
+ while (iter != v.intervals.end()) {
+ result.intervals.emplace_back(last, iter->first);  // gap between two consecutive ranges
+
+ last = iter->second;
+ ++iter;
+ }
+
+ if (!std::isinf(last)) {  // trailing gap, unless the last upper bound is infinite
+ result.intervals.emplace_back(last, inf);
+ }
+
+ return result;
+ }
+};
+
+} // namespace kqir
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
index 924e8c75..5fa57b1a 100644
--- a/src/search/ir_pass.h
+++ b/src/search/ir_pass.h
@@ -27,6 +27,8 @@ namespace kqir {
struct Pass {  // Interface of one IR transformation step.
virtual std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) = 0;  // takes ownership of node and returns the resulting node
+
+ virtual ~Pass() = default;  // virtual so a derived pass is destroyed correctly through a Pass pointer
};
struct Visitor : Pass {
diff --git a/src/search/ir_plan.h b/src/search/ir_plan.h
index da805846..94f801f4 100644
--- a/src/search/ir_plan.h
+++ b/src/search/ir_plan.h
@@ -24,6 +24,7 @@
#include <memory>
#include "ir.h"
+#include "search/interval.h"
#include "search/ir_sema_checker.h"
#include "string_util.h"
@@ -57,14 +58,6 @@ struct FieldScan : PlanOperator {
explicit FieldScan(std::unique_ptr<FieldRef> field) :
field(std::move(field)) {}
};
-struct Interval {
- double l, r; // [l, r)
-
- explicit Interval(double l, double r) : l(l), r(r) {}
-
- std::string ToString() const { return fmt::format("[{}, {})", l, r); }
-};
-
struct NumericFieldScan : FieldScan {
Interval range;
diff --git a/src/search/passes/interval_analysis.h
b/src/search/passes/interval_analysis.h
new file mode 100644
index 00000000..3010ca73
--- /dev/null
+++ b/src/search/passes/interval_analysis.h
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <type_traits>
+
+#include "search/interval.h"
+#include "search/ir.h"
+#include "search/ir_pass.h"
+#include "type_util.h"
+
+namespace kqir {
+
+struct IntervalAnalysis : Visitor {  // Combines numeric comparisons per field inside and/or nodes and folds constant groups.
+ std::map<Node *, std::pair<std::string, IntervalSet>> result;  // node -> (field name, value set), filled by VisitImpl
+
+ template <typename T>
+ std::unique_ptr<Node> VisitImpl(std::unique_ptr<T> node) {  // shared logic for OrExpr and AndExpr
+ node = Node::MustAs<T>(Visitor::Visit(std::move(node)));  // run the base visitor first; presumably it visits the inners
+
+ std::map<std::string, std::pair<IntervalSet, std::set<Node *>>>
interval_map;  // field name -> (combined value set, inners that contributed to it)
+ for (const auto &n : node->inners) {
+ IntervalSet new_interval;
+ std::string new_field;
+
+ if (auto v = dynamic_cast<NumericCompareExpr *>(n.get())) {  // a direct comparison: build its set here
+ new_interval = IntervalSet(v->op, v->num->val);
+ new_field = v->field->name;
+ } else if (auto iter = result.find(n.get()); iter != result.end()) {  // an inner that already has a recorded set
+ new_interval = iter->second.second;
+ new_field = iter->second.first;
+ } else {
+ continue;  // nothing known about this inner; it is left untouched
+ }
+
+ if (auto iter = interval_map.find(new_field); iter !=
interval_map.end()) {
+ if constexpr (std::is_same_v<T, OrExpr>) {
+ iter->second.first = iter->second.first | new_interval;  // or: union of the sets
+ } else if constexpr (std::is_same_v<T, AndExpr>) {
+ iter->second.first = iter->second.first & new_interval;  // and: intersection of the sets
+ } else {
+ static_assert(AlwaysFalse<T>);  // any other node type is rejected at compile time
+ }
+ iter->second.second.emplace(n.get());
+ } else {  // first inner seen for this field
+ interval_map.emplace(new_field, std::make_pair(new_interval,
std::set<Node *>{n.get()}));
+ }
+ }
+
+ if (interval_map.size() == 1) {  // NOTE(review): also true when other inners were skipped above; confirm that is intended
+ const auto &elem = *interval_map.begin();
+ result.emplace(node.get(), std::make_pair(elem.first,
elem.second.first));
+ }
+
+ for (const auto &[field, info] : interval_map) {
+ if (info.first.IsEmpty() || info.first.IsFull()) {  // the group for this field is constant: drop its inners
+ auto iter = std::remove_if(node->inners.begin(), node->inners.end(),
+ [&info = info](const auto &n) { return
info.second.count(n.get()) == 1; });
+ node->inners.erase(iter, node->inners.end());  // NOTE(review): erased inners may remain as Node* keys in result
+ }
+
+ if (info.first.IsEmpty()) {  // no value can satisfy the group: replace it with false
+ node->inners.emplace_back(std::make_unique<BoolLiteral>(false));
+ } else if (info.first.IsFull()) {  // every value satisfies the group: replace it with true
+ node->inners.emplace_back(std::make_unique<BoolLiteral>(true));
+ }
+ }
+
+ return node;
+ }
+
+ std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) override { return
VisitImpl(std::move(node)); }
+
+ std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) override { return
VisitImpl(std::move(node)); }
+};
+
+} // namespace kqir
diff --git a/src/search/passes/manager.h b/src/search/passes/manager.h
index 5d11e670..094faa23 100644
--- a/src/search/passes/manager.h
+++ b/src/search/passes/manager.h
@@ -31,21 +31,24 @@
namespace kqir {
+using PassSequence = std::vector<std::unique_ptr<Pass>>;
+
struct PassManager {
- template <typename... PN>
- static std::unique_ptr<Node> Execute(std::unique_ptr<Node> node) {
- return executeImpl<PN...>(std::move(node),
std::make_index_sequence<sizeof...(PN)>{});
+ static std::unique_ptr<Node> Execute(const PassSequence &seq,
std::unique_ptr<Node> node) {
+ for (auto &pass : seq) {
+ node = pass->Transform(std::move(node));
+ }
+ return node;
}
- static constexpr auto Default = Execute<SimplifyAndOrExpr, PushDownNotExpr,
SimplifyBoolean>;
-
- private:
- template <typename... PN, size_t... I>
- static std::unique_ptr<Node> executeImpl(std::unique_ptr<Node> node,
std::index_sequence<I...>) {
- std::tuple<PN...> passes;
-
- return std::move(((node = std::get<I>(passes).Transform(std::move(node))),
...));
+ template <typename... Passes>
+ static PassSequence GeneratePasses() {
+ PassSequence result;
+ (result.push_back(std::make_unique<Passes>()), ...);
+ return result;
}
+
+ static PassSequence ExprPasses() { return GeneratePasses<SimplifyAndOrExpr,
PushDownNotExpr, SimplifyBoolean>(); }
};
} // namespace kqir
diff --git a/src/search/passes/simplify_and_or_expr.h
b/src/search/passes/simplify_and_or_expr.h
index 0ddaed61..22ac7b57 100644
--- a/src/search/passes/simplify_and_or_expr.h
+++ b/src/search/passes/simplify_and_or_expr.h
@@ -42,7 +42,11 @@ struct SimplifyAndOrExpr : Visitor {
}
}
- return std::make_unique<OrExpr>(std::move(merged_nodes));
+ if (merged_nodes.size() == 1) {
+ return std::move(merged_nodes.front());
+ } else {
+ return std::make_unique<OrExpr>(std::move(merged_nodes));
+ }
}
std::unique_ptr<Node> Visit(std::unique_ptr<AndExpr> node) override {
@@ -59,7 +63,11 @@ struct SimplifyAndOrExpr : Visitor {
}
}
- return std::make_unique<AndExpr>(std::move(merged_nodes));
+ if (merged_nodes.size() == 1) {
+ return std::move(merged_nodes.front());
+ } else {
+ return std::make_unique<AndExpr>(std::move(merged_nodes));
+ }
}
};
diff --git a/tests/cppunit/interval_test.cc b/tests/cppunit/interval_test.cc
new file mode 100644
index 00000000..5090c4c8
--- /dev/null
+++ b/tests/cppunit/interval_test.cc
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "search/interval.h"
+
+#include <gtest/gtest.h>
+
+#include "search/ir.h"
+
+using namespace kqir;
+
+TEST(IntervalSet, Simple) {  // Checks construction, union, intersection and complement of IntervalSet.
+ ASSERT_TRUE(IntervalSet().IsEmpty());
+ ASSERT_TRUE(!IntervalSet().IsFull());
+ ASSERT_TRUE(IntervalSet(IntervalSet::full).IsFull());
+ ASSERT_TRUE(!IntervalSet(IntervalSet::full).IsEmpty());
+ ASSERT_TRUE((~IntervalSet()).IsFull());  // the complement of the empty set is the full set
+ ASSERT_TRUE((~IntervalSet(IntervalSet::full)).IsEmpty());  // and the other way round
+
+ ASSERT_EQ(IntervalSet(Interval(1, 2)) | IntervalSet(Interval(2, 4)),
IntervalSet(Interval(1, 4)));
+ ASSERT_EQ((IntervalSet(Interval(1, 2)) | IntervalSet(Interval(2,
4))).intervals, (IntervalSet::DataType{{1, 4}}));
+ ASSERT_EQ((IntervalSet(Interval(1, 2)) | IntervalSet(Interval(3,
4))).intervals,
+ (IntervalSet::DataType{{1, 2}, {3, 4}}));
+ ASSERT_EQ((IntervalSet(Interval(1, 4)) | IntervalSet(Interval(2,
3))).intervals, (IntervalSet::DataType{{1, 4}}));
+ ASSERT_EQ((IntervalSet(Interval(2, 3)) | IntervalSet(Interval(1,
4))).intervals, (IntervalSet::DataType{{1, 4}}));
+ ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) |
IntervalSet(NumericCompareExpr::LT, 4)).intervals,
+ (IntervalSet::DataType{{IntervalSet::minf, IntervalSet::inf}}));
+ ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) |
IntervalSet(NumericCompareExpr::NE, 4)).intervals,
+ (IntervalSet::DataType{{IntervalSet::minf, IntervalSet::inf}}));
+ ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 4) |
IntervalSet(NumericCompareExpr::LT, 1)).intervals,
+ (IntervalSet::DataType{{IntervalSet::minf, 1}, {4,
IntervalSet::inf}}));
+ ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 4) |
IntervalSet(NumericCompareExpr::NE, 1)).intervals,
+ (IntervalSet::DataType{{IntervalSet::minf, 1},
{IntervalSet::NextNum(1), IntervalSet::inf}}));
+ ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) &
IntervalSet(NumericCompareExpr::LT, 4)).intervals,
+ (IntervalSet::DataType{{1, 4}}));
+ ASSERT_EQ((IntervalSet(NumericCompareExpr::GET, 1) &
IntervalSet(NumericCompareExpr::NE, 4)).intervals,
+ (IntervalSet::DataType{{1, 4}, {IntervalSet::NextNum(4),
IntervalSet::inf}}));
+
+ ASSERT_EQ(IntervalSet(IntervalSet::full) & IntervalSet(IntervalSet::full),
IntervalSet(IntervalSet::full));
+ ASSERT_EQ(IntervalSet(IntervalSet::full) | IntervalSet(IntervalSet::full),
IntervalSet(IntervalSet::full));
+
+ ASSERT_EQ((IntervalSet({1, 5}) | IntervalSet({7, 10})) & IntervalSet({2, 8}),
+ IntervalSet({2, 5}) | IntervalSet({7, 8}));
+ ASSERT_EQ(~IntervalSet({2, 8}), IntervalSet({IntervalSet::minf, 2}) |
IntervalSet({8, IntervalSet::inf}));
+
+ for (auto i = 0; i < 1000; ++i) {  // randomized check of both De Morgan identities
+ auto gen = [] { return static_cast<double>(rand()) / 100; };  // NOTE(review): rand() is never seeded in this test
+ auto geni = [&gen] { return IntervalSet({gen(), gen()}); };  // a random range; an empty one yields the empty set
+ auto l = geni(), r = geni();
+ for (int j = 0; j < i % 10; ++j) {  // grow l by the union of up to nine more ranges
+ l = l | geni();
+ }
+ for (int j = 0; j < i % 7; ++j) {  // grow r by the union of up to six more ranges
+ r = r | geni();
+ }
+ ASSERT_EQ(~l | ~r, ~(l & r));
+ ASSERT_EQ(~l & ~r, ~(l | r));
+ }
+}
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index 0f0952bf..f40dd8a4 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -21,6 +21,7 @@
#include "search/ir_pass.h"
#include "gtest/gtest.h"
+#include "search/passes/interval_analysis.h"
#include "search/passes/lower_to_plan.h"
#include "search/passes/manager.h"
#include "search/passes/push_down_not_expr.h"
@@ -100,9 +101,11 @@ TEST(IRPassTest, PushDownNotExpr) {
}
TEST(IRPassTest, Manager) {
- ASSERT_EQ(
- PassManager::Default(*Parse("select * from a where not (x > 1 or (y < 2
or z = 3)) and (true or x = 1)"))->Dump(),
- "select * from a where (and x <= 1, y >= 2, z != 3)");
+ auto expr_passes = PassManager::ExprPasses();
+ ASSERT_EQ(PassManager::Execute(expr_passes,
+ *Parse("select * from a where not (x > 1 or
(y < 2 or z = 3)) and (true or x = 1)"))
+ ->Dump(),
+ "select * from a where (and x <= 1, y >= 2, z != 3)");
}
TEST(IRPassTest, LowerToPlan) {
@@ -117,3 +120,14 @@ TEST(IRPassTest, LowerToPlan) {
ASSERT_EQ(ltp.Transform(*Parse("select a from b where c = 1 order by d limit
1"))->Dump(),
"project a: (limit 0, 1: (sort d, asc: (filter c = 1: full-scan
b)))");
}
+
+TEST(IRPassTest, IntervalAnalysis) {  // Runs IntervalAnalysis, then the two simplify passes, and checks the dumped query.
+ auto ia_passes = PassManager::GeneratePasses<IntervalAnalysis,
SimplifyAndOrExpr, SimplifyBoolean>();
+
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a >
1 or a < 3"))->Dump(),
+ "select * from a where true");  // the two ranges together cover every value
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a <
1 and a > 3"))->Dump(),
+ "select * from a where false");  // the two ranges do not intersect
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a >
3 or a < 1) and a = 2"))->Dump(),
+ "select * from a where false");  // 2 is outside the set recorded for the nested or-expression
+}