This is an automated email from the ASF dual-hosted git repository.

zclllyybb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 99691f6895d [refine](function) use typed ANN query vector (#63834)
99691f6895d is described below

commit 99691f6895d6d3f527ced114fb225a742888e4b4
Author: Mryange <[email protected]>
AuthorDate: Fri May 29 10:39:25 2026 +0800

    [refine](function) use typed ANN query vector (#63834)
    
    ANN query vector extraction returned a generic `IColumn::Ptr`, so the
    TopN and range search paths had to downcast the column again before
    reading float data. This made the code more indirect and delayed type
    validation. This PR changes the helper and runtime state to keep the
    query vector as `ColumnFloat32::Ptr`, validates the concrete type at
    extraction time, and removes redundant casts from the ANN execution
    path.
---
 be/src/storage/index/ann/ann_range_search_runtime.cpp |  3 +--
 be/src/storage/index/ann/ann_range_search_runtime.h   |  2 +-
 be/src/storage/index/ann/ann_topn_runtime.cpp         | 19 +++++++++++++------
 be/src/storage/index/ann/ann_topn_runtime.h           |  5 +++--
 be/test/storage/index/ann/ann_range_search_test.cpp   |  3 ++-
 .../storage/index/ann/ann_topn_descriptor_test.cpp    |  3 +--
 .../storage/index/ann/extract_query_vector_test.cpp   | 18 +++++++++++++++++-
 7 files changed, 38 insertions(+), 15 deletions(-)

diff --git a/be/src/storage/index/ann/ann_range_search_runtime.cpp b/be/src/storage/index/ann/ann_range_search_runtime.cpp
index d66c660eea3..c887ca6fc33 100644
--- a/be/src/storage/index/ann/ann_range_search_runtime.cpp
+++ b/be/src/storage/index/ann/ann_range_search_runtime.cpp
@@ -34,8 +34,7 @@ namespace doris::segment_v2 {
  */
 AnnRangeSearchParams AnnRangeSearchRuntime::to_range_search_params() const {
     AnnRangeSearchParams params;
-    const auto* query = assert_cast<const ColumnFloat32*>(query_value.get());
-    params.query_value = query->get_data().data();
+    params.query_value = query_value->get_data().data();
     params.radius = static_cast<float>(radius);
     params.roaring = nullptr;
     params.is_le_or_lt = is_le_or_lt;
diff --git a/be/src/storage/index/ann/ann_range_search_runtime.h b/be/src/storage/index/ann/ann_range_search_runtime.h
index f5b884bd649..b1fd9dd9e72 100644
--- a/be/src/storage/index/ann/ann_range_search_runtime.h
+++ b/be/src/storage/index/ann/ann_range_search_runtime.h
@@ -132,6 +132,6 @@ struct AnnRangeSearchRuntime {
     double radius = 0.0;                       ///< Search radius/distance 
threshold
     AnnIndexMetric metric_type;                ///< Distance metric (L2, Inner 
Product, etc.)
     doris::VectorSearchUserParams user_params; ///< User-defined search 
parameters
-    IColumn::Ptr query_value;                  ///< Query vector data (deep 
copied)
+    ColumnFloat32::Ptr query_value;            ///< Query vector data
 };
 } // namespace doris::segment_v2
\ No newline at end of file
diff --git a/be/src/storage/index/ann/ann_topn_runtime.cpp b/be/src/storage/index/ann/ann_topn_runtime.cpp
index b80289be90d..a0e1382e6be 100644
--- a/be/src/storage/index/ann/ann_topn_runtime.cpp
+++ b/be/src/storage/index/ann/ann_topn_runtime.cpp
@@ -29,6 +29,7 @@
 #include "core/column/column_array.h"
 #include "core/column/column_const.h"
 #include "core/column/column_nullable.h"
+#include "core/column/column_vector.h"
 #include "core/data_type/primitive_type.h"
 #include "exprs/function/array/function_array_distance.h"
 #include "exprs/vexpr_context.h"
@@ -42,7 +43,7 @@
 
 namespace doris::segment_v2 {
 
-Result<IColumn::Ptr> extract_query_vector(std::shared_ptr<VExpr> arg_expr) {
+Result<ColumnFloat32::Ptr> extract_query_vector(std::shared_ptr<VExpr> 
arg_expr) {
     if (arg_expr->is_constant() == false) {
         return ResultError(Status::InvalidArgument("Ann topn expr must be 
constant, got\n{}",
                                                    arg_expr->debug_string()));
@@ -98,7 +99,14 @@ Result<IColumn::Ptr> extract_query_vector(std::shared_ptr<VExpr> arg_expr) {
         values_holder_col = value_nullable_col->get_nested_column_ptr();
     }
 
-    return values_holder_col;
+    auto float_col = 
check_and_get_column_ptr<ColumnFloat32>(values_holder_col);
+    if (float_col.get() == nullptr) {
+        return ResultError(Status::InvalidArgument(
+                "Ann topn query vector elements must be Float32, got column: 
{}",
+                values_holder_col->get_name()));
+    }
+
+    return float_col;
 }
 
 Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor& 
row_desc) {
@@ -188,10 +196,10 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::AnnIndexIterator*
     DCHECK(ann_index_iterator != nullptr);
     DCHECK(_order_by_expr_ctx != nullptr);
     DCHECK(_order_by_expr_ctx->root() != nullptr);
-    size_t query_array_size = _query_array->size();
-    if (_query_array.get() == nullptr || query_array_size == 0) {
+    if (_query_array.get() == nullptr || _query_array->size() == 0) {
         return Status::InternalError("Ann topn query vector is not 
initialized");
     }
+    size_t query_array_size = _query_array->size();
 
     // TODO:(zhiqiang) Maybe we can move this dimension check to prepare phase.
 
@@ -203,9 +211,8 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::AnnIndexIterator*
                 "Ann topn query vector dimension {} does not match index 
dimension {}",
                 query_array_size, ann_index_reader->get_dimension());
     }
-    const ColumnFloat32* query = assert_cast<const 
ColumnFloat32*>(_query_array.get());
     segment_v2::AnnTopNParam ann_query_params {
-            .query_value = query->get_data().data(),
+            .query_value = _query_array->get_data().data(),
             .query_value_size = query_array_size,
             .limit = _limit,
             ._user_params = _user_params,
diff --git a/be/src/storage/index/ann/ann_topn_runtime.h b/be/src/storage/index/ann/ann_topn_runtime.h
index 0c76c07d4ad..49f48fa7388 100644
--- a/be/src/storage/index/ann/ann_topn_runtime.h
+++ b/be/src/storage/index/ann/ann_topn_runtime.h
@@ -39,6 +39,7 @@
 #include <vector>
 
 #include "core/column/column.h"
+#include "core/column/column_vector.h"
 #include "core/data_type/primitive_type.h"
 #include "exprs/vectorized_fn_call.h"
 #include "exprs/vexpr.h"
@@ -51,7 +52,7 @@ namespace doris::segment_v2 {
 struct AnnIndexStats;
 class AnnIndexIterator;
 
-Result<IColumn::Ptr> extract_query_vector(std::shared_ptr<VExpr> arg_expr);
+Result<ColumnFloat32::Ptr> extract_query_vector(std::shared_ptr<VExpr> 
arg_expr);
 
 /**
  * @brief Runtime execution engine for ANN (Approximate Nearest Neighbor) 
Top-N queries.
@@ -164,7 +165,7 @@ private:
     size_t _src_column_idx = -1;                ///< Source vector column index
     size_t _dest_column_idx = -1;               ///< Destination distance 
column index
     segment_v2::AnnIndexMetric _metric_type;    ///< Distance metric type
-    IColumn::Ptr _query_array;                  ///< Query vector data 
(contiguous float buffer)
+    ColumnFloat32::Ptr _query_array;            ///< Query vector data 
(contiguous float buffer)
     doris::VectorSearchUserParams _user_params; ///< User-defined search 
parameters
 };
 } // namespace doris::segment_v2
diff --git a/be/test/storage/index/ann/ann_range_search_test.cpp b/be/test/storage/index/ann/ann_range_search_test.cpp
index aa9868350fb..572783ebecb 100644
--- a/be/test/storage/index/ann/ann_range_search_test.cpp
+++ b/be/test/storage/index/ann/ann_range_search_test.cpp
@@ -86,8 +86,9 @@ TEST_F(VectorSearchTest, TestPrepareAnnRangeSearch) {
     EXPECT_EQ(ann_range_search_runtime.radius, 10.0f);
     std::vector<int> query_array_groud_truth = {1, 2, 3, 4, 5, 6, 7, 20};
     std::vector<int> query_array_f32;
+    const auto& query_value = 
range_search_ctx->_ann_range_search_runtime.query_value;
     for (int i = 0; i < query_array_groud_truth.size(); ++i) {
-        
query_array_f32.push_back(static_cast<int>(ann_range_search_runtime.query_value[i]));
+        
query_array_f32.push_back(static_cast<int>(query_value->get_data()[i]));
     }
     for (int i = 0; i < query_array_f32.size(); ++i) {
         EXPECT_EQ(query_array_f32[i], query_array_groud_truth[i]);
diff --git a/be/test/storage/index/ann/ann_topn_descriptor_test.cpp b/be/test/storage/index/ann/ann_topn_descriptor_test.cpp
index fa64807bece..3e6043a52ad 100644
--- a/be/test/storage/index/ann/ann_topn_descriptor_test.cpp
+++ b/be/test/storage/index/ann/ann_topn_descriptor_test.cpp
@@ -116,8 +116,7 @@ TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) {
     ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(),
                                         
predicate->get_order_by_expr_ctx()->root()->debug_string());
 
-    const ColumnFloat32* query_column =
-            assert_cast<const ColumnFloat32*>(predicate->_query_array.get());
+    const auto& query_column = predicate->_query_array;
     const float* query_value = query_column->get_data().data();
     const size_t query_value_size = predicate->_query_array->size();
     ASSERT_EQ(query_value_size, 8);
diff --git a/be/test/storage/index/ann/extract_query_vector_test.cpp b/be/test/storage/index/ann/extract_query_vector_test.cpp
index 13e059825b3..3b72e4f4ca6 100644
--- a/be/test/storage/index/ann/extract_query_vector_test.cpp
+++ b/be/test/storage/index/ann/extract_query_vector_test.cpp
@@ -183,7 +183,7 @@ TEST_F(ExtractQueryVectorTest, ValuesMatchInput) {
 
     auto result = extract_query_vector(mock);
     ASSERT_TRUE(result.has_value());
-    auto* float_col = assert_cast<const ColumnFloat32*>(result.value().get());
+    const auto& float_col = result.value();
     ASSERT_EQ(float_col->size(), 4u);
     for (size_t i = 0; i < input.size(); ++i) {
         EXPECT_FLOAT_EQ(float_col->get_data()[i], input[i]);
@@ -245,4 +245,20 @@ TEST_F(ExtractQueryVectorTest, NonArrayColumnFails) {
     EXPECT_TRUE(result.error().to_string().find("Array literal") != 
std::string::npos);
 }
 
+TEST_F(ExtractQueryVectorTest, NonFloatArrayFails) {
+    auto int_col = ColumnInt32::create();
+    int_col->insert_value(1);
+    int_col->insert_value(2);
+    auto offsets = ColumnArray::ColumnOffsets::create();
+    offsets->insert_value(2);
+    auto array_col = ColumnArray::create(std::move(int_col), 
std::move(offsets));
+
+    auto mock = std::make_shared<MockConstVExpr>();
+    mock->set_column(std::move(array_col));
+
+    auto result = extract_query_vector(mock);
+    ASSERT_FALSE(result.has_value());
+    EXPECT_TRUE(result.error().to_string().find("must be Float32") != 
std::string::npos);
+}
+
 } // namespace doris::segment_v2


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to