This is an automated email from the ASF dual-hosted git repository.
airborne pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 9e6da1ee7cd [opt](ann index) Support cast expression as rhs of
approximate top n (#55458)
9e6da1ee7cd is described below
commit 9e6da1ee7cd7323212c287ef42b5e272a8178507
Author: zhiqiang <[email protected]>
AuthorDate: Tue Sep 16 15:15:30 2025 +0800
[opt](ann index) Support cast expression as rhs of approximate top n
(#55458)
### What problem does this PR solve?
Support pushing expressions like the following down to the ANN index:
```sql
select id from ann_cast_rhs_ip order by
inner_product_approximate(embedding, cast('[0.1,0.2,0.3,0.4]' as array<float>))
desc limit 3;
```
Related PR: https://github.com/apache/doris/pull/54276
---
.../segment_v2/ann_index/ann_index_reader.cpp | 8 ++
.../rowset/segment_v2/ann_index/ann_index_reader.h | 4 +-
.../ann_index/ann_range_search_runtime.cpp | 6 +-
.../ann_index/ann_range_search_runtime.h | 29 ++--
.../segment_v2/ann_index/ann_search_params.h | 2 +-
.../segment_v2/ann_index/ann_topn_runtime.cpp | 149 +++++++++++++++-----
.../rowset/segment_v2/ann_index/ann_topn_runtime.h | 5 +-
be/src/vec/exec/scan/olap_scanner.cpp | 1 +
be/src/vec/exprs/vectorized_fn_call.cpp | 61 ++++----
.../olap/vector_search/ann_index_reader_test.cpp | 4 +-
.../olap/vector_search/ann_range_search_test.cpp | 63 ++++++++-
.../vector_search/ann_topn_descriptor_test.cpp | 26 ++--
.../ann_topn_runtime_negative_test.cpp | 148 ++++++++++++++++++++
be/test/olap/vector_search/vector_search_utils.cpp | 10 ++
.../data/ann_index_p0/cast_string_as_array.out | Bin 0 -> 363 bytes
.../ann_index_p0/cast_string_as_array.groovy | 155 +++++++++++++++++++++
16 files changed, 571 insertions(+), 100 deletions(-)
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
index c9ff38102b5..35880336580 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.cpp
@@ -24,6 +24,7 @@
#include "common/config.h"
#include "io/io_common.h"
#include "olap/rowset/segment_v2/ann_index/ann_index.h"
+#include "olap/rowset/segment_v2/ann_index/ann_index_writer.h"
#include "olap/rowset/segment_v2/ann_index/ann_search_params.h"
#include "olap/rowset/segment_v2/ann_index/faiss_ann_index.h"
#include "olap/rowset/segment_v2/index_file_reader.h"
@@ -58,6 +59,9 @@ AnnIndexReader::AnnIndexReader(const TabletIndex* index_meta,
it = index_properties.find("metric_type");
DCHECK(it != index_properties.end());
_metric_type = string_to_metric(it->second);
+ it = index_properties.find(AnnIndexColumnWriter::DIM);
+ DCHECK(it != index_properties.end());
+ _dim = std::stoi(it->second);
}
Status AnnIndexReader::new_iterator(std::unique_ptr<IndexIterator>* iterator) {
@@ -225,4 +229,8 @@ Status AnnIndexReader::range_search(const
AnnRangeSearchParams& params,
return Status::OK();
}
+size_t AnnIndexReader::get_dimension() const {
+ return _dim;
+}
+
} // namespace doris::segment_v2
\ No newline at end of file
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.h
b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.h
index 8d9aba5239c..68e62f7f347 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.h
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_index_reader.h
@@ -59,13 +59,15 @@ public:
AnnIndexMetric get_metric_type() const { return _metric_type; }
+ size_t get_dimension() const;
+
private:
TabletIndex _index_meta;
std::shared_ptr<IndexFileReader> _index_file_reader;
std::unique_ptr<VectorIndex> _vector_index;
AnnIndexType _index_type;
AnnIndexMetric _metric_type;
-
+ size_t _dim;
DorisCallOnce<Status> _load_index_once;
};
diff --git
a/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.cpp
b/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.cpp
index 416af7826d2..f8e0e4af8c3 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.cpp
@@ -35,7 +35,8 @@ namespace doris::segment_v2 {
*/
AnnRangeSearchParams AnnRangeSearchRuntime::to_range_search_params() const {
AnnRangeSearchParams params;
- params.query_value = query_value.get();
+ const auto* query = assert_cast<const
vectorized::ColumnFloat32*>(query_value.get());
+ params.query_value = query->get_data().data();
params.radius = static_cast<float>(radius);
params.roaring = nullptr;
params.is_le_or_lt = is_le_or_lt;
@@ -58,6 +59,7 @@ std::string AnnRangeSearchRuntime::to_string() const {
"dst_col_idx: {}, metric_type {}, radius: {}, user params: {},
query_vector is null: "
"{}",
is_ann_range_search, is_le_or_lt, src_col_idx, dst_col_idx,
- metric_to_string(metric_type), radius, user_params.to_string(),
query_value == nullptr);
+ metric_to_string(metric_type), radius, user_params.to_string(),
+ query_value.get() == nullptr);
}
} // namespace doris::segment_v2
\ No newline at end of file
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.h
b/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.h
index c3d112c9bf2..113bdf8786f 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.h
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_range_search_runtime.h
@@ -22,6 +22,11 @@
#include <string>
#include "olap/rowset/segment_v2/ann_index/ann_index.h"
+#include "runtime/define_primitive_type.h"
+#include "runtime/primitive_type.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
#include "vec/runtime/vector_search_user_params.h"
namespace doris::segment_v2 {
@@ -81,16 +86,8 @@ struct AnnRangeSearchRuntime {
dst_col_idx(other.dst_col_idx),
radius(other.radius),
metric_type(other.metric_type),
- user_params(other.user_params) {
- // Do deep copy to query_value.
- if (other.query_value) {
- query_value = std::make_unique<float[]>(other.dim);
- std::copy(other.query_value.get(), other.query_value.get() +
other.dim,
- query_value.get());
- } else {
- query_value = nullptr;
- }
- }
+ user_params(other.user_params),
+ query_value(other.query_value) {}
/**
* @brief Assignment operator with deep copy semantics.
@@ -110,14 +107,8 @@ struct AnnRangeSearchRuntime {
metric_type = other.metric_type;
user_params = other.user_params;
dim = other.dim;
- // Do deep copy to query_value.
- if (other.query_value) {
- query_value = std::make_unique<float[]>(other.dim);
- std::copy(other.query_value.get(), other.query_value.get() +
other.dim,
- query_value.get());
- } else {
- query_value = nullptr;
- }
+ query_value = other.query_value;
+
return *this;
}
@@ -142,7 +133,7 @@ struct AnnRangeSearchRuntime {
double radius = 0.0; ///< Search radius/distance
threshold
AnnIndexMetric metric_type; ///< Distance metric (L2, Inner
Product, etc.)
doris::VectorSearchUserParams user_params; ///< User-defined search
parameters
- std::unique_ptr<float[]> query_value; ///< Query vector data (deep
copied)
+ vectorized::IColumn::Ptr query_value; ///< Query vector data (deep
copied)
};
#include "common/compile_check_end.h"
} // namespace doris::segment_v2
\ No newline at end of file
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h
b/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h
index c38c8d2138d..b2d9c758659 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_search_params.h
@@ -94,7 +94,7 @@ struct AnnTopNParam {
struct AnnRangeSearchParams {
bool is_le_or_lt = true;
- float* query_value = nullptr;
+ const float* query_value = nullptr;
float radius = -1;
roaring::Roaring* roaring; // roaring from segment_iterator
std::string to_string() const {
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp
b/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp
index 62dac6e7095..dc99192d7f1 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.cpp
@@ -22,16 +22,19 @@
#include <string>
#include <utility>
+#include "common/exception.h"
#include "common/logging.h"
+#include "common/status.h"
#include "olap/rowset/segment_v2/ann_index/ann_index_iterator.h"
#include "olap/rowset/segment_v2/ann_index/ann_search_params.h"
+#include "olap/rowset/segment_v2/inverted_index_query_type.h"
+#include "runtime/primitive_type.h"
#include "runtime/runtime_state.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
-#include "vec/columns/column_const.h"
#include "vec/columns/column_nullable.h"
-#include "vec/common/assert_cast.h"
#include "vec/exprs/varray_literal.h"
+#include "vec/exprs/vcast_expr.h"
#include "vec/exprs/vexpr_context.h"
#include "vec/exprs/vexpr_fwd.h"
#include "vec/exprs/virtual_slot_ref.h"
@@ -40,6 +43,81 @@
namespace doris::segment_v2 {
#include "common/compile_check_begin.h"
+
+Result<vectorized::IColumn::Ptr>
extract_query_vector(std::shared_ptr<vectorized::VExpr> arg_expr) {
+ if (arg_expr->is_constant() == false) {
+ return ResultError(Status::InvalidArgument("Ann topn expr must be
constant, got\n{}",
+ arg_expr->debug_string()));
+ }
+
+ // Accept either ArrayLiteral([..]) or CAST('..' AS
Nullable(Array(Nullable(Float32))))
+ // First, check the expr node type for clarity.
+
+ bool is_array_literal =
+ std::dynamic_pointer_cast<vectorized::VArrayLiteral>(arg_expr) !=
nullptr;
+ bool is_cast_expr =
std::dynamic_pointer_cast<vectorized::VCastExpr>(arg_expr) != nullptr;
+ if (!is_array_literal && !is_cast_expr) {
+ return ResultError(
+ Status::InvalidArgument("Constant must be ArrayLiteral or CAST
to array, got\n{}",
+ arg_expr->debug_string()));
+ }
+
+ // We'll validate shape by inspecting the materialized constant column
below.
+
+ std::shared_ptr<ColumnPtrWrapper> column_wrapper;
+ auto st = arg_expr->get_const_col(nullptr, &column_wrapper);
+ if (!st.ok()) {
+ return ResultError(Status::InvalidArgument("Failed to get constant
column, error: {}",
+ st.to_string()));
+ }
+
+ // Execute the constant array literal and extract its float elements into
_query_array
+ vectorized::IColumn::Ptr col_ptr =
+ column_wrapper->column_ptr->convert_to_full_column_if_const();
+
+ // The expected runtime column layout for the literal is:
+ // Nullable(ColumnArray(Nullable(ColumnFloat32))) with exactly 1 row (one
array literal)
+ const vectorized::IColumn* top_col = col_ptr.get();
+ const vectorized::IColumn* array_holder_col = top_col;
+ // Handle outer Nullable and remember result nullability preference
+ if (auto* nullable_col =
+
vectorized::check_and_get_column<vectorized::ColumnNullable>(*top_col)) {
+ if (nullable_col->has_null()) {
+ return ResultError(Status::InvalidArgument("Ann query vector
cannot be NULL"));
+ }
+ array_holder_col = &nullable_col->get_nested_column();
+ }
+
+ // Must be an array column with single row
+ const auto* array_col =
+
vectorized::check_and_get_column<vectorized::ColumnArray>(*array_holder_col);
+ if (array_col == nullptr || array_col->size() != 1) {
+ return ResultError(Status::InvalidArgument(
+ "Ann topn expr constant should be an Array literal, got
column: {}",
+ array_holder_col->get_name()));
+ }
+
+ // Fetch nested data column: Nullable(ColumnFloat32) or ColumnFloat32
+ const vectorized::IColumn& nested_data_any = array_col->get_data();
+ vectorized::IColumn::Ptr values_holder_col = array_col->get_data_ptr();
+ size_t value_count = array_col->get_offsets()[0];
+
+ if (value_count == 0) {
+ return ResultError(Status::InvalidArgument("Ann topn query vector
cannot be empty"));
+ }
+
+ if (auto* value_nullable_col =
+
vectorized::check_and_get_column<vectorized::ColumnNullable>(nested_data_any)) {
+ if (value_nullable_col->has_null(0, value_count)) {
+ return ResultError(Status::InvalidArgument(
+ "Ann topn query vector elements cannot contain NULL
values"));
+ }
+ values_holder_col = value_nullable_col->get_nested_column_ptr();
+ }
+
+ return values_holder_col;
+}
+
Status AnnTopNRuntime::prepare(RuntimeState* state, const RowDescriptor&
row_desc) {
RETURN_IF_ERROR(_order_by_expr_ctx->prepare(state, row_desc));
RETURN_IF_ERROR(_order_by_expr_ctx->open(state));
@@ -54,13 +132,13 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const
RowDescriptor& row_des
|----------------
| |
| |
- SlotRef ArrayLiteral
+ SlotRef CAST(String as Nullable<ArrayFloat>) OR ArrayLiteral
*/
std::shared_ptr<vectorized::VirtualSlotRef> vir_slot_ref =
std::dynamic_pointer_cast<vectorized::VirtualSlotRef>(_order_by_expr_ctx->root());
DCHECK(vir_slot_ref != nullptr);
if (vir_slot_ref == nullptr) {
- return Status::InternalError(
+ return Status::InvalidArgument(
"root of order by expr of ann topn must be a
vectorized::VirtualSlotRef, got\n{}",
_order_by_expr_ctx->root()->debug_string());
}
@@ -71,27 +149,35 @@ Status AnnTopNRuntime::prepare(RuntimeState* state, const
RowDescriptor& row_des
std::dynamic_pointer_cast<vectorized::VectorizedFnCall>(vir_col_expr);
if (distance_fn_call == nullptr) {
- return Status::InternalError("Ann topn expr expect FuncationCall,
got\n{}",
- vir_col_expr->debug_string());
+ return Status::InvalidArgument("Ann topn expr expect FuncationCall,
got\n{}",
+ vir_col_expr->debug_string());
}
std::shared_ptr<vectorized::VSlotRef> slot_ref =
std::dynamic_pointer_cast<vectorized::VSlotRef>(distance_fn_call->children()[0]);
if (slot_ref == nullptr) {
- return Status::InternalError("Ann topn expr expect SlotRef, got\n{}",
-
distance_fn_call->children()[0]->debug_string());
+ return Status::InvalidArgument("Ann topn expr expect SlotRef, got\n{}",
+
distance_fn_call->children()[0]->debug_string());
}
// slot_ref->column_id() is acutually the columnd idx in block.
_src_column_idx = slot_ref->column_id();
- std::shared_ptr<vectorized::VArrayLiteral> array_literal =
-
std::dynamic_pointer_cast<vectorized::VArrayLiteral>(distance_fn_call->children()[1]);
- if (array_literal == nullptr) {
- return Status::InternalError("Ann topn expr expect ArrayLiteral,
got\n{}",
-
distance_fn_call->children()[1]->debug_string());
+ if (distance_fn_call->children()[1]->is_constant() == false) {
+ return Status::InvalidArgument("Ann topn expr expect constant
ArrayLiteral, got\n{}",
+
distance_fn_call->children()[1]->debug_string());
}
- _query_array = array_literal->get_column_ptr();
+
+ // Accept either ArrayLiteral([..]) or CAST('..' AS
Nullable(Array(Nullable(Float32))))
+ // First, check the expr node type for clarity.
+ auto arg_expr = distance_fn_call->children()[1];
+
+ auto query_array_result = extract_query_vector(arg_expr);
+ if (!query_array_result.has_value()) {
+ return query_array_result.error();
+ }
+ _query_array = query_array_result.value();
+
_user_params = state->get_vector_search_params();
std::set<std::string> distance_func_names =
{vectorized::L2DistanceApproximate::name,
@@ -121,27 +207,26 @@ Status
AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator* ann
DCHECK(ann_index_iterator_casted != nullptr);
DCHECK(_order_by_expr_ctx != nullptr);
DCHECK(_order_by_expr_ctx->root() != nullptr);
+ size_t query_array_size = _query_array->size();
+ if (_query_array.get() == nullptr || query_array_size == 0) {
+ return Status::InternalError("Ann topn query vector is not
initialized");
+ }
- const vectorized::ColumnConst* const_column =
- assert_cast<const vectorized::ColumnConst*>(_query_array.get());
- const vectorized::ColumnArray* column_array =
- assert_cast<const
vectorized::ColumnArray*>(const_column->get_data_column_ptr().get());
- const vectorized::ColumnNullable* column_nullable =
- assert_cast<const
vectorized::ColumnNullable*>(column_array->get_data_ptr().get());
- const vectorized::ColumnFloat32* cf32 = assert_cast<const
vectorized::ColumnFloat32*>(
- column_nullable->get_nested_column_ptr().get());
-
- const float* query_value = cf32->get_data().data();
- const size_t query_value_size = cf32->get_data().size();
+ // TODO:(zhiqiang) Maybe we can move this dimension check to prepare phase.
- std::unique_ptr<float[]> query_value_f32 =
std::make_unique<float[]>(query_value_size);
- for (size_t i = 0; i < query_value_size; ++i) {
- query_value_f32[i] = static_cast<float>(query_value[i]);
+ auto index_reader =
ann_index_iterator_casted->get_reader(AnnIndexReaderType::ANN);
+ auto ann_index_reader =
std::dynamic_pointer_cast<AnnIndexReader>(index_reader);
+ DCHECK(ann_index_reader != nullptr);
+ if (ann_index_reader->get_dimension() != query_array_size) {
+ return Status::InvalidArgument(
+ "Ann topn query vector dimension {} does not match index
dimension {}",
+ query_array_size, ann_index_reader->get_dimension());
}
-
+ const vectorized::ColumnFloat32* query =
+ assert_cast<const vectorized::ColumnFloat32*>(_query_array.get());
segment_v2::AnnTopNParam ann_query_params {
- .query_value = query_value_f32.get(),
- .query_value_size = query_value_size,
+ .query_value = query->get_data().data(),
+ .query_value_size = query_array_size,
.limit = _limit,
._user_params = _user_params,
.roaring = roaring,
@@ -157,11 +242,9 @@ Status
AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator* ann
size_t num_results = ann_query_params.distance->size();
auto result_column_float = vectorized::ColumnFloat32::create(num_results);
-
for (size_t i = 0; i < num_results; ++i) {
result_column_float->get_data()[i] = (*ann_query_params.distance)[i];
}
-
result_column = std::move(result_column_float);
row_ids = std::move(ann_query_params.row_ids);
ann_index_stats = *ann_query_params.stats;
diff --git a/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.h
b/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.h
index 8fd4dcee8a6..121901ff918 100644
--- a/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.h
+++ b/be/src/olap/rowset/segment_v2/ann_index/ann_topn_runtime.h
@@ -35,6 +35,7 @@
#pragma once
+#include "runtime/primitive_type.h"
#include "runtime/runtime_state.h"
#include "vec/columns/column.h"
#include "vec/exprs/varray_literal.h"
@@ -49,6 +50,8 @@ namespace doris::segment_v2 {
#include "common/compile_check_begin.h"
struct AnnIndexStats;
+Result<vectorized::IColumn::Ptr>
extract_query_vector(std::shared_ptr<vectorized::VExpr> arg_expr);
+
/**
* @brief Runtime execution engine for ANN (Approximate Nearest Neighbor)
Top-N queries.
*
@@ -161,7 +164,7 @@ private:
size_t _src_column_idx = -1; ///< Source vector column index
size_t _dest_column_idx = -1; ///< Destination distance
column index
segment_v2::AnnIndexMetric _metric_type; ///< Distance metric type
- vectorized::IColumn::Ptr _query_array; ///< Query vector data
+ vectorized::IColumn::Ptr _query_array; ///< Query vector data
(contiguous float buffer)
doris::VectorSearchUserParams _user_params; ///< User-defined search
parameters
};
#include "common/compile_check_end.h"
diff --git a/be/src/vec/exec/scan/olap_scanner.cpp
b/be/src/vec/exec/scan/olap_scanner.cpp
index 606d1757b1b..1e8a7f93212 100644
--- a/be/src/vec/exec/scan/olap_scanner.cpp
+++ b/be/src/vec/exec/scan/olap_scanner.cpp
@@ -155,6 +155,7 @@ Status OlapScanner::prepare() {
_score_runtime = local_state->_score_runtime;
_score_runtime = local_state->_score_runtime;
+ // All scanners share the same ann_topn_runtime.
_ann_topn_runtime = local_state->_ann_topn_runtime;
// set limit to reduce end of rowset and segment mem use
diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp
b/be/src/vec/exprs/vectorized_fn_call.cpp
index 786674eadf3..6ee7a12a251 100644
--- a/be/src/vec/exprs/vectorized_fn_call.cpp
+++ b/be/src/vec/exprs/vectorized_fn_call.cpp
@@ -42,6 +42,7 @@
#include "vec/columns/column_array.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
#include "vec/core/block.h"
#include "vec/core/column_numbers.h"
#include "vec/core/types.h"
@@ -345,7 +346,7 @@ bool VectorizedFnCall::equals(const VExpr& other) {
|----------------
| |
| |
- SlotRef ArrayLiteral
+ SlotRef ArrayLiteral/Cast(String as Array<FLOAT>)
*/
void VectorizedFnCall::prepare_ann_range_search(
@@ -425,44 +426,44 @@ void VectorizedFnCall::prepare_ann_range_search(
range_search_runtime.metric_type =
segment_v2::string_to_metric(metric_name);
}
- UInt16 idx_of_slot_ref = 0;
- UInt16 idx_of_array_literal = 0;
+ // Identify the slot ref child and the constant query array child
(ArrayLiteral or CAST to array)
+ Int32 idx_of_slot_ref = -1;
+ Int32 idx_of_array_expr = -1;
for (UInt16 i = 0; i < function_call->get_num_children(); ++i) {
auto child = function_call->get_child(i);
- if (std::dynamic_pointer_cast<VSlotRef>(child) != nullptr) {
+ if (idx_of_slot_ref == -1 &&
std::dynamic_pointer_cast<VSlotRef>(child) != nullptr) {
idx_of_slot_ref = i;
- } else if (std::dynamic_pointer_cast<VArrayLiteral>(child) != nullptr)
{
- idx_of_array_literal = i;
+ continue;
+ }
+ // Accept either ArrayLiteral or Cast-to-array constant
+ if (idx_of_array_expr == -1 &&
+ (std::dynamic_pointer_cast<VArrayLiteral>(child) != nullptr ||
+ std::dynamic_pointer_cast<VCastExpr>(child) != nullptr)) {
+ idx_of_array_expr = i;
}
}
- std::shared_ptr<VSlotRef> slot_ref =
-
std::dynamic_pointer_cast<VSlotRef>(function_call->get_child(idx_of_slot_ref));
- std::shared_ptr<VArrayLiteral> array_literal =
std::dynamic_pointer_cast<VArrayLiteral>(
- function_call->get_child(idx_of_array_literal));
-
- if (slot_ref == nullptr || array_literal == nullptr) {
+ if (idx_of_slot_ref == -1 || idx_of_array_expr == -1) {
suitable_for_ann_index = false;
- // slot ref or array literal is null.
+ // slot ref or array literal/cast is missing.
return;
}
+ auto slot_ref = std::dynamic_pointer_cast<VSlotRef>(
+ function_call->get_child(static_cast<UInt16>(idx_of_slot_ref)));
range_search_runtime.src_col_idx = slot_ref->column_id();
range_search_runtime.dst_col_idx = vir_slot_ref == nullptr ? -1 :
vir_slot_ref->column_id();
- auto col_const = array_literal->get_column_ptr();
- auto col_array = col_const->convert_to_full_column_if_const();
- const ColumnArray* array_col = assert_cast<const
ColumnArray*>(col_array.get());
- DCHECK(array_col->size() == 1);
- size_t dim = array_col->get_offsets()[0];
- range_search_runtime.dim = dim;
- range_search_runtime.query_value = std::make_unique<float[]>(dim);
-
- const ColumnNullable* cn = assert_cast<const
ColumnNullable*>(array_col->get_data_ptr().get());
- const ColumnFloat32* cf32 =
- assert_cast<const
ColumnFloat32*>(cn->get_nested_column_ptr().get());
- for (size_t i = 0; i < dim; ++i) {
- range_search_runtime.query_value[i] = cf32->get_data()[i];
+
+ // Materialize the constant array expression and validate its shape and
types
+ std::shared_ptr<ColumnPtrWrapper> column_wrapper;
+ auto array_expr =
function_call->get_child(static_cast<UInt16>(idx_of_array_expr));
+ auto extract_result = extract_query_vector(array_expr);
+ if (!extract_result.has_value()) {
+ suitable_for_ann_index = false;
+ return;
}
+ range_search_runtime.query_value = extract_result.value();
+ range_search_runtime.dim = range_search_runtime.query_value->size();
range_search_runtime.is_ann_range_search = true;
range_search_runtime.user_params = user_params;
VLOG_DEBUG << fmt::format("Ann range search params: {}",
range_search_runtime.to_string());
@@ -513,6 +514,14 @@ Status VectorizedFnCall::evaluate_ann_range_search(
return Status::OK();
}
+ // Check dimension if available (>0)
+ const size_t index_dim = ann_index_reader->get_dimension();
+ if (index_dim > 0 && index_dim != range_search_runtime.dim) {
+ return Status::InvalidArgument(
+ "Ann range search query dimension {} does not match index
dimension {}",
+ range_search_runtime.dim, index_dim);
+ }
+
AnnRangeSearchParams params =
range_search_runtime.to_range_search_params();
params.roaring = &row_bitmap;
diff --git a/be/test/olap/vector_search/ann_index_reader_test.cpp
b/be/test/olap/vector_search/ann_index_reader_test.cpp
index 7f914038982..8af387ab230 100644
--- a/be/test/olap/vector_search/ann_index_reader_test.cpp
+++ b/be/test/olap/vector_search/ann_index_reader_test.cpp
@@ -423,7 +423,9 @@ TEST_F(AnnIndexReaderTest, AnnIndexReaderRangeSearch) {
for (size_t i = 0; i < iterations; ++i) {
std::map<std::string, std::string> index_properties;
index_properties["index_type"] = "hnsw";
- index_properties["metric_type"] = "l2";
+ // Use canonical metric name and include required dimension property
+ index_properties["metric_type"] = "l2_distance";
+ index_properties["dim"] = "128";
std::unique_ptr<doris::TabletIndex> index_meta =
std::make_unique<doris::TabletIndex>();
index_meta->_properties = index_properties;
auto mock_index_file_reader = std::make_shared<MockIndexFileReader>();
diff --git a/be/test/olap/vector_search/ann_range_search_test.cpp
b/be/test/olap/vector_search/ann_range_search_test.cpp
index d9541d6cbb8..080e3c76f9e 100644
--- a/be/test/olap/vector_search/ann_range_search_test.cpp
+++ b/be/test/olap/vector_search/ann_range_search_test.cpp
@@ -140,6 +140,7 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) {
std::map<std::string, std::string> properties;
properties["index_type"] = "hnsw";
properties["metric_type"] = "l2_distance";
+ properties["dim"] = "8"; // match query vector size from thrift
auto pair = vector_search_utils::create_tmp_ann_index_reader(properties);
mock_ann_index_iter->_ann_reader = pair.second;
@@ -230,6 +231,7 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) {
std::map<std::string, std::string> properties;
properties["index_type"] = "hnsw";
properties["metric_type"] = "l2_distance";
+ properties["dim"] = "8"; // match query vector size from thrift
auto pair = vector_search_utils::create_tmp_ann_index_reader(properties);
mock_ann_index_iter->_ann_reader = pair.second;
@@ -303,11 +305,12 @@ TEST_F(VectorSearchTest,
TestRangeSearchRuntimeInfoToString) {
runtime_info2.radius = 15.5;
runtime_info2.metric_type = doris::segment_v2::AnnIndexMetric::L2;
runtime_info2.dim = 4;
- runtime_info2.query_value = std::make_unique<float[]>(4);
- runtime_info2.query_value[0] = 1.0f;
- runtime_info2.query_value[1] = 2.0f;
- runtime_info2.query_value[2] = 3.0f;
- runtime_info2.query_value[3] = 4.0f;
+ auto f32 = ColumnFloat32::create(4);
+ f32->get_data()[0] = 1.0f;
+ f32->get_data()[1] = 2.0f;
+ f32->get_data()[2] = 3.0f;
+ f32->get_data()[3] = 4.0f;
+ runtime_info2.query_value = std::move(f32);
doris::VectorSearchUserParams user_params;
user_params.hnsw_ef_search = 100;
@@ -692,6 +695,56 @@ TEST_F(VectorSearchTest, TestAnnIndexReader_NewIterator) {
EXPECT_NE(ann_iterator, nullptr);
}
+TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch_DimensionMismatch) {
+ // Prepare a valid range search expr from thrift
+ TExpr texpr = read_from_json<TExpr>(ann_range_search_thrift);
+ TDescriptorTable table1 =
read_from_json<TDescriptorTable>(thrift_table_desc);
+ std::unique_ptr<doris::ObjectPool> pool =
std::make_unique<doris::ObjectPool>();
+ auto desc_tbl = std::make_unique<DescriptorTbl>();
+ DescriptorTbl* desc_tbl_ptr = desc_tbl.get();
+ ASSERT_TRUE(DescriptorTbl::create(pool.get(), table1,
&(desc_tbl_ptr)).ok());
+ RowDescriptor row_desc = RowDescriptor(*desc_tbl_ptr, {0}, {false});
+ std::unique_ptr<doris::RuntimeState> state =
std::make_unique<doris::RuntimeState>();
+ state->set_desc_tbl(desc_tbl_ptr);
+
+ VExprContextSPtr range_search_ctx;
+ ASSERT_TRUE(vectorized::VExpr::create_expr_tree(texpr,
range_search_ctx).ok());
+ ASSERT_TRUE(range_search_ctx->prepare(state.get(), row_desc).ok());
+ ASSERT_TRUE(range_search_ctx->open(state.get()).ok());
+ doris::VectorSearchUserParams user_params;
+ range_search_ctx->prepare_ann_range_search(user_params);
+
ASSERT_TRUE(range_search_ctx->_ann_range_search_runtime.is_ann_range_search);
+ // Force a dimension mismatch: query dim is 8 in thrift; set index dim to 4
+ std::vector<ColumnId> idx_to_cid(4);
+ idx_to_cid[0] = 0;
+ idx_to_cid[1] = 1; // embedding
+ idx_to_cid[2] = 2;
+ idx_to_cid[3] = 3; // virtual dist
+
+ std::vector<std::unique_ptr<segment_v2::IndexIterator>>
cid_to_index_iterators(4);
+ auto mock_iter =
std::make_unique<doris::vector_search_utils::MockAnnIndexIterator>();
+
+ // Back its reader with a real AnnIndexReader but with dim=4
+ std::map<std::string, std::string> properties;
+ properties["index_type"] = "hnsw";
+ properties["metric_type"] = "l2_distance";
+ properties["dim"] = "4"; // mismatch
+ auto pair = vector_search_utils::create_tmp_ann_index_reader(properties);
+ mock_iter->_ann_reader = pair.second;
+ cid_to_index_iterators[1] = std::move(mock_iter);
+
+ std::vector<std::unique_ptr<segment_v2::ColumnIterator>>
column_iterators(4);
+ column_iterators[3] =
std::make_unique<doris::segment_v2::VirtualColumnIterator>();
+
+ roaring::Roaring row_bitmap;
+ segment_v2::AnnIndexStats stats;
+
+ auto st =
range_search_ctx->evaluate_ann_range_search(cid_to_index_iterators, idx_to_cid,
+ column_iterators,
row_bitmap, stats);
+ EXPECT_FALSE(st.ok());
+ EXPECT_TRUE(st.is<doris::ErrorCode::INVALID_ARGUMENT>());
+}
+
TEST_F(VectorSearchTest, TestAnnIndexIterator_ReadFromIndex_NullParam) {
// Test AnnIndexIterator::read_from_index with null parameter
std::map<std::string, std::string> properties;
diff --git a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
index 03d6c8800cb..530f1f1a038 100644
--- a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
+++ b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
@@ -30,6 +30,7 @@
#include "olap/rowset/segment_v2/ann_index/ann_topn_runtime.h"
#include "runtime/primitive_type.h"
#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
#include "vec/exprs/virtual_slot_ref.h"
#include "vector_search_utils.h"
@@ -119,17 +120,10 @@ TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) {
ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(),
predicate->get_order_by_expr_ctx()->root()->debug_string());
- const ColumnConst* const_column =
- assert_cast<const ColumnConst*>(predicate->_query_array.get());
- const ColumnArray* column_array =
- assert_cast<const
ColumnArray*>(const_column->get_data_column_ptr().get());
- const ColumnNullable* column_nullable =
- assert_cast<const
ColumnNullable*>(column_array->get_data_ptr().get());
- const ColumnFloat32* cf32 =
- assert_cast<const
ColumnFloat32*>(column_nullable->get_nested_column_ptr().get());
-
- const float* query_value = cf32->get_data().data();
- const size_t query_value_size = cf32->get_data().size();
+ const vectorized::ColumnFloat32* query_column =
+ assert_cast<const
vectorized::ColumnFloat32*>(predicate->_query_array.get());
+ const float* query_value = query_column->get_data().data();
+ const size_t query_value_size = predicate->_query_array->size();
ASSERT_EQ(query_value_size, 8);
std::vector<float> query_value_f32;
for (size_t i = 0; i < query_value_size; ++i) {
@@ -153,6 +147,16 @@ TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) {
std::cout << "query_vector: " << fmt::format("[{}]",
fmt::join(*query_vector, ","))
<< std::endl;
+ // Attach a valid ANN reader to the mock iterator so runtime can fetch
reader and check dim
+ {
+ std::map<std::string, std::string> properties;
+ properties["index_type"] = "hnsw";
+ properties["metric_type"] = "l2_distance";
+ properties["dim"] = "8"; // match the query vector dimension
+ auto pair =
vector_search_utils::create_tmp_ann_index_reader(properties);
+ _ann_index_iterator->_ann_reader = pair.second;
+ }
+
EXPECT_CALL(*_ann_index_iterator, read_from_index(testing::_))
.Times(1)
.WillOnce(testing::Invoke([](const segment_v2::IndexParam& value) {
diff --git a/be/test/olap/vector_search/ann_topn_runtime_negative_test.cpp
b/be/test/olap/vector_search/ann_topn_runtime_negative_test.cpp
new file mode 100644
index 00000000000..a84d99ec783
--- /dev/null
+++ b/be/test/olap/vector_search/ann_topn_runtime_negative_test.cpp
@@ -0,0 +1,148 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "common/status.h"
+#include "olap/rowset/segment_v2/ann_index/ann_topn_runtime.h"
+#include "vec/exprs/vectorized_fn_call.h"
+#include "vec/exprs/vexpr.h"
+#include "vec/exprs/virtual_slot_ref.h"
+#include "vector_search_utils.h"
+
+using ::testing::HasSubstr;
+
+namespace doris::vectorized {
+
+// These tests target uncovered error branches in AnnTopNRuntime::prepare and
+// evaluate_vector_ann_search using the existing VectorSearchTest fixture
setup.
+
+TEST_F(VectorSearchTest, AnnTopNRuntimePrepare_NoFunctionCall) {
+ // Build a DescriptorTbl where the virtual column expr is a constant
literal (not a function call)
+ doris::ObjectPool obj_pool_local;
+ TDescriptorTable thrift_tbl;
+ {
+ TTableDescriptor table_desc;
+ table_desc.id = 1000;
+ thrift_tbl.tableDescriptors.push_back(table_desc);
+
+ TTupleDescriptor tuple_desc;
+ tuple_desc.__isset.tableId = true;
+ tuple_desc.id = 2000;
+ tuple_desc.tableId = 1000;
+ thrift_tbl.tupleDescriptors.push_back(tuple_desc);
+
+ // Slot 0: materialized, with a virtual column expr set to
FLOAT_LITERAL (not a function)
+ TSlotDescriptor slot0;
+ slot0.id = 3000;
+ slot0.parent = 2000;
+ slot0.isMaterialized = true;
+ slot0.need_materialize = true;
+ slot0.__isset.need_materialize = true;
+ // type: DOUBLE (matches fixture)
+ TTypeNode type_node;
+ type_node.type = TTypeNodeType::type::SCALAR;
+ TScalarType scalar_type;
+ scalar_type.__set_type(TPrimitiveType::DOUBLE);
+ type_node.__set_scalar_type(scalar_type);
+ slot0.slotType.types.push_back(type_node);
+ // Provide a simple FLOAT_LITERAL as the virtual column expr
+ doris::TExpr vexpr;
+ doris::TExprNode node;
+ node.node_type = TExprNodeType::FLOAT_LITERAL;
+ node.type = TTypeDesc();
+ node.type.types.push_back(type_node);
+ doris::TFloatLiteral flit;
+ flit.value = 1.0;
+ node.__set_float_literal(flit);
+ node.__isset.float_literal = true;
+ vexpr.nodes.push_back(node);
+ slot0.virtual_column_expr = vexpr;
+ slot0.__isset.virtual_column_expr = true;
+ thrift_tbl.slotDescriptors.push_back(slot0);
+
+ // Slot 1: a normal slot to satisfy references
+ TSlotDescriptor slot1 = slot0;
+ slot1.id = 3001;
+ slot1.__isset.virtual_column_expr = false;
+ thrift_tbl.slotDescriptors.push_back(slot1);
+ thrift_tbl.__isset.slotDescriptors = true;
+ }
+
+ doris::DescriptorTbl* desc_tbl_local = nullptr;
+ ASSERT_TRUE(DescriptorTbl::create(&obj_pool_local, thrift_tbl,
&desc_tbl_local).ok());
+ RowDescriptor row_desc_local(*desc_tbl_local, {2000}, {false});
+
+ // Create a VirtualSlotRef root expr that points to the local descriptor's
slot id (3000)
+ doris::TExpr local_virtual_slot_ref_expr = _virtual_slot_ref_expr;
+ ASSERT_TRUE(local_virtual_slot_ref_expr.nodes.size() == 1);
+ local_virtual_slot_ref_expr.nodes[0].slot_ref.slot_id = 3000;
+ std::shared_ptr<VExprContext> vslot_ctx;
+ ASSERT_TRUE(VExpr::create_expr_tree(local_virtual_slot_ref_expr,
vslot_ctx).ok());
+
+ doris::RuntimeState state_local;
+ state_local.set_desc_tbl(desc_tbl_local);
+
+ auto runtime = segment_v2::AnnTopNRuntime::create_shared(true, 10,
vslot_ctx);
+ Status st = runtime->prepare(&state_local, row_desc_local);
+ ASSERT_FALSE(st.ok());
+ EXPECT_THAT(st.to_string(), HasSubstr("expect FuncationCall"));
+}
+
+// Note: We intentionally avoid testing a non-VirtualSlotRef root since it
triggers DCHECK.
+
+// Removed additional negative prepare tests that rely on internal descriptor
mutations.
+
+TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluate_DimensionMismatch) {
+ // Prepare a valid runtime first.
+ std::shared_ptr<VExprContext> dist_ctx;
+ auto fn_thrift = read_from_json<TExpr>(_distance_function_call_thrift);
+ ASSERT_TRUE(VExpr::create_expr_tree(fn_thrift, dist_ctx).ok());
+
+ std::shared_ptr<VExprContext> vslot_ctx;
+ ASSERT_TRUE(VExpr::create_expr_tree(_virtual_slot_ref_expr,
vslot_ctx).ok());
+ auto vir_slot =
std::dynamic_pointer_cast<VirtualSlotRef>(vslot_ctx->root());
+ ASSERT_TRUE(vir_slot != nullptr);
+ vir_slot->set_virtual_column_expr(dist_ctx->root());
+
+ auto runtime = segment_v2::AnnTopNRuntime::create_shared(true, 10,
vslot_ctx);
+ ASSERT_TRUE(runtime->prepare(&_runtime_state, _row_desc).ok());
+
+ // Attach an ANN reader with a different dimension to trigger the mismatch
branch.
+ {
+ std::map<std::string, std::string> props;
+ props["index_type"] = "hnsw";
+ props["metric_type"] = "l2_distance";
+ props["dim"] = "4"; // runtime query vector dimension is 8 from
fixture JSON
+ auto pair = vector_search_utils::create_tmp_ann_index_reader(props);
+ _ann_index_iterator->_ann_reader = pair.second;
+ }
+
+ roaring::Roaring bitmap;
+ vectorized::IColumn::MutablePtr result_col = ColumnFloat32::create(0);
+ std::unique_ptr<std::vector<uint64_t>> row_ids;
+ doris::segment_v2::AnnIndexStats stats;
+ Status st = runtime->evaluate_vector_ann_search(_ann_index_iterator.get(),
&bitmap, 10,
+ result_col, row_ids,
stats);
+ ASSERT_FALSE(st.ok());
+ EXPECT_THAT(st.to_string(), HasSubstr("dimension"));
+}
+
+} // namespace doris::vectorized
diff --git a/be/test/olap/vector_search/vector_search_utils.cpp
b/be/test/olap/vector_search/vector_search_utils.cpp
index 506d0e4ea8d..cb02b464d6a 100644
--- a/be/test/olap/vector_search/vector_search_utils.cpp
+++ b/be/test/olap/vector_search/vector_search_utils.cpp
@@ -265,6 +265,16 @@ float get_radius_from_matrix(const float* vector, int dim,
std::pair<std::unique_ptr<MockTabletIndex>,
std::shared_ptr<segment_v2::AnnIndexReader>>
create_tmp_ann_index_reader(std::map<std::string, std::string> properties) {
+ // Ensure required properties exist for tests
+ if (properties.find("index_type") == properties.end()) {
+ properties["index_type"] = "hnsw";
+ }
+ if (properties.find("metric_type") == properties.end()) {
+ properties["metric_type"] = "l2_distance";
+ }
+ if (properties.find("dim") == properties.end()) {
+ properties["dim"] = "4"; // default small dimension for tests
+ }
auto mock_tablet_index = std::make_unique<MockTabletIndex>();
mock_tablet_index->_properties = properties;
auto mock_index_file_reader = std::make_shared<MockIndexFileReader>();
diff --git a/regression-test/data/ann_index_p0/cast_string_as_array.out
b/regression-test/data/ann_index_p0/cast_string_as_array.out
new file mode 100644
index 00000000000..dbbe421d76a
Binary files /dev/null and
b/regression-test/data/ann_index_p0/cast_string_as_array.out differ
diff --git a/regression-test/suites/ann_index_p0/cast_string_as_array.groovy
b/regression-test/suites/ann_index_p0/cast_string_as_array.groovy
new file mode 100644
index 00000000000..9d0ea331ef2
--- /dev/null
+++ b/regression-test/suites/ann_index_p0/cast_string_as_array.groovy
@@ -0,0 +1,155 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("cast_string_as_array") {
+ sql "unset variable all;"
+ sql "set enable_common_expr_pushdown=true;"
+
+ // L2 table: dim=3
+ sql "drop table if exists ann_cast_rhs_l2"
+ sql """
+ CREATE TABLE ann_cast_rhs_l2 (
+ id INT NOT NULL,
+ embedding ARRAY<FLOAT> NOT NULL,
+ INDEX idx_emb (`embedding`) USING ANN PROPERTIES(
+ "index_type"="hnsw",
+ "metric_type"="l2_distance",
+ "dim"="3"
+ )
+ ) ENGINE=OLAP
+ DUPLICATE KEY(id)
+ DISTRIBUTED BY HASH(id) BUCKETS AUTO
+ PROPERTIES ("replication_num" = "1");
+ """
+ sql """
+ INSERT INTO ann_cast_rhs_l2 VALUES
+ (1, [1.0, 2.0, 3.0]),
+ (2, [0.5, 2.1, 2.9]),
+ (3, [10.0, 10.0, 10.0]),
+ (4, [2.0, 3.0, 4.0]);
+ """
+
+ // Success: CAST(string AS array<float>) on RHS
+ qt_sql_0 "select id from ann_cast_rhs_l2 order by
l2_distance_approximate(embedding, cast('[1.0,2.0,3.0]' as array<float>)) limit
3;"
+
+ // Failure: a padded string with integer literals is currently expected to be
rejected (see the exception asserted below)
+ test {
+ sql "select id from ann_cast_rhs_l2 order by
l2_distance_approximate(embedding, cast(' [1, 2 , 3 ] ' as array<float>)) limit
3;"
+
+ exception "Ann query vector cannot be NULL"
+ }
+
+ // Success: nested cast(string->string->array<float>) should also work
+ qt_sql_1 "select id from ann_cast_rhs_l2 order by
l2_distance_approximate(embedding, cast(cast('[1.0,2.0,3.0]' as string) as
array<float>)) limit 3;"
+
+ // Failure: empty array is not allowed for ANN query vector
+ test {
+ sql "select id from ann_cast_rhs_l2 order by
l2_distance_approximate(embedding, cast('[]' as array<float>)) limit 1;"
+ exception "Ann topn query vector cannot be empty"
+ }
+
+ // A special case.
+ // Constant propagation may optimize l2_distance_approximate(embedding,
NULL) to NULL before reaching the
+ // runtime of ANN topn. The query is expected to fail with the error below.
+ test {
+ sql "select id from ann_cast_rhs_l2 order by
l2_distance_approximate(embedding, cast(NULL as array<float>)) limit 1;"
+ exception "Constant must be ArrayLiteral or CAST to array"
+ }
+
+
+ // Failure: dim mismatch (2 vs table dim=3)
+ test {
+ sql "select id from ann_cast_rhs_l2 order by
l2_distance_approximate(embedding, cast('[1.0,2.0]' as array<float>)) limit 1;"
+ exception "[INVALID_ARGUMENT]"
+ }
+
+ // Inner product table: dim=4
+ sql "drop table if exists ann_cast_rhs_ip"
+ sql """
+ CREATE TABLE ann_cast_rhs_ip (
+ id INT NOT NULL,
+ embedding ARRAY<FLOAT> NOT NULL,
+ INDEX idx_emb (`embedding`) USING ANN PROPERTIES(
+ "index_type"="hnsw",
+ "metric_type"="inner_product",
+ "dim"="4"
+ )
+ ) ENGINE=OLAP
+ DUPLICATE KEY(id)
+ DISTRIBUTED BY HASH(id) BUCKETS AUTO
+ PROPERTIES ("replication_num" = "1");
+ """
+
+ sql "truncate table ann_cast_rhs_ip"
+ sql """
+ INSERT INTO ann_cast_rhs_ip VALUES
+ (1, [0.1, 0.2, 0.3, 0.4]),
+ (2, [0.5, 0.6, 0.7, 0.8]),
+ (3, [1.0, 1.0, 1.0, 1.0]);
+ """
+
+ // Success: DESC for inner_product
+ qt_sql_3 "select id from ann_cast_rhs_ip order by
inner_product_approximate(embedding, cast('[0.1,0.2,0.3,0.4]' as array<float>))
desc limit 3;"
+
+ // Failure: dim mismatch (3 vs table dim=4)
+ test {
+ sql "select id from ann_cast_rhs_ip order by
inner_product_approximate(embedding, cast('[0.1,0.2,0.3]' as array<float>))
desc limit 1;"
+ exception "[INVALID_ARGUMENT]"
+ }
+
+ // ----------------------
+ // Range search cases (CAST string -> array<float>)
+ // ----------------------
+
+ // L2 range search with <= radius: expect ids 1 and 2 (distance to [1,2,3]
is <= 1.0)
+ qt_sql_rs_l2_le "select id from ann_cast_rhs_l2 where
l2_distance_approximate(embedding, cast('[1,2,3]' as array<float>)) <= 1.0
order by id;"
+
+ // L2 range search with >= radius: expect ids 3 and 4 (distance to [1,2,3]
is >= 1.0)
+ qt_sql_rs_l2_ge "select id from ann_cast_rhs_l2 where
l2_distance_approximate(embedding, cast('[1,2,3]' as array<float>)) >= 1.0
order by id;"
+
+ // L2 range search: dim mismatch should error
+ test {
+ sql "select id from ann_cast_rhs_l2 where
l2_distance_approximate(embedding, cast('[1,2]' as array<float>)) <= 1.0 order
by id;"
+ exception "[INVALID_ARGUMENT]"
+ }
+
+ // Inner product range search with >= threshold: expect ids 2 and 3
+ qt_sql_rs_ip_ge "select id from ann_cast_rhs_ip where
inner_product_approximate(embedding, cast('[0.1,0.2,0.3,0.4]' as array<float>))
>= 0.6 order by id;"
+
+ // Inner product range search with < threshold: expect id 1 only
+ qt_sql_rs_ip_lt "select id from ann_cast_rhs_ip where
inner_product_approximate(embedding, cast('[0.1,0.2,0.3,0.4]' as array<float>))
< 0.6 order by id;"
+
+ // Inner product range search: dim mismatch should error
+ test {
+ sql "select id from ann_cast_rhs_ip where
inner_product_approximate(embedding, cast('[0.1,0.2,0.3]' as array<float>)) >=
0.6 order by id;"
+ exception "[INVALID_ARGUMENT]"
+ }
+
+ // ----------------------
+ // Non-constant RHS behavior
+ // ----------------------
+
+ // Fall back to full scan if RHS is not constant
+ qt_sql_fall_back "select l2_distance_approximate(embedding, embedding)
from ann_cast_rhs_l2 order by l2_distance_approximate(embedding, embedding)
limit 10;"
+
+ // Range search with non-constant RHS should execute without index pushdown
+ // L2: distance(embedding, embedding) == 0, so <= 0 selects all rows
+ qt_sql_rs_l2_nonconst_le "select id from ann_cast_rhs_l2 where
l2_distance_approximate(embedding, embedding) <= 0.0 order by id;"
+
+ // IP: inner_product(embedding, embedding) is sum of squares; with
threshold 1.5 expect ids 2 and 3
+ qt_sql_rs_ip_nonconst_ge "select id from ann_cast_rhs_ip where
inner_product_approximate(embedding, embedding) >= 1.5 order by id;"
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]