This is an automated email from the ASF dual-hosted git repository.
airborne pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 26c1950074f [feature](score) support BM25 scoring in inverted index query_v2 (#59847)
26c1950074f is described below
commit 26c1950074f1049411b3455cff59a6b47a089fdf
Author: zzzxl <[email protected]>
AuthorDate: Tue Mar 17 15:17:56 2026 +0800
[feature](score) support BM25 scoring in inverted index query_v2 (#59847)
---
be/src/exprs/function/function_search.cpp | 69 +--
be/src/exprs/function/function_search.h | 3 +-
be/src/exprs/vexpr_context.h | 9 +
be/src/exprs/vsearch.cpp | 5 +-
.../storage/compaction/collection_statistics.cpp | 146 ++----
be/src/storage/compaction/collection_statistics.h | 19 +-
be/src/storage/index/index_iterator.h | 1 +
be/src/storage/index/index_query_context.h | 3 +
.../boolean_query/occur_boolean_weight.cpp | 23 +-
.../query_v2/boolean_query/occur_boolean_weight.h | 43 +-
.../query_v2/collect/doc_set_collector.cpp | 47 ++
.../query_v2/collect/doc_set_collector.h} | 24 +-
.../inverted/query_v2/collect/multi_segment_util.h | 97 ++++
.../inverted/query_v2/collect/top_k_collector.cpp | 61 +++
.../inverted/query_v2/collect/top_k_collector.h | 107 +++++
.../index/inverted/query_v2/composite_reader.h | 12 +
.../inverted/query_v2/match_all_docs_scorer.h | 12 +
.../phrase_prefix_query/phrase_prefix_weight.h | 4 +-
.../query_v2/phrase_query/multi_phrase_weight.h | 7 +-
.../inverted/query_v2/phrase_query/phrase_weight.h | 2 +-
.../inverted/query_v2/prefix_query/prefix_weight.h | 15 +-
.../query_v2/regexp_query/regexp_weight.cpp | 24 +-
.../inverted/query_v2/regexp_query/regexp_weight.h | 2 +-
be/src/storage/index/inverted/query_v2/scorer.h | 1 +
.../index/inverted/query_v2/segment_postings.h | 88 +++-
.../inverted/query_v2/term_query/term_scorer.h | 5 +
.../inverted/query_v2/term_query/term_weight.h | 43 +-
.../index/inverted/query_v2/wand/block_wand.h | 286 ++++++++++++
be/src/storage/index/inverted/query_v2/weight.h | 68 ++-
.../query_v2/wildcard_query/wildcard_weight.h | 2 +-
.../index/inverted/similarity/bm25_similarity.cpp | 9 +
.../index/inverted/similarity/bm25_similarity.h | 1 +
.../storage/index/inverted/similarity/similarity.h | 1 +
be/src/storage/predicate_collector.cpp | 263 +++++++++++
be/src/storage/predicate_collector.h | 87 ++++
be/src/storage/segment/segment_iterator.cpp | 3 +
.../compaction/collection_statistics_test.cpp | 29 --
.../index/inverted/query/query_helper_test.cpp | 2 +
.../inverted/query_v2/occur_boolean_query_test.cpp | 3 +
.../inverted/query_v2/segment_postings_test.cpp | 40 +-
.../inverted/query_v2/top_k_collector_test.cpp | 490 +++++++++++++++++++++
.../inverted/query_v2/union_postings_test.cpp | 8 +-
contrib/clucene | 2 +-
.../rewrite/PushDownScoreTopNIntoOlapScan.java | 12 +-
.../java/org/apache/doris/qe/SessionVariable.java | 5 +
gensrc/thrift/PaloInternalService.thrift | 2 +
.../inverted_index_p0/test_bm25_score.groovy | 2 +-
47 files changed, 1907 insertions(+), 280 deletions(-)
diff --git a/be/src/exprs/function/function_search.cpp
b/be/src/exprs/function/function_search.cpp
index 37849dd246d..de09e85c561 100644
--- a/be/src/exprs/function/function_search.cpp
+++ b/be/src/exprs/function/function_search.cpp
@@ -49,6 +49,8 @@
#include "storage/index/inverted/query_v2/bit_set_query/bit_set_query.h"
#include
"storage/index/inverted/query_v2/boolean_query/boolean_query_builder.h"
#include "storage/index/inverted/query_v2/boolean_query/operator.h"
+#include "storage/index/inverted/query_v2/collect/doc_set_collector.h"
+#include "storage/index/inverted/query_v2/collect/top_k_collector.h"
#include "storage/index/inverted/query_v2/phrase_query/multi_phrase_query.h"
#include "storage/index/inverted/query_v2/phrase_query/phrase_query.h"
#include "storage/index/inverted/query_v2/regexp_query/regexp_query.h"
@@ -377,7 +379,8 @@ Status
FunctionSearch::evaluate_inverted_index_with_search_param(
std::unordered_map<std::string, IndexIterator*> iterators, uint32_t
num_rows,
InvertedIndexResultBitmap& bitmap_result, bool enable_cache,
const IndexExecContext* index_exec_ctx,
- const std::unordered_map<std::string, int>& field_name_to_column_id)
const {
+ const std::unordered_map<std::string, int>& field_name_to_column_id,
+ const std::shared_ptr<IndexQueryContext>& index_query_context) const {
const bool is_nested_query = search_param.root.clause_type == "NESTED";
if (is_nested_query && !is_nested_group_search_supported()) {
return Status::NotSupported(
@@ -431,9 +434,14 @@ Status
FunctionSearch::evaluate_inverted_index_with_search_param(
}
}
- auto context = std::make_shared<IndexQueryContext>();
- context->collection_statistics = std::make_shared<CollectionStatistics>();
- context->collection_similarity = std::make_shared<CollectionSimilarity>();
+ std::shared_ptr<IndexQueryContext> context;
+ if (index_query_context) {
+ context = index_query_context;
+ } else {
+ context = std::make_shared<IndexQueryContext>();
+ context->collection_statistics =
std::make_shared<CollectionStatistics>();
+ context->collection_similarity =
std::make_shared<CollectionSimilarity>();
+ }
// NESTED() queries evaluate predicates on the flattened "element space"
of a nested group.
// For VARIANT nested groups, the indexed lucene field (stored_field_name)
uses:
@@ -551,43 +559,52 @@ Status
FunctionSearch::evaluate_inverted_index_with_search_param(
query_v2::QueryExecutionContext exec_ctx =
build_query_execution_context(num_rows, resolver, &null_resolver);
- auto weight = root_query->weight(false);
- if (!weight) {
- LOG(WARNING) << "search: Failed to build query weight";
- bitmap_result =
InvertedIndexResultBitmap(std::make_shared<roaring::Roaring>(),
-
std::make_shared<roaring::Roaring>());
- return Status::OK();
+ bool enable_scoring = false;
+ bool is_asc = false;
+ size_t top_k = 0;
+ if (index_query_context) {
+ enable_scoring = index_query_context->collection_similarity != nullptr;
+ is_asc = index_query_context->is_asc;
+ top_k = index_query_context->query_limit;
}
- auto scorer = weight->scorer(exec_ctx, root_binding_key);
- if (!scorer) {
- LOG(WARNING) << "search: Failed to build scorer";
+ auto weight = root_query->weight(enable_scoring);
+ if (!weight) {
+ LOG(WARNING) << "search: Failed to build query weight";
bitmap_result =
InvertedIndexResultBitmap(std::make_shared<roaring::Roaring>(),
std::make_shared<roaring::Roaring>());
return Status::OK();
}
std::shared_ptr<roaring::Roaring> roaring =
std::make_shared<roaring::Roaring>();
- uint32_t doc = scorer->doc();
- uint32_t matched_docs = 0;
- while (doc != query_v2::TERMINATED) {
- roaring->add(doc);
- ++matched_docs;
- doc = scorer->advance();
+ if (enable_scoring && !is_asc && top_k > 0) {
+ bool use_wand = index_query_context->runtime_state != nullptr &&
+ index_query_context->runtime_state->query_options()
+ .enable_inverted_index_wand_query;
+ query_v2::collect_multi_segment_top_k(weight, exec_ctx,
root_binding_key, top_k, roaring,
+
index_query_context->collection_similarity, use_wand);
+ } else {
+ query_v2::collect_multi_segment_doc_set(
+ weight, exec_ctx, root_binding_key, roaring,
+ index_query_context ?
index_query_context->collection_similarity : nullptr,
+ enable_scoring);
}
- VLOG_DEBUG << "search: Query completed, matched " << matched_docs << "
documents";
+ VLOG_DEBUG << "search: Query completed, matched " <<
roaring->cardinality() << " documents";
// Extract NULL bitmap from three-valued logic scorer
// The scorer correctly computes which documents evaluate to NULL based on
query logic
// For example: TRUE OR NULL = TRUE (not NULL), FALSE OR NULL = NULL
std::shared_ptr<roaring::Roaring> null_bitmap =
std::make_shared<roaring::Roaring>();
- if (scorer->has_null_bitmap(exec_ctx.null_resolver)) {
- const auto* bitmap = scorer->get_null_bitmap(exec_ctx.null_resolver);
- if (bitmap != nullptr) {
- *null_bitmap = *bitmap;
- VLOG_TRACE << "search: Extracted NULL bitmap with " <<
null_bitmap->cardinality()
- << " NULL documents";
+ if (exec_ctx.null_resolver) {
+ auto scorer = weight->scorer(exec_ctx, root_binding_key);
+ if (scorer && scorer->has_null_bitmap(exec_ctx.null_resolver)) {
+ const auto* bitmap =
scorer->get_null_bitmap(exec_ctx.null_resolver);
+ if (bitmap != nullptr) {
+ *null_bitmap = *bitmap;
+ VLOG_TRACE << "search: Extracted NULL bitmap with " <<
null_bitmap->cardinality()
+ << " NULL documents";
+ }
}
}
diff --git a/be/src/exprs/function/function_search.h
b/be/src/exprs/function/function_search.h
index 216d84fb3da..376e1aa0728 100644
--- a/be/src/exprs/function/function_search.h
+++ b/be/src/exprs/function/function_search.h
@@ -174,7 +174,8 @@ public:
std::unordered_map<std::string, IndexIterator*> iterators,
uint32_t num_rows,
InvertedIndexResultBitmap& bitmap_result, bool enable_cache,
const IndexExecContext* index_exec_ctx,
- const std::unordered_map<std::string, int>&
field_name_to_column_id) const;
+ const std::unordered_map<std::string, int>&
field_name_to_column_id,
+ const std::shared_ptr<IndexQueryContext>& index_query_context =
nullptr) const;
Status evaluate_nested_query(
const TSearchParam& search_param, const TSearchClause&
nested_clause,
diff --git a/be/src/exprs/vexpr_context.h b/be/src/exprs/vexpr_context.h
index 98eed2a7604..5b6aaafdddd 100644
--- a/be/src/exprs/vexpr_context.h
+++ b/be/src/exprs/vexpr_context.h
@@ -198,6 +198,14 @@ public:
return iter->second.get();
}
+ void set_index_query_context(segment_v2::IndexQueryContextPtr
index_query_context) {
+ _index_query_context = index_query_context;
+ }
+
+ const segment_v2::IndexQueryContextPtr& get_index_query_context() const {
+ return _index_query_context;
+ }
+
private:
// A reference to a vector of column IDs for the current expression's
output columns.
const std::vector<ColumnId>& _col_ids;
@@ -224,6 +232,7 @@ private:
segment_v2::Segment* _segment = nullptr; // Ref
segment_v2::ColumnIteratorOptions _column_iter_opts;
+ segment_v2::IndexQueryContextPtr _index_query_context;
};
class VExprContext {
diff --git a/be/src/exprs/vsearch.cpp b/be/src/exprs/vsearch.cpp
index 45a0e3717ca..f4ed11e95fc 100644
--- a/be/src/exprs/vsearch.cpp
+++ b/be/src/exprs/vsearch.cpp
@@ -258,11 +258,14 @@ Status VSearchExpr::evaluate_inverted_index(VExprContext*
context, uint32_t segm
return Status::OK();
}
+ auto index_query_context = index_context->get_index_query_context();
+
auto function = std::make_shared<FunctionSearch>();
auto result_bitmap = InvertedIndexResultBitmap();
auto status = function->evaluate_inverted_index_with_search_param(
_search_param, bundle.field_types, bundle.iterators,
segment_num_rows, result_bitmap,
- _enable_cache, index_context.get(),
bundle.field_name_to_column_id);
+ _enable_cache, index_context.get(), bundle.field_name_to_column_id,
+ index_query_context);
if (!status.ok()) {
LOG(WARNING) << "VSearchExpr: Function evaluation failed: " <<
status.to_string();
diff --git a/be/src/storage/compaction/collection_statistics.cpp
b/be/src/storage/compaction/collection_statistics.cpp
index 4752237e764..16decc2b15f 100644
--- a/be/src/storage/compaction/collection_statistics.cpp
+++ b/be/src/storage/compaction/collection_statistics.cpp
@@ -29,6 +29,7 @@
#include "storage/index/index_reader_helper.h"
#include "storage/index/inverted/analyzer/analyzer.h"
#include "storage/index/inverted/util/string_helper.h"
+#include "storage/index/inverted/util/term_iterator.h"
#include "storage/rowset/rowset.h"
#include "storage/rowset/rowset_reader.h"
#include "util/uid_util.h"
@@ -109,94 +110,15 @@ Status CollectionStatistics::collect(RuntimeState* state,
return Status::OK();
}
-VSlotRef* find_slot_ref(const VExprSPtr& expr) {
- if (!expr) return nullptr;
- auto cur = VExpr::expr_without_cast(expr);
- if (cur->node_type() == TExprNodeType::SLOT_REF) {
- return static_cast<VSlotRef*>(cur.get());
- }
- for (auto& ch : cur->children()) {
- if (auto* s = find_slot_ref(ch)) return s;
- }
- return nullptr;
-}
-
-Status handle_match_pred(RuntimeState* state, const TabletSchemaSPtr&
tablet_schema,
- const VExprSPtr& expr,
- std::unordered_map<std::wstring, CollectInfo>*
collect_infos) {
- auto* left_slot_ref = find_slot_ref(expr->children()[0]);
- if (left_slot_ref == nullptr) {
- return Status::Error<ErrorCode::INVERTED_INDEX_NOT_SUPPORTED>(
- "Index statistics collection failed: Cannot find slot
reference in match predicate "
- "left expression");
- }
- auto* right_literal = static_cast<VLiteral*>(expr->children()[1].get());
- DCHECK(right_literal != nullptr);
-
- const auto* sd =
state->desc_tbl().get_slot_descriptor(left_slot_ref->slot_id());
- if (sd == nullptr) {
- return Status::Error<ErrorCode::INVERTED_INDEX_NOT_SUPPORTED>(
- "Index statistics collection failed: Cannot find slot
descriptor for slot_id={}",
- left_slot_ref->slot_id());
- }
- int32_t col_idx = tablet_schema->field_index(left_slot_ref->column_name());
- if (col_idx == -1) {
- return Status::Error<ErrorCode::INVERTED_INDEX_NOT_SUPPORTED>(
- "Index statistics collection failed: Cannot find column index
for column={}",
- left_slot_ref->column_name());
- }
-
- const auto& column = tablet_schema->column(col_idx);
- auto index_metas = tablet_schema->inverted_indexs(sd->col_unique_id(),
column.suffix_path());
-#ifndef BE_TEST
- if (index_metas.empty()) {
- return Status::Error<ErrorCode::INVERTED_INDEX_NOT_SUPPORTED>(
- "Index statistics collection failed: Score query is not
supported without inverted "
- "index for column={}",
- left_slot_ref->column_name());
- }
-#endif
-
- auto format_options = DataTypeSerDe::get_default_format_options();
- format_options.timezone = &state->timezone_obj();
- for (const auto* index_meta : index_metas) {
- if (!InvertedIndexAnalyzer::should_analyzer(index_meta->properties()))
{
- continue;
- }
- if
(!segment_v2::IndexReaderHelper::is_need_similarity_score(expr->op(),
index_meta)) {
- continue;
- }
-
- auto term_infos = InvertedIndexAnalyzer::get_analyse_result(
- right_literal->value(format_options),
index_meta->properties());
- if (term_infos.empty()) {
- LOG(WARNING) << "Index statistics collection: no terms extracted
from literal value, "
- << "col_unique_id=" <<
index_meta->col_unique_ids()[0];
- continue;
- }
-
- std::string field_name =
std::to_string(index_meta->col_unique_ids()[0]);
- if (!column.suffix_path().empty()) {
- field_name += "." + column.suffix_path();
- }
- std::wstring ws_field_name = StringHelper::to_wstring(field_name);
- auto iter = collect_infos->find(ws_field_name);
- if (iter == collect_infos->end()) {
- CollectInfo collect_info;
- collect_info.term_infos.insert(term_infos.begin(),
term_infos.end());
- collect_info.index_meta = index_meta;
- (*collect_infos)[ws_field_name] = std::move(collect_info);
- } else {
- iter->second.term_infos.insert(term_infos.begin(),
term_infos.end());
- }
- }
- return Status::OK();
-}
-
Status CollectionStatistics::extract_collect_info(
RuntimeState* state, const VExprContextSPtrs&
common_expr_ctxs_push_down,
- const TabletSchemaSPtr& tablet_schema,
- std::unordered_map<std::wstring, CollectInfo>* collect_infos) {
+ const TabletSchemaSPtr& tablet_schema, CollectInfoMap* collect_infos) {
+ DCHECK(collect_infos != nullptr);
+
+ std::unordered_map<TExprNodeType::type, PredicateCollectorPtr> collectors;
+ collectors[TExprNodeType::MATCH_PRED] =
std::make_unique<MatchPredicateCollector>();
+ collectors[TExprNodeType::SEARCH_EXPR] =
std::make_unique<SearchPredicateCollector>();
+
for (const auto& root_expr_ctx : common_expr_ctxs_push_down) {
const auto& root_expr = root_expr_ctx->root();
if (root_expr == nullptr) {
@@ -207,27 +129,35 @@ Status CollectionStatistics::extract_collect_info(
stack.emplace(root_expr);
while (!stack.empty()) {
- const auto& expr = stack.top();
+ auto expr = stack.top();
stack.pop();
- if (expr->node_type() == TExprNodeType::MATCH_PRED) {
- RETURN_IF_ERROR(handle_match_pred(state, tablet_schema, expr,
collect_infos));
+ if (!expr) {
+ continue;
+ }
+
+ auto collector_it = collectors.find(expr->node_type());
+ if (collector_it != collectors.end()) {
+ RETURN_IF_ERROR(
+ collector_it->second->collect(state, tablet_schema,
expr, collect_infos));
}
const auto& children = expr->children();
- for (int32_t i = static_cast<int32_t>(children.size()) - 1; i >=
0; --i) {
- if (!children[i]->children().empty()) {
- stack.emplace(children[i]);
- }
+ for (const auto& child : children) {
+ stack.push(child);
}
}
}
+
+ LOG(INFO) << "Extracted collect info for " << collect_infos->size() << "
fields";
+
return Status::OK();
}
-Status CollectionStatistics::process_segment(
- const RowsetSharedPtr& rowset, int32_t seg_id, const TabletSchema*
tablet_schema,
- const std::unordered_map<std::wstring, CollectInfo>& collect_infos,
io::IOContext* io_ctx) {
+Status CollectionStatistics::process_segment(const RowsetSharedPtr& rowset,
int32_t seg_id,
+ const TabletSchema* tablet_schema,
+ const CollectInfoMap&
collect_infos,
+ io::IOContext* io_ctx) {
auto seg_path = DORIS_TRY(rowset->segment_path(seg_id));
auto rowset_meta = rowset->rowset_meta();
@@ -239,36 +169,42 @@ Status CollectionStatistics::process_segment(
RETURN_IF_ERROR(idx_file_reader->init(config::inverted_index_read_buffer_size,
io_ctx));
int32_t total_seg_num_docs = 0;
+
for (const auto& [ws_field_name, collect_info] : collect_infos) {
+ lucene::search::IndexSearcher* index_searcher = nullptr;
+ lucene::index::IndexReader* index_reader = nullptr;
+
#ifdef BE_TEST
auto compound_reader =
DORIS_TRY(idx_file_reader->open(collect_info.index_meta, io_ctx));
auto* reader = lucene::index::IndexReader::open(compound_reader.get());
- auto index_searcher =
std::make_shared<lucene::search::IndexSearcher>(reader, true);
-
- auto* index_reader = index_searcher->getReader();
+ auto searcher_ptr =
std::make_shared<lucene::search::IndexSearcher>(reader, true);
+ index_searcher = searcher_ptr.get();
+ index_reader = index_searcher->getReader();
#else
InvertedIndexCacheHandle inverted_index_cache_handle;
auto index_file_key =
idx_file_reader->get_index_file_cache_key(collect_info.index_meta);
InvertedIndexSearcherCache::CacheKey
searcher_cache_key(index_file_key);
+
if (!InvertedIndexSearcherCache::instance()->lookup(searcher_cache_key,
&inverted_index_cache_handle)) {
auto compound_reader =
DORIS_TRY(idx_file_reader->open(collect_info.index_meta,
io_ctx));
auto* reader =
lucene::index::IndexReader::open(compound_reader.get());
size_t reader_size = reader->getTermInfosRAMUsed();
- auto index_searcher =
std::make_shared<lucene::search::IndexSearcher>(reader, true);
+ auto searcher_ptr =
std::make_shared<lucene::search::IndexSearcher>(reader, true);
auto* cache_value = new InvertedIndexSearcherCache::CacheValue(
- std::move(index_searcher), reader_size, UnixMillis());
+ std::move(searcher_ptr), reader_size, UnixMillis());
InvertedIndexSearcherCache::instance()->insert(searcher_cache_key,
cache_value,
&inverted_index_cache_handle);
}
auto searcher_variant =
inverted_index_cache_handle.get_index_searcher();
- auto index_searcher =
std::get<FulltextIndexSearcherPtr>(searcher_variant);
- auto* index_reader = index_searcher->getReader();
+ auto index_searcher_ptr =
std::get<FulltextIndexSearcherPtr>(searcher_variant);
+ index_searcher = index_searcher_ptr.get();
+ index_reader = index_searcher->getReader();
#endif
-
total_seg_num_docs = std::max(total_seg_num_docs,
index_reader->maxDoc());
+
_total_num_tokens[ws_field_name] +=
index_reader->sumTotalTermFreq(ws_field_name.c_str()).value_or(0);
@@ -278,7 +214,9 @@ Status CollectionStatistics::process_segment(
_term_doc_freqs[ws_field_name][iter->term()] += iter->doc_freq();
}
}
+
_total_num_docs += total_seg_num_docs;
+
return Status::OK();
}
diff --git a/be/src/storage/compaction/collection_statistics.h
b/be/src/storage/compaction/collection_statistics.h
index 0b4bcc18d8d..e51d5750db6 100644
--- a/be/src/storage/compaction/collection_statistics.h
+++ b/be/src/storage/compaction/collection_statistics.h
@@ -25,6 +25,7 @@
#include "runtime/runtime_state.h"
#include "storage/index/inverted/query/query_info.h"
#include "storage/olap_common.h"
+#include "storage/predicate_collector.h"
namespace doris {
#include "common/compile_check_begin.h"
@@ -44,18 +45,6 @@ class TabletIndex;
class TabletSchema;
using TabletSchemaSPtr = std::shared_ptr<TabletSchema>;
-struct TermInfoComparer {
- bool operator()(const segment_v2::TermInfo& lhs, const
segment_v2::TermInfo& rhs) const {
- return lhs.term < rhs.term;
- }
-};
-
-class CollectInfo {
-public:
- std::set<segment_v2::TermInfo, TermInfoComparer> term_infos;
- const TabletIndex* index_meta = nullptr;
-};
-
class CollectionStatistics {
public:
CollectionStatistics() = default;
@@ -73,10 +62,9 @@ private:
Status extract_collect_info(RuntimeState* state,
const VExprContextSPtrs&
common_expr_ctxs_push_down,
const TabletSchemaSPtr& tablet_schema,
- std::unordered_map<std::wstring, CollectInfo>*
collect_infos);
+ CollectInfoMap* collect_infos);
Status process_segment(const RowsetSharedPtr& rowset, int32_t seg_id,
- const TabletSchema* tablet_schema,
- const std::unordered_map<std::wstring,
CollectInfo>& collect_infos,
+ const TabletSchema* tablet_schema, const
CollectInfoMap& collect_infos,
io::IOContext* io_ctx);
uint64_t get_term_doc_freq_by_col(const std::wstring& lucene_col_name,
@@ -94,6 +82,7 @@ private:
MOCK_DEFINE(friend class BM25SimilarityTest;)
MOCK_DEFINE(friend class CollectionStatisticsTest;)
MOCK_DEFINE(friend class BooleanQueryTest;)
+ MOCK_DEFINE(friend class OccurBooleanQueryTest;)
};
using CollectionStatisticsPtr = std::shared_ptr<CollectionStatistics>;
diff --git a/be/src/storage/index/index_iterator.h
b/be/src/storage/index/index_iterator.h
index b4d15c23187..becb5b56ea5 100644
--- a/be/src/storage/index/index_iterator.h
+++ b/be/src/storage/index/index_iterator.h
@@ -57,6 +57,7 @@ public:
virtual Result<bool> has_null() = 0;
void set_context(const IndexQueryContextPtr& context) { _context =
context; }
+ IndexQueryContextPtr get_context() const { return _context; }
protected:
IndexQueryContextPtr _context = nullptr;
diff --git a/be/src/storage/index/index_query_context.h
b/be/src/storage/index/index_query_context.h
index fdd48d7c9f9..9afa6ea504e 100644
--- a/be/src/storage/index/index_query_context.h
+++ b/be/src/storage/index/index_query_context.h
@@ -30,6 +30,9 @@ struct IndexQueryContext {
CollectionStatisticsPtr collection_statistics;
CollectionSimilarityPtr collection_similarity;
+
+ size_t query_limit = 0;
+ bool is_asc = false;
};
using IndexQueryContextPtr = std::shared_ptr<IndexQueryContext>;
diff --git
a/be/src/storage/index/inverted/query_v2/boolean_query/occur_boolean_weight.cpp
b/be/src/storage/index/inverted/query_v2/boolean_query/occur_boolean_weight.cpp
index ff0594673af..2789ac9474c 100644
---
a/be/src/storage/index/inverted/query_v2/boolean_query/occur_boolean_weight.cpp
+++
b/be/src/storage/index/inverted/query_v2/boolean_query/occur_boolean_weight.cpp
@@ -45,6 +45,12 @@ OccurBooleanWeight<ScoreCombinerPtrT>::OccurBooleanWeight(
template <typename ScoreCombinerPtrT>
ScorerPtr OccurBooleanWeight<ScoreCombinerPtrT>::scorer(const
QueryExecutionContext& context) {
+ return scorer(context, {});
+}
+
+template <typename ScoreCombinerPtrT>
+ScorerPtr OccurBooleanWeight<ScoreCombinerPtrT>::scorer(const
QueryExecutionContext& context,
+ const std::string&
binding_key) {
if (_sub_weights.empty()) {
return std::make_shared<EmptyScorer>();
}
@@ -53,27 +59,28 @@ ScorerPtr
OccurBooleanWeight<ScoreCombinerPtrT>::scorer(const QueryExecutionCont
if (occur == Occur::MUST_NOT) {
return std::make_shared<EmptyScorer>();
}
- return weight->scorer(context);
+ return weight->scorer(context, binding_key);
}
_max_doc = context.segment_num_rows;
if (_enable_scoring) {
- auto specialized = complex_scorer(context, _score_combiner);
+ auto specialized = complex_scorer(context, _score_combiner,
binding_key);
return into_box_scorer(std::move(specialized), _score_combiner);
} else {
auto combiner = std::make_shared<DoNothingCombiner>();
- auto specialized = complex_scorer(context, combiner);
+ auto specialized = complex_scorer(context, combiner, binding_key);
return into_box_scorer(std::move(specialized), combiner);
}
}
template <typename ScoreCombinerPtrT>
std::unordered_map<Occur, std::vector<ScorerPtr>>
-OccurBooleanWeight<ScoreCombinerPtrT>::per_occur_scorers(const
QueryExecutionContext& context) {
+OccurBooleanWeight<ScoreCombinerPtrT>::per_occur_scorers(const
QueryExecutionContext& context,
+ const std::string&
binding_key) {
std::unordered_map<Occur, std::vector<ScorerPtr>> result;
for (size_t i = 0; i < _sub_weights.size(); ++i) {
const auto& [occur, weight] = _sub_weights[i];
- const auto& binding_key = _binding_keys[i];
- auto sub_scorer = weight->scorer(context, binding_key);
+ const auto& key = _binding_keys[i].empty() ? binding_key :
_binding_keys[i];
+ auto sub_scorer = weight->scorer(context, key);
if (sub_scorer) {
result[occur].push_back(std::move(sub_scorer));
}
@@ -217,8 +224,8 @@ SpecializedScorer
OccurBooleanWeight<ScoreCombinerPtrT>::build_positive_opt(
template <typename ScoreCombinerPtrT>
template <typename CombinerT>
SpecializedScorer OccurBooleanWeight<ScoreCombinerPtrT>::complex_scorer(
- const QueryExecutionContext& context, CombinerT combiner) {
- auto scorers_by_occur = per_occur_scorers(context);
+ const QueryExecutionContext& context, CombinerT combiner, const
std::string& binding_key) {
+ auto scorers_by_occur = per_occur_scorers(context, binding_key);
auto must_scorers = std::move(scorers_by_occur[Occur::MUST]);
auto should_scorers = std::move(scorers_by_occur[Occur::SHOULD]);
auto must_not_scorers = std::move(scorers_by_occur[Occur::MUST_NOT]);
diff --git
a/be/src/storage/index/inverted/query_v2/boolean_query/occur_boolean_weight.h
b/be/src/storage/index/inverted/query_v2/boolean_query/occur_boolean_weight.h
index e9f7708991e..d3157c81473 100644
---
a/be/src/storage/index/inverted/query_v2/boolean_query/occur_boolean_weight.h
+++
b/be/src/storage/index/inverted/query_v2/boolean_query/occur_boolean_weight.h
@@ -22,6 +22,7 @@
#include "storage/index/inverted/query_v2/boolean_query/occur.h"
#include "storage/index/inverted/query_v2/scorer.h"
#include "storage/index/inverted/query_v2/term_query/term_scorer.h"
+#include "storage/index/inverted/query_v2/wand/block_wand.h"
#include "storage/index/inverted/query_v2/weight.h"
namespace doris::segment_v2::inverted_index::query_v2 {
@@ -51,14 +52,21 @@ public:
~OccurBooleanWeight() override = default;
ScorerPtr scorer(const QueryExecutionContext& context) override;
+ ScorerPtr scorer(const QueryExecutionContext& context, const std::string&
binding_key) override;
+
+ void for_each_pruning(const QueryExecutionContext& context, float
threshold,
+ PruningCallback callback) override;
+ void for_each_pruning(const QueryExecutionContext& context, const
std::string& binding_key,
+ float threshold, PruningCallback callback) override;
private:
std::unordered_map<Occur, std::vector<ScorerPtr>> per_occur_scorers(
- const QueryExecutionContext& context);
+ const QueryExecutionContext& context, const std::string&
binding_key = {});
AllAndEmptyScorerCounts
remove_and_count_all_and_empty_scorers(std::vector<ScorerPtr>& scorers);
template <typename CombinerT>
- SpecializedScorer complex_scorer(const QueryExecutionContext& context,
CombinerT combiner);
+ SpecializedScorer complex_scorer(const QueryExecutionContext& context,
CombinerT combiner,
+ const std::string& binding_key = {});
template <typename CombinerT>
std::optional<CombinationMethod> build_should_opt(std::vector<ScorerPtr>&
must_scorers,
@@ -101,4 +109,35 @@ private:
uint32_t _max_doc = 0;
};
+template <typename ScoreCombinerPtrT>
+void OccurBooleanWeight<ScoreCombinerPtrT>::for_each_pruning(const
QueryExecutionContext& context,
+ float threshold,
+ PruningCallback
callback) {
+ for_each_pruning(context, {}, threshold, std::move(callback));
+}
+
+template <typename ScoreCombinerPtrT>
+void OccurBooleanWeight<ScoreCombinerPtrT>::for_each_pruning(const
QueryExecutionContext& context,
+ const
std::string& binding_key,
+ float threshold,
+ PruningCallback
callback) {
+ if (_sub_weights.empty()) {
+ return;
+ }
+
+ _max_doc = context.segment_num_rows;
+ auto specialized = complex_scorer(context, _score_combiner, binding_key);
+
+ std::visit(
+ [&](auto&& arg) {
+ using T = std::decay_t<decltype(arg)>;
+ if constexpr (std::is_same_v<T, std::vector<TermScorerPtr>>) {
+ block_wand(std::move(arg), threshold, std::move(callback));
+ } else {
+ for_each_pruning_scorer(std::move(arg), threshold,
std::move(callback));
+ }
+ },
+ std::move(specialized));
+}
+
} // namespace doris::segment_v2::inverted_index::query_v2
\ No newline at end of file
diff --git
a/be/src/storage/index/inverted/query_v2/collect/doc_set_collector.cpp
b/be/src/storage/index/inverted/query_v2/collect/doc_set_collector.cpp
new file mode 100644
index 00000000000..08cdf1e5a2a
--- /dev/null
+++ b/be/src/storage/index/inverted/query_v2/collect/doc_set_collector.cpp
@@ -0,0 +1,47 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "storage/index/inverted/query_v2/collect/doc_set_collector.h"
+
+#include "storage/index/inverted/query_v2/collect/multi_segment_util.h"
+
+namespace doris::segment_v2::inverted_index::query_v2 {
+
+void collect_multi_segment_doc_set(const WeightPtr& weight, const
QueryExecutionContext& context,
+ const std::string& binding_key,
+ const std::shared_ptr<roaring::Roaring>&
roaring,
+ const CollectionSimilarityPtr& similarity,
bool enable_scoring) {
+ for_each_index_segment(context, binding_key,
+ [&](const QueryExecutionContext& seg_ctx, uint32_t
doc_base) {
+ auto scorer = weight->scorer(seg_ctx,
binding_key);
+ if (!scorer) {
+ return;
+ }
+
+ uint32_t doc = scorer->doc();
+ while (doc != TERMINATED) {
+ uint32_t global_doc = doc + doc_base;
+ roaring->add(global_doc);
+ if (enable_scoring && similarity) {
+ similarity->collect(global_doc,
scorer->score());
+ }
+ doc = scorer->advance();
+ }
+ });
+}
+
+} // namespace doris::segment_v2::inverted_index::query_v2
diff --git a/be/src/storage/index/index_query_context.h
b/be/src/storage/index/inverted/query_v2/collect/doc_set_collector.h
similarity index 60%
copy from be/src/storage/index/index_query_context.h
copy to be/src/storage/index/inverted/query_v2/collect/doc_set_collector.h
index fdd48d7c9f9..bad94d724bc 100644
--- a/be/src/storage/index/index_query_context.h
+++ b/be/src/storage/index/inverted/query_v2/collect/doc_set_collector.h
@@ -17,21 +17,17 @@
#pragma once
-#include "storage/compaction/collection_similarity.h"
-#include "storage/compaction/collection_statistics.h"
+#include <memory>
+#include <roaring/roaring.hh>
-namespace doris::segment_v2 {
-#include "common/compile_check_begin.h"
+#include "storage/compaction/collection_similarity.h"
+#include "storage/index/inverted/query_v2/weight.h"
-struct IndexQueryContext {
- io::IOContext* io_ctx = nullptr;
- OlapReaderStatistics* stats = nullptr;
- RuntimeState* runtime_state = nullptr;
+namespace doris::segment_v2::inverted_index::query_v2 {
- CollectionStatisticsPtr collection_statistics;
- CollectionSimilarityPtr collection_similarity;
-};
-using IndexQueryContextPtr = std::shared_ptr<IndexQueryContext>;
+// Walks every index segment reachable from `context` and adds each document
+// matched by `weight` to `roaring`, using global doc ids (segment-local doc id
+// plus the segment's doc base). Segments for which `weight` yields no scorer
+// are skipped. When `enable_scoring` is true and `similarity` is non-null,
+// every (global doc id, score) pair is also reported via `similarity->collect`.
+void collect_multi_segment_doc_set(const WeightPtr& weight, const
QueryExecutionContext& context,
+ const std::string& binding_key,
+ const std::shared_ptr<roaring::Roaring>&
roaring,
+ const CollectionSimilarityPtr& similarity,
bool enable_scoring);
-#include "common/compile_check_end.h"
-} // namespace doris::segment_v2
\ No newline at end of file
+} // namespace doris::segment_v2::inverted_index::query_v2
diff --git
a/be/src/storage/index/inverted/query_v2/collect/multi_segment_util.h
b/be/src/storage/index/inverted/query_v2/collect/multi_segment_util.h
new file mode 100644
index 00000000000..0519fe172c9
--- /dev/null
+++ b/be/src/storage/index/inverted/query_v2/collect/multi_segment_util.h
@@ -0,0 +1,97 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "storage/index/inverted/query_v2/weight.h"
+
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wshadow-field"
+#pragma clang diagnostic ignored "-Woverloaded-virtual"
+#pragma clang diagnostic ignored "-Winconsistent-missing-override"
+#pragma clang diagnostic ignored "-Wreorder-ctor"
+#pragma clang diagnostic ignored "-Wshorten-64-to-32"
+#elif defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Woverloaded-virtual"
+#endif
+#include "CLucene.h"
+#include "CLucene/index/_MultiSegmentReader.h"
+#ifdef __clang__
+#pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
+
+namespace doris::segment_v2::inverted_index::query_v2 {
+
+/// Builds a QueryExecutionContext that targets a single index segment.
+///
+/// `seg_reader` is wrapped in a shared_ptr with a no-op deleter, so the returned
+/// context only borrows it and `seg_reader` must outlive the context. Field
+/// bindings and the null resolver are copied from `original_ctx`; the reader is
+/// additionally registered under `binding_key` unless that key is empty.
+inline QueryExecutionContext create_segment_context(lucene::index::IndexReader* seg_reader,
+                                                    const QueryExecutionContext& original_ctx,
+                                                    const std::string& binding_key) {
+    QueryExecutionContext segment_ctx;
+    segment_ctx.binding_fields = original_ctx.binding_fields;
+    segment_ctx.null_resolver = original_ctx.null_resolver;
+    segment_ctx.segment_num_rows = seg_reader->numDocs();
+
+    // No-op deleter: ownership presumably stays with the enclosing multi-segment reader.
+    std::shared_ptr<lucene::index::IndexReader> borrowed(seg_reader,
+                                                        [](lucene::index::IndexReader*) {});
+    if (!binding_key.empty()) {
+        segment_ctx.reader_bindings[binding_key] = borrowed;
+    }
+    segment_ctx.readers.push_back(borrowed);
+
+    return segment_ctx;
+}
+
+/// Invokes `callback(segment_context, doc_base)` once per index segment of `context`.
+///
+/// - No reader in `context`: `callback(context, 0)` runs only if
+///   `context.segment_num_rows > 0`.
+/// - A reader that is not a lucene MultiSegmentReader: `callback(context, 0)`.
+/// - A MultiSegmentReader: one call per sub-reader, with a context built by
+///   create_segment_context() and that sub-reader's start offset as `doc_base`
+///   (global doc id = segment-local doc id + doc_base).
+template <typename SegmentCallback>
+void for_each_index_segment(const QueryExecutionContext& context, const std::string& binding_key,
+                            SegmentCallback&& callback) {
+    auto* root_reader = context.readers.empty() ? nullptr : context.readers.front().get();
+    if (root_reader == nullptr) {
+        // AllQuery/MatchAllDocsQuery resolves no field, so no reader exists; its
+        // scorer only needs segment_num_rows from the original context.
+        if (context.segment_num_rows > 0) {
+            callback(context, 0);
+        }
+        return;
+    }
+
+    auto* multi_reader = dynamic_cast<lucene::index::MultiSegmentReader*>(root_reader);
+    if (multi_reader == nullptr) {
+        // Not a multi-segment reader: the original context is used as-is.
+        callback(context, 0);
+        return;
+    }
+
+    const auto* sub_readers = multi_reader->getSubReaders();
+    const auto* starts = multi_reader->getStarts();
+    if (sub_readers == nullptr || sub_readers->length == 0) {
+        return;
+    }
+
+    for (size_t seg_idx = 0; seg_idx < sub_readers->length; ++seg_idx) {
+        auto doc_base = static_cast<uint32_t>(starts[seg_idx]);
+        QueryExecutionContext seg_ctx =
+                create_segment_context((*sub_readers)[seg_idx], context, binding_key);
+        callback(seg_ctx, doc_base);
+    }
+}
+
+} // namespace doris::segment_v2::inverted_index::query_v2
diff --git a/be/src/storage/index/inverted/query_v2/collect/top_k_collector.cpp
b/be/src/storage/index/inverted/query_v2/collect/top_k_collector.cpp
new file mode 100644
index 00000000000..8ac21495dae
--- /dev/null
+++ b/be/src/storage/index/inverted/query_v2/collect/top_k_collector.cpp
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "storage/index/inverted/query_v2/collect/top_k_collector.h"
+
+#include "storage/index/inverted/query_v2/collect/multi_segment_util.h"
+
+namespace doris::segment_v2::inverted_index::query_v2 {
+
+/// Collects the top-`k` documents matched by `weight` across every index segment
+/// of `context` and publishes them into `roaring` (and `similarity`, if set).
+///
+/// Each segment is scored into its own TopKCollector. The survivors are re-based
+/// to global doc ids (segment-local id + segment base) and merged into one global
+/// collector, whose contents are written out in descending score order.
+///
+/// @param weight      query weight that produces per-segment scorers.
+/// @param context     execution context; may wrap a multi-segment reader.
+/// @param binding_key reader binding used to resolve per-segment readers.
+/// @param k           number of documents to keep (TopKCollector clamps it to [1, kMaxK]).
+/// @param roaring     output bitmap that receives the global ids of the top-k docs.
+/// @param similarity  optional; receives (global doc id, score) for each top-k doc.
+/// @param use_wand    when true, prune with Weight::for_each_pruning; otherwise drive
+///                    the plain scorer through Weight::for_each_pruning_scorer.
+void collect_multi_segment_top_k(const WeightPtr& weight, const QueryExecutionContext& context,
+                                 const std::string& binding_key, size_t k,
+                                 const std::shared_ptr<roaring::Roaring>& roaring,
+                                 const CollectionSimilarityPtr& similarity, bool use_wand) {
+    TopKCollector final_collector(k);
+
+    for_each_index_segment(
+            context, binding_key, [&](const QueryExecutionContext& seg_ctx, uint32_t seg_base) {
+                // Every doc already held by `final_collector` comes from an earlier
+                // segment, so it has a smaller global doc id (segment bases are assumed to
+                // increase). A doc here whose score does not exceed the global threshold
+                // can never displace them, because ties favor the smaller doc id.
+                const float initial_threshold = final_collector.threshold();
+
+                TopKCollector seg_collector(k);
+                auto callback = [&seg_collector, initial_threshold](uint32_t doc_id,
+                                                                    float score) -> float {
+                    // The segment collector starts at -inf. Without this floor, the first
+                    // hit would lower the pruning threshold below the global one and
+                    // disable skipping for the rest of the segment.
+                    return std::max(seg_collector.collect(doc_id, score), initial_threshold);
+                };
+
+                if (use_wand) {
+                    weight->for_each_pruning(seg_ctx, binding_key, initial_threshold, callback);
+                } else {
+                    auto scorer = weight->scorer(seg_ctx, binding_key);
+                    if (scorer) {
+                        Weight::for_each_pruning_scorer(scorer, initial_threshold, callback);
+                    }
+                }
+
+                for (const auto& doc : seg_collector.into_sorted_vec()) {
+                    final_collector.collect(doc.doc_id + seg_base, doc.score);
+                }
+            });
+
+    for (const auto& doc : final_collector.into_sorted_vec()) {
+        roaring->add(doc.doc_id);
+        if (similarity) {
+            similarity->collect(doc.doc_id, doc.score);
+        }
+    }
+}
+
+} // namespace doris::segment_v2::inverted_index::query_v2
diff --git a/be/src/storage/index/inverted/query_v2/collect/top_k_collector.h
b/be/src/storage/index/inverted/query_v2/collect/top_k_collector.h
new file mode 100644
index 00000000000..889aff9fa55
--- /dev/null
+++ b/be/src/storage/index/inverted/query_v2/collect/top_k_collector.h
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <glog/logging.h>
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <ranges>
+#include <roaring/roaring.hh>
+#include <string>
+#include <vector>
+
+#include "storage/compaction/collection_similarity.h"
+#include "storage/index/inverted/query_v2/weight.h"
+
+namespace doris::segment_v2::inverted_index::query_v2 {
+
+// A matched document together with the relevance score it received.
+struct ScoredDoc {
+    ScoredDoc() = default;
+    ScoredDoc(uint32_t id, float value) : doc_id(id), score(value) {}
+
+    uint32_t doc_id = 0;
+    float score = 0.0F;
+};
+
+// Ranking order: higher score first; among equal scores the smaller doc_id
+// comes first, which keeps top-k results deterministic.
+struct ScoredDocByScoreDesc {
+    bool operator()(const ScoredDoc& lhs, const ScoredDoc& rhs) const {
+        if (lhs.score != rhs.score) {
+            return lhs.score > rhs.score;
+        }
+        return lhs.doc_id < rhs.doc_id;
+    }
+};
+
+// Bounded collector for the `k` best (doc_id, score) pairs, ranked by
+// ScoredDocByScoreDesc (higher score first, smaller doc_id on ties).
+//
+// Candidates are appended to a buffer reserved for at least 2k entries. When the
+// buffer fills, it is cut back to the best k with nth_element. `_threshold` is
+// the lowest kept score. It is only raised when the buffer first reaches k
+// entries or is truncated, so it is a conservative lower bound that callers
+// (e.g. WAND pruning) can rely on.
+class TopKCollector {
+public:
+ // Largest supported k; larger requests are clamped with a warning.
+ static constexpr size_t kMaxK = 10000;
+
+ // `k` is clamped to [1, kMaxK]. NOTE(review): k == 0 therefore collects one
+ // document; confirm no caller relies on k == 0 meaning "collect nothing".
+ explicit TopKCollector(size_t k) : _k(std::clamp(k, size_t(1), kMaxK)) {
+ if (k > kMaxK) {
+ LOG(WARNING) << "TopKCollector: requested k=" << k << " exceeds
maximum " << kMaxK
+ << ", truncated to " << kMaxK;
+ }
+ // collect() treats `size() == capacity()` as the signal to truncate, so
+ // the reservation doubles as the batch size between truncations.
+ _buffer.reserve(_k * 2);
+ }
+
+ // Offers a candidate. Scores below the current threshold are dropped.
+ // Returns the (possibly raised) threshold so the caller can tighten pruning.
+ float collect(uint32_t doc_id, float score) {
+ if (score < _threshold) {
+ return _threshold;
+ }
+ _buffer.emplace_back(doc_id, score);
+ if (_buffer.size() == _buffer.capacity()) {
+ // Buffer full: keep only the best k and raise the threshold.
+ _truncate();
+ } else if (_buffer.size() == _k) {
+ // First time k candidates are held: the worst of them sets the threshold.
+ _update_threshold_at_capacity();
+ }
+ return _threshold;
+ }
+
+ // Current pruning threshold; -infinity until k candidates have been seen.
+ float threshold() const { return _threshold; }
+ // Number of documents into_sorted_vec() would return.
+ size_t size() const { return std::min(_buffer.size(), _k); }
+
+ // Returns the best (at most k) documents sorted by ScoredDocByScoreDesc and
+ // moves the internal buffer out. NOTE(review): do not call collect() after
+ // this. The moved-from buffer loses the 2k reservation that collect()'s
+ // truncation check relies on, and _truncate() could then index past the end
+ // when k > 1.
+ [[nodiscard]] std::vector<ScoredDoc> into_sorted_vec() {
+ if (_buffer.size() > _k) {
+ _truncate();
+ }
+ std::ranges::sort(_buffer, ScoredDocByScoreDesc {});
+ return std::move(_buffer);
+ }
+
+private:
+ // Moves the best k to the front (unordered), drops the rest, and refreshes
+ // the threshold. Requires _buffer.size() >= _k.
+ void _truncate() {
+ std::ranges::nth_element(_buffer, _buffer.begin() + _k,
ScoredDocByScoreDesc {});
+ _buffer.resize(_k);
+ _update_threshold_at_capacity();
+ }
+
+ // max_element under the descending comparator yields the lowest-ranked
+ // kept document, i.e. the minimum kept score.
+ void _update_threshold_at_capacity() {
+ auto it = std::ranges::max_element(_buffer, ScoredDocByScoreDesc {});
+ _threshold = it->score;
+ }
+
+ // Effective k after clamping.
+ size_t _k;
+ float _threshold = -std::numeric_limits<float>::infinity();
+ // Unsorted candidates; may hold up to capacity() entries between truncations.
+ std::vector<ScoredDoc> _buffer;
+};
+
+// Collects the top-`k` documents matched by `weight` across all index segments
+// of `context`. Each segment is scored into its own TopKCollector, and the
+// survivors are merged into a global one using global doc ids. The global ids of
+// the winners are added to `roaring`; each (doc id, score) is also reported to
+// `similarity` when it is non-null. With `use_wand` the weight's
+// for_each_pruning() is used; otherwise its plain scorer is driven through
+// Weight::for_each_pruning_scorer(). `k` is clamped to [1, TopKCollector::kMaxK].
+void collect_multi_segment_top_k(const WeightPtr& weight, const
QueryExecutionContext& context,
+ const std::string& binding_key, size_t k,
+ const std::shared_ptr<roaring::Roaring>&
roaring,
+ const CollectionSimilarityPtr& similarity,
bool use_wand = true);
+
+} // namespace doris::segment_v2::inverted_index::query_v2
diff --git a/be/src/storage/index/inverted/query_v2/composite_reader.h
b/be/src/storage/index/inverted/query_v2/composite_reader.h
index dda04bfd0fa..73a74b8653d 100644
--- a/be/src/storage/index/inverted/query_v2/composite_reader.h
+++ b/be/src/storage/index/inverted/query_v2/composite_reader.h
@@ -17,8 +17,20 @@
#pragma once
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Woverloaded-virtual"
+#elif defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Woverloaded-virtual"
+#endif
#include <CLucene.h>
#include <CLucene/index/IndexReader.h>
+#ifdef __clang__
+#pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
#include <algorithm>
#include <ranges>
diff --git a/be/src/storage/index/inverted/query_v2/match_all_docs_scorer.h
b/be/src/storage/index/inverted/query_v2/match_all_docs_scorer.h
index 085a5281f0c..63a705660dc 100644
--- a/be/src/storage/index/inverted/query_v2/match_all_docs_scorer.h
+++ b/be/src/storage/index/inverted/query_v2/match_all_docs_scorer.h
@@ -21,7 +21,19 @@
#include <utility>
#include <vector>
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Woverloaded-virtual"
+#elif defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Woverloaded-virtual"
+#endif
#include "CLucene.h" // IWYU pragma: keep
+#ifdef __clang__
+#pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
#include "storage/index/inverted/query_v2/scorer.h"
namespace doris::segment_v2::inverted_index::query_v2 {
diff --git
a/be/src/storage/index/inverted/query_v2/phrase_prefix_query/phrase_prefix_weight.h
b/be/src/storage/index/inverted/query_v2/phrase_prefix_query/phrase_prefix_weight.h
index 9651cf629b4..efd7e62b5b3 100644
---
a/be/src/storage/index/inverted/query_v2/phrase_prefix_query/phrase_prefix_weight.h
+++
b/be/src/storage/index/inverted/query_v2/phrase_prefix_query/phrase_prefix_weight.h
@@ -65,7 +65,7 @@ private:
std::vector<std::pair<size_t, PostingsPtr>> all_postings;
for (const auto& [offset, term] : _phrase_terms) {
auto posting = create_position_posting(reader.get(), _field, term,
_enable_scoring,
- _context->io_ctx);
+ _similarity,
_context->io_ctx);
if (!posting) {
return std::make_shared<EmptyScorer>();
}
@@ -81,7 +81,7 @@ private:
std::vector<SegmentPostingsPtr> suffix_postings;
for (const auto& term : expanded_terms) {
auto posting = create_position_posting(reader.get(), _field, term,
_enable_scoring,
- _context->io_ctx);
+ _similarity,
_context->io_ctx);
if (posting) {
suffix_postings.emplace_back(std::move(posting));
}
diff --git
a/be/src/storage/index/inverted/query_v2/phrase_query/multi_phrase_weight.h
b/be/src/storage/index/inverted/query_v2/phrase_query/multi_phrase_weight.h
index 2399ebf42b4..e7eea460507 100644
--- a/be/src/storage/index/inverted/query_v2/phrase_query/multi_phrase_weight.h
+++ b/be/src/storage/index/inverted/query_v2/phrase_query/multi_phrase_weight.h
@@ -66,7 +66,7 @@ private:
if (term_info.is_single_term()) {
auto posting =
create_position_posting(reader.get(), _field,
term_info.get_single_term(),
- _enable_scoring,
_context->io_ctx);
+ _enable_scoring, _similarity,
_context->io_ctx);
if (posting) {
if (posting->size_hint() > SPARSE_TERM_DOC_THRESHOLD) {
auto loaded_posting = LoadedPostings::load(*posting);
@@ -81,8 +81,9 @@ private:
const auto& terms = term_info.get_multi_terms();
std::vector<PostingsPtr> postings;
for (const auto& term : terms) {
- auto posting = create_position_posting(reader.get(),
_field, term,
- _enable_scoring,
_context->io_ctx);
+ auto posting =
+ create_position_posting(reader.get(), _field,
term, _enable_scoring,
+ _similarity,
_context->io_ctx);
if (posting) {
if (posting->size_hint() <= SPARSE_TERM_DOC_THRESHOLD)
{
postings.push_back(LoadedPostings::load(*posting));
diff --git
a/be/src/storage/index/inverted/query_v2/phrase_query/phrase_weight.h
b/be/src/storage/index/inverted/query_v2/phrase_query/phrase_weight.h
index 5e06d80c058..2600308a9b7 100644
--- a/be/src/storage/index/inverted/query_v2/phrase_query/phrase_weight.h
+++ b/be/src/storage/index/inverted/query_v2/phrase_query/phrase_weight.h
@@ -62,7 +62,7 @@ private:
size_t offset = term_info.position;
auto posting =
create_position_posting(reader.get(), _field,
term_info.get_single_term(),
- _enable_scoring, _context->io_ctx);
+ _enable_scoring, _similarity,
_context->io_ctx);
if (posting) {
term_postings_list.emplace_back(offset, std::move(posting));
} else {
diff --git
a/be/src/storage/index/inverted/query_v2/prefix_query/prefix_weight.h
b/be/src/storage/index/inverted/query_v2/prefix_query/prefix_weight.h
index 62ee3a28260..5e863fd49ce 100644
--- a/be/src/storage/index/inverted/query_v2/prefix_query/prefix_weight.h
+++ b/be/src/storage/index/inverted/query_v2/prefix_query/prefix_weight.h
@@ -17,9 +17,21 @@
#pragma once
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Woverloaded-virtual"
+#elif defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Woverloaded-virtual"
+#endif
#include <CLucene/config/repl_wchar.h>
#include <CLucene/index/IndexReader.h>
#include <CLucene/index/Term.h>
+#ifdef __clang__
+#pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
#include "storage/index/index_query_context.h"
#include "storage/index/inverted/query_v2/bit_set_query/bit_set_scorer.h"
@@ -135,7 +147,8 @@ private:
auto term_wstr = StringHelper::to_wstring(term);
auto t = make_term_ptr(_field.c_str(), term_wstr.c_str());
auto iter = make_term_doc_ptr(reader.get(), t.get(),
_enable_scoring, _context->io_ctx);
- auto segment_postings = make_segment_postings(std::move(iter),
_enable_scoring);
+ auto segment_postings =
+ make_segment_postings(std::move(iter), _enable_scoring,
nullptr);
uint32_t doc = segment_postings->doc();
while (doc != TERMINATED) {
diff --git
a/be/src/storage/index/inverted/query_v2/regexp_query/regexp_weight.cpp
b/be/src/storage/index/inverted/query_v2/regexp_query/regexp_weight.cpp
index 0dcabbe54d9..ac6a905ba47 100644
--- a/be/src/storage/index/inverted/query_v2/regexp_query/regexp_weight.cpp
+++ b/be/src/storage/index/inverted/query_v2/regexp_query/regexp_weight.cpp
@@ -17,8 +17,20 @@
#include "storage/index/inverted/query_v2/regexp_query/regexp_weight.h"
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Woverloaded-virtual"
+#elif defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Woverloaded-virtual"
+#endif
#include <CLucene/index/IndexReader.h>
#include <CLucene/index/Term.h>
+#ifdef __clang__
+#pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
#include <gen_cpp/PaloBrokerService_types.h>
#include <algorithm>
@@ -41,7 +53,9 @@ RegexpWeight::RegexpWeight(IndexQueryContextPtr context,
std::wstring field, std
_pattern(std::move(pattern)),
_enable_scoring(enable_scoring),
_nullable(nullable) {
- // _max_expansions =
_context->runtime_state->query_options().inverted_index_max_expansions;
+ if (_context->runtime_state) {
+ _max_expansions =
_context->runtime_state->query_options().inverted_index_max_expansions;
+ }
}
ScorerPtr RegexpWeight::scorer(const QueryExecutionContext& context,
@@ -91,13 +105,11 @@ ScorerPtr RegexpWeight::regexp_scorer(const
QueryExecutionContext& context,
return std::make_shared<EmptyScorer>();
}
+ auto reader = lookup_reader(_field, context, binding_key);
auto doc_bitset = std::make_shared<roaring::Roaring>();
for (const auto& term : matching_terms) {
- auto t = make_term_ptr(_field.c_str(), term.c_str());
- auto reader = lookup_reader(_field, context, binding_key);
- auto iter = make_term_doc_ptr(reader.get(), t.get(), _enable_scoring,
_context->io_ctx);
- auto segment_postings = make_segment_postings(std::move(iter),
_enable_scoring);
-
+ auto segment_postings =
+ create_term_posting(reader.get(), _field, term, false,
nullptr, _context->io_ctx);
uint32_t doc = segment_postings->doc();
while (doc != TERMINATED) {
doc_bitset->add(doc);
diff --git
a/be/src/storage/index/inverted/query_v2/regexp_query/regexp_weight.h
b/be/src/storage/index/inverted/query_v2/regexp_query/regexp_weight.h
index c3607585698..f8b1a0ba9fa 100644
--- a/be/src/storage/index/inverted/query_v2/regexp_query/regexp_weight.h
+++ b/be/src/storage/index/inverted/query_v2/regexp_query/regexp_weight.h
@@ -46,7 +46,7 @@ private:
std::wstring _field;
std::string _pattern;
- bool _enable_scoring = false;
+ [[maybe_unused]] bool _enable_scoring = false;
bool _nullable = true;
// Set to 0 to disable limit (ES has no default limit for prefix queries)
// The limit prevents collecting too many terms, but can cause incorrect
results
diff --git a/be/src/storage/index/inverted/query_v2/scorer.h
b/be/src/storage/index/inverted/query_v2/scorer.h
index 41b2ce23987..2b43c475f99 100644
--- a/be/src/storage/index/inverted/query_v2/scorer.h
+++ b/be/src/storage/index/inverted/query_v2/scorer.h
@@ -71,5 +71,6 @@ public:
float score() override { return 0.0F; }
};
+using EmptyScorerPtr = std::shared_ptr<EmptyScorer>;
} // namespace doris::segment_v2::inverted_index::query_v2
diff --git a/be/src/storage/index/inverted/query_v2/segment_postings.h
b/be/src/storage/index/inverted/query_v2/segment_postings.h
index 8112da4772d..316d81d66d4 100644
--- a/be/src/storage/index/inverted/query_v2/segment_postings.h
+++ b/be/src/storage/index/inverted/query_v2/segment_postings.h
@@ -17,14 +17,19 @@
#pragma once
+#include <limits>
#include <variant>
#include "CLucene/index/DocRange.h"
#include "storage/index/inverted/inverted_index_common.h"
#include "storage/index/inverted/query_v2/doc_set.h"
+#include "storage/index/inverted/similarity/similarity.h"
namespace doris::segment_v2::inverted_index::query_v2 {
+using doris::segment_v2::Similarity;
+using doris::segment_v2::SimilarityPtr;
+
class Postings : public DocSet {
public:
Postings() = default;
@@ -40,20 +45,25 @@ public:
using PostingsPtr = std::shared_ptr<Postings>;
-class SegmentPostings final : public Postings {
+class SegmentPostings : public Postings {
public:
using IterVariant = std::variant<std::monostate, TermDocsPtr,
TermPositionsPtr>;
- explicit SegmentPostings(TermDocsPtr iter, bool enable_scoring = false)
- : _iter(std::move(iter)), _enable_scoring(enable_scoring) {
+ explicit SegmentPostings(TermDocsPtr iter, bool enable_scoring,
SimilarityPtr similarity)
+ : _iter(std::move(iter)),
+ _enable_scoring(enable_scoring),
+ _similarity(std::move(similarity)) {
if (auto* p = std::get_if<TermDocsPtr>(&_iter)) {
_raw_iter = p->get();
}
_init_doc();
}
- explicit SegmentPostings(TermPositionsPtr iter, bool enable_scoring =
false)
- : _iter(std::move(iter)), _enable_scoring(enable_scoring),
_has_positions(true) {
+ explicit SegmentPostings(TermPositionsPtr iter, bool enable_scoring,
SimilarityPtr similarity)
+ : _iter(std::move(iter)),
+ _enable_scoring(enable_scoring),
+ _has_positions(true),
+ _similarity(std::move(similarity)) {
if (auto* p = std::get_if<TermPositionsPtr>(&_iter)) {
_raw_iter = p->get();
}
@@ -155,14 +165,63 @@ public:
bool scoring_enabled() const { return _enable_scoring; }
+ // Number of blocks decoded so far; incremented once per _refill().
+ int64_t block_id() const { return _block_id; }
+
+ // Asks the underlying iterator to skip to the block that may contain
+ // `target_doc`. This is a no-op unless `target_doc` is past the current doc.
+ // When the iterator reports a block change, the cached block-max score and
+ // the decoded block buffer are invalidated (doc_many_size_ = 0), presumably
+ // so that the next doc read goes through _refill().
+ void seek_block(uint32_t target_doc) {
+ if (target_doc <= _doc) {
+ return;
+ }
+ if (_raw_iter->skipToBlock(target_doc)) {
+ _block_max_score_cache = -1.0F;
+ _cursor = 0;
+ _block.doc_many_size_ = 0;
+ }
+ }
+
+ // Last doc id of the current block. The iterator's sentinels -1 and
+ // INT32_MAX (0x7FFFFFFF) are mapped to TERMINATED.
+ uint32_t last_doc_in_block() const {
+ int32_t last_doc = _raw_iter->getLastDocInBlock();
+ if (last_doc == -1 || last_doc == 0x7FFFFFFFL) {
+ return TERMINATED;
+ }
+ return static_cast<uint32_t>(last_doc);
+ }
+
+ // Score bound for the current block, used for block-max pruning.
+ // - Returns float max when scoring is disabled or no similarity is set, so
+ //   no block is ever pruned.
+ // - Otherwise returns similarity->score(max block freq, max block norm),
+ //   cached until the block changes. -1 marks "not cached", which assumes
+ //   scores are never negative.
+ // - Falls back to the term-wide max_score() (uncached) when the iterator
+ //   reports no block statistics (negative values).
+ // NOTE(review): this is a true upper bound only if score() does not increase
+ // as the norm decreases, or if getMaxBlockNorm() returns the score-maximizing
+ // norm. Verify against the similarity and index format, because an
+ // underestimate makes WAND drop valid hits.
+ float block_max_score() {
+ if (!_enable_scoring || !_similarity) {
+ return std::numeric_limits<float>::max();
+ }
+ if (_block_max_score_cache >= 0.0F) {
+ return _block_max_score_cache;
+ }
+ int32_t max_block_freq = _raw_iter->getMaxBlockFreq();
+ int32_t max_block_norm = _raw_iter->getMaxBlockNorm();
+ if (max_block_freq >= 0 && max_block_norm >= 0) {
+ _block_max_score_cache =
_similarity->score(static_cast<float>(max_block_freq),
+
static_cast<int64_t>(max_block_norm));
+ return _block_max_score_cache;
+ }
+ return _similarity->max_score();
+ }
+
+ // Term-wide score bound from the similarity. Returns float max when scoring
+ // is disabled or no similarity is set.
+ float max_score() const {
+ if (!_enable_scoring || !_similarity) {
+ return std::numeric_limits<float>::max();
+ }
+ return _similarity->max_score();
+ }
+
+ // Raw block statistics from the iterator. block_max_score() treats negative
+ // values as "unavailable".
+ int32_t max_block_freq() const { return _raw_iter->getMaxBlockFreq(); }
+ int32_t max_block_norm() const { return _raw_iter->getMaxBlockNorm(); }
+
private:
bool _refill() {
- _block.need_positions = _has_positions;
- if (!_raw_iter->readRange(&_block)) {
+ if (!_raw_iter->readBlock(&_block)) {
return false;
}
_cursor = 0;
_prox_cursor = 0;
+ _block_max_score_cache = -1.0F;
+ _block_id++;
return _block.doc_many_size_ > 0;
}
@@ -187,17 +246,20 @@ private:
DocRange _block;
uint32_t _cursor = 0;
uint32_t _prox_cursor = 0;
+ mutable float _block_max_score_cache = -1.0F;
+ mutable int64_t _block_id = 0;
+ SimilarityPtr _similarity;
};
-
using SegmentPostingsPtr = std::shared_ptr<SegmentPostings>;
-inline SegmentPostingsPtr make_segment_postings(TermDocsPtr iter, bool
enable_scoring = false) {
- return std::make_shared<SegmentPostings>(std::move(iter), enable_scoring);
+inline SegmentPostingsPtr make_segment_postings(TermDocsPtr iter, bool
enable_scoring,
+ SimilarityPtr similarity) {
+ return std::make_shared<SegmentPostings>(std::move(iter), enable_scoring,
similarity);
}
-inline SegmentPostingsPtr make_segment_postings(TermPositionsPtr iter,
- bool enable_scoring = false) {
- return std::make_shared<SegmentPostings>(std::move(iter), enable_scoring);
+inline SegmentPostingsPtr make_segment_postings(TermPositionsPtr iter, bool
enable_scoring,
+ SimilarityPtr similarity) {
+ return std::make_shared<SegmentPostings>(std::move(iter), enable_scoring,
similarity);
}
} // namespace doris::segment_v2::inverted_index::query_v2
\ No newline at end of file
diff --git a/be/src/storage/index/inverted/query_v2/term_query/term_scorer.h
b/be/src/storage/index/inverted/query_v2/term_query/term_scorer.h
index d03d5fbb9bd..b6c5882ef21 100644
--- a/be/src/storage/index/inverted/query_v2/term_query/term_scorer.h
+++ b/be/src/storage/index/inverted/query_v2/term_query/term_scorer.h
@@ -45,6 +45,11 @@ public:
uint32_t freq() const override { return _segment_postings->freq(); }
uint32_t norm() const override { return _segment_postings->norm(); }
+ // Block-max WAND hooks (called through BlockWand's ScorerWrapper): thin
+ // forwards to the underlying SegmentPostings. block_max_score() is const
+ // here even though SegmentPostings::block_max_score() is not; this compiles
+ // only because the postings are held through a pointer. NOTE(review): confirm
+ // this is intended.
+ void seek_block(uint32_t target) { _segment_postings->seek_block(target); }
+ uint32_t last_doc_in_block() const { return
_segment_postings->last_doc_in_block(); }
+ float block_max_score() const { return
_segment_postings->block_max_score(); }
+ float max_score() const { return _segment_postings->max_score(); }
+
float score() override { return _similarity->score(freq(), norm()); }
bool has_null_bitmap(const NullBitmapResolver* resolver = nullptr)
override {
diff --git a/be/src/storage/index/inverted/query_v2/term_query/term_weight.h
b/be/src/storage/index/inverted/query_v2/term_query/term_weight.h
index fbecd4c3700..544569df0ad 100644
--- a/be/src/storage/index/inverted/query_v2/term_query/term_weight.h
+++ b/be/src/storage/index/inverted/query_v2/term_query/term_weight.h
@@ -17,15 +17,22 @@
#pragma once
+#include <variant>
+
#include "storage/index/inverted/query_v2/segment_postings.h"
#include "storage/index/inverted/query_v2/term_query/term_scorer.h"
+#include "storage/index/inverted/query_v2/wand/block_wand.h"
#include "storage/index/inverted/query_v2/weight.h"
#include "storage/index/inverted/similarity/similarity.h"
namespace doris::segment_v2::inverted_index::query_v2 {
+using TermOrEmptyScorer = std::variant<EmptyScorerPtr, TermScorerPtr>;
+
class TermWeight : public Weight {
public:
+ using Weight::for_each_pruning;
+
TermWeight(IndexQueryContextPtr context, std::wstring field, std::wstring
term,
SimilarityPtr similarity, bool enable_scoring)
: _context(std::move(context)),
@@ -36,25 +43,43 @@ public:
~TermWeight() override = default;
ScorerPtr scorer(const QueryExecutionContext& ctx, const std::string&
binding_key) override {
+ auto result = specialized_scorer(ctx, binding_key);
+ return std::visit([](auto&& sc) -> ScorerPtr { return sc; }, result);
+ }
+
+ // Runs block-max WAND over this term's postings in one segment.
+ // `threshold` is the initial pruning threshold. `callback(doc, score)` is
+ // handed to block_wand_single_scorer and is expected to return the updated
+ // threshold (NOTE(review): contract inferred from BlockWand::execute; confirm
+ // against block_wand_single_scorer). When the term has no postings in this
+ // segment (EmptyScorer), nothing is visited.
+ template <typename Callback>
+ void for_each_pruning(const QueryExecutionContext& context, const
std::string& binding_key,
+ float threshold, Callback&& callback) {
+ auto result = specialized_scorer(context, binding_key);
+ std::visit(
+ [&](auto&& sc) {
+ using T = std::decay_t<decltype(sc)>;
+ // Only a real TermScorer has postings to prune over.
+ if constexpr (std::is_same_v<T, TermScorerPtr>) {
+ block_wand_single_scorer(std::move(sc), threshold,
+
std::forward<Callback>(callback));
+ }
+ },
+ std::move(result));
+ }
+
+private:
+ TermOrEmptyScorer specialized_scorer(const QueryExecutionContext& ctx,
+ const std::string& binding_key) {
auto reader = lookup_reader(_field, ctx, binding_key);
auto logical_field = logical_field_or_fallback(ctx, binding_key,
_field);
-
if (!reader) {
return std::make_shared<EmptyScorer>();
}
- auto t = make_term_ptr(_field.c_str(), _term.c_str());
- auto iter = make_term_doc_ptr(reader.get(), t.get(), _enable_scoring,
_context->io_ctx);
- if (iter) {
- return std::make_shared<TermScorer>(
- make_segment_postings(std::move(iter), _enable_scoring),
_similarity,
- logical_field);
+ SegmentPostingsPtr segment_postings;
+ segment_postings = create_term_posting(reader.get(), _field, _term,
_enable_scoring,
+ _similarity, _context->io_ctx);
+ if (segment_postings) {
+ return std::make_shared<TermScorer>(segment_postings, _similarity,
logical_field);
}
-
return std::make_shared<EmptyScorer>();
}
-private:
IndexQueryContextPtr _context;
std::wstring _field;
diff --git a/be/src/storage/index/inverted/query_v2/wand/block_wand.h
b/be/src/storage/index/inverted/query_v2/wand/block_wand.h
new file mode 100644
index 00000000000..e0138012778
--- /dev/null
+++ b/be/src/storage/index/inverted/query_v2/wand/block_wand.h
@@ -0,0 +1,286 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <ranges>
+#include <vector>
+
+#include "storage/index/inverted/query_v2/term_query/term_scorer.h"
+
+namespace doris::segment_v2::inverted_index::query_v2 {
+
+class BlockWand {
+public:
+ template <typename Callback>
+ static void execute(TermScorerPtr scorer, float threshold, Callback&&
callback) {
+ uint32_t doc = scorer->doc();
+ while (doc != TERMINATED) {
+ while (scorer->block_max_score() < threshold) {
+ uint32_t last_doc_in_block = scorer->last_doc_in_block();
+ if (last_doc_in_block == TERMINATED) {
+ return;
+ }
+ doc = last_doc_in_block + 1;
+ scorer->seek_block(doc);
+ }
+
+ doc = scorer->seek(doc);
+ if (doc == TERMINATED) {
+ break;
+ }
+
+ while (true) {
+ float score = scorer->score();
+ if (score > threshold) {
+ threshold = callback(doc, score);
+ }
+ if (doc == scorer->last_doc_in_block()) {
+ break;
+ }
+ doc = scorer->advance();
+ if (doc == TERMINATED) {
+ return;
+ }
+ }
+ doc++;
+ scorer->seek_block(doc);
+ }
+ }
+
+ template <typename Callback>
+ static void execute(std::vector<TermScorerPtr> scorers, float threshold,
Callback&& callback) {
+ if (scorers.empty()) {
+ return;
+ }
+
+ if (scorers.size() == 1) {
+ execute(std::move(scorers[0]), threshold,
std::forward<Callback>(callback));
+ return;
+ }
+
+ std::vector<ScorerWrapper> wrappers;
+ wrappers.reserve(scorers.size());
+ for (auto& s : scorers) {
+ if (s->doc() != TERMINATED) {
+ wrappers.emplace_back(std::move(s));
+ }
+ }
+
+ std::sort(wrappers.begin(), wrappers.end(),
+ [](const ScorerWrapper& a, const ScorerWrapper& b) { return
a.doc() < b.doc(); });
+
+ while (true) {
+ auto result = find_pivot_doc(wrappers, threshold);
+ if (result.pivot_doc == TERMINATED) {
+ break;
+ }
+ auto [before_pivot_len, pivot_len, pivot_doc] = result;
+
+ assert(std::ranges::is_sorted(wrappers,
+ [](const ScorerWrapper& a, const
ScorerWrapper& b) {
+ return a.doc() < b.doc();
+ }));
+ assert(pivot_doc != TERMINATED);
+ assert(before_pivot_len < pivot_len);
+
+ float block_max_score_upperbound = 0.0F;
+ for (size_t i = 0; i < pivot_len; ++i) {
+ wrappers[i].seek_block(pivot_doc);
+ block_max_score_upperbound += wrappers[i].block_max_score();
+ }
+
+ if (block_max_score_upperbound <= threshold) {
+ block_max_was_too_low_advance_one_scorer(wrappers, pivot_len);
+ continue;
+ }
+
+ if (!align_scorers(wrappers, pivot_doc, before_pivot_len)) {
+ continue;
+ }
+
+ float score = 0.0F;
+ for (size_t i = 0; i < pivot_len; ++i) {
+ score += wrappers[i].score();
+ }
+
+ if (score > threshold) {
+ threshold = callback(pivot_doc, score);
+ }
+
+ advance_all_scorers_on_pivot(wrappers, pivot_len);
+ }
+ }
+
+private:
+ class ScorerWrapper {
+ public:
+ explicit ScorerWrapper(TermScorerPtr scorer)
+ : _scorer(std::move(scorer)), _max_score(_scorer->max_score())
{}
+
+ uint32_t doc() const { return _scorer->doc(); }
+ uint32_t advance() { return _scorer->advance(); }
+ uint32_t seek(uint32_t target) { return _scorer->seek(target); }
+ float score() { return _scorer->score(); }
+
+ void seek_block(uint32_t target) { _scorer->seek_block(target); }
+ uint32_t last_doc_in_block() const { return
_scorer->last_doc_in_block(); }
+ float block_max_score() const { return _scorer->block_max_score(); }
+ float max_score() const { return _max_score; }
+
+ private:
+ TermScorerPtr _scorer;
+ float _max_score;
+ };
+
+ struct PivotResult {
+ size_t before_pivot_len;
+ size_t pivot_len;
+ uint32_t pivot_doc;
+ };
+
+ static PivotResult find_pivot_doc(std::vector<ScorerWrapper>& scorers,
float threshold) {
+ float max_score = 0.0F;
+ size_t before_pivot_len = 0;
+ uint32_t pivot_doc = TERMINATED;
+
+ while (before_pivot_len < scorers.size()) {
+ max_score += scorers[before_pivot_len].max_score();
+ if (max_score > threshold) {
+ pivot_doc = scorers[before_pivot_len].doc();
+ break;
+ }
+ before_pivot_len++;
+ }
+
+ if (pivot_doc == TERMINATED) {
+ return PivotResult {.before_pivot_len = 0, .pivot_len = 0,
.pivot_doc = TERMINATED};
+ }
+
+ size_t pivot_len = before_pivot_len + 1;
+ while (pivot_len < scorers.size() && scorers[pivot_len].doc() ==
pivot_doc) {
+ pivot_len++;
+ }
+
+ return PivotResult {.before_pivot_len = before_pivot_len,
+ .pivot_len = pivot_len,
+ .pivot_doc = pivot_doc};
+ }
+
+ static void restore_ordering(std::vector<ScorerWrapper>& scorers, size_t
ord) {
+ uint32_t doc = scorers[ord].doc();
+ while (ord + 1 < scorers.size() && doc > scorers[ord + 1].doc()) {
+ std::swap(scorers[ord], scorers[ord + 1]);
+ ord++;
+ }
+ assert(std::ranges::is_sorted(scorers, [](const ScorerWrapper& a,
const ScorerWrapper& b) {
+ return a.doc() < b.doc();
+ }));
+ }
+
+ static void
block_max_was_too_low_advance_one_scorer(std::vector<ScorerWrapper>& scorers,
+ size_t pivot_len) {
+ assert(std::ranges::is_sorted(scorers, [](const ScorerWrapper& a,
const ScorerWrapper& b) {
+ return a.doc() < b.doc();
+ }));
+
+ size_t scorer_to_seek = pivot_len - 1;
+ float global_max_score = scorers[scorer_to_seek].max_score();
+ uint32_t doc_to_seek_after =
scorers[scorer_to_seek].last_doc_in_block();
+ for (size_t i = pivot_len - 1; i > 0; --i) {
+ size_t scorer_ord = i - 1;
+ const auto& scorer = scorers[scorer_ord];
+ doc_to_seek_after = std::min(doc_to_seek_after,
scorer.last_doc_in_block());
+ if (scorer.max_score() > global_max_score) {
+ global_max_score = scorer.max_score();
+ scorer_to_seek = scorer_ord;
+ }
+ }
+ if (doc_to_seek_after != TERMINATED) {
+ doc_to_seek_after++;
+ }
+ for (size_t i = pivot_len; i < scorers.size(); ++i) {
+ const auto& scorer = scorers[i];
+ doc_to_seek_after = std::min(doc_to_seek_after, scorer.doc());
+ }
+ scorers[scorer_to_seek].seek(doc_to_seek_after);
+ restore_ordering(scorers, scorer_to_seek);
+
+ assert(std::ranges::is_sorted(scorers, [](const ScorerWrapper& a,
const ScorerWrapper& b) {
+ return a.doc() < b.doc();
+ }));
+ }
+
+ static bool align_scorers(std::vector<ScorerWrapper>& scorers, uint32_t
pivot_doc,
+ size_t before_pivot_len) {
+ for (size_t i = before_pivot_len; i > 0; --i) {
+ size_t idx = i - 1;
+ uint32_t new_doc = scorers[idx].seek(pivot_doc);
+ if (new_doc != pivot_doc) {
+ if (new_doc == TERMINATED) {
+ std::swap(scorers[idx], scorers.back());
+ scorers.pop_back();
+ if (scorers.empty()) {
+ return false;
+ }
+ }
+ // Full re-sort to guarantee invariant after swap-with-back,
+ // consistent with advance_all_scorers_on_pivot approach.
+ std::ranges::sort(scorers, [](const ScorerWrapper& a, const
ScorerWrapper& b) {
+ return a.doc() < b.doc();
+ });
+ return false;
+ }
+ }
+ return true;
+ }
+
+ static void advance_all_scorers_on_pivot(std::vector<ScorerWrapper>&
scorers,
+ size_t pivot_len) {
+ for (size_t i = 0; i < pivot_len; ++i) {
+ scorers[i].advance();
+ }
+
+ size_t i = 0;
+ while (i < scorers.size()) {
+ if (scorers[i].doc() == TERMINATED) {
+ std::swap(scorers[i], scorers.back());
+ scorers.pop_back();
+ } else {
+ i++;
+ }
+ }
+
+ std::ranges::sort(scorers, [](const ScorerWrapper& a, const
ScorerWrapper& b) {
+ return a.doc() < b.doc();
+ });
+ }
+};
+
+template <typename Callback>
+inline void block_wand_single_scorer(TermScorerPtr scorer, float threshold,
Callback&& callback) {
+ BlockWand::execute(std::move(scorer), threshold,
std::forward<Callback>(callback));
+}
+
+template <typename Callback>
+inline void block_wand(std::vector<TermScorerPtr> scorers, float threshold,
Callback&& callback) {
+ BlockWand::execute(std::move(scorers), threshold,
std::forward<Callback>(callback));
+}
+
+} // namespace doris::segment_v2::inverted_index::query_v2
diff --git a/be/src/storage/index/inverted/query_v2/weight.h
b/be/src/storage/index/inverted/query_v2/weight.h
index 6ee08c24a54..22d7b513030 100644
--- a/be/src/storage/index/inverted/query_v2/weight.h
+++ b/be/src/storage/index/inverted/query_v2/weight.h
@@ -17,6 +17,7 @@
#pragma once
+#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
@@ -54,6 +55,8 @@ struct QueryExecutionContext {
class Weight {
public:
+ using PruningCallback = std::function<float(uint32_t doc_id, float score)>;
+
Weight() = default;
virtual ~Weight() = default;
@@ -63,6 +66,34 @@ public:
return scorer(context);
}
+ virtual void for_each_pruning(const QueryExecutionContext& context, float
threshold,
+ PruningCallback callback) {
+ auto sc = scorer(context);
+ if (!sc) {
+ return;
+ }
+ for_each_pruning_scorer(sc, threshold, std::move(callback));
+ }
+
+ virtual void for_each_pruning(const QueryExecutionContext& context,
+ const std::string& binding_key, float
threshold,
+ PruningCallback callback) {
+ (void)binding_key;
+ for_each_pruning(context, threshold, std::move(callback));
+ }
+
+ static void for_each_pruning_scorer(const ScorerPtr& scorer, float
threshold,
+ PruningCallback callback) {
+ int32_t doc = scorer->doc();
+ while (doc != TERMINATED) {
+ float score = scorer->score();
+ if (score > threshold) {
+ threshold = callback(doc, score);
+ }
+ doc = scorer->advance();
+ }
+ }
+
protected:
const FieldBindingContext* get_field_binding(const QueryExecutionContext&
ctx,
const std::string&
binding_key) const {
@@ -108,27 +139,36 @@ protected:
SegmentPostingsPtr create_term_posting(lucene::index::IndexReader* reader,
const std::wstring& field, const
std::string& term,
- bool enable_scoring, const
io::IOContext* io_ctx) const {
- auto term_wstr = StringHelper::to_wstring(term);
- auto t = make_term_ptr(field.c_str(), term_wstr.c_str());
+ bool enable_scoring, const
SimilarityPtr& similarity,
+ const io::IOContext* io_ctx) const {
+ return create_term_posting(reader, field,
StringHelper::to_wstring(term), enable_scoring,
+ similarity, io_ctx);
+ }
+
+ SegmentPostingsPtr create_term_posting(lucene::index::IndexReader* reader,
+ const std::wstring& field, const
std::wstring& term,
+ bool enable_scoring, const
SimilarityPtr& similarity,
+ const io::IOContext* io_ctx) const {
+ auto t = make_term_ptr(field.c_str(), term.c_str());
auto iter = make_term_doc_ptr(reader, t.get(), enable_scoring, io_ctx);
- if (iter) {
- return make_segment_postings(std::move(iter), enable_scoring);
- }
- return nullptr;
+ return iter ? make_segment_postings(std::move(iter), enable_scoring,
similarity) : nullptr;
}
SegmentPostingsPtr create_position_posting(lucene::index::IndexReader*
reader,
const std::wstring& field,
const std::string& term,
- bool enable_scoring,
+ bool enable_scoring, const
SimilarityPtr& similarity,
const io::IOContext* io_ctx)
const {
- auto term_wstr = StringHelper::to_wstring(term);
- auto t = make_term_ptr(field.c_str(), term_wstr.c_str());
+ return create_position_posting(reader, field,
StringHelper::to_wstring(term),
+ enable_scoring, similarity, io_ctx);
+ }
+
+ SegmentPostingsPtr create_position_posting(lucene::index::IndexReader*
reader,
+ const std::wstring& field,
const std::wstring& term,
+ bool enable_scoring, const
SimilarityPtr& similarity,
+ const io::IOContext* io_ctx)
const {
+ auto t = make_term_ptr(field.c_str(), term.c_str());
auto iter = make_term_positions_ptr(reader, t.get(), enable_scoring,
io_ctx);
- if (iter) {
- return make_segment_postings(std::move(iter), enable_scoring);
- }
- return nullptr;
+ return iter ? make_segment_postings(std::move(iter), enable_scoring,
similarity) : nullptr;
}
};
diff --git
a/be/src/storage/index/inverted/query_v2/wildcard_query/wildcard_weight.h
b/be/src/storage/index/inverted/query_v2/wildcard_query/wildcard_weight.h
index 566cd5d8dd1..c84d049d7c6 100644
--- a/be/src/storage/index/inverted/query_v2/wildcard_query/wildcard_weight.h
+++ b/be/src/storage/index/inverted/query_v2/wildcard_query/wildcard_weight.h
@@ -41,7 +41,7 @@ public:
ScorerPtr scorer(const QueryExecutionContext& ctx, const std::string&
binding_key) override {
std::string regex_pattern = wildcard_to_regex(_pattern);
auto regexp_weight = std::make_shared<RegexpWeight>(
- _context, std::move(_field), std::move(regex_pattern),
_enable_scoring, _nullable);
+ _context, _field, std::move(regex_pattern), _enable_scoring,
_nullable);
return regexp_weight->scorer(ctx, binding_key);
}
diff --git a/be/src/storage/index/inverted/similarity/bm25_similarity.cpp
b/be/src/storage/index/inverted/similarity/bm25_similarity.cpp
index 095b6ada6be..d3e1ba5ac63 100644
--- a/be/src/storage/index/inverted/similarity/bm25_similarity.cpp
+++ b/be/src/storage/index/inverted/similarity/bm25_similarity.cpp
@@ -17,6 +17,8 @@
#include "storage/index/inverted/similarity/bm25_similarity.h"
+#include <cmath>
+
namespace doris::segment_v2 {
#include "common/compile_check_begin.h"
@@ -83,6 +85,13 @@ float BM25Similarity::score(float freq, int64_t
encoded_norm) {
return _weight - _weight / (1.0F + freq * norm_inverse);
}
+float BM25Similarity::max_score() {
+ // 2013265944 = byte4_to_int(int_to_byte4(MAX_INT32)) from Lucene's
SmallFloat encoding,
+ // representing the maximum possible term frequency. Combined with
norm=255 (shortest
+ // document length), this yields the theoretical upper-bound BM25 score
for this term.
+ return score(static_cast<float>(2013265944), 255);
+}
+
int32_t BM25Similarity::number_of_leading_zeros(uint64_t value) {
if (value == 0) {
return 64;
diff --git a/be/src/storage/index/inverted/similarity/bm25_similarity.h
b/be/src/storage/index/inverted/similarity/bm25_similarity.h
index 06bfcd55a62..e3bfcf17802 100644
--- a/be/src/storage/index/inverted/similarity/bm25_similarity.h
+++ b/be/src/storage/index/inverted/similarity/bm25_similarity.h
@@ -42,6 +42,7 @@ public:
const std::vector<std::wstring>& terms) override;
float score(float freq, int64_t encoded_norm) override;
+ float max_score() override;
static uint8_t int_to_byte4(int32_t i);
static int32_t byte4_to_int(uint8_t b);
diff --git a/be/src/storage/index/inverted/similarity/similarity.h
b/be/src/storage/index/inverted/similarity/similarity.h
index be43534c835..7b4ad195a79 100644
--- a/be/src/storage/index/inverted/similarity/similarity.h
+++ b/be/src/storage/index/inverted/similarity/similarity.h
@@ -36,6 +36,7 @@ public:
const std::vector<std::wstring>& terms) = 0;
virtual float score(float freq, int64_t encoded_norm) = 0;
+ virtual float max_score() = 0;
};
using SimilarityPtr = std::shared_ptr<Similarity>;
diff --git a/be/src/storage/predicate_collector.cpp
b/be/src/storage/predicate_collector.cpp
new file mode 100644
index 00000000000..8e319ae329f
--- /dev/null
+++ b/be/src/storage/predicate_collector.cpp
@@ -0,0 +1,263 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "storage/predicate_collector.h"
+
+#include <glog/logging.h>
+
+#include "exprs/vexpr.h"
+#include "exprs/vexpr_context.h"
+#include "exprs/vliteral.h"
+#include "exprs/vsearch.h"
+#include "exprs/vslot_ref.h"
+#include "gen_cpp/Exprs_types.h"
+#include "storage/index/index_reader_helper.h"
+#include "storage/index/inverted/analyzer/analyzer.h"
+#include "storage/index/inverted/util/string_helper.h"
+#include "storage/tablet/tablet_schema.h"
+
+namespace doris {
+
+using namespace segment_v2;
+
+VSlotRef* PredicateCollector::find_slot_ref(const VExprSPtr& expr) const {
+ if (!expr) {
+ return nullptr;
+ }
+
+ auto cur = VExpr::expr_without_cast(expr);
+ if (cur->node_type() == TExprNodeType::SLOT_REF) {
+ return static_cast<VSlotRef*>(cur.get());
+ }
+
+ for (const auto& ch : cur->children()) {
+ if (auto* s = find_slot_ref(ch)) {
+ return s;
+ }
+ }
+
+ return nullptr;
+}
+
+std::string PredicateCollector::build_field_name(int32_t col_unique_id,
+ const std::string&
suffix_path) const {
+ std::string field_name = std::to_string(col_unique_id);
+ if (!suffix_path.empty()) {
+ field_name += "." + suffix_path;
+ }
+ return field_name;
+}
+
+Status MatchPredicateCollector::collect(RuntimeState* state, const
TabletSchemaSPtr& tablet_schema,
+ const VExprSPtr& expr, CollectInfoMap*
collect_infos) {
+ DCHECK(collect_infos != nullptr);
+
+ auto* left_slot_ref = find_slot_ref(expr->children()[0]);
+ if (left_slot_ref == nullptr) {
+ return Status::Error<ErrorCode::INVERTED_INDEX_NOT_SUPPORTED>(
+ "Index statistics collection failed: Cannot find slot
reference in match predicate "
+ "left expression");
+ }
+
+ auto* right_literal = static_cast<VLiteral*>(expr->children()[1].get());
+ DCHECK(right_literal != nullptr);
+
+ const auto* sd =
state->desc_tbl().get_slot_descriptor(left_slot_ref->slot_id());
+ if (sd == nullptr) {
+ return Status::Error<ErrorCode::INVERTED_INDEX_NOT_SUPPORTED>(
+ "Index statistics collection failed: Cannot find slot
descriptor for slot_id={}",
+ left_slot_ref->slot_id());
+ }
+
+ int32_t col_idx = tablet_schema->field_index(left_slot_ref->column_name());
+ if (col_idx == -1) {
+ return Status::Error<ErrorCode::INVERTED_INDEX_NOT_SUPPORTED>(
+ "Index statistics collection failed: Cannot find column index
for column={}",
+ left_slot_ref->column_name());
+ }
+
+ const auto& column = tablet_schema->column(col_idx);
+ auto index_metas = tablet_schema->inverted_indexs(sd->col_unique_id(),
column.suffix_path());
+
+#ifndef BE_TEST
+ if (index_metas.empty()) {
+ return Status::Error<ErrorCode::INVERTED_INDEX_NOT_SUPPORTED>(
+ "Index statistics collection failed: Score query is not
supported without inverted "
+ "index for column={}",
+ left_slot_ref->column_name());
+ }
+#endif
+
+ for (const auto* index_meta : index_metas) {
+ if (!InvertedIndexAnalyzer::should_analyzer(index_meta->properties()))
{
+ continue;
+ }
+
+ if (!IndexReaderHelper::is_need_similarity_score(expr->op(),
index_meta)) {
+ continue;
+ }
+
+ auto options = DataTypeSerDe::get_default_format_options();
+ options.timezone = &state->timezone_obj();
+ auto term_infos =
InvertedIndexAnalyzer::get_analyse_result(right_literal->value(options),
+
index_meta->properties());
+
+ std::string field_name =
+ build_field_name(index_meta->col_unique_ids()[0],
column.suffix_path());
+ std::wstring ws_field_name = StringHelper::to_wstring(field_name);
+
+ auto iter = collect_infos->find(ws_field_name);
+ if (iter == collect_infos->end()) {
+ CollectInfo collect_info;
+ collect_info.term_infos.insert(term_infos.begin(),
term_infos.end());
+ collect_info.index_meta = index_meta;
+ (*collect_infos)[ws_field_name] = std::move(collect_info);
+ } else {
+ iter->second.term_infos.insert(term_infos.begin(),
term_infos.end());
+ }
+ }
+
+ return Status::OK();
+}
+
+Status SearchPredicateCollector::collect(RuntimeState* state, const
TabletSchemaSPtr& tablet_schema,
+ const VExprSPtr& expr,
CollectInfoMap* collect_infos) {
+ DCHECK(collect_infos != nullptr);
+
+ auto* search_expr = dynamic_cast<VSearchExpr*>(expr.get());
+ if (search_expr == nullptr) {
+ return Status::InternalError("SearchPredicateCollector: expr is not
VSearchExpr type");
+ }
+
+ const TSearchParam& search_param = search_expr->get_search_param();
+
+ RETURN_IF_ERROR(collect_from_clause(search_param.root, state,
tablet_schema, collect_infos));
+
+ return Status::OK();
+}
+
+Status SearchPredicateCollector::collect_from_clause(const TSearchClause&
clause,
+ RuntimeState* state,
+ const TabletSchemaSPtr&
tablet_schema,
+ CollectInfoMap*
collect_infos) {
+ const std::string& clause_type = clause.clause_type;
+ ClauseTypeCategory category = get_clause_type_category(clause_type);
+
+ if (category == ClauseTypeCategory::COMPOUND) {
+ if (clause.__isset.children) {
+ for (const auto& child_clause : clause.children) {
+ RETURN_IF_ERROR(
+ collect_from_clause(child_clause, state,
tablet_schema, collect_infos));
+ }
+ }
+ return Status::OK();
+ }
+
+ return collect_from_leaf(clause, state, tablet_schema, collect_infos);
+}
+
+Status SearchPredicateCollector::collect_from_leaf(const TSearchClause&
clause, RuntimeState* state,
+ const TabletSchemaSPtr&
tablet_schema,
+ CollectInfoMap*
collect_infos) {
+ if (!clause.__isset.field_name || !clause.__isset.value) {
+ return Status::InvalidArgument("Search clause missing field_name or
value");
+ }
+
+ const std::string& field_name = clause.field_name;
+ const std::string& value = clause.value;
+ const std::string& clause_type = clause.clause_type;
+
+ if (!is_score_query_type(clause_type)) {
+ return Status::OK();
+ }
+
+ int32_t col_idx = tablet_schema->field_index(field_name);
+ if (col_idx == -1) {
+ return Status::OK();
+ }
+
+ const auto& column = tablet_schema->column(col_idx);
+
+ auto index_metas = tablet_schema->inverted_indexs(column.unique_id(),
column.suffix_path());
+ if (index_metas.empty()) {
+ return Status::OK();
+ }
+
+ ClauseTypeCategory category = get_clause_type_category(clause_type);
+ for (const auto* index_meta : index_metas) {
+ std::set<TermInfo, TermInfoComparer> term_infos;
+
+ if (category == ClauseTypeCategory::TOKENIZED) {
+ if
(InvertedIndexAnalyzer::should_analyzer(index_meta->properties())) {
+ auto analyzed_terms =
+ InvertedIndexAnalyzer::get_analyse_result(value,
index_meta->properties());
+ term_infos.insert(analyzed_terms.begin(),
analyzed_terms.end());
+ } else {
+ term_infos.insert(TermInfo(value));
+ }
+ } else if (category == ClauseTypeCategory::NON_TOKENIZED) {
+ if (clause_type == "TERM" &&
+
InvertedIndexAnalyzer::should_analyzer(index_meta->properties())) {
+ auto analyzed_terms =
+ InvertedIndexAnalyzer::get_analyse_result(value,
index_meta->properties());
+ term_infos.insert(analyzed_terms.begin(),
analyzed_terms.end());
+ } else {
+ term_infos.insert(TermInfo(value));
+ }
+ }
+
+ std::string lucene_field_name =
+ build_field_name(index_meta->col_unique_ids()[0],
column.suffix_path());
+ std::wstring ws_field_name =
StringHelper::to_wstring(lucene_field_name);
+
+ auto iter = collect_infos->find(ws_field_name);
+ if (iter == collect_infos->end()) {
+ CollectInfo collect_info;
+ collect_info.term_infos = std::move(term_infos);
+ collect_info.index_meta = index_meta;
+ (*collect_infos)[ws_field_name] = std::move(collect_info);
+ } else {
+ iter->second.term_infos.insert(term_infos.begin(),
term_infos.end());
+ }
+ }
+
+ return Status::OK();
+}
+
+bool SearchPredicateCollector::is_score_query_type(const std::string&
clause_type) const {
+ return clause_type == "TERM" || clause_type == "EXACT" || clause_type ==
"PHRASE" ||
+ clause_type == "MATCH" || clause_type == "ANY" || clause_type ==
"ALL";
+}
+
+SearchPredicateCollector::ClauseTypeCategory
SearchPredicateCollector::get_clause_type_category(
+ const std::string& clause_type) const {
+ if (clause_type == "AND" || clause_type == "OR" || clause_type == "NOT" ||
+ clause_type == "OCCUR_BOOLEAN") {
+ return ClauseTypeCategory::COMPOUND;
+ } else if (clause_type == "TERM" || clause_type == "EXACT") {
+ return ClauseTypeCategory::NON_TOKENIZED;
+ } else if (clause_type == "PHRASE" || clause_type == "MATCH" ||
clause_type == "ANY" ||
+ clause_type == "ALL") {
+ return ClauseTypeCategory::TOKENIZED;
+ } else {
+ LOG(WARNING) << "Unknown clause type '" << clause_type
+ << "', defaulting to NON_TOKENIZED category";
+ return ClauseTypeCategory::NON_TOKENIZED;
+ }
+}
+
+} // namespace doris
\ No newline at end of file
diff --git a/be/src/storage/predicate_collector.h
b/be/src/storage/predicate_collector.h
new file mode 100644
index 00000000000..aa5a49344b9
--- /dev/null
+++ b/be/src/storage/predicate_collector.h
@@ -0,0 +1,87 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+
+#include "common/status.h"
+#include "exprs/vexpr_fwd.h"
+#include "gen_cpp/Exprs_types.h"
+#include "runtime/runtime_state.h"
+#include "storage/index/inverted/query/query_info.h"
+
+namespace doris {
+
+class VSlotRef;
+class TabletIndex;
+class TabletSchema;
+using TabletSchemaSPtr = std::shared_ptr<TabletSchema>;
+
+struct TermInfoComparer {
+ bool operator()(const segment_v2::TermInfo& lhs, const
segment_v2::TermInfo& rhs) const {
+ return lhs.term < rhs.term;
+ }
+};
+
+struct CollectInfo {
+ std::set<segment_v2::TermInfo, TermInfoComparer> term_infos;
+ const TabletIndex* index_meta = nullptr;
+};
+using CollectInfoMap = std::unordered_map<std::wstring, CollectInfo>;
+
+class PredicateCollector {
+public:
+ virtual ~PredicateCollector() = default;
+
+ virtual Status collect(RuntimeState* state, const TabletSchemaSPtr&
tablet_schema,
+ const VExprSPtr& expr, CollectInfoMap*
collect_infos) = 0;
+
+protected:
+ VSlotRef* find_slot_ref(const VExprSPtr& expr) const;
+ std::string build_field_name(int32_t col_unique_id, const std::string&
suffix_path) const;
+};
+
+class MatchPredicateCollector : public PredicateCollector {
+public:
+ Status collect(RuntimeState* state, const TabletSchemaSPtr& tablet_schema,
+ const VExprSPtr& expr, CollectInfoMap* collect_infos)
override;
+};
+
+class SearchPredicateCollector : public PredicateCollector {
+public:
+ Status collect(RuntimeState* state, const TabletSchemaSPtr& tablet_schema,
+ const VExprSPtr& expr, CollectInfoMap* collect_infos)
override;
+
+private:
+ enum class ClauseTypeCategory { NON_TOKENIZED, TOKENIZED, COMPOUND };
+
+ Status collect_from_clause(const TSearchClause& clause, RuntimeState*
state,
+ const TabletSchemaSPtr& tablet_schema,
+ CollectInfoMap* collect_infos);
+ Status collect_from_leaf(const TSearchClause& clause, RuntimeState* state,
+ const TabletSchemaSPtr& tablet_schema,
CollectInfoMap* collect_infos);
+ bool is_score_query_type(const std::string& clause_type) const;
+ ClauseTypeCategory get_clause_type_category(const std::string&
clause_type) const;
+};
+
+using PredicateCollectorPtr = std::unique_ptr<PredicateCollector>;
+
+} // namespace doris
\ No newline at end of file
diff --git a/be/src/storage/segment/segment_iterator.cpp
b/be/src/storage/segment/segment_iterator.cpp
index 9006c96af4d..d6612071efc 100644
--- a/be/src/storage/segment/segment_iterator.cpp
+++ b/be/src/storage/segment/segment_iterator.cpp
@@ -1488,6 +1488,8 @@ Status SegmentIterator::_init_index_iterators() {
if (_score_runtime) {
_index_query_context->collection_statistics =
_opts.collection_statistics;
_index_query_context->collection_similarity =
std::make_shared<CollectionSimilarity>();
+ _index_query_context->query_limit = _score_runtime->get_limit();
+ _index_query_context->is_asc = _score_runtime->is_asc();
}
// Inverted index iterators
@@ -2981,6 +2983,7 @@ Status
SegmentIterator::_construct_compound_expr_context() {
auto inverted_index_context = std::make_shared<IndexExecContext>(
_schema->column_ids(), _index_iterators, _storage_name_and_type,
_common_expr_index_exec_status, _score_runtime, _segment.get(),
iter_opts);
+ inverted_index_context->set_index_query_context(_index_query_context);
for (const auto& expr_ctx : _opts.common_expr_ctxs_push_down) {
VExprContextSPtr context;
// _ann_range_search_runtime will do deep copy.
diff --git a/be/test/storage/compaction/collection_statistics_test.cpp
b/be/test/storage/compaction/collection_statistics_test.cpp
index da1b37e9bda..c37d533c12d 100644
--- a/be/test/storage/compaction/collection_statistics_test.cpp
+++ b/be/test/storage/compaction/collection_statistics_test.cpp
@@ -614,35 +614,6 @@ TEST_F(CollectionStatisticsTest,
CollectWithDoubleCastWrappedSlotRef) {
EXPECT_TRUE(status.ok()) << status.msg();
}
-TEST_F(CollectionStatisticsTest, FindSlotRefHandlesNullDirectCastAndNested) {
- // null
- VExprSPtr null_expr;
- EXPECT_EQ(find_slot_ref(null_expr), nullptr);
-
- // direct SLOT_REF
- auto slot_ref_direct =
- std::make_shared<collection_statistics::MockVSlotRef>("content",
SlotId(1));
- EXPECT_EQ(find_slot_ref(slot_ref_direct),
static_cast<VSlotRef*>(slot_ref_direct.get()));
-
- // CAST(SLOT_REF)
- auto slot_ref_cast =
- std::make_shared<collection_statistics::MockVSlotRef>("content",
SlotId(1));
- auto cast_expr =
std::make_shared<collection_statistics::MockVExpr>(TExprNodeType::CAST_EXPR);
- cast_expr->_children.push_back(slot_ref_cast);
- EXPECT_EQ(find_slot_ref(cast_expr),
static_cast<VSlotRef*>(slot_ref_cast.get()));
-
- // BINARY_PRED(CAST(SLOT_REF), literal)
- auto slot_ref_nested =
- std::make_shared<collection_statistics::MockVSlotRef>("content",
SlotId(1));
- auto inner_cast =
std::make_shared<collection_statistics::MockVExpr>(TExprNodeType::CAST_EXPR);
- inner_cast->_children.push_back(slot_ref_nested);
- auto lit = std::make_shared<collection_statistics::MockVLiteral>("x");
- auto bin =
std::make_shared<collection_statistics::MockVExpr>(TExprNodeType::BINARY_PRED);
- bin->_children.push_back(inner_cast);
- bin->_children.push_back(lit);
- EXPECT_EQ(find_slot_ref(bin),
static_cast<VSlotRef*>(slot_ref_nested.get()));
-}
-
TEST(TermInfoComparerTest, OrdersByTermAndDedups) {
using doris::TermInfoComparer;
using doris::segment_v2::TermInfo;
diff --git a/be/test/storage/index/inverted/query/query_helper_test.cpp
b/be/test/storage/index/inverted/query/query_helper_test.cpp
index e138f7eb5b5..95c1a3a6fad 100644
--- a/be/test/storage/index/inverted/query/query_helper_test.cpp
+++ b/be/test/storage/index/inverted/query/query_helper_test.cpp
@@ -46,6 +46,8 @@ public:
MOCK_FUNCTION float score(float freq, int64_t encoded_norm) override {
return _score_value; }
+ MOCK_FUNCTION float max_score() override { return
std::numeric_limits<float>::max(); }
+
private:
float _score_value;
};
diff --git
a/be/test/storage/index/inverted/query_v2/occur_boolean_query_test.cpp
b/be/test/storage/index/inverted/query_v2/occur_boolean_query_test.cpp
index 1356d114459..8ea18a4395c 100644
--- a/be/test/storage/index/inverted/query_v2/occur_boolean_query_test.cpp
+++ b/be/test/storage/index/inverted/query_v2/occur_boolean_query_test.cpp
@@ -25,12 +25,15 @@
#include <set>
#include <vector>
+#include "storage/index/inverted/analyzer/custom_analyzer.h"
#include "storage/index/inverted/query_v2/all_query/all_query.h"
#include "storage/index/inverted/query_v2/boolean_query/occur.h"
#include "storage/index/inverted/query_v2/boolean_query/occur_boolean_weight.h"
#include "storage/index/inverted/query_v2/query.h"
#include "storage/index/inverted/query_v2/scorer.h"
+#include "storage/index/inverted/query_v2/segment_postings.h"
#include "storage/index/inverted/query_v2/weight.h"
+#include "storage/index/inverted/similarity/bm25_similarity.h"
namespace doris::segment_v2::inverted_index::query_v2 {
namespace {
diff --git a/be/test/storage/index/inverted/query_v2/segment_postings_test.cpp
b/be/test/storage/index/inverted/query_v2/segment_postings_test.cpp
index 7dd289f3d50..0b8e4649b3a 100644
--- a/be/test/storage/index/inverted/query_v2/segment_postings_test.cpp
+++ b/be/test/storage/index/inverted/query_v2/segment_postings_test.cpp
@@ -46,7 +46,9 @@ public:
int32_t read(int32_t*, int32_t*, int32_t) override { return 0; }
int32_t read(int32_t*, int32_t*, int32_t*, int32_t) override { return 0; }
- bool readRange(DocRange* docRange) override {
+ bool readRange(DocRange* docRange) override { return
_fillDocRange(docRange); }
+ bool readBlock(DocRange* docRange) override { return
_fillDocRange(docRange); }
+ bool _fillDocRange(DocRange* docRange) {
if (_read_done || _docs.empty()) {
return false;
}
@@ -62,7 +64,7 @@ public:
}
bool skipTo(const int32_t target) override { return false; }
- void skipToBlock(const int32_t target) override {}
+ bool skipToBlock(const int32_t target) override { return false; }
void close() override {}
lucene::index::TermPositions* __asTermPositions() override { return
nullptr; }
@@ -105,7 +107,9 @@ public:
int32_t read(int32_t*, int32_t*, int32_t) override { return 0; }
int32_t read(int32_t*, int32_t*, int32_t*, int32_t) override { return 0; }
- bool readRange(DocRange* docRange) override {
+ bool readRange(DocRange* docRange) override { return
_fillDocRange(docRange); }
+ bool readBlock(DocRange* docRange) override { return
_fillDocRange(docRange); }
+ bool _fillDocRange(DocRange* docRange) {
if (_read_done || _docs.empty()) {
return false;
}
@@ -121,7 +125,7 @@ public:
}
bool skipTo(const int32_t target) override { return false; }
- void skipToBlock(const int32_t target) override {}
+ bool skipToBlock(const int32_t target) override { return false; }
void close() override {}
lucene::index::TermPositions* __asTermPositions() override { return this; }
@@ -173,7 +177,7 @@ TEST_F(SegmentPostingsTest,
test_postings_positions_with_offset) {
TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_next_true) {
TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3));
- SegmentPostings base(std::move(ptr), true);
+ SegmentPostings base(std::move(ptr), true, nullptr);
EXPECT_EQ(base.doc(), 1);
EXPECT_EQ(base.size_hint(), 3);
@@ -183,21 +187,21 @@ TEST_F(SegmentPostingsTest,
test_segment_postings_base_constructor_next_true) {
TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_next_false)
{
TermDocsPtr ptr(new MockTermDocs({}, {}, {}, 0));
- SegmentPostings base(std::move(ptr));
+ SegmentPostings base(std::move(ptr), true, nullptr);
EXPECT_EQ(base.doc(), TERMINATED);
}
TEST_F(SegmentPostingsTest,
test_segment_postings_base_constructor_doc_terminate) {
TermDocsPtr ptr(new MockTermDocs({TERMINATED}, {1}, {1}, 1));
- SegmentPostings base(std::move(ptr));
+ SegmentPostings base(std::move(ptr), true, nullptr);
EXPECT_EQ(base.doc(), TERMINATED);
}
TEST_F(SegmentPostingsTest, test_segment_postings_base_advance_success) {
TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3));
- SegmentPostings base(std::move(ptr));
+ SegmentPostings base(std::move(ptr), true, nullptr);
EXPECT_EQ(base.doc(), 1);
EXPECT_EQ(base.advance(), 3);
@@ -206,14 +210,14 @@ TEST_F(SegmentPostingsTest,
test_segment_postings_base_advance_success) {
TEST_F(SegmentPostingsTest, test_segment_postings_base_advance_end) {
TermDocsPtr ptr(new MockTermDocs({1}, {2}, {1}, 1));
- SegmentPostings base(std::move(ptr));
+ SegmentPostings base(std::move(ptr), true, nullptr);
EXPECT_EQ(base.advance(), TERMINATED);
}
TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_target_le_doc) {
TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3));
- SegmentPostings base(std::move(ptr));
+ SegmentPostings base(std::move(ptr), true, nullptr);
EXPECT_EQ(base.seek(0), 1);
EXPECT_EQ(base.seek(1), 1);
@@ -221,21 +225,21 @@ TEST_F(SegmentPostingsTest,
test_segment_postings_base_seek_target_le_doc) {
+// seek() to a target that is present in the list (5) lands exactly on it.
TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_in_block_success) {
TermDocsPtr ptr(new MockTermDocs({1, 3, 5, 7}, {2, 4, 6, 8}, {1, 1, 1, 1},
4));
- SegmentPostings base(std::move(ptr));
+ SegmentPostings base(std::move(ptr), true, nullptr);
EXPECT_EQ(base.seek(5), 5);
}
+// seek() to a target beyond the last document (5) returns TERMINATED.
TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_fail) {
TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3));
- SegmentPostings base(std::move(ptr));
+ SegmentPostings base(std::move(ptr), true, nullptr);
EXPECT_EQ(base.seek(10), TERMINATED);
}
TEST_F(SegmentPostingsTest,
test_segment_postings_base_append_positions_exception) {
TermDocsPtr ptr(new MockTermDocs({1}, {2}, {1}, 1));
- SegmentPostings base(std::move(ptr));
+ SegmentPostings base(std::move(ptr), true, nullptr);
std::vector<uint32_t> output;
EXPECT_THROW(base.append_positions_with_offset(0, output), Exception);
@@ -243,7 +247,7 @@ TEST_F(SegmentPostingsTest,
test_segment_postings_base_append_positions_exceptio
TEST_F(SegmentPostingsTest, test_segment_postings_termdocs) {
TermDocsPtr ptr(new MockTermDocs({1, 3}, {2, 4}, {1, 1}, 2));
- SegmentPostings postings(std::move(ptr));
+ SegmentPostings postings(std::move(ptr), true, nullptr);
EXPECT_EQ(postings.doc(), 1);
EXPECT_EQ(postings.size_hint(), 2);
@@ -252,16 +256,14 @@ TEST_F(SegmentPostingsTest,
test_segment_postings_termdocs) {
+// Positions-backed postings: the cursor must start on the first document (1), whose
+// frequency is 2. The doc() check is kept as context: only the constructor call changes
+// in this patch, and the sibling termdocs test still asserts doc() == 1 with the same
+// constructor arguments.
TEST_F(SegmentPostingsTest, test_segment_postings_termpositions) {
TermPositionsPtr ptr(
new MockTermPositions({1, 3}, {2, 3}, {1, 1}, {{10, 20}, {30, 40,
50}}, 2));
- SegmentPostings postings(std::move(ptr), true);
+ SegmentPostings postings(std::move(ptr), true, nullptr);

EXPECT_EQ(postings.doc(), 1);
EXPECT_EQ(postings.freq(), 2);
}
TEST_F(SegmentPostingsTest,
test_segment_postings_termpositions_append_positions) {
TermPositionsPtr ptr(
new MockTermPositions({1, 3}, {2, 3}, {1, 1}, {{10, 20}, {30, 40,
50}}, 2));
- SegmentPostings postings(std::move(ptr), true);
+ SegmentPostings postings(std::move(ptr), true, nullptr);
std::vector<uint32_t> output = {999};
postings.append_positions_with_offset(100, output);
@@ -274,7 +276,7 @@ TEST_F(SegmentPostingsTest,
test_segment_postings_termpositions_append_positions
TEST_F(SegmentPostingsTest, test_no_score_segment_posting) {
TermDocsPtr ptr(new MockTermDocs({1, 3}, {5, 7}, {10, 20}, 2));
- SegmentPostings posting(std::move(ptr));
+ SegmentPostings posting(std::move(ptr), false, nullptr);
EXPECT_EQ(posting.doc(), 1);
EXPECT_EQ(posting.freq(), 1);
diff --git a/be/test/storage/index/inverted/query_v2/top_k_collector_test.cpp
b/be/test/storage/index/inverted/query_v2/top_k_collector_test.cpp
new file mode 100644
index 00000000000..74ed27116f1
--- /dev/null
+++ b/be/test/storage/index/inverted/query_v2/top_k_collector_test.cpp
@@ -0,0 +1,490 @@
+
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "storage/index/inverted/query_v2/collect/top_k_collector.h"
+
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <limits>
+#include <numeric>
+#include <random>
+#include <vector>
+
+namespace doris::segment_v2::inverted_index::query_v2 {
+
+// Equal scores are ordered by ascending doc_id: a lower doc_id that arrives later still
+// displaces a higher doc_id with the same score.
+TEST(TopKCollectorTest, TestTieBreaking) {
+ {
+ // k == 1: doc 99 replaces doc 100 at the same score.
+ TopKCollector collector(1);
+
+ collector.collect(100, 5.0);
+ ASSERT_EQ(collector.size(), 1);
+ ASSERT_EQ(collector.threshold(), 5.0);
+
+ collector.collect(99, 5.0);
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), 1);
+ EXPECT_EQ(result[0].doc_id, 99);
+ EXPECT_EQ(result[0].score, 5.0);
+ }
+
+ {
+ // k == 2: doc 99 enters and doc 101 (the highest doc_id) is dropped.
+ TopKCollector collector(2);
+
+ collector.collect(100, 5.0);
+ collector.collect(101, 5.0);
+
+ collector.collect(99, 5.0);
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), 2);
+ EXPECT_EQ(result[0].doc_id, 99);
+ EXPECT_EQ(result[1].doc_id, 100);
+ }
+}
+
+// Keeps the k highest-scoring docs and returns them in descending score order.
+TEST(TopKCollectorTest, TestBasicCollection) {
+ TopKCollector collector(3);
+
+ collector.collect(1, 1.0);
+ collector.collect(2, 2.0);
+ collector.collect(3, 3.0);
+ collector.collect(4, 4.0);
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), 3);
+
+ EXPECT_EQ(result[0].doc_id, 4);
+ EXPECT_EQ(result[0].score, 4.0);
+
+ EXPECT_EQ(result[1].doc_id, 3);
+ EXPECT_EQ(result[1].score, 3.0);
+
+ EXPECT_EQ(result[2].doc_id, 2);
+ EXPECT_EQ(result[2].score, 2.0);
+}
+
+// Once k docs are held, threshold() is the lowest retained score, and collect() returns
+// the threshold. The threshold need not be exact after every insert: after 7.0 arrives
+// the 2nd-best score is 6.0, yet the value returned is still 5.0.
+TEST(TopKCollectorTest, TestThresholdPruning) {
+ TopKCollector collector(2);
+
+ collector.collect(1, 5.0);
+ collector.collect(2, 6.0);
+ EXPECT_EQ(collector.threshold(), 5.0);
+
+ float new_threshold = collector.collect(3, 4.0);
+ EXPECT_EQ(new_threshold, 5.0);
+
+ new_threshold = collector.collect(4, 7.0);
+ EXPECT_EQ(new_threshold, 5.0);
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), 2);
+ EXPECT_EQ(result[0].doc_id, 4);
+ EXPECT_EQ(result[1].doc_id, 2);
+}
+
+// k == 1: the threshold tracks the single best score seen so far.
+TEST(TopKCollectorTest, TestK1) {
+ TopKCollector collector(1);
+
+ collector.collect(10, 1.0);
+ EXPECT_EQ(collector.threshold(), 1.0);
+
+ collector.collect(20, 0.5);
+ collector.collect(30, 2.0);
+ EXPECT_EQ(collector.threshold(), 2.0);
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), 1);
+ EXPECT_EQ(result[0].doc_id, 30);
+}
+
+// While fewer than k docs have been collected, the threshold stays at -infinity. Once k
+// docs (scores 0..99) are held, it equals the lowest retained score, 0.0.
+TEST(TopKCollectorTest, TestLargeK) {
+ TopKCollector collector(100);
+
+ for (uint32_t i = 0; i < 50; i++) {
+ collector.collect(i, static_cast<float>(i));
+ }
+
+ EXPECT_EQ(collector.size(), 50);
+ EXPECT_EQ(collector.threshold(), -std::numeric_limits<float>::infinity());
+
+ for (uint32_t i = 50; i < 100; i++) {
+ collector.collect(i, static_cast<float>(i));
+ }
+ EXPECT_EQ(collector.threshold(), 0.0);
+
+ for (uint32_t i = 100; i < 150; i++) {
+ collector.collect(i, static_cast<float>(i));
+ }
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), 100);
+ EXPECT_EQ(result[0].doc_id, 149);
+ EXPECT_EQ(result[99].doc_id, 50);
+}
+
+// Collecting 2k docs still yields exactly the top k, in descending score order.
+TEST(TopKCollectorTest, TestBufferTruncation) {
+ TopKCollector collector(3);
+
+ collector.collect(1, 1.0);
+ collector.collect(2, 2.0);
+ collector.collect(3, 3.0);
+ collector.collect(4, 4.0);
+ collector.collect(5, 5.0);
+ collector.collect(6, 6.0);
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), 3);
+ EXPECT_EQ(result[0].score, 6.0);
+ EXPECT_EQ(result[1].score, 5.0);
+ EXPECT_EQ(result[2].score, 4.0);
+}
+
+// A collector that saw no docs returns an empty result.
+TEST(TopKCollectorTest, TestEmptyCollector) {
+ TopKCollector collector(5);
+
+ auto result = collector.into_sorted_vec();
+ EXPECT_TRUE(result.empty());
+}
+
+// Fewer docs than k: all of them are returned, sorted by descending score.
+TEST(TopKCollectorTest, TestFewerThanK) {
+ TopKCollector collector(10);
+
+ collector.collect(1, 3.0);
+ collector.collect(2, 1.0);
+ collector.collect(3, 2.0);
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), 3);
+ EXPECT_EQ(result[0].doc_id, 1);
+ EXPECT_EQ(result[1].doc_id, 3);
+ EXPECT_EQ(result[2].doc_id, 2);
+}
+
+// Negative scores are ranked like any other score (-0.5 > -1.0 > -2.0 > -3.0).
+TEST(TopKCollectorTest, TestNegativeScores) {
+ TopKCollector collector(3);
+
+ collector.collect(1, -1.0);
+ collector.collect(2, -2.0);
+ collector.collect(3, -0.5);
+ collector.collect(4, -3.0);
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), 3);
+ EXPECT_EQ(result[0].doc_id, 3);
+ EXPECT_EQ(result[1].doc_id, 1);
+ EXPECT_EQ(result[2].doc_id, 2);
+}
+
+// All scores equal: the k smallest doc_ids are kept, in ascending doc_id order.
+TEST(TopKCollectorTest, TestAllSameScore) {
+ TopKCollector collector(3);
+
+ collector.collect(5, 1.0);
+ collector.collect(3, 1.0);
+ collector.collect(7, 1.0);
+ collector.collect(1, 1.0);
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), 3);
+ EXPECT_EQ(result[0].doc_id, 1);
+ EXPECT_EQ(result[1].doc_id, 3);
+ EXPECT_EQ(result[2].doc_id, 5);
+}
+
+// Reference implementation for the stress tests. Returns the first min(k, docs.size())
+// entries of `docs` in ScoredDocByScoreDesc order. `docs` is taken by value so the
+// caller's vector is left untouched. `static` gives the helper internal linkage, so it
+// cannot clash with a same-named function in another translation unit. partial_sort
+// only orders the prefix that is returned: O(n log k) instead of O(n log n).
+static std::vector<ScoredDoc> compute_expected_topk(std::vector<ScoredDoc> docs, size_t k) {
+ const size_t top = std::min(docs.size(), k);
+ const auto top_end = docs.begin() + static_cast<std::ptrdiff_t>(top);
+ std::partial_sort(docs.begin(), top_end, docs.end(), ScoredDocByScoreDesc {});
+ return std::vector<ScoredDoc>(docs.begin(), top_end);
+}
+
+// Top-100 of 1M uniformly random scores must match the reference exactly.
+TEST(TopKCollectorTest, StressRandomScores1M) {
+ constexpr size_t N = 1000000;
+ constexpr size_t K = 100;
+
+ std::mt19937 rng(42);
+ std::uniform_real_distribution<float> dist(0.0f, 1000.0f);
+
+ TopKCollector collector(K);
+ std::vector<ScoredDoc> all_docs;
+ all_docs.reserve(N);
+
+ for (uint32_t i = 0; i < N; i++) {
+ float score = dist(rng);
+ collector.collect(i, score);
+ all_docs.emplace_back(i, score);
+ }
+
+ auto result = collector.into_sorted_vec();
+ auto expected = compute_expected_topk(all_docs, K);
+
+ ASSERT_EQ(result.size(), K);
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id, expected[i].doc_id) << "Mismatch at position " << i;
+ EXPECT_FLOAT_EQ(result[i].score, expected[i].score) << "Mismatch at position " << i;
+ }
+}
+
+// Strictly increasing scores: each new doc beats every doc seen so far.
+TEST(TopKCollectorTest, StressAscendingOrder500K) {
+ constexpr size_t N = 500000;
+ constexpr size_t K = 1000;
+
+ TopKCollector collector(K);
+ std::vector<ScoredDoc> all_docs;
+ all_docs.reserve(N);
+
+ for (uint32_t i = 0; i < N; i++) {
+ float score = static_cast<float>(i);
+ collector.collect(i, score);
+ all_docs.emplace_back(i, score);
+ }
+
+ auto result = collector.into_sorted_vec();
+ auto expected = compute_expected_topk(all_docs, K);
+
+ ASSERT_EQ(result.size(), K);
+ EXPECT_EQ(result[0].doc_id, N - 1);
+ EXPECT_EQ(result[K - 1].doc_id, N - K);
+
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id, expected[i].doc_id);
+ }
+}
+
+// Strictly decreasing scores: the first K docs are the answer and later docs never qualify.
+TEST(TopKCollectorTest, StressDescendingOrder500K) {
+ constexpr size_t N = 500000;
+ constexpr size_t K = 1000;
+
+ TopKCollector collector(K);
+ std::vector<ScoredDoc> all_docs;
+ all_docs.reserve(N);
+
+ for (uint32_t i = 0; i < N; i++) {
+ float score = static_cast<float>(N - i);
+ collector.collect(i, score);
+ all_docs.emplace_back(i, score);
+ }
+
+ auto result = collector.into_sorted_vec();
+ auto expected = compute_expected_topk(all_docs, K);
+
+ ASSERT_EQ(result.size(), K);
+ EXPECT_EQ(result[0].doc_id, 0);
+ EXPECT_EQ(result[K - 1].doc_id, K - 1);
+
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id, expected[i].doc_id);
+ }
+}
+
+// Only 100 distinct score values across 100k docs, so doc_id tie-breaking dominates.
+TEST(TopKCollectorTest, StressManyDuplicateScores) {
+ constexpr size_t N = 100000;
+ constexpr size_t K = 500;
+ constexpr int NUM_DISTINCT_SCORES = 100;
+
+ std::mt19937 rng(123);
+ std::uniform_int_distribution<int> score_dist(0, NUM_DISTINCT_SCORES - 1);
+
+ TopKCollector collector(K);
+ std::vector<ScoredDoc> all_docs;
+ all_docs.reserve(N);
+
+ for (uint32_t i = 0; i < N; i++) {
+ float score = static_cast<float>(score_dist(rng));
+ collector.collect(i, score);
+ all_docs.emplace_back(i, score);
+ }
+
+ auto result = collector.into_sorted_vec();
+ auto expected = compute_expected_topk(all_docs, K);
+
+ ASSERT_EQ(result.size(), K);
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id, expected[i].doc_id) << "Mismatch at position " << i;
+ EXPECT_FLOAT_EQ(result[i].score, expected[i].score);
+ }
+}
+
+// Every doc has the same score and doc_ids arrive shuffled; the K smallest ids must win.
+TEST(TopKCollectorTest, StressAllSameScore) {
+ constexpr size_t N = 50000;
+ constexpr size_t K = 1000;
+ constexpr float SCORE = 42.0f;
+
+ std::mt19937 rng(456);
+ std::vector<uint32_t> doc_ids(N);
+ std::iota(doc_ids.begin(), doc_ids.end(), 0);
+ std::shuffle(doc_ids.begin(), doc_ids.end(), rng);
+
+ TopKCollector collector(K);
+ for (uint32_t doc_id : doc_ids) {
+ collector.collect(doc_id, SCORE);
+ }
+
+ auto result = collector.into_sorted_vec();
+ ASSERT_EQ(result.size(), K);
+
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id, i) << "Expected doc_id " << i << " at position " << i;
+ EXPECT_FLOAT_EQ(result[i].score, SCORE);
+ }
+}
+
+// N = 50 * K random scores, so the collector must discard candidates many times.
+TEST(TopKCollectorTest, StressMultipleTruncations) {
+ constexpr size_t K = 100;
+ constexpr size_t N = K * 50;
+
+ std::mt19937 rng(789);
+ std::uniform_real_distribution<float> dist(0.0f, 10000.0f);
+
+ TopKCollector collector(K);
+ std::vector<ScoredDoc> all_docs;
+ all_docs.reserve(N);
+
+ for (uint32_t i = 0; i < N; i++) {
+ float score = dist(rng);
+ collector.collect(i, score);
+ all_docs.emplace_back(i, score);
+ }
+
+ auto result = collector.into_sorted_vec();
+ auto expected = compute_expected_topk(all_docs, K);
+
+ ASSERT_EQ(result.size(), K);
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id, expected[i].doc_id);
+ EXPECT_FLOAT_EQ(result[i].score, expected[i].score);
+ }
+}
+
+// Scores follow 1 / rank (rank repeating every 10000 docs) plus small random noise.
+TEST(TopKCollectorTest, StressZipfDistribution) {
+ constexpr size_t N = 500000;
+ constexpr size_t K = 100;
+
+ std::mt19937 rng(999);
+
+ TopKCollector collector(K);
+ std::vector<ScoredDoc> all_docs;
+ all_docs.reserve(N);
+
+ for (uint32_t i = 0; i < N; i++) {
+ float base_score = 1.0f / (static_cast<float>(i % 10000) + 1.0f);
+ float noise = static_cast<float>(rng() % 1000) / 1000000.0f;
+ float score = base_score + noise;
+
+ collector.collect(i, score);
+ all_docs.emplace_back(i, score);
+ }
+
+ auto result = collector.into_sorted_vec();
+ auto expected = compute_expected_topk(all_docs, K);
+
+ ASSERT_EQ(result.size(), K);
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id, expected[i].doc_id) << "Mismatch at position " << i;
+ }
+}
+
+// Very small K (10) out of 1M uniformly random scores.
+TEST(TopKCollectorTest, StressSmallKLargeN) {
+ constexpr size_t N = 1000000;
+ constexpr size_t K = 10;
+
+ std::mt19937 rng(111);
+ std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+
+ TopKCollector collector(K);
+ std::vector<ScoredDoc> all_docs;
+ all_docs.reserve(N);
+
+ for (uint32_t i = 0; i < N; i++) {
+ float score = dist(rng);
+ collector.collect(i, score);
+ all_docs.emplace_back(i, score);
+ }
+
+ auto result = collector.into_sorted_vec();
+ auto expected = compute_expected_topk(all_docs, K);
+
+ ASSERT_EQ(result.size(), K);
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id, expected[i].doc_id);
+ EXPECT_FLOAT_EQ(result[i].score, expected[i].score);
+ }
+}
+
+// Even doc_ids score in [0, 10) and odd doc_ids in [90, 100). There are 100k odd docs,
+// far more than K, so every retained doc must be odd.
+TEST(TopKCollectorTest, StressBimodalDistribution) {
+ constexpr size_t N = 200000;
+ constexpr size_t K = 500;
+
+ std::mt19937 rng(222);
+
+ TopKCollector collector(K);
+ std::vector<ScoredDoc> all_docs;
+ all_docs.reserve(N);
+
+ for (uint32_t i = 0; i < N; i++) {
+ float score;
+ if (i % 2 == 0) {
+ score = static_cast<float>(rng() % 1000) / 100.0f;
+ } else {
+ score = 90.0f + static_cast<float>(rng() % 1000) / 100.0f;
+ }
+ collector.collect(i, score);
+ all_docs.emplace_back(i, score);
+ }
+
+ auto result = collector.into_sorted_vec();
+ auto expected = compute_expected_topk(all_docs, K);
+
+ ASSERT_EQ(result.size(), K);
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id, expected[i].doc_id);
+ }
+
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id % 2, 1) << "Expected odd doc_id at position " << i;
+ }
+}
+
+// The first K docs fill the collector at BASE_SCORE. Afterwards, even doc_ids tie with
+// that score and odd doc_ids beat it by 0.001.
+TEST(TopKCollectorTest, StressThresholdBoundary) {
+ constexpr size_t K = 100;
+ constexpr size_t N = 10000;
+ constexpr float BASE_SCORE = 50.0f;
+
+ TopKCollector collector(K);
+ std::vector<ScoredDoc> all_docs;
+ all_docs.reserve(N);
+
+ for (uint32_t i = 0; i < K; i++) {
+ collector.collect(i, BASE_SCORE);
+ all_docs.emplace_back(i, BASE_SCORE);
+ }
+
+ for (uint32_t i = K; i < N; i++) {
+ float score = (i % 2 == 0) ? BASE_SCORE : BASE_SCORE + 0.001f;
+ collector.collect(i, score);
+ all_docs.emplace_back(i, score);
+ }
+
+ auto result = collector.into_sorted_vec();
+ auto expected = compute_expected_topk(all_docs, K);
+
+ ASSERT_EQ(result.size(), K);
+ for (size_t i = 0; i < K; i++) {
+ EXPECT_EQ(result[i].doc_id, expected[i].doc_id) << "Mismatch at position " << i;
+ }
+}
+
+} // namespace doris::segment_v2::inverted_index::query_v2
diff --git a/be/test/storage/index/inverted/query_v2/union_postings_test.cpp
b/be/test/storage/index/inverted/query_v2/union_postings_test.cpp
index 3ddd310ea37..d728bd7cda0 100644
--- a/be/test/storage/index/inverted/query_v2/union_postings_test.cpp
+++ b/be/test/storage/index/inverted/query_v2/union_postings_test.cpp
@@ -56,7 +56,9 @@ public:
int32_t read(int32_t*, int32_t*, int32_t) override { return 0; }
int32_t read(int32_t*, int32_t*, int32_t*, int32_t) override { return 0; }
- bool readRange(DocRange* docRange) override {
+ bool readRange(DocRange* docRange) override { return
_fillDocRange(docRange); }
+ bool readBlock(DocRange* docRange) override { return
_fillDocRange(docRange); }
+ bool _fillDocRange(DocRange* docRange) {
if (_read_done || _docs.empty()) {
return false;
}
@@ -72,7 +74,7 @@ public:
}
bool skipTo(const int32_t target) override { return false; }
- void skipToBlock(const int32_t target) override {}
+ bool skipToBlock(const int32_t target) override { return false; }
void close() override {}
lucene::index::TermPositions* __asTermPositions() override { return this; }
lucene::index::TermDocs* __asTermDocs() override { return this; }
@@ -105,7 +107,7 @@ static SegmentPostingsPtr
make_pos_postings(std::vector<uint32_t> docs, std::vec
int32_t df = static_cast<int32_t>(docs.size());
TermPositionsPtr ptr(new MockTermPositionsForUnion(std::move(docs),
std::move(freqs),
std::move(norms),
std::move(positions), df));
- return std::make_shared<SegmentPostings>(std::move(ptr), true);
+ return std::make_shared<SegmentPostings>(std::move(ptr), true, nullptr);
}
class UnionPostingsTest : public testing::Test {};
diff --git a/contrib/clucene b/contrib/clucene
index 8b57674e9d7..c51b5cc9adc 160000
--- a/contrib/clucene
+++ b/contrib/clucene
@@ -1 +1 @@
-Subproject commit 8b57674e9d78769b10aa0c1441cd12671a394745
+Subproject commit c51b5cc9adc63817ad8322f617c75737ece7288d
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java
index 7073febac4d..d24a6018438 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownScoreTopNIntoOlapScan.java
@@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SearchExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Score;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
@@ -116,12 +117,13 @@ public class PushDownScoreTopNIntoOlapScan implements
RewriteRuleFactory {
return null;
}
- // 2. Requirement: WHERE clause must contain a MATCH function.
- boolean hasMatchPredicate = filter.getConjuncts().stream()
- .anyMatch(conjunct -> !conjunct.collect(e -> e instanceof
Match).isEmpty());
- if (!hasMatchPredicate) {
+ // 2. Requirement: WHERE clause must contain a MATCH or SEARCH
function.
+ boolean hasMatchOrSearchPredicate = filter.getConjuncts().stream()
+ .anyMatch(conjunct -> !conjunct.collect(
+ e -> e instanceof Match || e instanceof
SearchExpression).isEmpty());
+ if (!hasMatchOrSearchPredicate) {
throw new AnalysisException(
- "WHERE clause must contain at least one MATCH function"
+ "WHERE clause must contain at least one MATCH or SEARCH
function"
+ " for score() push down optimization");
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index b513ba12345..c6266623ec6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -620,6 +620,7 @@ public class SessionVariable implements Serializable,
Writable {
// used for cross-platform (x86/arm) inverted index compatibility
// may removed in the future
public static final String INVERTED_INDEX_COMPATIBLE_READ =
"inverted_index_compatible_read";
+ public static final String ENABLE_INVERTED_INDEX_WAND_QUERY =
"enable_inverted_index_wand_query";
public static final String AUTO_ANALYZE_START_TIME =
"auto_analyze_start_time";
@@ -873,6 +874,9 @@ public class SessionVariable implements Serializable,
Writable {
+ "proportion as hot values, up to
HOT_VALUE_COLLECT_COUNT."})
public int hotValueCollectCount = 10; // Select the values that account
for at least 10% of the column
+ @VariableMgr.VarAttr(name = ENABLE_INVERTED_INDEX_WAND_QUERY,
+ description = {"是否开启倒排索引WAND查询优化", "Whether to enable inverted
index WAND query optimization"})
+ public boolean enableInvertedIndexWandQuery = true;
public void setHotValueCollectCount(int count) {
this.hotValueCollectCount = count;
@@ -5363,6 +5367,7 @@ public class SessionVariable implements Serializable,
Writable {
tResult.setInvertedIndexSkipThreshold(invertedIndexSkipThreshold);
tResult.setInvertedIndexCompatibleRead(invertedIndexCompatibleRead);
+ tResult.setEnableInvertedIndexWandQuery(enableInvertedIndexWandQuery);
tResult.setCteMaxRecursionDepth(cteMaxRecursionDepth);
tResult.setEnableParallelScan(enableParallelScan);
tResult.setEnableLeftSemiDirectReturnOpt(enableLeftSemiDirectReturnOpt);
diff --git a/gensrc/thrift/PaloInternalService.thrift
b/gensrc/thrift/PaloInternalService.thrift
index 495c1477647..a118cea519f 100644
--- a/gensrc/thrift/PaloInternalService.thrift
+++ b/gensrc/thrift/PaloInternalService.thrift
@@ -445,6 +445,8 @@ struct TQueryOptions {
// hash table expansion thresholds since all data is local.
202: optional bool single_backend_query = false;
+ 203: optional bool enable_inverted_index_wand_query = true;
+
// For cloud, to control if the content would be written into file cache
// In write path, to control if the content would be written into file cache.
// In read path, read from file cache or remote storage when execute query.
diff --git a/regression-test/suites/inverted_index_p0/test_bm25_score.groovy
b/regression-test/suites/inverted_index_p0/test_bm25_score.groovy
index cdbec257922..2686011e89e 100644
--- a/regression-test/suites/inverted_index_p0/test_bm25_score.groovy
+++ b/regression-test/suites/inverted_index_p0/test_bm25_score.groovy
@@ -139,7 +139,7 @@ suite("test_bm25_score", "p0") {
test {
sql """ select score() as score from test_bm25_score where request
= 'button.03.gif' order by score() limit 10; """
- exception "WHERE clause must contain at least one MATCH function
for score() push down optimization"
+ exception "WHERE clause must contain at least one MATCH or SEARCH
function for score() push down optimization"
}
test {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]