This is an automated email from the ASF dual-hosted git repository.
yangsiyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new acc243383ca [feature](bm25) support score range filter pushdown
(min_score semantics) (#60100)
acc243383ca is described below
commit acc243383caa284133f04e05060bde5fe7b4dbf9
Author: zzzxl <[email protected]>
AuthorDate: Wed Feb 25 14:25:58 2026 +0800
[feature](bm25) support score range filter pushdown (min_score semantics)
(#60100)
---
be/src/olap/collection_similarity.cpp | 96 ++--
be/src/olap/collection_similarity.h | 22 +-
be/src/olap/rowset/segment_v2/segment_iterator.cpp | 31 +-
be/src/pipeline/exec/olap_scan_operator.cpp | 11 +
be/src/vec/exec/scan/olap_scanner.cpp | 2 -
be/src/vec/exec/scan/scanner_scheduler.cpp | 4 +
be/src/vec/exprs/score_runtime.h | 22 +
be/test/olap/collection_similarity_test.cpp | 541 +++++++++++++++++++++
be/test/vec/exprs/score_runtime_test.cpp | 259 ++++++++++
.../glue/translator/PhysicalPlanTranslator.java | 3 +
.../LogicalOlapScanToPhysicalOlapScan.java | 1 +
.../rewrite/PushDownScoreTopNIntoOlapScan.java | 176 ++++++-
.../rewrite/PushDownVectorTopNIntoOlapScan.java | 2 +-
.../trees/expressions/ComparisonPredicate.java | 18 +
.../doris/nereids/trees/plans/ScoreRangeInfo.java | 79 +++
.../trees/plans/logical/LogicalOlapScan.java | 61 ++-
.../physical/PhysicalLazyMaterializeOlapScan.java | 1 +
.../trees/plans/physical/PhysicalOlapScan.java | 26 +-
.../org/apache/doris/planner/OlapScanNode.java | 9 +
.../translator/PhysicalPlanTranslatorTest.java | 2 +-
.../postprocess/MergeProjectPostProcessTest.java | 2 +-
.../PushDownFilterThroughProjectTest.java | 4 +-
.../doris/nereids/trees/plans/PlanEqualsTest.java | 6 +-
gensrc/thrift/PlanNodes.thrift | 7 +
.../test_bm25_score_range_filter.out | 33 ++
.../inverted_index_p0/test_bm25_score.groovy | 6 +-
.../test_bm25_score_range_filter.groovy | 159 ++++++
27 files changed, 1471 insertions(+), 112 deletions(-)
diff --git a/be/src/olap/collection_similarity.cpp
b/be/src/olap/collection_similarity.cpp
index 1223e20c0bd..415fce0ebe0 100644
--- a/be/src/olap/collection_similarity.cpp
+++ b/be/src/olap/collection_similarity.cpp
@@ -29,25 +29,32 @@ void CollectionSimilarity::collect(segment_v2::rowid_t
row_id, float score) {
void CollectionSimilarity::get_bm25_scores(roaring::Roaring* row_bitmap,
vectorized::IColumn::MutablePtr&
scores,
-
std::unique_ptr<std::vector<uint64_t>>& row_ids) const {
- size_t num_results = row_bitmap->cardinality();
- auto score_column = vectorized::ColumnFloat32::create(num_results);
- auto& score_data = score_column->get_data();
+
std::unique_ptr<std::vector<uint64_t>>& row_ids,
+ const ScoreRangeFilterPtr& filter)
const {
+ std::vector<float> filtered_scores;
+ filtered_scores.reserve(row_bitmap->cardinality());
- row_ids->resize(num_results);
+ roaring::Roaring new_bitmap;
- int32_t i = 0;
for (uint32_t row_id : *row_bitmap) {
- (*row_ids)[i] = row_id;
auto it = _bm25_scores.find(row_id);
- if (it != _bm25_scores.end()) {
- score_data[i] = it->second;
- } else {
- score_data[i] = 0.0;
+ float score = (it != _bm25_scores.end()) ? it->second : 0.0F;
+ if (filter && !filter->pass(score)) {
+ continue;
}
- i++;
+ row_ids->push_back(row_id);
+ filtered_scores.push_back(score);
+ new_bitmap.add(row_id);
}
+ size_t num_results = row_ids->size();
+ auto score_column = vectorized::ColumnFloat32::create(num_results);
+ if (num_results > 0) {
+ memcpy(score_column->get_data().data(), filtered_scores.data(),
+ num_results * sizeof(float));
+ }
+
+ *row_bitmap = std::move(new_bitmap);
auto null_map = vectorized::ColumnUInt8::create(num_results, 0);
scores = vectorized::ColumnNullable::create(std::move(score_column),
std::move(null_map));
}
@@ -55,23 +62,14 @@ void
CollectionSimilarity::get_bm25_scores(roaring::Roaring* row_bitmap,
void CollectionSimilarity::get_topn_bm25_scores(roaring::Roaring* row_bitmap,
vectorized::IColumn::MutablePtr& scores,
std::unique_ptr<std::vector<uint64_t>>& row_ids,
- OrderType order_type, size_t
top_k) const {
+ OrderType order_type, size_t
top_k,
+ const ScoreRangeFilterPtr&
filter) const {
std::vector<std::pair<uint32_t, float>> top_k_results;
if (order_type == OrderType::DESC) {
- find_top_k_scores<OrderType::DESC>(
- row_bitmap, _bm25_scores, top_k,
- [](const ScoreMapIterator& a, const ScoreMapIterator& b) {
- return a->second > b->second;
- },
- top_k_results);
+ find_top_k_scores<OrderType::DESC>(row_bitmap, _bm25_scores, top_k,
top_k_results, filter);
} else {
- find_top_k_scores<OrderType::ASC>(
- row_bitmap, _bm25_scores, top_k,
- [](const ScoreMapIterator& a, const ScoreMapIterator& b) {
- return a->second < b->second;
- },
- top_k_results);
+ find_top_k_scores<OrderType::ASC>(row_bitmap, _bm25_scores, top_k,
top_k_results, filter);
}
size_t num_results = top_k_results.size();
@@ -92,46 +90,65 @@ void
CollectionSimilarity::get_topn_bm25_scores(roaring::Roaring* row_bitmap,
scores = vectorized::ColumnNullable::create(std::move(score_column),
std::move(null_map));
}
-template <OrderType order, typename Compare>
-void CollectionSimilarity::find_top_k_scores(
- const roaring::Roaring* row_bitmap, const ScoreMap& all_scores, size_t
top_k, Compare comp,
- std::vector<std::pair<uint32_t, float>>& top_k_results) const {
+template <OrderType order>
+void CollectionSimilarity::find_top_k_scores(const roaring::Roaring*
row_bitmap,
+ const ScoreMap& all_scores,
size_t top_k,
+ std::vector<std::pair<uint32_t,
float>>& top_k_results,
+ const ScoreRangeFilterPtr&
filter) const {
if (top_k <= 0) {
return;
}
- std::priority_queue<ScoreMapIterator, std::vector<ScoreMapIterator>,
Compare> top_k_heap(comp);
+ auto pair_comp = [](const std::pair<uint32_t, float>& a, const
std::pair<uint32_t, float>& b) {
+ if constexpr (order == OrderType::DESC) {
+ return a.second > b.second;
+ } else {
+ return a.second < b.second;
+ }
+ };
+
+ std::priority_queue<std::pair<uint32_t, float>,
std::vector<std::pair<uint32_t, float>>,
+ decltype(pair_comp)>
+ top_k_heap(pair_comp);
std::vector<uint32_t> zero_score_ids;
+
for (uint32_t row_id : *row_bitmap) {
auto it = all_scores.find(row_id);
- if (it == all_scores.end()) {
+ float score = (it != all_scores.end()) ? it->second : 0.0F;
+
+ if (filter && !filter->pass(score)) {
+ continue;
+ }
+
+ if (score == 0.0F) {
zero_score_ids.push_back(row_id);
continue;
}
+
if (top_k_heap.size() < top_k) {
- top_k_heap.push(it);
- } else if (comp(it, top_k_heap.top())) {
+ top_k_heap.emplace(row_id, score);
+ } else if (pair_comp({row_id, score}, top_k_heap.top())) {
top_k_heap.pop();
- top_k_heap.push(it);
+ top_k_heap.emplace(row_id, score);
}
}
- top_k_results.reserve(top_k_heap.size());
+ top_k_results.reserve(top_k);
while (!top_k_heap.empty()) {
- auto top = top_k_heap.top();
- top_k_results.push_back({top->first, top->second});
+ top_k_results.push_back(top_k_heap.top());
top_k_heap.pop();
}
+ std::ranges::reverse(top_k_results);
if constexpr (order == OrderType::DESC) {
- std::ranges::reverse(top_k_results);
-
+ // DESC: high scores first, then zeros at the end
size_t remaining = top_k - top_k_results.size();
for (size_t i = 0; i < remaining && i < zero_score_ids.size(); ++i) {
top_k_results.emplace_back(zero_score_ids[i], 0.0F);
}
} else {
+ // ASC: zeros first, then low scores
std::vector<std::pair<uint32_t, float>> final_results;
final_results.reserve(top_k);
@@ -140,7 +157,6 @@ void CollectionSimilarity::find_top_k_scores(
final_results.emplace_back(zero_score_ids[i], 0.0F);
}
- std::ranges::reverse(top_k_results);
size_t remaining = top_k - final_results.size();
for (size_t i = 0; i < remaining && i < top_k_results.size(); ++i) {
final_results.push_back(top_k_results[i]);
diff --git a/be/src/olap/collection_similarity.h
b/be/src/olap/collection_similarity.h
index 2572660a1fb..2ae7b06921e 100644
--- a/be/src/olap/collection_similarity.h
+++ b/be/src/olap/collection_similarity.h
@@ -17,6 +17,7 @@
#pragma once
+#include "gen_cpp/Opcodes_types.h"
#include "rowset/segment_v2/common.h"
#include "vec/columns/column.h"
@@ -31,6 +32,16 @@ enum class OrderType {
DESC,
};
+struct ScoreRangeFilter {
+ TExprOpcode::type op;
+ double threshold;
+
+ bool pass(float score) const {
+ return (op == TExprOpcode::GT) ? (score > threshold) : (score >=
threshold);
+ }
+};
+using ScoreRangeFilterPtr = std::shared_ptr<ScoreRangeFilter>;
+
class CollectionSimilarity {
public:
CollectionSimilarity() { _bm25_scores.reserve(1024); }
@@ -39,17 +50,18 @@ public:
void collect(segment_v2::rowid_t row_id, float score);
void get_bm25_scores(roaring::Roaring* row_bitmap,
vectorized::IColumn::MutablePtr& scores,
- std::unique_ptr<std::vector<uint64_t>>& row_ids)
const;
+ std::unique_ptr<std::vector<uint64_t>>& row_ids,
+ const ScoreRangeFilterPtr& filter = nullptr) const;
void get_topn_bm25_scores(roaring::Roaring* row_bitmap,
vectorized::IColumn::MutablePtr& scores,
std::unique_ptr<std::vector<uint64_t>>& row_ids,
OrderType order_type,
- size_t top_k) const;
+ size_t top_k, const ScoreRangeFilterPtr& filter
= nullptr) const;
private:
- template <OrderType order, typename Compare>
+ template <OrderType order>
void find_top_k_scores(const roaring::Roaring* row_bitmap, const ScoreMap&
all_scores,
- size_t top_k, Compare comp,
- std::vector<std::pair<uint32_t, float>>&
top_k_results) const;
+ size_t top_k, std::vector<std::pair<uint32_t,
float>>& top_k_results,
+ const ScoreRangeFilterPtr& filter) const;
ScoreMap _bm25_scores;
};
diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
index c4f3f7300c0..3f62eec7ef8 100644
--- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
+++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
@@ -2687,16 +2687,19 @@ Status
SegmentIterator::_check_output_block(vectorized::Block* block) {
idx, block->columns(), _schema->num_column_ids(),
_virtual_column_exprs.size());
} else if (vectorized::check_and_get_column<vectorized::ColumnNothing>(
entry.column.get())) {
- std::vector<std::string> vcid_to_idx;
- for (const auto& pair : _vir_cid_to_idx_in_block) {
- vcid_to_idx.push_back(fmt::format("{}-{}", pair.first,
pair.second));
+ if (rows > 0) {
+ std::vector<std::string> vcid_to_idx;
+ for (const auto& pair : _vir_cid_to_idx_in_block) {
+ vcid_to_idx.push_back(fmt::format("{}-{}", pair.first,
pair.second));
+ }
+ std::string vir_cid_to_idx_in_block_msg =
+ fmt::format("_vir_cid_to_idx_in_block:[{}]",
fmt::join(vcid_to_idx, ","));
+ return Status::InternalError(
+ "Column in idx {} is nothing, block columns {},
normal_columns {}, "
+ "vir_cid_to_idx_in_block_msg {}",
+ idx, block->columns(), _schema->num_column_ids(),
+ vir_cid_to_idx_in_block_msg);
}
- std::string vir_cid_to_idx_in_block_msg =
- fmt::format("_vir_cid_to_idx_in_block:[{}]",
fmt::join(vcid_to_idx, ","));
- return Status::InternalError(
- "Column in idx {} is nothing, block columns {},
normal_columns {}, "
- "vir_cid_to_idx_in_block_msg {}",
- idx, block->columns(), _schema->num_column_ids(),
vir_cid_to_idx_in_block_msg);
} else if (entry.column->size() != rows) {
return Status::InternalError(
"Unmatched size {}, expected {}, column: {}, type: {},
idx_in_block: {}, "
@@ -3155,6 +3158,12 @@ void
SegmentIterator::_prepare_score_column_materialization() {
return;
}
+ ScoreRangeFilterPtr filter;
+ if (_score_runtime->has_score_range_filter()) {
+ const auto& range_info = _score_runtime->get_score_range_info();
+ filter = std::make_shared<ScoreRangeFilter>(range_info->op,
range_info->threshold);
+ }
+
vectorized::IColumn::MutablePtr result_column;
auto result_row_ids = std::make_unique<std::vector<uint64_t>>();
if (_score_runtime->get_limit() > 0 && _col_predicates.empty() &&
@@ -3162,10 +3171,10 @@ void
SegmentIterator::_prepare_score_column_materialization() {
OrderType order_type = _score_runtime->is_asc() ? OrderType::ASC :
OrderType::DESC;
_index_query_context->collection_similarity->get_topn_bm25_scores(
&_row_bitmap, result_column, result_row_ids, order_type,
- _score_runtime->get_limit());
+ _score_runtime->get_limit(), filter);
} else {
_index_query_context->collection_similarity->get_bm25_scores(&_row_bitmap,
result_column,
-
result_row_ids);
+
result_row_ids, filter);
}
const size_t dst_col_idx = _score_runtime->get_dest_column_idx();
auto* column_iter =
_column_iterators[_schema->column_id(dst_col_idx)].get();
diff --git a/be/src/pipeline/exec/olap_scan_operator.cpp
b/be/src/pipeline/exec/olap_scan_operator.cpp
index 03e5cf5a995..07c0522a4b5 100644
--- a/be/src/pipeline/exec/olap_scan_operator.cpp
+++ b/be/src/pipeline/exec/olap_scan_operator.cpp
@@ -79,6 +79,17 @@ Status OlapScanLocalState::init(RuntimeState* state,
LocalStateInfo& info) {
segment_v2::AnnTopNRuntime::create_shared(asc, limit,
ordering_expr_ctx);
}
+ // Parse score range filtering parameters and set to ScoreRuntime
+ if (olap_scan_node.__isset.score_range_info) {
+ const auto& score_range_info = olap_scan_node.score_range_info;
+ if (score_range_info.__isset.op && score_range_info.__isset.threshold)
{
+ if (_score_runtime) {
+ _score_runtime->set_score_range_info(score_range_info.op,
+
score_range_info.threshold);
+ }
+ }
+ }
+
RETURN_IF_ERROR(Base::init(state, info));
RETURN_IF_ERROR(_sync_cloud_tablets(state));
return Status::OK();
diff --git a/be/src/vec/exec/scan/olap_scanner.cpp
b/be/src/vec/exec/scan/olap_scanner.cpp
index 334838586c8..3778efc3019 100644
--- a/be/src/vec/exec/scan/olap_scanner.cpp
+++ b/be/src/vec/exec/scan/olap_scanner.cpp
@@ -153,8 +153,6 @@ Status OlapScanner::prepare() {
_slot_id_to_index_in_block = local_state->_slot_id_to_index_in_block;
_slot_id_to_col_type = local_state->_slot_id_to_col_type;
- _score_runtime = local_state->_score_runtime;
-
_score_runtime = local_state->_score_runtime;
// All scanners share the same ann_topn_runtime.
_ann_topn_runtime = local_state->_ann_topn_runtime;
diff --git a/be/src/vec/exec/scan/scanner_scheduler.cpp
b/be/src/vec/exec/scan/scanner_scheduler.cpp
index cb021e419e6..e28495802ae 100644
--- a/be/src/vec/exec/scan/scanner_scheduler.cpp
+++ b/be/src/vec/exec/scan/scanner_scheduler.cpp
@@ -382,6 +382,10 @@ void
ScannerScheduler::_make_sure_virtual_col_is_materialized(
return;
}
+ if (free_block->rows() == 0) {
+ return;
+ }
+
size_t idx = 0;
for (const auto& entry : *free_block) {
// Virtual column must be materialized on the end of SegmentIterator's
next batch method.
diff --git a/be/src/vec/exprs/score_runtime.h b/be/src/vec/exprs/score_runtime.h
index 99abb74ad73..5a951c682b8 100644
--- a/be/src/vec/exprs/score_runtime.h
+++ b/be/src/vec/exprs/score_runtime.h
@@ -17,6 +17,10 @@
#pragma once
+#include <gen_cpp/Exprs_types.h>
+
+#include <optional>
+
#include "vec/exprs/vexpr_context.h"
#include "vec/exprs/virtual_slot_ref.h"
@@ -27,6 +31,12 @@ class ScoreRuntime {
ENABLE_FACTORY_CREATOR(ScoreRuntime);
public:
+ // Score range filtering info for predicates like score() > 0.5
+ struct ScoreRangeInfo {
+ TExprOpcode::type op;
+ double threshold;
+ };
+
ScoreRuntime(VExprContextSPtr order_by_expr_ctx, bool asc, size_t limit)
: _order_by_expr_ctx(std::move(order_by_expr_ctx)), _asc(asc),
_limit(limit) {};
@@ -50,6 +60,15 @@ public:
bool is_asc() const { return _asc; }
size_t get_limit() const { return _limit; }
+ // Score range filtering methods
+ void set_score_range_info(TExprOpcode::type op, double threshold) {
+ _score_range_info = ScoreRangeInfo {.op = op, .threshold = threshold};
+ }
+
+ const std::optional<ScoreRangeInfo>& get_score_range_info() const { return
_score_range_info; }
+
+ bool has_score_range_filter() const { return
_score_range_info.has_value(); }
+
private:
VExprContextSPtr _order_by_expr_ctx;
const bool _asc = false;
@@ -57,6 +76,9 @@ private:
std::string _name = "score_runtime";
size_t _dest_column_idx = -1;
+
+ // Score range filtering info (e.g., score() > 0.5)
+ std::optional<ScoreRangeInfo> _score_range_info;
};
using ScoreRuntimeSPtr = std::shared_ptr<ScoreRuntime>;
diff --git a/be/test/olap/collection_similarity_test.cpp
b/be/test/olap/collection_similarity_test.cpp
index 061f48f26a5..2cb5800510b 100644
--- a/be/test/olap/collection_similarity_test.cpp
+++ b/be/test/olap/collection_similarity_test.cpp
@@ -222,4 +222,545 @@ TEST_F(CollectionSimilarityTest,
GetBm25ScoresEmptyBitmapTest) {
EXPECT_EQ(row_ids->size(), 0);
}
+// Tests for ScoreRangeFilter
+
+TEST_F(CollectionSimilarityTest, GetBm25ScoresWithFilterGTTest) {
+ similarity->collect(1, 0.3f);
+ similarity->collect(2, 0.6f);
+ similarity->collect(3, 0.9f);
+
+ roaring::Roaring bitmap = create_bitmap({1, 2, 3, 4}); // 4 has score 0
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GT, 0.5);
+ similarity->get_bm25_scores(&bitmap, scores, row_ids, filter);
+
+ verify_scores(scores, {0.6f, 0.9f});
+ verify_row_ids(row_ids, {2, 3});
+ EXPECT_EQ(bitmap.cardinality(), 2);
+}
+
+TEST_F(CollectionSimilarityTest, GetBm25ScoresWithFilterGETest) {
+ similarity->collect(1, 0.3f);
+ similarity->collect(2, 0.5f);
+ similarity->collect(3, 0.9f);
+
+ roaring::Roaring bitmap = create_bitmap({1, 2, 3, 4});
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GE, 0.5);
+ similarity->get_bm25_scores(&bitmap, scores, row_ids, filter);
+
+ verify_scores(scores, {0.5f, 0.9f});
+ verify_row_ids(row_ids, {2, 3});
+}
+
+TEST_F(CollectionSimilarityTest, GetBm25ScoresWithFilterZeroThresholdTest) {
+ similarity->collect(1, 0.3f);
+ similarity->collect(2, 0.6f);
+
+ roaring::Roaring bitmap = create_bitmap({1, 2, 3}); // 3 has score 0
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ // GT 0: should exclude score=0
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GT, 0.0);
+ similarity->get_bm25_scores(&bitmap, scores, row_ids, filter);
+
+ verify_scores(scores, {0.3f, 0.6f});
+ verify_row_ids(row_ids, {1, 2});
+}
+
+TEST_F(CollectionSimilarityTest, GetBm25ScoresWithFilterGEZeroTest) {
+ similarity->collect(1, 0.3f);
+ similarity->collect(2, 0.6f);
+
+ roaring::Roaring bitmap = create_bitmap({1, 2, 3}); // 3 has score 0
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ // GE 0: should include score=0
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GE, 0.0);
+ similarity->get_bm25_scores(&bitmap, scores, row_ids, filter);
+
+ verify_scores(scores, {0.3f, 0.6f, 0.0f});
+ verify_row_ids(row_ids, {1, 2, 3});
+}
+
+TEST_F(CollectionSimilarityTest, GetTopnBm25ScoresWithFilterDescTest) {
+ similarity->collect(1, 0.3f);
+ similarity->collect(2, 0.6f);
+ similarity->collect(3, 0.9f);
+ similarity->collect(4, 0.4f);
+
+ roaring::Roaring bitmap = create_bitmap({1, 2, 3, 4, 5}); // 5 has score 0
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GT, 0.35);
+ similarity->get_topn_bm25_scores(&bitmap, scores, row_ids,
OrderType::DESC, 3, filter);
+
+ // Only 0.9, 0.6, 0.4 pass filter (> 0.35), top 3 DESC
+ verify_scores(scores, {0.9f, 0.6f, 0.4f});
+ verify_row_ids(row_ids, {3, 2, 4});
+}
+
+TEST_F(CollectionSimilarityTest, GetTopnBm25ScoresWithFilterAscTest) {
+ similarity->collect(1, 0.3f);
+ similarity->collect(2, 0.6f);
+ similarity->collect(3, 0.9f);
+
+ roaring::Roaring bitmap = create_bitmap({1, 2, 3, 4, 5}); // 4,5 have
score 0
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ // GE 0: includes zeros, ASC order puts zeros first
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GE, 0.0);
+ similarity->get_topn_bm25_scores(&bitmap, scores, row_ids, OrderType::ASC,
3, filter);
+
+ // ASC: zeros first, then lowest scores
+ verify_scores(scores, {0.0f, 0.0f, 0.3f});
+ verify_row_ids(row_ids, {4, 5, 1});
+}
+
+TEST_F(CollectionSimilarityTest,
GetTopnBm25ScoresWithFilterExcludeZerosAscTest) {
+ similarity->collect(1, 0.3f);
+ similarity->collect(2, 0.6f);
+ similarity->collect(3, 0.9f);
+
+ roaring::Roaring bitmap = create_bitmap({1, 2, 3, 4, 5}); // 4,5 have
score 0
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ // GT 0: excludes zeros
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GT, 0.0);
+ similarity->get_topn_bm25_scores(&bitmap, scores, row_ids, OrderType::ASC,
3, filter);
+
+ verify_scores(scores, {0.3f, 0.6f, 0.9f});
+ verify_row_ids(row_ids, {1, 2, 3});
+}
+
+TEST_F(CollectionSimilarityTest, GetTopnBm25ScoresWithFilterAllFilteredTest) {
+ similarity->collect(1, 0.3f);
+ similarity->collect(2, 0.4f);
+
+ roaring::Roaring bitmap = create_bitmap({1, 2, 3});
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ // Filter threshold too high, all filtered out
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GT, 0.5);
+ similarity->get_topn_bm25_scores(&bitmap, scores, row_ids,
OrderType::DESC, 3, filter);
+
+ EXPECT_EQ(scores->size(), 0);
+ EXPECT_EQ(row_ids->size(), 0);
+}
+
+TEST_F(CollectionSimilarityTest, LargeDataGetBm25ScoresBasicTest) {
+ constexpr size_t NUM_ROWS = 100000;
+
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ float score = static_cast<float>(i) / static_cast<float>(NUM_ROWS);
+ similarity->collect(static_cast<uint32_t>(i), score);
+ }
+
+ std::vector<uint32_t> all_ids;
+ all_ids.reserve(NUM_ROWS);
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ all_ids.push_back(static_cast<uint32_t>(i));
+ }
+ roaring::Roaring bitmap = create_bitmap(all_ids);
+
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ similarity->get_bm25_scores(&bitmap, scores, row_ids);
+
+ EXPECT_EQ(scores->size(), NUM_ROWS);
+ EXPECT_EQ(row_ids->size(), NUM_ROWS);
+ EXPECT_EQ(bitmap.cardinality(), NUM_ROWS);
+}
+
+TEST_F(CollectionSimilarityTest, LargeDataGetBm25ScoresWithGTFilterTest) {
+ constexpr size_t NUM_ROWS = 100000;
+ constexpr double THRESHOLD = 0.5;
+
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ float score = static_cast<float>(i) / static_cast<float>(NUM_ROWS);
+ similarity->collect(static_cast<uint32_t>(i), score);
+ }
+
+ std::vector<uint32_t> all_ids;
+ all_ids.reserve(NUM_ROWS);
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ all_ids.push_back(static_cast<uint32_t>(i));
+ }
+ roaring::Roaring bitmap = create_bitmap(all_ids);
+
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GT,
THRESHOLD);
+ similarity->get_bm25_scores(&bitmap, scores, row_ids, filter);
+
+ size_t expected_count = 0;
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ float score = static_cast<float>(i) / static_cast<float>(NUM_ROWS);
+ if (score > THRESHOLD) {
+ expected_count++;
+ }
+ }
+
+ EXPECT_EQ(scores->size(), expected_count);
+ EXPECT_EQ(row_ids->size(), expected_count);
+ EXPECT_EQ(bitmap.cardinality(), expected_count);
+
+ auto* nullable_column =
dynamic_cast<vectorized::ColumnNullable*>(scores.get());
+ ASSERT_NE(nullable_column, nullptr);
+ const auto* float_column =
+ dynamic_cast<const
vectorized::ColumnFloat32*>(&nullable_column->get_nested_column());
+ ASSERT_NE(float_column, nullptr);
+ const auto& data = float_column->get_data();
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_GT(data[i], THRESHOLD);
+ }
+}
+
+TEST_F(CollectionSimilarityTest, LargeDataGetBm25ScoresWithGEFilterTest) {
+ constexpr size_t NUM_ROWS = 100000;
+ constexpr double THRESHOLD = 0.5;
+
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ float score = static_cast<float>(i) / static_cast<float>(NUM_ROWS);
+ similarity->collect(static_cast<uint32_t>(i), score);
+ }
+
+ std::vector<uint32_t> all_ids;
+ all_ids.reserve(NUM_ROWS);
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ all_ids.push_back(static_cast<uint32_t>(i));
+ }
+ roaring::Roaring bitmap = create_bitmap(all_ids);
+
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GE,
THRESHOLD);
+ similarity->get_bm25_scores(&bitmap, scores, row_ids, filter);
+
+ size_t expected_count = 0;
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ float score = static_cast<float>(i) / static_cast<float>(NUM_ROWS);
+ if (score >= THRESHOLD) {
+ expected_count++;
+ }
+ }
+
+ EXPECT_EQ(scores->size(), expected_count);
+ EXPECT_EQ(row_ids->size(), expected_count);
+
+ auto* nullable_column =
dynamic_cast<vectorized::ColumnNullable*>(scores.get());
+ ASSERT_NE(nullable_column, nullptr);
+ const auto* float_column =
+ dynamic_cast<const
vectorized::ColumnFloat32*>(&nullable_column->get_nested_column());
+ ASSERT_NE(float_column, nullptr);
+ const auto& data = float_column->get_data();
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_GE(data[i], THRESHOLD);
+ }
+}
+
+TEST_F(CollectionSimilarityTest, LargeDataTopNDescTest) {
+ constexpr size_t NUM_ROWS = 100000;
+ constexpr size_t TOP_K = 100;
+
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ float score = static_cast<float>(i) / static_cast<float>(NUM_ROWS);
+ similarity->collect(static_cast<uint32_t>(i), score);
+ }
+
+ std::vector<uint32_t> all_ids;
+ all_ids.reserve(NUM_ROWS);
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ all_ids.push_back(static_cast<uint32_t>(i));
+ }
+ roaring::Roaring bitmap = create_bitmap(all_ids);
+
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ similarity->get_topn_bm25_scores(&bitmap, scores, row_ids,
OrderType::DESC, TOP_K);
+
+ EXPECT_EQ(scores->size(), TOP_K);
+ EXPECT_EQ(row_ids->size(), TOP_K);
+
+ auto* nullable_column =
dynamic_cast<vectorized::ColumnNullable*>(scores.get());
+ ASSERT_NE(nullable_column, nullptr);
+ const auto* float_column =
+ dynamic_cast<const
vectorized::ColumnFloat32*>(&nullable_column->get_nested_column());
+ ASSERT_NE(float_column, nullptr);
+ const auto& data = float_column->get_data();
+
+ for (size_t i = 1; i < data.size(); ++i) {
+ EXPECT_GE(data[i - 1], data[i]) << "DESC order violated at index " <<
i;
+ }
+
+ float expected_max = static_cast<float>(NUM_ROWS - 1) /
static_cast<float>(NUM_ROWS);
+ EXPECT_FLOAT_EQ(data[0], expected_max);
+}
+
+TEST_F(CollectionSimilarityTest, LargeDataTopNAscTest) {
+ constexpr size_t NUM_ROWS = 100000;
+ constexpr size_t TOP_K = 100;
+
+ for (size_t i = 1; i <= NUM_ROWS; ++i) {
+ float score = static_cast<float>(i) / static_cast<float>(NUM_ROWS + 1);
+ similarity->collect(static_cast<uint32_t>(i), score);
+ }
+
+ std::vector<uint32_t> all_ids;
+ all_ids.reserve(NUM_ROWS);
+ for (size_t i = 1; i <= NUM_ROWS; ++i) {
+ all_ids.push_back(static_cast<uint32_t>(i));
+ }
+ roaring::Roaring bitmap = create_bitmap(all_ids);
+
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ similarity->get_topn_bm25_scores(&bitmap, scores, row_ids, OrderType::ASC,
TOP_K);
+
+ EXPECT_EQ(scores->size(), TOP_K);
+ EXPECT_EQ(row_ids->size(), TOP_K);
+
+ auto* nullable_column =
dynamic_cast<vectorized::ColumnNullable*>(scores.get());
+ ASSERT_NE(nullable_column, nullptr);
+ const auto* float_column =
+ dynamic_cast<const
vectorized::ColumnFloat32*>(&nullable_column->get_nested_column());
+ ASSERT_NE(float_column, nullptr);
+ const auto& data = float_column->get_data();
+
+ for (size_t i = 1; i < data.size(); ++i) {
+ EXPECT_LE(data[i - 1], data[i]) << "ASC order violated at index " << i;
+ }
+
+ float expected_min = 1.0f / static_cast<float>(NUM_ROWS + 1);
+ EXPECT_FLOAT_EQ(data[0], expected_min);
+}
+
+TEST_F(CollectionSimilarityTest, LargeDataTopNDescWithFilterTest) {
+ constexpr size_t NUM_ROWS = 100000;
+ constexpr size_t TOP_K = 50;
+ constexpr double THRESHOLD = 0.8;
+
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ float score = static_cast<float>(i) / static_cast<float>(NUM_ROWS);
+ similarity->collect(static_cast<uint32_t>(i), score);
+ }
+
+ std::vector<uint32_t> all_ids;
+ all_ids.reserve(NUM_ROWS);
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ all_ids.push_back(static_cast<uint32_t>(i));
+ }
+ roaring::Roaring bitmap = create_bitmap(all_ids);
+
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GT,
THRESHOLD);
+ similarity->get_topn_bm25_scores(&bitmap, scores, row_ids,
OrderType::DESC, TOP_K, filter);
+
+ EXPECT_EQ(scores->size(), TOP_K);
+ EXPECT_EQ(row_ids->size(), TOP_K);
+
+ auto* nullable_column =
dynamic_cast<vectorized::ColumnNullable*>(scores.get());
+ ASSERT_NE(nullable_column, nullptr);
+ const auto* float_column =
+ dynamic_cast<const
vectorized::ColumnFloat32*>(&nullable_column->get_nested_column());
+ ASSERT_NE(float_column, nullptr);
+ const auto& data = float_column->get_data();
+
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_GT(data[i], THRESHOLD);
+ }
+ for (size_t i = 1; i < data.size(); ++i) {
+ EXPECT_GE(data[i - 1], data[i]) << "DESC order violated at index " <<
i;
+ }
+}
+
+TEST_F(CollectionSimilarityTest, LargeDataTopKExceedsFilteredCountTest) {
+ constexpr size_t NUM_ROWS = 10000;
+ constexpr size_t TOP_K = 5000;
+ constexpr double THRESHOLD = 0.9;
+
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ float score = static_cast<float>(i) / static_cast<float>(NUM_ROWS);
+ similarity->collect(static_cast<uint32_t>(i), score);
+ }
+
+ std::vector<uint32_t> all_ids;
+ all_ids.reserve(NUM_ROWS);
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ all_ids.push_back(static_cast<uint32_t>(i));
+ }
+ roaring::Roaring bitmap = create_bitmap(all_ids);
+
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GT,
THRESHOLD);
+ similarity->get_topn_bm25_scores(&bitmap, scores, row_ids,
OrderType::DESC, TOP_K, filter);
+
+ size_t expected_count = 0;
+ for (size_t i = 0; i < NUM_ROWS; ++i) {
+ float score = static_cast<float>(i) / static_cast<float>(NUM_ROWS);
+ if (score > THRESHOLD) {
+ expected_count++;
+ }
+ }
+
+ EXPECT_EQ(scores->size(), expected_count);
+ EXPECT_EQ(row_ids->size(), expected_count);
+
+ auto* nullable_column =
dynamic_cast<vectorized::ColumnNullable*>(scores.get());
+ ASSERT_NE(nullable_column, nullptr);
+ const auto* float_column =
+ dynamic_cast<const
vectorized::ColumnFloat32*>(&nullable_column->get_nested_column());
+ ASSERT_NE(float_column, nullptr);
+ const auto& data = float_column->get_data();
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_GT(data[i], THRESHOLD);
+ }
+}
+
+TEST_F(CollectionSimilarityTest, LargeDataSparseBitmapTest) {
+ constexpr size_t NUM_SCORED_ROWS = 1000;
+ constexpr size_t BITMAP_SIZE = 100000;
+ constexpr size_t TOP_K = 100;
+
+ for (size_t i = 0; i < NUM_SCORED_ROWS; ++i) {
+ float score = static_cast<float>(i + 1) /
static_cast<float>(NUM_SCORED_ROWS);
+ similarity->collect(static_cast<uint32_t>(i * 100), score); // sparse distribution
+ }
+
+ std::vector<uint32_t> all_ids;
+ all_ids.reserve(BITMAP_SIZE);
+ for (size_t i = 0; i < BITMAP_SIZE; ++i) {
+ all_ids.push_back(static_cast<uint32_t>(i));
+ }
+ roaring::Roaring bitmap = create_bitmap(all_ids);
+
+ vectorized::IColumn::MutablePtr scores;
+ std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
+
+ auto filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GT, 0.0);
+ similarity->get_topn_bm25_scores(&bitmap, scores, row_ids,
OrderType::DESC, TOP_K, filter);
+
+ EXPECT_EQ(scores->size(), TOP_K);
+ EXPECT_EQ(row_ids->size(), TOP_K);
+
+ auto* nullable_column =
dynamic_cast<vectorized::ColumnNullable*>(scores.get());
+ ASSERT_NE(nullable_column, nullptr);
+ const auto* float_column =
+ dynamic_cast<const
vectorized::ColumnFloat32*>(&nullable_column->get_nested_column());
+ ASSERT_NE(float_column, nullptr);
+ const auto& data = float_column->get_data();
+ for (size_t i = 0; i < data.size(); ++i) {
+ EXPECT_GT(data[i], 0.0f);
+ }
+}
+
+TEST_F(CollectionSimilarityTest, LargeDataScoreAccumulationTest) {
+    constexpr size_t NUM_ROWS = 10000;
+    constexpr size_t ACCUMULATIONS_PER_ROW = 5;
+
+    // Feed every row the same 0.1 contribution in ACCUMULATIONS_PER_ROW passes; the assertions
+    // below expect each row's final score to be the sum of its contributions.
+    for (size_t pass = 0; pass < ACCUMULATIONS_PER_ROW; ++pass) {
+        for (size_t row = 0; row < NUM_ROWS; ++row) {
+            float contribution = 0.1f;
+            similarity->collect(static_cast<uint32_t>(row), contribution);
+        }
+    }
+
+    std::vector<uint32_t> row_list(NUM_ROWS);
+    for (size_t row = 0; row < NUM_ROWS; ++row) {
+        row_list[row] = static_cast<uint32_t>(row);
+    }
+    roaring::Roaring bitmap = create_bitmap(row_list);
+
+    vectorized::IColumn::MutablePtr scores;
+    auto row_ids = std::make_unique<std::vector<uint64_t>>();
+    similarity->get_bm25_scores(&bitmap, scores, row_ids);
+
+    EXPECT_EQ(scores->size(), NUM_ROWS);
+
+    auto* nullable_column = dynamic_cast<vectorized::ColumnNullable*>(scores.get());
+    ASSERT_NE(nullable_column, nullptr);
+    const auto* float_column =
+            dynamic_cast<const vectorized::ColumnFloat32*>(&nullable_column->get_nested_column());
+    ASSERT_NE(float_column, nullptr);
+
+    const float expected_score = 0.1f * ACCUMULATIONS_PER_ROW;
+    for (const float actual : float_column->get_data()) {
+        EXPECT_FLOAT_EQ(actual, expected_score);
+    }
+}
+
+TEST_F(CollectionSimilarityTest, LargeDataBoundaryThresholdTest) {
+    constexpr size_t NUM_ROWS = 10000;
+    constexpr double THRESHOLD = 0.5;
+
+    // Rows [0, NUM_ROWS) get score row / NUM_ROWS. The extra row id NUM_ROWS is pinned to exactly
+    // THRESHOLD, so the GT/GE boundary can be observed on a known row id.
+    for (size_t row = 0; row < NUM_ROWS; ++row) {
+        float score = static_cast<float>(row) / static_cast<float>(NUM_ROWS);
+        similarity->collect(static_cast<uint32_t>(row), score);
+    }
+    similarity->collect(static_cast<uint32_t>(NUM_ROWS), static_cast<float>(THRESHOLD));
+
+    std::vector<uint32_t> all_ids(NUM_ROWS + 1);
+    for (size_t row = 0; row <= NUM_ROWS; ++row) {
+        all_ids[row] = static_cast<uint32_t>(row);
+    }
+    roaring::Roaring bitmap = create_bitmap(all_ids);
+
+    // True when the pinned threshold row id appears in the returned row ids.
+    auto contains_threshold_row = [](const std::vector<uint64_t>& ids) {
+        for (const uint64_t id : ids) {
+            if (id == NUM_ROWS) {
+                return true;
+            }
+        }
+        return false;
+    };
+
+    {
+        // GT: a score equal to the threshold must be excluded.
+        vectorized::IColumn::MutablePtr scores;
+        auto row_ids = std::make_unique<std::vector<uint64_t>>();
+        auto gt_filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GT, THRESHOLD);
+        similarity->get_bm25_scores(&bitmap, scores, row_ids, gt_filter);
+
+        EXPECT_FALSE(contains_threshold_row(*row_ids))
+                << "GT filter should not include score == threshold";
+    }
+
+    {
+        // GE: a score equal to the threshold must be kept. This query uses a freshly built bitmap.
+        vectorized::IColumn::MutablePtr scores;
+        auto row_ids = std::make_unique<std::vector<uint64_t>>();
+        roaring::Roaring bitmap2 = create_bitmap(all_ids);
+        auto ge_filter = std::make_shared<ScoreRangeFilter>(TExprOpcode::GE, THRESHOLD);
+        similarity->get_bm25_scores(&bitmap2, scores, row_ids, ge_filter);
+
+        EXPECT_TRUE(contains_threshold_row(*row_ids))
+                << "GE filter should include score == threshold";
+    }
+}
+
 } // namespace doris
