This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 9bffa860b7d branch-4.0: [fix](range search) Fix ann range search
prepare failed #56621 (#56656)
9bffa860b7d is described below
commit 9bffa860b7dc3d1b12cfd87d6426039ddef417d4
Author: github-actions[bot]
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Sep 30 15:46:13 2025 +0800
branch-4.0: [fix](range search) Fix ann range search prepare failed #56621
(#56656)
Cherry-picked from #56621
Co-authored-by: zhiqiang <[email protected]>
---
be/src/vec/exprs/vectorized_fn_call.cpp | 145 +++++++++++++++++++++-----------
1 file changed, 96 insertions(+), 49 deletions(-)
diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp
b/be/src/vec/exprs/vectorized_fn_call.cpp
index e43accefab8..cc61a6976ed 100644
--- a/be/src/vec/exprs/vectorized_fn_call.cpp
+++ b/be/src/vec/exprs/vectorized_fn_call.cpp
@@ -332,19 +332,31 @@ bool VectorizedFnCall::equals(const VExpr& other) {
}
/*
- FuncationCall(LE/LT/GE/GT)
- |----------------
- | |
- | |
- VirtualSlotRef Float32Literal
- |
- |
- FuncationCall
- |----------------
- | |
- | |
- SlotRef ArrayLiteral/Cast(String as Array<FLOAT>)
-*/
+ * For ANN range search we expect a comparison expression (LE/LT/GE/GT) whose left side is either:
+ * 1) a vector distance function call, or
+ * 2) a cast/virtual slot that unwraps to the function call when the planner promotes float to
+ * double literals.
+ *
+ * Visually the logical tree looks like:
+ *
+ *   FunctionCall(LE/LT/GE/GT)
+ *   |----------------
+ *   |               |
+ *   |               |
+ *   VirtualSlotRef* Float32Literal/Float64Literal
+ *   |
+ *   |
+ *   Cast(Float -> Double)*
+ *   |
+ *   FunctionCall(distance)
+ *   |----------------
+ *   |               |
+ *   |               |
+ *   SlotRef         ArrayLiteral/Cast(String as Array<FLOAT>)
+ *
+ * Items marked with * are optional and depend on literal types/virtual column usage. The helper
+ * below normalizes the shape and validates distance function, slot, and constant vector inputs.
+ */
void VectorizedFnCall::prepare_ann_range_search(
const doris::VectorSearchUserParams& user_params,
@@ -355,10 +367,14 @@ void VectorizedFnCall::prepare_ann_range_search(
if (OPS_FOR_ANN_RANGE_SEARCH.find(this->op()) ==
OPS_FOR_ANN_RANGE_SEARCH.end()) {
suitable_for_ann_index = false;
- // Not a range search function.
return;
}
+ auto mark_unsuitable = [&](const std::string& reason) {
+ suitable_for_ann_index = false;
+ VLOG_DEBUG << "ANN range search skipped: " << reason;
+ };
+
range_search_runtime.is_le_or_lt =
(this->op() == TExprOpcode::LE || this->op() == TExprOpcode::LT);
@@ -367,46 +383,71 @@ void VectorizedFnCall::prepare_ann_range_search(
auto left_child = get_child(0);
auto right_child = get_child(1);
- // right side
auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child);
if (right_literal == nullptr) {
suitable_for_ann_index = false;
- // Right child is not a literal.
return;
}
auto right_col =
right_literal->get_column_ptr()->convert_to_full_column_if_const();
auto right_type = right_literal->get_data_type();
- if (right_type->get_primitive_type() != PrimitiveType::TYPE_FLOAT) {
- suitable_for_ann_index = false;
- // Right child is not a Float32Literal.
+
+ PrimitiveType right_primitive = right_type->get_primitive_type();
+ const bool float32_literal = right_primitive == PrimitiveType::TYPE_FLOAT;
+ const bool float64_literal = right_primitive == PrimitiveType::TYPE_DOUBLE;
+ if (!float32_literal && !float64_literal) {
+ mark_unsuitable("Right child is not a Float32Literal or
Float64Literal.");
return;
}
- const ColumnFloat32* cf32_right = assert_cast<const
ColumnFloat32*>(right_col.get());
- range_search_runtime.radius = cf32_right->get_data()[0];
+ if (float32_literal) {
+ const ColumnFloat32* cf32_right = assert_cast<const
ColumnFloat32*>(right_col.get());
+ range_search_runtime.radius = cf32_right->get_data()[0];
+ } else if (float64_literal) {
+ const ColumnFloat64* cf64_right = assert_cast<const
ColumnFloat64*>(right_col.get());
+ range_search_runtime.radius =
static_cast<float>(cf64_right->get_data()[0]);
+ }
+
+ auto get_virtual_expr = [&](const VExprSPtr& expr,
+ std::shared_ptr<VirtualSlotRef>& slot_ref) ->
VExprSPtr {
+ auto virtual_ref = std::dynamic_pointer_cast<VirtualSlotRef>(expr);
+ if (virtual_ref != nullptr) {
+ DCHECK(virtual_ref->get_virtual_column_expr() != nullptr);
+ slot_ref = virtual_ref;
+ return virtual_ref->get_virtual_column_expr();
+ }
+ return expr;
+ };
+
+ std::shared_ptr<VirtualSlotRef> vir_slot_ref;
+ auto normalized_left = get_virtual_expr(left_child, vir_slot_ref);
- // left side
std::shared_ptr<VectorizedFnCall> function_call;
- auto vir_slot_ref = std::dynamic_pointer_cast<VirtualSlotRef>(left_child);
- // Return type of L2Distance is always float.
- if (vir_slot_ref != nullptr) {
- DCHECK(vir_slot_ref->get_virtual_column_expr() != nullptr);
- function_call = std::dynamic_pointer_cast<VectorizedFnCall>(
- vir_slot_ref->get_virtual_column_expr());
+ if (float32_literal) {
+ function_call =
std::dynamic_pointer_cast<VectorizedFnCall>(normalized_left);
+ if (function_call == nullptr) {
+ mark_unsuitable("Left child is not a function call.");
+ return;
+ }
} else {
- function_call =
std::dynamic_pointer_cast<VectorizedFnCall>(left_child);
- }
+ auto cast_float_to_double =
std::dynamic_pointer_cast<VCastExpr>(normalized_left);
+ if (cast_float_to_double == nullptr) {
+ mark_unsuitable("Left child is not a cast expression.");
+ return;
+ }
- if (function_call == nullptr) {
- suitable_for_ann_index = false;
- // Left child is not a function call.
- return;
+ auto normalized_cast_child =
+ get_virtual_expr(cast_float_to_double->get_child(0),
vir_slot_ref);
+ function_call =
std::dynamic_pointer_cast<VectorizedFnCall>(normalized_cast_child);
+ if (function_call == nullptr) {
+ mark_unsuitable("Left child of cast is not a function call.");
+ return;
+ }
}
if (DISTANCE_FUNCS.find(function_call->_function_name) ==
DISTANCE_FUNCS.end()) {
- // Left child is not a approximate distance function. Got
function_call->_function_name
- suitable_for_ann_index = false;
+ mark_unsuitable(fmt::format("Left child is not a supported distance
function: {}",
+ function_call->_function_name));
return;
} else {
// Strip the _approximate suffix.
@@ -418,23 +459,24 @@ void VectorizedFnCall::prepare_ann_range_search(
// Identify the slot ref child and the constant query array child
(ArrayLiteral or CAST to array)
Int32 idx_of_slot_ref = -1;
Int32 idx_of_array_expr = -1;
- for (UInt16 i = 0; i < function_call->get_num_children(); ++i) {
- auto child = function_call->get_child(i);
+ auto classify_child = [&](const VExprSPtr& child, UInt16 index) {
if (idx_of_slot_ref == -1 &&
std::dynamic_pointer_cast<VSlotRef>(child) != nullptr) {
- idx_of_slot_ref = i;
- continue;
+ idx_of_slot_ref = index;
+ return;
}
- // Accept either ArrayLiteral or Cast-to-array constant
if (idx_of_array_expr == -1 &&
(std::dynamic_pointer_cast<VArrayLiteral>(child) != nullptr ||
std::dynamic_pointer_cast<VCastExpr>(child) != nullptr)) {
- idx_of_array_expr = i;
+ idx_of_array_expr = index;
}
+ };
+
+ for (UInt16 i = 0; i < function_call->get_num_children(); ++i) {
+ classify_child(function_call->get_child(i), i);
}
if (idx_of_slot_ref == -1 || idx_of_array_expr == -1) {
- suitable_for_ann_index = false;
- // slot ref or array literal/cast is missing.
+ mark_unsuitable("slot ref or array literal/cast is missing.");
return;
}
@@ -444,11 +486,10 @@ void VectorizedFnCall::prepare_ann_range_search(
range_search_runtime.dst_col_idx = vir_slot_ref == nullptr ? -1 :
vir_slot_ref->column_id();
// Materialize the constant array expression and validate its shape and
types
- std::shared_ptr<ColumnPtrWrapper> column_wrapper;
auto array_expr =
function_call->get_child(static_cast<UInt16>(idx_of_array_expr));
auto extract_result = extract_query_vector(array_expr);
if (!extract_result.has_value()) {
- suitable_for_ann_index = false;
+ mark_unsuitable("Failed to extract query vector from constant array
expression.");
return;
}
range_search_runtime.query_value = extract_result.value();
@@ -481,14 +522,17 @@ Status VectorizedFnCall::evaluate_ann_range_search(
DCHECK(src_col_cid < cid_to_index_iterators.size());
segment_v2::IndexIterator* index_iterator =
cid_to_index_iterators[src_col_cid].get();
if (index_iterator == nullptr) {
- // No index iterator for column cid
+ VLOG_DEBUG << "ANN range search skipped: "
+ << fmt::format("No index iterator for column cid {}",
src_col_cid);
+ ;
return Status::OK();
}
segment_v2::AnnIndexIterator* ann_index_iterator =
dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterator);
if (ann_index_iterator == nullptr) {
- // No ann index iterator for column cid
+ VLOG_DEBUG << "ANN range search skipped: "
+ << fmt::format("Column cid {} has no ANN index iterator",
src_col_cid);
return Status::OK();
}
DCHECK(ann_index_iterator->get_reader(AnnIndexReaderType::ANN) != nullptr)
@@ -499,7 +543,10 @@ Status VectorizedFnCall::evaluate_ann_range_search(
<< "Ann index reader should not be null. Column cid: " <<
src_col_cid;
// Check if metrics type is match.
if (ann_index_reader->get_metric_type() !=
range_search_runtime.metric_type) {
- // Metric type not match, can not execute range search by index.
+ VLOG_DEBUG << "ANN range search skipped: "
+ << fmt::format("Metric type mismatch. Index={} Query={}",
+
segment_v2::metric_to_string(ann_index_reader->get_metric_type()),
+
segment_v2::metric_to_string(range_search_runtime.metric_type));
return Status::OK();
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]