This is an automated email from the ASF dual-hosted git repository.
twice pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks.git
The following commit(s) were added to refs/heads/unstable by this push:
new 899216db Optimize numeric comparison via interval analysis in KQIR (#2257)
899216db is described below
commit 899216db7316083b5de9cef03dd0e818ededed96
Author: Twice <[email protected]>
AuthorDate: Sun Apr 21 15:17:24 2024 +0900
Optimize numeric comparison via interval analysis in KQIR (#2257)
Co-authored-by: hulk <[email protected]>
---
src/search/ir.h | 1 +
src/search/ir_pass.h | 2 +
src/search/passes/interval_analysis.h | 115 ++++++++++++++++++++++++++++------
src/search/passes/manager.h | 1 +
tests/cppunit/ir_pass_test.cc | 32 ++++++++++
5 files changed, 133 insertions(+), 18 deletions(-)
diff --git a/src/search/ir.h b/src/search/ir.h
index d7da716a..3acc5ae8 100644
--- a/src/search/ir.h
+++ b/src/search/ir.h
@@ -85,6 +85,7 @@ struct FieldRef : Ref {
const FieldInfo *info = nullptr;
explicit FieldRef(std::string name) : name(std::move(name)) {}
+ FieldRef(std::string name, const FieldInfo *info) : name(std::move(name)),
info(info) {}
std::string_view Name() const override { return "FieldRef"; }
std::string Dump() const override { return name; }
diff --git a/src/search/ir_pass.h b/src/search/ir_pass.h
index 5fa57b1a..2068a45a 100644
--- a/src/search/ir_pass.h
+++ b/src/search/ir_pass.h
@@ -28,6 +28,8 @@ namespace kqir {
struct Pass {
virtual std::unique_ptr<Node> Transform(std::unique_ptr<Node> node) = 0;
+ virtual void Reset() {}
+
virtual ~Pass() = default;
};
diff --git a/src/search/passes/interval_analysis.h b/src/search/passes/interval_analysis.h
index 3010ca73..59959979 100644
--- a/src/search/passes/interval_analysis.h
+++ b/src/search/passes/interval_analysis.h
@@ -21,6 +21,7 @@
#pragma once
#include <algorithm>
+#include <cmath>
#include <memory>
#include <set>
#include <type_traits>
@@ -33,61 +34,139 @@
namespace kqir {
struct IntervalAnalysis : Visitor {
- std::map<Node *, std::pair<std::string, IntervalSet>> result;
+ struct IntervalInfo {
+ std::string field_name;
+ const FieldInfo *field_info;
+ IntervalSet intervals;
+ };
+
+ std::map<Node *, IntervalInfo> result;
+
+ void Reset() override { result.clear(); }
template <typename T>
std::unique_ptr<Node> VisitImpl(std::unique_ptr<T> node) {
node = Node::MustAs<T>(Visitor::Visit(std::move(node)));
- std::map<std::string, std::pair<IntervalSet, std::set<Node *>>>
interval_map;
+ struct LocalIntervalInfo {
+ IntervalSet intervals;
+ std::set<Node *> nodes;
+ const FieldInfo *field;
+ };
+
+ std::map<std::string, LocalIntervalInfo> interval_map;
for (const auto &n : node->inners) {
IntervalSet new_interval;
+ const FieldInfo *new_field_info = nullptr;
std::string new_field;
if (auto v = dynamic_cast<NumericCompareExpr *>(n.get())) {
new_interval = IntervalSet(v->op, v->num->val);
new_field = v->field->name;
+ new_field_info = v->field->info;
} else if (auto iter = result.find(n.get()); iter != result.end()) {
- new_interval = iter->second.second;
- new_field = iter->second.first;
+ new_interval = iter->second.intervals;
+ new_field = iter->second.field_name;
+ new_field_info = iter->second.field_info;
} else {
continue;
}
if (auto iter = interval_map.find(new_field); iter !=
interval_map.end()) {
if constexpr (std::is_same_v<T, OrExpr>) {
- iter->second.first = iter->second.first | new_interval;
+ iter->second.intervals = iter->second.intervals | new_interval;
} else if constexpr (std::is_same_v<T, AndExpr>) {
- iter->second.first = iter->second.first & new_interval;
+ iter->second.intervals = iter->second.intervals & new_interval;
} else {
static_assert(AlwaysFalse<T>);
}
- iter->second.second.emplace(n.get());
+ iter->second.nodes.emplace(n.get());
+ iter->second.field = new_field_info;
} else {
- interval_map.emplace(new_field, std::make_pair(new_interval,
std::set<Node *>{n.get()}));
+ interval_map.emplace(new_field, LocalIntervalInfo{new_interval,
std::set<Node *>{n.get()}, new_field_info});
}
}
if (interval_map.size() == 1) {
const auto &elem = *interval_map.begin();
- result.emplace(node.get(), std::make_pair(elem.first,
elem.second.first));
+ result.emplace(node.get(), IntervalInfo{elem.first, elem.second.field,
elem.second.intervals});
}
for (const auto &[field, info] : interval_map) {
- if (info.first.IsEmpty() || info.first.IsFull()) {
- auto iter = std::remove_if(node->inners.begin(), node->inners.end(),
- [&info = info](const auto &n) { return
info.second.count(n.get()) == 1; });
- node->inners.erase(iter, node->inners.end());
+ auto iter = std::remove_if(node->inners.begin(), node->inners.end(),
+ [&info = info](const auto &n) { return
info.nodes.count(n.get()) == 1; });
+ node->inners.erase(iter, node->inners.end());
+
+ auto field_node = std::make_unique<FieldRef>(field, info.field);
+ node->inners.emplace_back(GenerateFromInterval(info.intervals,
field_node.get()));
+ }
+
+ return node;
+ }
+
+ static std::unique_ptr<QueryExpr> GenerateFromInterval(const IntervalSet
&intervals, FieldRef *field) {
+ if (intervals.IsEmpty()) {
+ return std::make_unique<BoolLiteral>(false);
+ }
+
+ if (intervals.IsFull()) {
+ return std::make_unique<BoolLiteral>(true);
+ }
+
+ std::vector<std::unique_ptr<QueryExpr>> exprs;
+
+ if (intervals.intervals.size() > 1 &&
std::isinf(intervals.intervals.front().first) &&
+ std::isinf(intervals.intervals.back().second)) {
+ bool is_all_ne = true;
+ auto iter = intervals.intervals.begin();
+ auto last = iter->second;
+ ++iter;
+ while (iter != intervals.intervals.end()) {
+ if (iter->first != IntervalSet::NextNum(last)) {
+ is_all_ne = false;
+ break;
+ }
+
+ last = iter->second;
+ ++iter;
}
- if (info.first.IsEmpty()) {
- node->inners.emplace_back(std::make_unique<BoolLiteral>(false));
- } else if (info.first.IsFull()) {
- node->inners.emplace_back(std::make_unique<BoolLiteral>(true));
+ if (is_all_ne) {
+ for (auto i = intervals.intervals.begin(); i !=
intervals.intervals.end() && !std::isinf(i->second); ++i) {
+
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::NE,
field->CloneAs<FieldRef>(),
+
std::make_unique<NumericLiteral>(i->second)));
+ }
+
+ return std::make_unique<AndExpr>(std::move(exprs));
}
}
- return node;
+ for (auto [l, r] : intervals.intervals) {
+ if (std::isinf(l)) {
+
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT,
field->CloneAs<FieldRef>(),
+
std::make_unique<NumericLiteral>(r)));
+ } else if (std::isinf(r)) {
+
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET,
field->CloneAs<FieldRef>(),
+
std::make_unique<NumericLiteral>(l)));
+ } else if (r == IntervalSet::NextNum(l)) {
+
exprs.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::EQ,
field->CloneAs<FieldRef>(),
+
std::make_unique<NumericLiteral>(l)));
+ } else {
+ std::vector<std::unique_ptr<QueryExpr>> sub_expr;
+
sub_expr.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::GET,
field->CloneAs<FieldRef>(),
+
std::make_unique<NumericLiteral>(l)));
+
sub_expr.emplace_back(std::make_unique<NumericCompareExpr>(NumericCompareExpr::LT,
field->CloneAs<FieldRef>(),
+
std::make_unique<NumericLiteral>(r)));
+
+ exprs.emplace_back(std::make_unique<AndExpr>(std::move(sub_expr)));
+ }
+ }
+
+ if (exprs.size() == 1) {
+ return std::move(exprs.front());
+ } else {
+ return std::make_unique<OrExpr>(std::move(exprs));
+ }
}
std::unique_ptr<Node> Visit(std::unique_ptr<OrExpr> node) override { return
VisitImpl(std::move(node)); }
diff --git a/src/search/passes/manager.h b/src/search/passes/manager.h
index 094faa23..480e27a7 100644
--- a/src/search/passes/manager.h
+++ b/src/search/passes/manager.h
@@ -36,6 +36,7 @@ using PassSequence = std::vector<std::unique_ptr<Pass>>;
struct PassManager {
static std::unique_ptr<Node> Execute(const PassSequence &seq,
std::unique_ptr<Node> node) {
for (auto &pass : seq) {
+ pass->Reset();
node = pass->Transform(std::move(node));
}
return node;
diff --git a/tests/cppunit/ir_pass_test.cc b/tests/cppunit/ir_pass_test.cc
index f40dd8a4..bfb630a5 100644
--- a/tests/cppunit/ir_pass_test.cc
+++ b/tests/cppunit/ir_pass_test.cc
@@ -21,6 +21,8 @@
#include "search/ir_pass.h"
#include "gtest/gtest.h"
+#include "search/interval.h"
+#include "search/ir_sema_checker.h"
#include "search/passes/interval_analysis.h"
#include "search/passes/lower_to_plan.h"
#include "search/passes/manager.h"
@@ -130,4 +132,34 @@ TEST(IRPassTest, IntervalAnalysis) {
"select * from a where false");
ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a >
3 or a < 1) and a = 2"))->Dump(),
"select * from a where false");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where b =
1 and (a = 1 or a != 1)"))->Dump(),
+ "select * from a where b = 1");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a =
1 or b = 1 or a != 1"))->Dump(),
+ "select * from a where true");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where (a <
3 or a > 1) and b >= 1"))->Dump(),
+ "select * from a where b >= 1");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a !=
1 or a != 2"))->Dump(),
+ "select * from a where true");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a =
1 and a = 2"))->Dump(),
+ "select * from a where false");
+
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a <
1 and a < 3"))->Dump(),
+ "select * from a where a < 1");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a <
1 or a < 3"))->Dump(),
+ "select * from a where a < 3");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a =
1 and a < 3"))->Dump(),
+ "select * from a where a = 1");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a =
1 or a < 3"))->Dump(),
+ "select * from a where a < 3");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a =
1 or a = 3"))->Dump(),
+ "select * from a where (or a = 1, a = 3)");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a !=
1"))->Dump(),
+ "select * from a where a != 1");
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a !=
1 and a != 2"))->Dump(),
+ "select * from a where (and a != 1, a != 2)");
+ ASSERT_EQ(
+ PassManager::Execute(ia_passes, *Parse("select * from a where a >= 0 and
a >= 1 and a < 4 and a != 2"))->Dump(),
+ fmt::format("select * from a where (or (and a >= 1, a < 2), (and a >=
{}, a < 4))", IntervalSet::NextNum(2)));
+ ASSERT_EQ(PassManager::Execute(ia_passes, *Parse("select * from a where a !=
1 and b > 1 and b = 2"))->Dump(),
+ "select * from a where (and a != 1, b = 2)");
}