HappenLee commented on code in PR #54276:
URL: https://github.com/apache/doris/pull/54276#discussion_r2287754696
##########
be/src/vec/exprs/vectorized_fn_call.cpp:
##########
@@ -301,5 +324,262 @@ bool VectorizedFnCall::equals(const VExpr& other) {
return true;
}
+/*
+ FunctionCall(LE/LT/GE/GT)
+ |----------------
+ | |
+ | |
+ VirtualSlotRef Float64Literal
+ |
+ |
+ FunctionCall
+ |----------------
+ | |
+ | |
+ CastToArray ArrayLiteral
+ |
+ |
+ SlotRef
+*/
+
+Status VectorizedFnCall::prepare_ann_range_search(
+ const doris::VectorSearchUserParams& user_params,
+ segment_v2::AnnRangeSearchRuntime& range_search_runtime, bool&
suitable_for_ann_index) {
+ if (!suitable_for_ann_index) {
+ return Status::OK();
+ }
+ std::set<TExprOpcode::type> ops = {TExprOpcode::GE, TExprOpcode::LE,
TExprOpcode::LE,
+ TExprOpcode::GT, TExprOpcode::LT};
+ if (ops.find(this->op()) == ops.end()) {
+ suitable_for_ann_index = false;
+ // Not a range search function.
+ return Status::OK();
+ }
+
+ range_search_runtime.is_le_or_lt =
+ (this->op() == TExprOpcode::LE || this->op() == TExprOpcode::LT);
+
+ DCHECK(_children.size() == 2);
+
+ auto left_child = get_child(0);
+ auto right_child = get_child(1);
+
+ // Return type of L2Distance is always double.
+ auto right_literal = std::dynamic_pointer_cast<VLiteral>(right_child);
+ if (right_literal == nullptr) {
+ suitable_for_ann_index = false;
+ // Right child is not a literal.
+ return Status::OK();
+ }
+
+ auto right_col =
right_literal->get_column_ptr()->convert_to_full_column_if_const();
+ auto right_type = right_literal->get_data_type();
+ if (right_type->get_primitive_type() != PrimitiveType::TYPE_DOUBLE) {
+ suitable_for_ann_index = false;
+ // Right child is not a Float64Literal.
+ return Status::OK();
+ }
+
+ const ColumnFloat64* cf64_right = assert_cast<const
ColumnFloat64*>(right_col.get());
+ range_search_runtime.radius = cf64_right->get_data()[0];
+
+ std::shared_ptr<VectorizedFnCall> function_call;
+ auto vir_slot_ref = std::dynamic_pointer_cast<VirtualSlotRef>(left_child);
+ if (vir_slot_ref != nullptr) {
+ DCHECK(vir_slot_ref->get_virtual_column_expr() != nullptr);
+ function_call = std::dynamic_pointer_cast<VectorizedFnCall>(
+ vir_slot_ref->get_virtual_column_expr());
+ } else {
+ function_call =
std::dynamic_pointer_cast<VectorizedFnCall>(left_child);
+ }
+
+ if (function_call == nullptr) {
+ suitable_for_ann_index = false;
+ // Left child is not a function call.
+ return Status::OK();
+ }
+
+ // Now left child is a function call, we need to check if it is a distance
function
+ std::set<std::string> distance_functions = {L2DistanceApproximate::name,
+ InnerProductApproximate::name};
+ if (distance_functions.find(function_call->_function_name) ==
distance_functions.end()) {
+ // Left child is not a approximate distance function. Got
function_call->_function_name
+ suitable_for_ann_index = false;
+ return Status::OK();
+ } else {
+ // Strip the _approximate suffix.
+ std::string metric_name = function_call->_function_name;
+ metric_name = metric_name.substr(0, metric_name.size() - 12);
+ range_search_runtime.metric_type =
segment_v2::string_to_metric(metric_name);
+ }
+
+ if (function_call->get_num_children() != 2) {
Review Comment:
The `suitable_for_ann_index = false; return Status::OK();` pair is repeated
several times in this function; it would be better to wrap it in a #define.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]