\ No newline at end of file
diff --git a/be/test/vec/exprs/score_runtime_test.cpp
b/be/test/vec/exprs/score_runtime_test.cpp
new file mode 100644
index 00000000000..c0466c40433
--- /dev/null
+++ b/be/test/vec/exprs/score_runtime_test.cpp
@@ -0,0 +1,259 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/exprs/score_runtime.h"
+
+#include <gtest/gtest.h>
+#include <sys/resource.h>
+
+#include <memory>
+
+#include "gen_cpp/Exprs_types.h"
+#include "gen_cpp/Opcodes_types.h"
+#include "vec/exprs/vexpr_context.h"
+#include "vec/exprs/virtual_slot_ref.h"
+
+namespace doris::vectorized {
+
+namespace {
+
+// Minimal VExpr stand-in: both execute overloads do nothing and return Status::OK().
+// It serves as the virtual-column expression of the test VirtualSlotRef and, on its own,
+// as a context root that is NOT a VirtualSlotRef (see make_dummy_context()).
+class DummyExpr : public VExpr {
+public:
+ // NOTE(review): COMPOUND_PRED looks like an arbitrary node type chosen only so that one is
+ // set; confirm nothing under test depends on this particular value.
+ DummyExpr() { set_node_type(TExprNodeType::COMPOUND_PRED); }
+
+ const std::string& expr_name() const override {
+ static const std::string kName = "DummyExpr";
+ return kName;
+ }
+
+ // No-op: the tests never evaluate this expression.
+ Status execute(VExprContext*, Block*, int*) const override { return
Status::OK(); }
+ // No-op: result_column is left untouched.
+ Status execute_column(VExprContext* context, const Block* block, Selector*
selector,
+ size_t count, ColumnPtr& result_column) const
override {
+ return Status::OK();
+ }
+};
+
+// Builds a VExprContext whose root is a VirtualSlotRef labelled "score" with a scalar DOUBLE
+// type, bound to `column_id`, and backed by a DummyExpr as its virtual-column expression.
+VExprContextSPtr make_virtual_slot_context(int column_id) {
+ TExprNode node;
+ node.node_type = TExprNodeType::SLOT_REF;
+ // Type descriptor: a single SCALAR node of primitive type DOUBLE.
+ TTypeDesc type_desc;
+ TTypeNode type_node;
+ type_node.type = TTypeNodeType::SCALAR;
+ TScalarType scalar_type;
+ scalar_type.__set_type(TPrimitiveType::DOUBLE);
+ type_node.__set_scalar_type(scalar_type);
+ type_desc.types.push_back(type_node);
+ node.__set_type(type_desc);
+ node.num_children = 0;
+
+ // slot_id -1: presumably a placeholder because no real tuple slot backs this virtual
+ // column (the column is addressed via set_column_id below) — TODO confirm.
+ TSlotRef slot_ref;
+ slot_ref.slot_id = -1;
+ node.__set_slot_ref(slot_ref);
+ node.__set_label("score");
+
+ auto vir_slot = std::make_shared<VirtualSlotRef>(node);
+ vir_slot->set_column_id(column_id);
+ // Writes the member directly instead of going through a setter; the test relies on
+ // having access to it.
+ vir_slot->_virtual_column_expr = std::make_shared<DummyExpr>();
+ return
std::make_shared<VExprContext>(std::static_pointer_cast<VExpr>(vir_slot));
+}
+
+// Builds a VExprContext whose root is a bare DummyExpr, i.e. not a VirtualSlotRef; used to
+// drive ScoreRuntime::prepare() down its failure path.
+VExprContextSPtr make_dummy_context() {
+ auto dummy = std::make_shared<DummyExpr>();
+ return
std::make_shared<VExprContext>(std::static_pointer_cast<VExpr>(dummy));
+}
+
+} // namespace
+
+class ScoreRuntimeTest : public testing::Test {};
+
+// create_shared() keeps the sort direction and limit it was given; both read back unchanged.
+TEST_F(ScoreRuntimeTest, ConstructorSetsFields) {
+    auto expr_ctx = make_virtual_slot_context(3);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, true, 100);
+
+    EXPECT_TRUE(score_rt->is_asc());
+    EXPECT_EQ(100, score_rt->get_limit());
+}
+
+TEST_F(ScoreRuntimeTest, ConstructorDescOrder) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 50);
+
+    EXPECT_FALSE(score_rt->is_asc());
+    EXPECT_EQ(50, score_rt->get_limit());
+}
+
+// A zero limit is stored as-is.
+TEST_F(ScoreRuntimeTest, ConstructorZeroLimit) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, true, 0);
+
+    EXPECT_EQ(0, score_rt->get_limit());
+}
+
+// A freshly created runtime carries no score range filter.
+TEST_F(ScoreRuntimeTest, NoScoreRangeByDefault) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 10);
+
+    EXPECT_FALSE(score_rt->has_score_range_filter());
+    EXPECT_FALSE(score_rt->get_score_range_info().has_value());
+}
+
+// set_score_range_info() records both the opcode and the threshold.
+TEST_F(ScoreRuntimeTest, SetScoreRangeGT) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 10);
+
+    score_rt->set_score_range_info(TExprOpcode::GT, 0.5);
+
+    EXPECT_TRUE(score_rt->has_score_range_filter());
+    const auto& range = score_rt->get_score_range_info();
+    ASSERT_TRUE(range.has_value());
+    EXPECT_EQ(TExprOpcode::GT, range->op);
+    EXPECT_DOUBLE_EQ(0.5, range->threshold);
+}
+
+TEST_F(ScoreRuntimeTest, SetScoreRangeGE) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 10);
+
+    score_rt->set_score_range_info(TExprOpcode::GE, 1.0);
+
+    ASSERT_TRUE(score_rt->has_score_range_filter());
+    const auto& range = score_rt->get_score_range_info();
+    EXPECT_EQ(TExprOpcode::GE, range->op);
+    EXPECT_DOUBLE_EQ(1.0, range->threshold);
+}
+
+// The runtime itself stores any opcode; restricting to GT/GE happens in the planner.
+TEST_F(ScoreRuntimeTest, SetScoreRangeLT) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 10);
+
+    score_rt->set_score_range_info(TExprOpcode::LT, 3.14);
+
+    ASSERT_TRUE(score_rt->has_score_range_filter());
+    const auto& range = score_rt->get_score_range_info();
+    EXPECT_EQ(TExprOpcode::LT, range->op);
+    EXPECT_DOUBLE_EQ(3.14, range->threshold);
+}
+
+TEST_F(ScoreRuntimeTest, SetScoreRangeLE) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 10);
+
+    score_rt->set_score_range_info(TExprOpcode::LE, 99.9);
+
+    ASSERT_TRUE(score_rt->has_score_range_filter());
+    const auto& range = score_rt->get_score_range_info();
+    EXPECT_EQ(TExprOpcode::LE, range->op);
+    EXPECT_DOUBLE_EQ(99.9, range->threshold);
+}
+
+TEST_F(ScoreRuntimeTest, SetScoreRangeEQ) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 10);
+
+    score_rt->set_score_range_info(TExprOpcode::EQ, 0.0);
+
+    ASSERT_TRUE(score_rt->has_score_range_filter());
+    const auto& range = score_rt->get_score_range_info();
+    EXPECT_EQ(TExprOpcode::EQ, range->op);
+    EXPECT_DOUBLE_EQ(0.0, range->threshold);
+}
+
+// Negative thresholds are stored unchanged.
+TEST_F(ScoreRuntimeTest, SetScoreRangeNegativeThreshold) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 10);
+
+    score_rt->set_score_range_info(TExprOpcode::GT, -1.5);
+
+    ASSERT_TRUE(score_rt->has_score_range_filter());
+    EXPECT_DOUBLE_EQ(-1.5, score_rt->get_score_range_info()->threshold);
+}
+
+// A second call replaces the previously stored range.
+TEST_F(ScoreRuntimeTest, OverwriteScoreRangeInfo) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 10);
+
+    score_rt->set_score_range_info(TExprOpcode::GT, 0.5);
+    EXPECT_EQ(TExprOpcode::GT, score_rt->get_score_range_info()->op);
+    EXPECT_DOUBLE_EQ(0.5, score_rt->get_score_range_info()->threshold);
+
+    score_rt->set_score_range_info(TExprOpcode::LE, 2.0);
+    EXPECT_EQ(TExprOpcode::LE, score_rt->get_score_range_info()->op);
+    EXPECT_DOUBLE_EQ(2.0, score_rt->get_score_range_info()->threshold);
+}
+
+// With a VirtualSlotRef root, prepare() succeeds and the destination column index equals the
+// virtual slot's column id.
+TEST_F(ScoreRuntimeTest, PrepareSuccessWithVirtualSlotRef) {
+    constexpr int kColumnId = 5;
+    auto expr_ctx = make_virtual_slot_context(kColumnId);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, true, 10);
+
+    RuntimeState state;
+    RowDescriptor row_desc;
+    Status status = score_rt->prepare(&state, row_desc);
+    ASSERT_TRUE(status.ok()) << status.to_string();
+    EXPECT_EQ(kColumnId, score_rt->get_dest_column_idx());
+}
+
+TEST_F(ScoreRuntimeTest, PrepareSuccessColumnIdZero) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 20);
+
+    RuntimeState state;
+    RowDescriptor row_desc;
+    Status status = score_rt->prepare(&state, row_desc);
+    ASSERT_TRUE(status.ok()) << status.to_string();
+    EXPECT_EQ(0, score_rt->get_dest_column_idx());
+}
+
+// prepare() must terminate the process (checked with ASSERT_DEATH) when the context root is
+// not a VirtualSlotRef; the death output is required to mention "VirtualSlotRef".
+TEST_F(ScoreRuntimeTest, PrepareFailsWhenRootIsNotVirtualSlotRef) {
+ // Per the GoogleTest docs, the "threadsafe" death-test style re-executes the test binary
+ // for the child process instead of forking the current process as-is.
+ GTEST_FLAG_SET(death_test_style, "threadsafe");
+ auto ctx = make_dummy_context();
+ auto runtime = ScoreRuntime::create_shared(ctx, false, 10);
+
+ RuntimeState state;
+ RowDescriptor row_desc;
+ ASSERT_DEATH(({
+ // Runs in the death child: a core-file size limit of 0 keeps the expected
+ // crash from writing a core dump.
+ struct rlimit core_limit;
+ core_limit.rlim_cur = 0;
+ core_limit.rlim_max = 0;
+ setrlimit(RLIMIT_CORE, &core_limit);
+ auto st = runtime->prepare(&state, row_desc);
+ }),
+ "VirtualSlotRef");
+}
+
+TEST_F(ScoreRuntimeTest, CreateSharedReturnsValidPtr) {
+    auto expr_ctx = make_virtual_slot_context(0);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, true, 1);
+    ASSERT_NE(nullptr, score_rt);
+}
+
+// End to end: construct, prepare, then attach a score range.
+TEST_F(ScoreRuntimeTest, FullWorkflow) {
+    constexpr int kColumnId = 7;
+    auto expr_ctx = make_virtual_slot_context(kColumnId);
+    auto score_rt = ScoreRuntime::create_shared(expr_ctx, false, 42);
+
+    // Freshly created: DESC, limit 42, no range filter yet.
+    EXPECT_FALSE(score_rt->is_asc());
+    EXPECT_EQ(42, score_rt->get_limit());
+    EXPECT_FALSE(score_rt->has_score_range_filter());
+
+    // After prepare(), the destination column index equals the virtual slot's column id.
+    RuntimeState state;
+    RowDescriptor row_desc;
+    ASSERT_TRUE(score_rt->prepare(&state, row_desc).ok());
+    EXPECT_EQ(kColumnId, score_rt->get_dest_column_idx());
+
+    score_rt->set_score_range_info(TExprOpcode::GT, 0.8);
+    EXPECT_TRUE(score_rt->has_score_range_filter());
+    EXPECT_EQ(TExprOpcode::GT, score_rt->get_score_range_info()->op);
+    EXPECT_DOUBLE_EQ(0.8, score_rt->get_score_range_info()->threshold);
+}
+
+} // namespace doris::vectorized
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 3f87d1a856b..4fd59cdf933 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -893,6 +893,9 @@ public class PhysicalPlanTranslator extends
DefaultPlanVisitor<PlanFragment, Pla
if (olapScan.getScoreLimit().isPresent()) {
olapScanNode.setScoreSortLimit(olapScan.getScoreLimit().get());
}
+ if (olapScan.getScoreRangeInfo().isPresent()) {
+ olapScanNode.setScoreRangeInfo(olapScan.getScoreRangeInfo().get());
+ }
// translate ann topn info
if (!olapScan.getAnnOrderKeys().isEmpty()) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java
index 7404f01786f..20d08e257a8 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java
@@ -68,6 +68,7 @@ public class LogicalOlapScanToPhysicalOlapScan extends
OneImplementationRuleFact
olapScan.getVirtualColumns(),
olapScan.getScoreOrderKeys(),
olapScan.getScoreLimit(),
+ olapScan.getScoreRangeInfo(),
olapScan.getAnnOrderKeys(),
olapScan.getAnnLimit(),
olapScan.getTableAlias())
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java
index adb75495d46..8120cceffbe 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java
@@ -21,19 +21,32 @@ import
org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
+import org.apache.doris.nereids.trees.expressions.LessThan;
+import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Score;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
+import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.ScoreRangeInfo;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.thrift.TExprOpcode;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@@ -41,6 +54,7 @@ import org.apache.logging.log4j.Logger;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
/**
* Push down score function into olap scan node.
@@ -56,15 +70,19 @@ import java.util.Optional;
* Score function.
* 3. The Filter node must contain at least one Match function.
*
+ * Additionally, this rule now supports score range predicates in WHERE clause:
+ * - score() > X, score() >= X, score() < X, score() <= X
+ * These predicates are extracted and pushed down to the scan node.
+ *
* Example:
* Before:
- * SELECT score() as score FROM table WHERE text_col MATCH 'query' ORDER BY
- * score DESC LIMIT 10
+ * SELECT score() as score FROM table WHERE text_col MATCH 'query' AND score()
> 0.5
+ * ORDER BY score DESC LIMIT 10
*
* After:
* The Score function is pushed down into the OlapScan node as a virtual
column,
* and the TopN information (order by, limit) is also pushed down to be used by
- * the storage engine.
+ * the storage engine. The score range predicate is also extracted.
*/
public class PushDownScoreTopNIntoOlapScan implements RewriteRuleFactory {
private static final Logger LOG =
LogManager.getLogger(PushDownScoreTopNIntoOlapScan.class);
@@ -107,12 +125,28 @@ public class PushDownScoreTopNIntoOlapScan implements
RewriteRuleFactory {
+ " for score() push down optimization");
}
- // 3. Requirement: WHERE clause must NOT contain any score() function.
- boolean hasScorePredicate = filter.getConjuncts().stream()
- .anyMatch(conjunct -> !conjunct.collect(e -> e instanceof
Score).isEmpty());
- if (hasScorePredicate) {
- throw new AnalysisException(
- "score() function can only be used in SELECT clause, not
in WHERE clause");
+ // 3. Check for score() predicates in WHERE clause and extract score
range info
+ List<Expression> scorePredicates = filter.getConjuncts().stream()
+ .filter(conjunct -> !conjunct.collect(e -> e instanceof
Score).isEmpty())
+ .collect(ImmutableList.toImmutableList());
+
+ Optional<ScoreRangeInfo> scoreRangeInfo = Optional.empty();
+ Expression extractedScorePredicate = null;
+ if (!scorePredicates.isEmpty()) {
+ if (scorePredicates.size() > 1) {
+ throw new AnalysisException(
+ "Only one score() range predicate is supported in
WHERE clause. "
+ + "Found " + scorePredicates.size() + "
predicates: " + scorePredicates);
+ }
+ Expression predicate = scorePredicates.get(0);
+ scoreRangeInfo = extractScoreRangeInfo(predicate);
+ if (!scoreRangeInfo.isPresent()) {
+ throw new AnalysisException(
+ "score() predicate in WHERE clause must be in the form
of 'score() > literal' "
+ + "or 'score() >= literal' (min_score
semantics). "
+ + "Operators <, <=, and = are not supported.");
+ }
+ extractedScorePredicate = predicate;
}
// 4. Requirement: TopN must have exactly one ordering expression.
@@ -157,7 +191,8 @@ public class PushDownScoreTopNIntoOlapScan implements
RewriteRuleFactory {
// topN info.
Plan newScan =
scan.withVirtualColumnsAndTopN(ImmutableList.of(scoreAlias),
ImmutableList.of(), Optional.empty(),
- topN.getOrderKeys(), Optional.of(topN.getLimit() +
topN.getOffset()));
+ topN.getOrderKeys(), Optional.of(topN.getLimit() +
topN.getOffset()),
+ scoreRangeInfo);
// Rebuild the plan tree above the new scan.
// We need to replace the original score() function with a reference
to the new
@@ -166,8 +201,25 @@ public class PushDownScoreTopNIntoOlapScan implements
RewriteRuleFactory {
replaceMap.put(scoreExpr, scoreAlias.toSlot());
replaceMap.put(scoreAlias, scoreAlias.toSlot());
+ // If we extracted a score predicate, remove it from the filter
+ // as it will be pushed down to the scan node
+ Set<Expression> newConjuncts;
+ if (extractedScorePredicate != null) {
+ final Expression predicateToRemove = extractedScorePredicate;
+ newConjuncts = filter.getConjuncts().stream()
+ .filter(c -> !c.equals(predicateToRemove))
+ .collect(ImmutableSet.toImmutableSet());
+ } else {
+ newConjuncts = filter.getConjuncts();
+ }
+
// The filter node remains, as the MATCH predicate is still needed.
- Plan newFilter = filter.withConjunctsAndChild(filter.getConjuncts(),
newScan);
+ Plan newFilter;
+ if (newConjuncts.isEmpty()) {
+ newFilter = newScan;
+ } else {
+ newFilter = filter.withConjunctsAndChild(newConjuncts, newScan);
+ }
// Rebuild project list with the replaced expressions.
List<NamedExpression> newProjections = ExpressionUtils
@@ -177,4 +229,106 @@ public class PushDownScoreTopNIntoOlapScan implements
RewriteRuleFactory {
// Rebuild the TopN node on top of the new project.
return topN.withChildren(newProject);
}
+
+    /**
+     * Translate a single score() conjunct into a {@link ScoreRangeInfo}.
+     *
+     * Only lower-bound (min_score, Elasticsearch-style) comparisons are recognized:
+     *   - score() > X and score() >= X
+     *   - the mirrored forms X < score() and X <= score()
+     * Upper bounds (< and <=) and equality (=) are not recognized.
+     *
+     * @param predicate a WHERE conjunct that references score()
+     * @return the extracted range, or empty when the comparison shape or operator is unsupported
+     * @throws AnalysisException if the predicate is not a comparison yet still contains score()
+     *         (for example score() nested under OR)
+     */
+    private Optional<ScoreRangeInfo> extractScoreRangeInfo(Expression predicate) {
+        if (predicate instanceof ComparisonPredicate) {
+            ComparisonPredicate comparison = (ComparisonPredicate) predicate;
+            Expression lhs = comparison.left();
+            Expression rhs = comparison.right();
+
+            // score() on the left: the operator is used as written.
+            if (isScoreExpression(lhs) && isNumericLiteral(rhs)) {
+                TExprOpcode opcode = getMinScoreOpcode(comparison);
+                if (opcode != null) {
+                    return Optional.of(new ScoreRangeInfo(opcode, extractNumericValue(rhs)));
+                }
+            }
+            // score() on the right: flip the operator so the range reads "score() op literal".
+            if (isScoreExpression(rhs) && isNumericLiteral(lhs)) {
+                TExprOpcode opcode = getReversedMinScoreOpcode(comparison);
+                if (opcode != null) {
+                    return Optional.of(new ScoreRangeInfo(opcode, extractNumericValue(lhs)));
+                }
+            }
+            return Optional.empty();
+        }
+
+        boolean mentionsScore = !predicate.collect(e -> e instanceof Score).isEmpty();
+        if (mentionsScore) {
+            throw new AnalysisException(
+                    "score() predicate must be a top-level AND condition in WHERE clause. "
+                            + "Nesting score() inside OR or other compound expressions is not supported. "
+                            + "Invalid expression: " + predicate.toSql());
+        }
+        return Optional.empty();
+    }
+
+    /**
+     * Check if the expression is a Score function, possibly wrapped in Cast expressions.
+     * The optimizer may wrap score() in Cast for type coercion (e.g., score() >= 4.0 may become
+     * CAST(score() AS DECIMAL) >= 4.0).
+     *
+     * Casts to integral types are deliberately NOT unwrapped. CAST(score() AS INT) drops the
+     * fractional part, so a threshold on the cast value is not equivalent to the same threshold
+     * on the raw score. For example, CAST(score() AS INT) > 0 keeps only rows with score >= 1,
+     * whereas a pushed-down score() > 0 would keep every positive score. Leaving such predicates
+     * unrecognized makes the caller reject them instead of pushing down a filter with different
+     * semantics.
+     */
+    private boolean isScoreExpression(Expression expr) {
+        if (expr instanceof Score) {
+            return true;
+        }
+        if (expr instanceof Cast) {
+            Cast cast = (Cast) expr;
+            if (cast.getDataType().isIntegralType()) {
+                // Truncating cast: the comparison cannot be expressed as a raw-score threshold.
+                return false;
+            }
+            return isScoreExpression(cast.child());
+        }
+        return false;
+    }
+
+    /**
+     * Whether {@code expr} is one of the literal kinds {@link #extractNumericValue} can read:
+     * double, float, integer-like, or DECIMALV3.
+     */
+    private boolean isNumericLiteral(Expression expr) {
+        boolean floating = expr instanceof DoubleLiteral || expr instanceof FloatLiteral;
+        boolean exact = expr instanceof IntegerLikeLiteral || expr instanceof DecimalV3Literal;
+        return floating || exact;
+    }
+
+    /**
+     * Read a literal accepted by {@link #isNumericLiteral} as a double.
+     * Integer-like literals are widened from their long value; DECIMALV3 literals use getDouble().
+     *
+     * @throws IllegalArgumentException if {@code expr} is not one of the supported literal kinds
+     */
+    private double extractNumericValue(Expression expr) {
+        if (expr instanceof DoubleLiteral) {
+            return ((DoubleLiteral) expr).getValue();
+        }
+        if (expr instanceof FloatLiteral) {
+            return ((FloatLiteral) expr).getValue();
+        }
+        if (expr instanceof IntegerLikeLiteral) {
+            return ((IntegerLikeLiteral) expr).getLongValue();
+        }
+        if (expr instanceof DecimalV3Literal) {
+            return ((DecimalV3Literal) expr).getDouble();
+        }
+        throw new IllegalArgumentException("Not a numeric literal: " + expr);
+    }
+
+    /**
+     * Opcode for the min_score forms where score() is on the left: score() > X maps to GT and
+     * score() >= X maps to GE.
+     *
+     * @return the opcode, or null for any other comparison (e.g. <, <=, =)
+     */
+    private TExprOpcode getMinScoreOpcode(ComparisonPredicate comp) {
+        return comp instanceof GreaterThan ? TExprOpcode.GT
+                : comp instanceof GreaterThanEqual ? TExprOpcode.GE
+                : null;
+    }
+
+    /**
+     * Opcode for the mirrored min_score forms where score() is on the right:
+     * X < score() is equivalent to score() > X (GT), and X <= score() to score() >= X (GE).
+     *
+     * @return the opcode, or null for any other comparison
+     */
+    private TExprOpcode getReversedMinScoreOpcode(ComparisonPredicate comp) {
+        return comp instanceof LessThan ? TExprOpcode.GT
+                : comp instanceof LessThanEqual ? TExprOpcode.GE
+                : null;
+    }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
index 46ef4d0af2b..25efde5cce1 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
@@ -149,7 +149,7 @@ public class PushDownVectorTopNIntoOlapScan implements
RewriteRuleFactory {
Plan plan = scan.withVirtualColumnsAndTopN(
ImmutableList.of(orderKeyAlias),
topN.getOrderKeys(), Optional.of(topN.getLimit() +
topN.getOffset()),
- ImmutableList.of(), Optional.empty());
+ ImmutableList.of(), Optional.empty(), Optional.empty());
Map<Expression, Expression> replaceMap = Maps.newHashMap();
replaceMap.put(orderKeyAlias, orderKeyAlias.toSlot());
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java
index 8c760ed55c7..8c29056ea96 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.exceptions.UnboundException;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Score;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
@@ -68,5 +69,22 @@ public abstract class ComparisonPredicate extends
BinaryOperator {
throw new AnalysisException("comparison predicate could not
contains json type: " + this.toSql());
}
}
+ checkScoreComparisonType();
+ }
+
+    /**
+     * Reject comparisons between a bare score() operand and a non-numeric, non-NULL operand.
+     * The left side is checked first, then the right side.
+     */
+    private void checkScoreComparisonType() {
+        rejectNonNumericScoreOperand(left(), right());
+        rejectNonNumericScoreOperand(right(), left());
+    }
+
+    /**
+     * Throw an AnalysisException if {@code candidate} is score() and {@code other} is neither
+     * numeric nor NULL.
+     */
+    private void rejectNonNumericScoreOperand(Expression candidate, Expression other) {
+        if (!(candidate instanceof Score)) {
+            return;
+        }
+        DataType otherType = other.getDataType();
+        if (otherType.isNumericType() || otherType.isNullType()) {
+            return;
+        }
+        throw new AnalysisException(
+                "score() function can only be compared with numeric types, but found: "
+                        + otherType + " in " + this.toSql());
     }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/ScoreRangeInfo.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/ScoreRangeInfo.java
new file mode 100644
index 00000000000..097eb6911eb
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/ScoreRangeInfo.java
@@ -0,0 +1,79 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.plans;
+
+import org.apache.doris.thrift.TExprOpcode;
+import org.apache.doris.thrift.TScoreRangeInfo;
+
+import java.util.Objects;
+
+/**
+ * Score range filter parameters for BM25 range queries such as "score() > 0.5".
+ * Holds a comparison opcode and a numeric threshold. The planner converts it with
+ * {@link #toThrift()} to push the score range predicate down to the storage engine.
+ */
+public class ScoreRangeInfo {
+    // Operator applied as "score() <op> threshold"; never null.
+    private final TExprOpcode op;
+    // Value the score is compared against.
+    private final double threshold;
+
+    /**
+     * @param op comparison opcode; must not be null
+     * @param threshold value the score is compared against
+     * @throws NullPointerException if {@code op} is null
+     */
+    public ScoreRangeInfo(TExprOpcode op, double threshold) {
+        this.threshold = threshold;
+        this.op = Objects.requireNonNull(op, "op cannot be null");
+    }
+
+    public TExprOpcode getOp() {
+        return op;
+    }
+
+    public double getThreshold() {
+        return threshold;
+    }
+
+    /**
+     * Build the Thrift message that carries this range to the BE.
+     */
+    public TScoreRangeInfo toThrift() {
+        TScoreRangeInfo result = new TScoreRangeInfo();
+        result.setThreshold(threshold);
+        result.setOp(op);
+        return result;
+    }
+
+    @Override
+    public boolean equals(Object other) {
+        if (other == this) {
+            return true;
+        }
+        if (other == null || other.getClass() != getClass()) {
+            return false;
+        }
+        ScoreRangeInfo rhs = (ScoreRangeInfo) other;
+        return op == rhs.op && Double.compare(threshold, rhs.threshold) == 0;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(op, threshold);
+    }
+
+    @Override
+    public String toString() {
+        return new StringBuilder("ScoreRangeInfo{op=")
+                .append(op)
+                .append(", threshold=")
+                .append(threshold)
+                .append('}')
+                .toString();
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
index a334b8ddcc5..bbcc44a38fc 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java
@@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.ScoreRangeInfo;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
@@ -150,6 +151,8 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
private final List<OrderKey> scoreOrderKeys;
private final Optional<Long> scoreLimit;
+ // Score range filter parameters for BM25 range queries like score() > 0.5
+ private final Optional<ScoreRangeInfo> scoreRangeInfo;
// use for ann push down
private final List<OrderKey> annOrderKeys;
private final Optional<Long> annLimit;
@@ -168,16 +171,20 @@ public class LogicalOlapScan extends
LogicalCatalogRelation implements OlapScan,
-1, false, PreAggStatus.unset(), ImmutableList.of(),
ImmutableList.of(),
Maps.newHashMap(), Optional.empty(), Optional.empty(), false,
ImmutableMap.of(),
ImmutableList.of(), ImmutableList.of(), ImmutableList.of(),
- ImmutableList.of(), Optional.empty(), ImmutableList.of(),
Optional.empty(), "");
+ ImmutableList.of(), Optional.empty(), Optional.empty(),
ImmutableList.of(), Optional.empty(), "");
}
+ /**
+ * Constructor for LogicalOlapScan.
+ * Delegates to the full constructor with {@code table.getPartitionIds()} as the partition ids
+ * and empty/default values for the remaining optional arguments. The extra
+ * {@code Optional.empty()} added here is presumably the new score range info argument,
+ * following the field order scoreOrderKeys, scoreLimit, scoreRangeInfo — TODO confirm.
+ */
public LogicalOlapScan(RelationId id, OlapTable table, List<String>
qualifier, List<Long> tabletIds,
List<String> hints, Optional<TableSample> tableSample,
Collection<Slot> operativeSlots) {
this(id, table, qualifier, Optional.empty(), Optional.empty(),
table.getPartitionIds(), false, tabletIds,
-1, false, PreAggStatus.unset(), ImmutableList.of(), hints,
Maps.newHashMap(), Optional.empty(),
tableSample, false, ImmutableMap.of(), ImmutableList.of(),
operativeSlots,
- ImmutableList.of(), ImmutableList.of(), Optional.empty(),
ImmutableList.of(), Optional.empty(), "");
+ ImmutableList.of(), ImmutableList.of(), Optional.empty(),
Optional.empty(),
+ ImmutableList.of(), Optional.empty(), "");
}
/**
@@ -190,7 +197,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
specifiedPartitions, false, tabletIds,
-1, false, PreAggStatus.unset(), specifiedPartitions, hints,
Maps.newHashMap(), Optional.empty(),
tableSample, false, ImmutableMap.of(), ImmutableList.of(),
operativeSlots,
- ImmutableList.of(), ImmutableList.of(), Optional.empty(),
+ ImmutableList.of(), ImmutableList.of(), Optional.empty(),
Optional.empty(),
ImmutableList.of(), Optional.empty(), "");
}
@@ -205,7 +212,8 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedPartitionIds, false, tabletIds,
selectedIndexId, true, preAggStatus,
specifiedPartitions, hints, Maps.newHashMap(),
Optional.empty(), tableSample, true, ImmutableMap.of(),
- ImmutableList.of(), operativeSlots, ImmutableList.of(),
ImmutableList.of(), Optional.empty(),
+ ImmutableList.of(), operativeSlots, ImmutableList.of(),
ImmutableList.of(),
+ Optional.empty(), Optional.empty(),
ImmutableList.of(), Optional.empty(), "");
}
@@ -221,7 +229,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
Optional<List<Slot>> cachedOutput, Optional<TableSample>
tableSample, boolean directMvScan,
Map<String, Set<List<String>>> colToSubPathsMap, List<Long>
specifiedTabletIds,
Collection<Slot> operativeSlots, List<NamedExpression>
virtualColumns,
- List<OrderKey> scoreOrderKeys, Optional<Long> scoreLimit,
+ List<OrderKey> scoreOrderKeys, Optional<Long> scoreLimit,
Optional<ScoreRangeInfo> scoreRangeInfo,
List<OrderKey> annOrderKeys, Optional<Long> annLimit, String
tableAlias) {
super(id, PlanType.LOGICAL_OLAP_SCAN, table, qualifier,
operativeSlots, virtualColumns, groupExpression,
logicalProperties, tableAlias);
@@ -257,6 +265,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
this.subPathToSlotMap = Maps.newHashMap();
this.scoreOrderKeys = Utils.fastToImmutableList(scoreOrderKeys);
this.scoreLimit = scoreLimit;
+ this.scoreRangeInfo = scoreRangeInfo;
this.annOrderKeys = Utils.fastToImmutableList(annOrderKeys);
this.annLimit = annLimit;
}
@@ -328,6 +337,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
&& Objects.equals(tableSample, that.tableSample)
&& Objects.equals(scoreOrderKeys, that.scoreOrderKeys)
&& Objects.equals(scoreLimit, that.scoreLimit)
+ && Objects.equals(scoreRangeInfo, that.scoreRangeInfo)
&& Objects.equals(annOrderKeys, that.annOrderKeys)
&& Objects.equals(annLimit, that.annLimit);
}
@@ -345,7 +355,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan,
colToSubPathsMap, manuallySpecifiedTabletIds, operativeSlots,
virtualColumns,
- scoreOrderKeys, scoreLimit, annOrderKeys, annLimit,
tableAlias);
+ scoreOrderKeys, scoreLimit, scoreRangeInfo, annOrderKeys,
annLimit, tableAlias);
}
@Override
@@ -356,7 +366,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan,
colToSubPathsMap, manuallySpecifiedTabletIds, operativeSlots,
virtualColumns,
- scoreOrderKeys, scoreLimit, annOrderKeys, annLimit,
tableAlias);
+ scoreOrderKeys, scoreLimit, scoreRangeInfo, annOrderKeys,
annLimit, tableAlias);
}
/**
@@ -369,7 +379,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan,
colToSubPathsMap, manuallySpecifiedTabletIds, operativeSlots,
virtualColumns,
- scoreOrderKeys, scoreLimit, annOrderKeys, annLimit,
tableAlias);
+ scoreOrderKeys, scoreLimit, scoreRangeInfo, annOrderKeys,
annLimit, tableAlias);
}
/**
@@ -383,7 +393,8 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedPartitionIds, partitionPruned, selectedTabletIds,
indexId, true, PreAggStatus.unset(),
manuallySpecifiedPartitions, hints, cacheSlotWithSlotName,
cachedOutput, tableSample, directMvScan, colToSubPathsMap,
manuallySpecifiedTabletIds,
- operativeSlots, virtualColumns, scoreOrderKeys, scoreLimit,
annOrderKeys, annLimit, tableAlias);
+ operativeSlots, virtualColumns, scoreOrderKeys, scoreLimit,
scoreRangeInfo,
+ annOrderKeys, annLimit, tableAlias);
}
/**
@@ -396,7 +407,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan,
colToSubPathsMap, manuallySpecifiedTabletIds, operativeSlots,
virtualColumns, scoreOrderKeys,
- scoreLimit, annOrderKeys, annLimit, tableAlias);
+ scoreLimit, scoreRangeInfo, annOrderKeys, annLimit,
tableAlias);
}
/**
@@ -409,7 +420,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan,
colToSubPathsMap, manuallySpecifiedTabletIds, operativeSlots,
virtualColumns,
- scoreOrderKeys, scoreLimit, annOrderKeys, annLimit,
tableAlias);
+ scoreOrderKeys, scoreLimit, scoreRangeInfo, annOrderKeys,
annLimit, tableAlias);
}
/**
@@ -422,7 +433,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan,
colToSubPathsMap, manuallySpecifiedTabletIds, operativeSlots,
virtualColumns,
- scoreOrderKeys, scoreLimit, annOrderKeys, annLimit,
tableAlias);
+ scoreOrderKeys, scoreLimit, scoreRangeInfo, annOrderKeys,
annLimit, tableAlias);
}
/**
@@ -435,7 +446,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan,
colToSubPathsMap, manuallySpecifiedTabletIds, operativeSlots,
virtualColumns,
- scoreOrderKeys, scoreLimit, annOrderKeys, annLimit,
tableAlias);
+ scoreOrderKeys, scoreLimit, scoreRangeInfo, annOrderKeys,
annLimit, tableAlias);
}
@Override
@@ -447,7 +458,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, Maps.newHashMap(), Optional.empty(), tableSample,
directMvScan,
colToSubPathsMap, selectedTabletIds, operativeSlots,
virtualColumns, scoreOrderKeys,
- scoreLimit, annOrderKeys, annLimit, tableAlias);
+ scoreLimit, scoreRangeInfo, annOrderKeys, annLimit,
tableAlias);
}
@Override
@@ -458,7 +469,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan,
colToSubPathsMap, manuallySpecifiedTabletIds, operativeSlots,
virtualColumns,
- scoreOrderKeys, scoreLimit, annOrderKeys, annLimit,
tableAlias);
+ scoreOrderKeys, scoreLimit, scoreRangeInfo, annOrderKeys,
annLimit, tableAlias);
}
/**
@@ -478,20 +489,22 @@ public class LogicalOlapScan extends
LogicalCatalogRelation implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan, colToSubPathsMap,
manuallySpecifiedTabletIds, operativeSlots, virtualColumns,
scoreOrderKeys, scoreLimit,
- annOrderKeys, annLimit, tableAlias);
+ scoreRangeInfo, annOrderKeys, annLimit, tableAlias);
}
/**
- * add virtual column to olap scan.
+ * Add virtual column to olap scan with optional score range info.
* @param virtualColumns generated virtual columns
- * @return scan with virtual columns
+ * @param scoreRangeInfo optional score range filter info for BM25 range
queries
+ * @return scan with virtual columns and optional score range info
*/
public LogicalOlapScan withVirtualColumnsAndTopN(
List<NamedExpression> virtualColumns,
List<OrderKey> annOrderKeys,
Optional<Long> annLimit,
List<OrderKey> scoreOrderKeys,
- Optional<Long> scoreLimit) {
+ Optional<Long> scoreLimit,
+ Optional<ScoreRangeInfo> scoreRangeInfo) {
LogicalProperties logicalProperties = getLogicalProperties();
List<Slot> output = Lists.newArrayList(logicalProperties.getOutput());
output.addAll(virtualColumns.stream().map(NamedExpression::toSlot).collect(Collectors.toList()));
@@ -502,7 +515,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan, colToSubPathsMap,
manuallySpecifiedTabletIds, operativeSlots, virtualColumns,
scoreOrderKeys, scoreLimit,
- annOrderKeys, annLimit, tableAlias);
+ scoreRangeInfo, annOrderKeys, annLimit, tableAlias);
}
@Override
@@ -675,6 +688,10 @@ public class LogicalOlapScan extends
LogicalCatalogRelation implements OlapScan,
return annLimit;
}
+ public Optional<ScoreRangeInfo> getScoreRangeInfo() {
+ return scoreRangeInfo;
+ }
+
private List<SlotReference> createSlotsVectorized(List<Column> columns) {
List<String> qualified = qualified();
SlotReference[] slots = new SlotReference[columns.size()];
@@ -837,7 +854,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, cachedOutput, tableSample,
directMvScan, colToSubPathsMap,
manuallySpecifiedTabletIds, operativeSlots, virtualColumns,
scoreOrderKeys, scoreLimit,
- annOrderKeys, annLimit, tableAlias);
+ scoreRangeInfo, annOrderKeys, annLimit, tableAlias);
}
private Map<Slot, Slot> constructReplaceMap(MTMV mtmv) {
@@ -878,7 +895,7 @@ public class LogicalOlapScan extends LogicalCatalogRelation
implements OlapScan,
selectedIndexId, indexSelected, preAggStatus,
manuallySpecifiedPartitions,
hints, cacheSlotWithSlotName, Optional.of(outputSlots),
tableSample, directMvScan, colToSubPathsMap,
manuallySpecifiedTabletIds, operativeSlots, virtualColumns,
scoreOrderKeys, scoreLimit,
- annOrderKeys, annLimit, tableAlias);
+ scoreRangeInfo, annOrderKeys, annLimit, tableAlias);
}
@Override
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalLazyMaterializeOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalLazyMaterializeOlapScan.java
index aa1d67ce65b..cc5095644f3 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalLazyMaterializeOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalLazyMaterializeOlapScan.java
@@ -59,6 +59,7 @@ public class PhysicalLazyMaterializeOlapScan extends
PhysicalOlapScan {
physicalOlapScan.getVirtualColumns(),
physicalOlapScan.getScoreOrderKeys(),
physicalOlapScan.getScoreLimit(),
+ physicalOlapScan.getScoreRangeInfo(),
physicalOlapScan.getAnnOrderKeys(),
physicalOlapScan.getAnnLimit()
);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalOlapScan.java
index f52ab20558f..a2a4d1e152c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalOlapScan.java
@@ -32,6 +32,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.ScoreRangeInfo;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
@@ -70,6 +71,7 @@ public class PhysicalOlapScan extends PhysicalCatalogRelation
implements OlapSca
private final List<OrderKey> scoreOrderKeys;
private final Optional<Long> scoreLimit;
+ private final Optional<ScoreRangeInfo> scoreRangeInfo;
// use for ann push down
private final List<OrderKey> annOrderKeys;
private final Optional<Long> annLimit;
@@ -82,14 +84,14 @@ public class PhysicalOlapScan extends
PhysicalCatalogRelation implements OlapSca
PreAggStatus preAggStatus, List<Slot> baseOutputs,
Optional<GroupExpression> groupExpression, LogicalProperties
logicalProperties,
Optional<TableSample> tableSample, List<Slot> operativeSlots,
List<NamedExpression> virtualColumns,
- List<OrderKey> scoreOrderKeys, Optional<Long> scoreLimit,
+ List<OrderKey> scoreOrderKeys, Optional<Long> scoreLimit,
Optional<ScoreRangeInfo> scoreRangeInfo,
List<OrderKey> annOrderKeys, Optional<Long> annLimit) {
this(id, olapTable, qualifier,
selectedIndexId, selectedTabletIds, selectedPartitionIds,
distributionSpec,
preAggStatus, baseOutputs,
groupExpression, logicalProperties, null,
null, tableSample, operativeSlots, virtualColumns,
scoreOrderKeys, scoreLimit,
- annOrderKeys, annLimit, "");
+ scoreRangeInfo, annOrderKeys, annLimit, "");
}
/**
@@ -102,12 +104,12 @@ public class PhysicalOlapScan extends
PhysicalCatalogRelation implements OlapSca
PhysicalProperties physicalProperties, Statistics statistics,
Optional<TableSample> tableSample,
Collection<Slot> operativeSlots, List<NamedExpression>
virtualColumns,
- List<OrderKey> scoreOrderKeys, Optional<Long> scoreLimit,
+ List<OrderKey> scoreOrderKeys, Optional<Long> scoreLimit,
Optional<ScoreRangeInfo> scoreRangeInfo,
List<OrderKey> annOrderKeys, Optional<Long> annLimit) {
this(id, olapTable, qualifier, selectedIndexId, selectedTabletIds,
selectedPartitionIds,
distributionSpec, preAggStatus, baseOutputs, groupExpression,
logicalProperties,
physicalProperties, statistics, tableSample, operativeSlots,
virtualColumns,
- scoreOrderKeys, scoreLimit, annOrderKeys, annLimit, "");
+ scoreOrderKeys, scoreLimit, scoreRangeInfo, annOrderKeys,
annLimit, "");
}
/**
@@ -120,7 +122,7 @@ public class PhysicalOlapScan extends
PhysicalCatalogRelation implements OlapSca
PhysicalProperties physicalProperties, Statistics statistics,
Optional<TableSample> tableSample,
Collection<Slot> operativeSlots, List<NamedExpression>
virtualColumns,
- List<OrderKey> scoreOrderKeys, Optional<Long> scoreLimit,
+ List<OrderKey> scoreOrderKeys, Optional<Long> scoreLimit,
Optional<ScoreRangeInfo> scoreRangeInfo,
List<OrderKey> annOrderKeys, Optional<Long> annLimit, String
tableAlias) {
super(id, PlanType.PHYSICAL_OLAP_SCAN, olapTable, qualifier,
groupExpression, logicalProperties, physicalProperties,
statistics, operativeSlots, tableAlias);
@@ -135,6 +137,7 @@ public class PhysicalOlapScan extends
PhysicalCatalogRelation implements OlapSca
this.virtualColumns = ImmutableList.copyOf(virtualColumns);
this.scoreOrderKeys = ImmutableList.copyOf(scoreOrderKeys);
this.scoreLimit = scoreLimit;
+ this.scoreRangeInfo = scoreRangeInfo;
this.annOrderKeys = ImmutableList.copyOf(annOrderKeys);
this.annLimit = annLimit;
}
@@ -200,6 +203,10 @@ public class PhysicalOlapScan extends
PhysicalCatalogRelation implements OlapSca
return annLimit;
}
+ public Optional<ScoreRangeInfo> getScoreRangeInfo() {
+ return scoreRangeInfo;
+ }
+
@Override
public String getFingerprint() {
String partitions = "";
@@ -276,6 +283,7 @@ public class PhysicalOlapScan extends
PhysicalCatalogRelation implements OlapSca
&& Objects.equals(virtualColumns, olapScan.virtualColumns)
&& Objects.equals(scoreOrderKeys, olapScan.scoreOrderKeys)
&& Objects.equals(scoreLimit, olapScan.scoreLimit)
+ && Objects.equals(scoreRangeInfo, olapScan.scoreRangeInfo)
&& Objects.equals(annOrderKeys, olapScan.annOrderKeys)
&& Objects.equals(annLimit, olapScan.annLimit);
}
@@ -295,7 +303,7 @@ public class PhysicalOlapScan extends
PhysicalCatalogRelation implements OlapSca
return new PhysicalOlapScan(relationId, getTable(), qualifier,
selectedIndexId, selectedTabletIds,
selectedPartitionIds, distributionSpec, preAggStatus,
baseOutputs,
groupExpression, getLogicalProperties(), null, null,
tableSample, operativeSlots, virtualColumns,
- scoreOrderKeys, scoreLimit, annOrderKeys, annLimit,
tableAlias);
+ scoreOrderKeys, scoreLimit, scoreRangeInfo, annOrderKeys,
annLimit, tableAlias);
}
@Override
@@ -304,7 +312,7 @@ public class PhysicalOlapScan extends
PhysicalCatalogRelation implements OlapSca
return new PhysicalOlapScan(relationId, getTable(), qualifier,
selectedIndexId, selectedTabletIds,
selectedPartitionIds, distributionSpec, preAggStatus,
baseOutputs, groupExpression,
logicalProperties.get(), null, null, tableSample,
operativeSlots, virtualColumns,
- scoreOrderKeys, scoreLimit, annOrderKeys, annLimit,
tableAlias);
+ scoreOrderKeys, scoreLimit, scoreRangeInfo, annOrderKeys,
annLimit, tableAlias);
}
@Override
@@ -313,7 +321,7 @@ public class PhysicalOlapScan extends
PhysicalCatalogRelation implements OlapSca
return new PhysicalOlapScan(relationId, getTable(), qualifier,
selectedIndexId, selectedTabletIds,
selectedPartitionIds, distributionSpec, preAggStatus,
baseOutputs, groupExpression,
getLogicalProperties(), physicalProperties, statistics,
tableSample, operativeSlots,
- virtualColumns, scoreOrderKeys, scoreLimit, annOrderKeys,
annLimit, tableAlias);
+ virtualColumns, scoreOrderKeys, scoreLimit, scoreRangeInfo,
annOrderKeys, annLimit, tableAlias);
}
@Override
@@ -340,7 +348,7 @@ public class PhysicalOlapScan extends
PhysicalCatalogRelation implements OlapSca
selectedPartitionIds, distributionSpec, preAggStatus,
baseOutputs,
groupExpression, getLogicalProperties(),
getPhysicalProperties(), statistics,
tableSample, operativeSlots, virtualColumns, scoreOrderKeys,
scoreLimit,
- annOrderKeys, annLimit, tableAlias);
+ scoreRangeInfo, annOrderKeys, annLimit, tableAlias);
}
@Override
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java
b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java
index 1a20ecb3578..0c4f00fc323 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java
@@ -53,6 +53,7 @@ import org.apache.doris.common.UserException;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.info.PartitionNamesInfo;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
+import org.apache.doris.nereids.trees.plans.ScoreRangeInfo;
import org.apache.doris.planner.normalize.Normalizer;
import org.apache.doris.planner.normalize.PartitionRangePredicateNormalizer;
import org.apache.doris.qe.ConnectContext;
@@ -187,6 +188,7 @@ public class OlapScanNode extends ScanNode {
private SortInfo scoreSortInfo = null;
private long scoreSortLimit = -1;
+ private ScoreRangeInfo scoreRangeInfo = null;
// cached for prepared statement to quickly prune partition
// only used in short circuit plan at present
@@ -272,6 +274,10 @@ public class OlapScanNode extends ScanNode {
this.scoreSortLimit = scoreSortLimit;
}
+ public void setScoreRangeInfo(ScoreRangeInfo scoreRangeInfo) {
+ this.scoreRangeInfo = scoreRangeInfo;
+ }
+
public void setAnnSortInfo(SortInfo annSortInfo) {
this.annSortInfo = annSortInfo;
}
@@ -1159,6 +1165,9 @@ public class OlapScanNode extends ScanNode {
if (scoreSortLimit != -1) {
msg.olap_scan_node.setScoreSortLimit(scoreSortLimit);
}
+ if (scoreRangeInfo != null) {
+ msg.olap_scan_node.setScoreRangeInfo(scoreRangeInfo.toThrift());
+ }
if (annSortInfo != null) {
TSortInfo tAnnSortInfo = new TSortInfo(
Expr.treesToThrift(annSortInfo.getOrderingExprs()),
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java
index fdf1c726cfc..288c5f65a2b 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java
@@ -90,7 +90,7 @@ public class PhysicalPlanTranslatorTest extends
TestWithFeService {
Collections.emptyList(), Collections.emptyList(), null,
PreAggStatus.on(),
ImmutableList.of(), Optional.empty(), t1Properties,
Optional.empty(),
ImmutableList.of(), ImmutableList.of(), ImmutableList.of(),
Optional.empty(),
- ImmutableList.of(), Optional.empty());
+ Optional.empty(), ImmutableList.of(), Optional.empty());
Literal t1FilterRight = new IntegerLiteral(1);
Expression t1FilterExpr = new GreaterThan(col1, t1FilterRight);
PhysicalFilter<PhysicalOlapScan> filter =
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/MergeProjectPostProcessTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/MergeProjectPostProcessTest.java
index 3177cd97316..2b561837213 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/MergeProjectPostProcessTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/MergeProjectPostProcessTest.java
@@ -80,7 +80,7 @@ public class MergeProjectPostProcessTest {
Collections.emptyList(), Collections.emptyList(), null,
PreAggStatus.on(), ImmutableList.of(),
Optional.empty(), t1Properties, Optional.empty(),
ImmutableList.of(),
ImmutableList.of(), ImmutableList.of(), Optional.empty(),
- ImmutableList.of(), Optional.empty());
+ Optional.empty(), ImmutableList.of(), Optional.empty());
Alias x = new Alias(a, "x");
List<NamedExpression> projList3 = Lists.newArrayList(x, b, c);
PhysicalProject proj3 = new PhysicalProject(projList3, placeHolder,
scan);
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/PushDownFilterThroughProjectTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/PushDownFilterThroughProjectTest.java
index 76e9bb0722f..979bf1abd04 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/PushDownFilterThroughProjectTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/PushDownFilterThroughProjectTest.java
@@ -94,7 +94,7 @@ public class PushDownFilterThroughProjectTest {
qualifier, 0L, Collections.emptyList(),
Collections.emptyList(), null,
PreAggStatus.on(), ImmutableList.of(), Optional.empty(),
t1Properties,
Optional.empty(), ImmutableList.of(), ImmutableList.of(),
ImmutableList.of(), Optional.empty(),
- ImmutableList.of(), Optional.empty());
+ Optional.empty(), ImmutableList.of(), Optional.empty());
Alias x = new Alias(a, "x");
List<NamedExpression> projList3 = Lists.newArrayList(x, b, c);
PhysicalProject proj3 = new PhysicalProject(projList3, placeHolder,
scan);
@@ -134,7 +134,7 @@ public class PushDownFilterThroughProjectTest {
qualifier, 0L, Collections.emptyList(),
Collections.emptyList(), null,
PreAggStatus.on(), ImmutableList.of(), Optional.empty(),
t1Properties,
Optional.empty(), new ArrayList<>(), ImmutableList.of(),
ImmutableList.of(), Optional.empty(),
- ImmutableList.of(), Optional.empty());
+ Optional.empty(), ImmutableList.of(), Optional.empty());
Alias x = new Alias(a, "x");
List<NamedExpression> projList3 = Lists.newArrayList(x, b, c);
PhysicalProject proj3 = new PhysicalProject(projList3, placeHolder,
scan);
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
index 59c6eac3593..6f42d6a5839 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java
@@ -342,14 +342,14 @@ class PlanEqualsTest {
PreAggStatus.on(), ImmutableList.of(), Optional.empty(),
logicalProperties,
Optional.empty(),
ImmutableList.of(), ImmutableList.of(), ImmutableList.of(),
Optional.empty(),
- ImmutableList.of(), Optional.empty());
+ Optional.empty(), ImmutableList.of(), Optional.empty());
PhysicalOlapScan expected = new PhysicalOlapScan(id, olapTable,
Lists.newArrayList("a"),
1L, selectedTabletId, olapTable.getPartitionIds(),
distributionSpecHash,
PreAggStatus.on(), ImmutableList.of(), Optional.empty(),
logicalProperties,
Optional.empty(),
ImmutableList.of(), ImmutableList.of(), ImmutableList.of(),
Optional.empty(),
- ImmutableList.of(), Optional.empty());
+ Optional.empty(), ImmutableList.of(), Optional.empty());
Assertions.assertEquals(expected, actual);
PhysicalOlapScan unexpected = new PhysicalOlapScan(id, olapTable,
Lists.newArrayList("b"),
@@ -357,7 +357,7 @@ class PlanEqualsTest {
PreAggStatus.on(), ImmutableList.of(), Optional.empty(),
logicalProperties,
Optional.empty(),
ImmutableList.of(), ImmutableList.of(), ImmutableList.of(),
Optional.empty(),
- ImmutableList.of(), Optional.empty());
+ Optional.empty(), ImmutableList.of(), Optional.empty());
Assertions.assertNotEquals(unexpected, actual);
}
diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift
index 1369c6e1045..7b281dcf712 100644
--- a/gensrc/thrift/PlanNodes.thrift
+++ b/gensrc/thrift/PlanNodes.thrift
@@ -862,6 +862,12 @@ enum TPushAggOp {
COUNT_ON_INDEX = 4
}
+struct TScoreRangeInfo {
+ // Score range filter parameters for BM25 range queries like score() > 0.5
+ // Carried to the BE through TOlapScanNode.score_range_info (field 23).
+ // op: comparison operator relating score() to threshold; presumably the FE
+ //     normalizes reversed forms (e.g. 0.5 < score()) to score() <op> threshold — TODO confirm.
+ 1: optional Opcodes.TExprOpcode op
+ // threshold: the value the BM25 score is compared against.
+ 2: optional double threshold
+}
+
struct TOlapScanNode {
1: required Types.TTupleId tuple_id
2: required list<string> key_column_name
@@ -887,6 +893,7 @@ struct TOlapScanNode {
20: optional i64 score_sort_limit
21: optional TSortInfo ann_sort_info
22: optional i64 ann_sort_limit
+ 23: optional TScoreRangeInfo score_range_info
}
struct TEqJoinCondition {
diff --git
a/regression-test/data/inverted_index_p0/test_bm25_score_range_filter.out
b/regression-test/data/inverted_index_p0/test_bm25_score_range_filter.out
new file mode 100644
index 00000000000..d84baf2af71
--- /dev/null
+++ b/regression-test/data/inverted_index_p0/test_bm25_score_range_filter.out
@@ -0,0 +1,33 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !score_gt --
+3 apple apple apple 0.5016066
+
+-- !score_ge --
+3 apple apple apple 0.5016066
+
+-- !score_gt_zero --
+3 apple apple apple 0.5016066
+6 apple 0.4399566
+2 apple apple banana 0.4363004
+8 apple apple banana cherry 0.3967367
+7 apple banana 0.3662894
+1 apple banana cherry 0.3137539
+
+-- !score_asc --
+1 apple banana cherry 0.3137539
+7 apple banana 0.3662894
+8 apple apple banana cherry 0.3967367
+2 apple apple banana 0.4363004
+6 apple 0.4399566
+
+-- !score_high_threshold --
+
+-- !score_reversed --
+3 apple apple apple 0.5016066
+
+-- !score_match_all --
+7 apple banana 0.9206117
+2 apple apple banana 0.9111184
+8 apple apple banana cherry 0.8119956
+1 apple banana cherry 0.7885718
+
diff --git a/regression-test/suites/inverted_index_p0/test_bm25_score.groovy
b/regression-test/suites/inverted_index_p0/test_bm25_score.groovy
index 36cb8a8270c..cdbec257922 100644
--- a/regression-test/suites/inverted_index_p0/test_bm25_score.groovy
+++ b/regression-test/suites/inverted_index_p0/test_bm25_score.groovy
@@ -156,10 +156,8 @@ suite("test_bm25_score", "p0") {
exception "score() function requires WHERE clause with MATCH
function, ORDER BY and LIMIT for optimization"
}
- test {
- sql """ select *, score() from test_bm25_score where request
match_any 'button.03.gif' and score() > 0.5 order by score() limit 10; """
- exception "score() function can only be used in SELECT clause, not
in WHERE clause"
- }
+ // score() > 0.5 with ORDER BY and LIMIT is now supported after score
range filter pushdown feature
+ sql """ select *, score() from test_bm25_score where request match_any
'button.03.gif' and score() > 0.5 order by score() limit 10; """
test {
sql """ select *, score() as score from test_bm25_score where
request match_any 'button.03.gif' and score() > 0.5; """
diff --git
a/regression-test/suites/inverted_index_p0/test_bm25_score_range_filter.groovy
b/regression-test/suites/inverted_index_p0/test_bm25_score_range_filter.groovy
new file mode 100644
index 00000000000..240b67ef24a
--- /dev/null
+++
b/regression-test/suites/inverted_index_p0/test_bm25_score_range_filter.groovy
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_bm25_score_range_filter", "p0") {
+    def tableName = "test_bm25_score_range_filter"
+
+    sql "DROP TABLE IF EXISTS ${tableName}"
+
+    sql """
+        CREATE TABLE ${tableName} (
+            `id` int(11) NULL,
+            `content` text NULL,
+            INDEX content_idx (`content`) USING INVERTED PROPERTIES("parser" = "english", "support_phrase" = "true")
+        ) ENGINE=OLAP
+        DUPLICATE KEY(`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 1
+        PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1"
+        );
+    """
+
+    sql """ INSERT INTO ${tableName} VALUES (1, 'apple banana cherry'); """
+    sql """ INSERT INTO ${tableName} VALUES (2, 'apple apple banana'); """
+    sql """ INSERT INTO ${tableName} VALUES (3, 'apple apple apple'); """
+    sql """ INSERT INTO ${tableName} VALUES (4, 'banana cherry date'); """
+    sql """ INSERT INTO ${tableName} VALUES (5, 'cherry date elderberry'); """
+    sql """ INSERT INTO ${tableName} VALUES (6, 'apple'); """
+    sql """ INSERT INTO ${tableName} VALUES (7, 'apple banana'); """
+    sql """ INSERT INTO ${tableName} VALUES (8, 'apple apple banana cherry'); """
+
+    sql "sync"
+    sql """ set enable_common_expr_pushdown = true; """
+
+    // Test 1: Basic score range filter with GT (>)
+    qt_score_gt """
+        SELECT id, content, score() as s
+        FROM ${tableName}
+        WHERE content match_any 'apple' AND score() > 0.5
+        ORDER BY s DESC
+        LIMIT 10;
+    """
+
+    // Test 2: Score range filter with GE (>=)
+    qt_score_ge """
+        SELECT id, content, score() as s
+        FROM ${tableName}
+        WHERE content match_any 'apple' AND score() >= 0.5
+        ORDER BY s DESC
+        LIMIT 10;
+    """
+
+    // Test 3: Score range filter with zero threshold
+    qt_score_gt_zero """
+        SELECT id, content, score() as s
+        FROM ${tableName}
+        WHERE content match_any 'apple' AND score() > 0
+        ORDER BY s DESC
+        LIMIT 10;
+    """
+
+    // Test 4: Score range filter with ASC order
+    qt_score_asc """
+        SELECT id, content, score() as s
+        FROM ${tableName}
+        WHERE content match_any 'apple' AND score() > 0.3
+        ORDER BY s ASC
+        LIMIT 5;
+    """
+
+    // Test 5: Score range filter with high threshold (may filter all)
+    qt_score_high_threshold """
+        SELECT id, content, score() as s
+        FROM ${tableName}
+        WHERE content match_any 'apple' AND score() > 10.0
+        ORDER BY s DESC
+        LIMIT 10;
+    """
+
+    // Test 6: Reversed predicate form (literal < score())
+    qt_score_reversed """
+        SELECT id, content, score() as s
+        FROM ${tableName}
+        WHERE content match_any 'apple' AND 0.5 < score()
+        ORDER BY s DESC
+        LIMIT 10;
+    """
+
+    // Test 7: Score range with match_all
+    qt_score_match_all """
+        SELECT id, content, score() as s
+        FROM ${tableName}
+        WHERE content match_all 'apple banana' AND score() > 0.3
+        ORDER BY s DESC
+        LIMIT 10;
+    """
+
+    // Test 8: Log EXPLAIN VERBOSE output for the score range pushdown (logged only, not asserted)
+    def explain_result = sql """
+        EXPLAIN VERBOSE
+        SELECT id, content, score() as s
+        FROM ${tableName}
+        WHERE content match_any 'apple' AND score() > 0.5
+        ORDER BY s DESC
+        LIMIT 10;
+    """
+    log.info("Explain result: ${explain_result}")
+
+    // Test 9: Error case - multiple score predicates not supported
+    test {
+        sql """
+            SELECT id, content, score() as s
+            FROM ${tableName}
+            WHERE content match_any 'apple' AND score() > 0.3 AND score() < 0.9
+            ORDER BY s DESC
+            LIMIT 10;
+        """
+        exception "Only one score() range predicate is supported"
+    }
+
+    // Test 10: Error case - score() < X not supported (max_score semantics)
+    test {
+        sql """
+            SELECT id, content, score() as s
+            FROM ${tableName}
+            WHERE content match_any 'apple' AND score() < 0.9
+            ORDER BY s DESC
+            LIMIT 10;
+        """
+        exception "score() predicate in WHERE clause must be in the form of 'score() > literal'"
+    }
+
+    // Test 11: Error case - score() compared with string type not supported
+    test {
+        sql """
+            SELECT id, content, score() as s
+            FROM ${tableName}
+            WHERE content match_any 'apple' AND score() > 'invalid_string'
+            ORDER BY s DESC
+            LIMIT 10;
+        """
+        exception "score() function can only be compared with numeric types"
+    }
+
+    sql "DROP TABLE IF EXISTS ${tableName}"
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]