This is an automated email from the ASF dual-hosted git repository.
zclllyybb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 99691f6895d [refine](function) use typed ANN query vector (#63834)
99691f6895d is described below
commit 99691f6895d6d3f527ced114fb225a742888e4b4
Author: Mryange <[email protected]>
AuthorDate: Fri May 29 10:39:25 2026 +0800
[refine](function) use typed ANN query vector (#63834)
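ANN query vector extraction returned a generic `IColumn::Ptr`, so the
TopN and range search paths had to downcast the column again (via
`assert_cast`) before reading the float data. This made the code more
indirect and deferred type validation until the vector was used. This PR
changes the helper and runtime state to hold the query vector as
`ColumnFloat32::Ptr`. It validates the concrete element type at extraction
time, returning `InvalidArgument` for non-Float32 arrays. It also removes
the redundant casts from the ANN execution path.
It additionally moves the `_query_array->size()` call in
`AnnTopNRuntime::evaluate_vector_ann_search` after the null check, so an
uninitialized query vector is reported as an error instead of being
dereferenced. A test for the non-Float32 case is added.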
---
be/src/storage/index/ann/ann_range_search_runtime.cpp | 3 +--
be/src/storage/index/ann/ann_range_search_runtime.h | 2 +-
be/src/storage/index/ann/ann_topn_runtime.cpp | 19 +++++++++++++------
be/src/storage/index/ann/ann_topn_runtime.h | 5 +++--
be/test/storage/index/ann/ann_range_search_test.cpp | 3 ++-
.../storage/index/ann/ann_topn_descriptor_test.cpp | 3 +--
.../storage/index/ann/extract_query_vector_test.cpp | 18 +++++++++++++++++-
7 files changed, 38 insertions(+), 15 deletions(-)
diff --git a/be/src/storage/index/ann/ann_range_search_runtime.cpp b/be/src/storage/index/ann/ann_range_search_runtime.cpp
index d66c660eea3..c887ca6fc33 100644
--- a/be/src/storage/index/ann/ann_range_search_runtime.cpp
+++ b/be/src/storage/index/ann/ann_range_search_runtime.cpp
@@ -34,8 +34,7 @@ namespace doris::segment_v2 {
*/
AnnRangeSearchParams AnnRangeSearchRuntime::to_range_search_params() const {
AnnRangeSearchParams params;
- const auto* query = assert_cast<const ColumnFloat32*>(query_value.get());
- params.query_value = query->get_data().data();
+ params.query_value = query_value->get_data().data();
params.radius = static_cast<float>(radius);
params.roaring = nullptr;
params.is_le_or_lt = is_le_or_lt;
diff --git a/be/src/storage/index/ann/ann_range_search_runtime.h b/be/src/storage/index/ann/ann_range_search_runtime.h
index f5b884bd649..b1fd9dd9e72 100644
--- a/be/src/storage/index/ann/ann_range_search_runtime.h
+++ b/be/src/storage/index/ann/ann_range_search_runtime.h
@@ -132,6 +132,6 @@ struct AnnRangeSearchRuntime {
double radius = 0.0; ///< Search radius/distance
threshold
AnnIndexMetric metric_type; ///< Distance metric (L2, Inner
Product, etc.)
doris::VectorSearchUserParams user_params; ///< User-defined search
parameters
- IColumn::Ptr query_value; ///< Query vector data (deep
copied)
+ ColumnFloat32::Ptr query_value; ///< Query vector data
};
} // namespace doris::segment_v2
\ No newline at end of file
diff --git a/be/src/storage/index/ann/ann_topn_runtime.cpp b/be/src/storage/index/ann/ann_topn_runtime.cpp
index b80289be90d..a0e1382e6be 100644
--- a/be/src/storage/index/ann/ann_topn_runtime.cpp
+++ b/be/src/storage/index/ann/ann_topn_runtime.cpp
@@ -29,6 +29,7 @@
#include "core/column/column_array.h"
#include "core/column/column_const.h"
#include "core/column/column_nullable.h"
+#include "core/column/column_vector.h"
#include "core/data_type/primitive_type.h"
#include "exprs/function/array/function_array_distance.h"
#include "exprs/vexpr_context.h"
@@ -42,7 +43,7 @@
namespace doris::segment_v2 {
-Result<IColumn::Ptr> extract_query_vector(std::shared_ptr<VExpr> arg_expr) {
+Result<ColumnFloat32::Ptr> extract_query_vector(std::shared_ptr<VExpr>
arg_expr) {
if (arg_expr->is_constant() == false) {
return ResultError(Status::InvalidArgument("Ann topn expr must be
constant, got\n{}",
arg_expr->debug_string()));
@@ -98,7 +99,14 @@ Result<IColumn::Ptr> extract_query_vector(std::shared_ptr<VExpr> arg_expr) {
values_holder_col = value_nullable_col->get_nested_column_ptr();
}
- return values_holder_col;
+ auto float_col =
check_and_get_column_ptr<ColumnFloat32>(values_holder_col);
+ if (float_col.get() == nullptr) {
+ return ResultError(Status::InvalidArgument(
+ "Ann topn query vector elements must be Float32, got column:
{}",
+ values_holder_col->get_name()));
+ }
+
+ return float_col;
}
Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor&
row_desc) {
@@ -188,10 +196,10 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::AnnIndexIterator*
DCHECK(ann_index_iterator != nullptr);
DCHECK(_order_by_expr_ctx != nullptr);
DCHECK(_order_by_expr_ctx->root() != nullptr);
- size_t query_array_size = _query_array->size();
- if (_query_array.get() == nullptr || query_array_size == 0) {
+ if (_query_array.get() == nullptr || _query_array->size() == 0) {
return Status::InternalError("Ann topn query vector is not
initialized");
}
+ size_t query_array_size = _query_array->size();
// TODO:(zhiqiang) Maybe we can move this dimension check to prepare phase.
@@ -203,9 +211,8 @@ Status AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::AnnIndexIterator*
"Ann topn query vector dimension {} does not match index
dimension {}",
query_array_size, ann_index_reader->get_dimension());
}
- const ColumnFloat32* query = assert_cast<const
ColumnFloat32*>(_query_array.get());
segment_v2::AnnTopNParam ann_query_params {
- .query_value = query->get_data().data(),
+ .query_value = _query_array->get_data().data(),
.query_value_size = query_array_size,
.limit = _limit,
._user_params = _user_params,
diff --git a/be/src/storage/index/ann/ann_topn_runtime.h b/be/src/storage/index/ann/ann_topn_runtime.h
index 0c76c07d4ad..49f48fa7388 100644
--- a/be/src/storage/index/ann/ann_topn_runtime.h
+++ b/be/src/storage/index/ann/ann_topn_runtime.h
@@ -39,6 +39,7 @@
#include <vector>
#include "core/column/column.h"
+#include "core/column/column_vector.h"
#include "core/data_type/primitive_type.h"
#include "exprs/vectorized_fn_call.h"
#include "exprs/vexpr.h"
@@ -51,7 +52,7 @@ namespace doris::segment_v2 {
struct AnnIndexStats;
class AnnIndexIterator;
-Result<IColumn::Ptr> extract_query_vector(std::shared_ptr<VExpr> arg_expr);
+Result<ColumnFloat32::Ptr> extract_query_vector(std::shared_ptr<VExpr>
arg_expr);
/**
* @brief Runtime execution engine for ANN (Approximate Nearest Neighbor)
Top-N queries.
@@ -164,7 +165,7 @@ private:
size_t _src_column_idx = -1; ///< Source vector column index
size_t _dest_column_idx = -1; ///< Destination distance
column index
segment_v2::AnnIndexMetric _metric_type; ///< Distance metric type
- IColumn::Ptr _query_array; ///< Query vector data
(contiguous float buffer)
+ ColumnFloat32::Ptr _query_array; ///< Query vector data
(contiguous float buffer)
doris::VectorSearchUserParams _user_params; ///< User-defined search
parameters
};
} // namespace doris::segment_v2
diff --git a/be/test/storage/index/ann/ann_range_search_test.cpp b/be/test/storage/index/ann/ann_range_search_test.cpp
index aa9868350fb..572783ebecb 100644
--- a/be/test/storage/index/ann/ann_range_search_test.cpp
+++ b/be/test/storage/index/ann/ann_range_search_test.cpp
@@ -86,8 +86,9 @@ TEST_F(VectorSearchTest, TestPrepareAnnRangeSearch) {
EXPECT_EQ(ann_range_search_runtime.radius, 10.0f);
std::vector<int> query_array_groud_truth = {1, 2, 3, 4, 5, 6, 7, 20};
std::vector<int> query_array_f32;
+ const auto& query_value =
range_search_ctx->_ann_range_search_runtime.query_value;
for (int i = 0; i < query_array_groud_truth.size(); ++i) {
-
query_array_f32.push_back(static_cast<int>(ann_range_search_runtime.query_value[i]));
+
query_array_f32.push_back(static_cast<int>(query_value->get_data()[i]));
}
for (int i = 0; i < query_array_f32.size(); ++i) {
EXPECT_EQ(query_array_f32[i], query_array_groud_truth[i]);
diff --git a/be/test/storage/index/ann/ann_topn_descriptor_test.cpp b/be/test/storage/index/ann/ann_topn_descriptor_test.cpp
index fa64807bece..3e6043a52ad 100644
--- a/be/test/storage/index/ann/ann_topn_descriptor_test.cpp
+++ b/be/test/storage/index/ann/ann_topn_descriptor_test.cpp
@@ -116,8 +116,7 @@ TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) {
ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(),
predicate->get_order_by_expr_ctx()->root()->debug_string());
- const ColumnFloat32* query_column =
- assert_cast<const ColumnFloat32*>(predicate->_query_array.get());
+ const auto& query_column = predicate->_query_array;
const float* query_value = query_column->get_data().data();
const size_t query_value_size = predicate->_query_array->size();
ASSERT_EQ(query_value_size, 8);
diff --git a/be/test/storage/index/ann/extract_query_vector_test.cpp b/be/test/storage/index/ann/extract_query_vector_test.cpp
index 13e059825b3..3b72e4f4ca6 100644
--- a/be/test/storage/index/ann/extract_query_vector_test.cpp
+++ b/be/test/storage/index/ann/extract_query_vector_test.cpp
@@ -183,7 +183,7 @@ TEST_F(ExtractQueryVectorTest, ValuesMatchInput) {
auto result = extract_query_vector(mock);
ASSERT_TRUE(result.has_value());
- auto* float_col = assert_cast<const ColumnFloat32*>(result.value().get());
+ const auto& float_col = result.value();
ASSERT_EQ(float_col->size(), 4u);
for (size_t i = 0; i < input.size(); ++i) {
EXPECT_FLOAT_EQ(float_col->get_data()[i], input[i]);
@@ -245,4 +245,20 @@ TEST_F(ExtractQueryVectorTest, NonArrayColumnFails) {
EXPECT_TRUE(result.error().to_string().find("Array literal") !=
std::string::npos);
}
+TEST_F(ExtractQueryVectorTest, NonFloatArrayFails) {
+ auto int_col = ColumnInt32::create();
+ int_col->insert_value(1);
+ int_col->insert_value(2);
+ auto offsets = ColumnArray::ColumnOffsets::create();
+ offsets->insert_value(2);
+ auto array_col = ColumnArray::create(std::move(int_col),
std::move(offsets));
+
+ auto mock = std::make_shared<MockConstVExpr>();
+ mock->set_column(std::move(array_col));
+
+ auto result = extract_query_vector(mock);
+ ASSERT_FALSE(result.has_value());
+ EXPECT_TRUE(result.error().to_string().find("must be Float32") !=
std::string::npos);
+}
+
} // namespace doris::segment_v2
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]