This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch vector-index-dev
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/vector-index-dev by this push:
new abff03a60d1 fix ann index compaction & inner product plan (#52150)
abff03a60d1 is described below
commit abff03a60d1b07800ee17826ef73e232d09cf388
Author: zhiqiang <[email protected]>
AuthorDate: Tue Jun 24 09:06:28 2025 +0800
fix ann index compaction & inner product plan (#52150)
---
be/src/apache-orc | 2 +-
be/src/olap/rowset/segment_v2/ann_index_reader.cpp | 5 +-
be/src/olap/rowset/segment_v2/column_writer.h | 1 +
.../olap/rowset/segment_v2/index_file_writer.cpp | 4 +
be/src/olap/rowset/segment_v2/segment_iterator.cpp | 68 ++++-------
be/src/olap/rowset/segment_v2/segment_iterator.h | 1 +
be/src/olap/rowset/segment_v2/segment_writer.cpp | 2 +-
be/src/pipeline/exec/olap_scan_operator.cpp | 6 -
be/src/vec/exec/scan/olap_scanner.cpp | 2 -
be/src/vec/exprs/ann_topn_runtime.cpp | 8 +-
be/src/vec/exprs/vectorized_fn_call.cpp | 2 +-
be/src/vec/exprs/vexpr_context.cpp | 5 +-
be/src/vec/olap/block_reader.cpp | 1 -
be/src/vector/faiss_vector_index.cpp | 126 ++++++++++++++-------
be/src/vector/vector_index.h | 2 -
.../vector_search/ann_topn_descriptor_test.cpp | 8 +-
.../olap/vector_search/faiss_vector_index_test.cpp | 48 +++++---
be/test/olap/vector_search/vector_search_utils.h | 3 +-
.../rewrite/PushDownVectorTopNIntoOlapScan.java | 32 +++++-
.../functions/scalar/InnerProductApproximate.java | 3 +-
20 files changed, 199 insertions(+), 130 deletions(-)
diff --git a/be/src/apache-orc b/be/src/apache-orc
index 13ada78b494..72787269f5f 160000
--- a/be/src/apache-orc
+++ b/be/src/apache-orc
@@ -1 +1 @@
-Subproject commit 13ada78b494133cacc0ccb5120e3f4611828fdbb
+Subproject commit 72787269f5f52ab0174bac1dbf54050bb7b60242
diff --git a/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
b/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
index d83c4fc23c9..8e28a2e1ae5 100644
--- a/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
+++ b/be/src/olap/rowset/segment_v2/ann_index_reader.cpp
@@ -78,8 +78,11 @@ Status AnnIndexReader::load_index(io::IOContext* io_ctx) {
}
_vector_index = std::make_unique<FaissVectorIndex>();
{
- // SCOPED_TIMER()
+ RuntimeProfile::Counter load_counter {TUnit::TIME_NS};
+ SCOPED_TIMER(&load_counter);
RETURN_IF_ERROR(_vector_index->load(compound_dir->get()));
+ LOG_INFO("Ann index load costs {} ms",
+ load_counter.value() / 1e6); // Convert to milliseconds
}
return Status::OK();
diff --git a/be/src/olap/rowset/segment_v2/column_writer.h
b/be/src/olap/rowset/segment_v2/column_writer.h
index 6e5bec9f59b..f10c6781e9f 100644
--- a/be/src/olap/rowset/segment_v2/column_writer.h
+++ b/be/src/olap/rowset/segment_v2/column_writer.h
@@ -159,6 +159,7 @@ public:
virtual Status write_bitmap_index() = 0;
virtual Status write_inverted_index() = 0;
+
virtual Status write_ann_index() { return Status::OK(); }
virtual Status write_bloom_filter_index() = 0;
diff --git a/be/src/olap/rowset/segment_v2/index_file_writer.cpp
b/be/src/olap/rowset/segment_v2/index_file_writer.cpp
index 10468ac3794..91dd777c3ff 100644
--- a/be/src/olap/rowset/segment_v2/index_file_writer.cpp
+++ b/be/src/olap/rowset/segment_v2/index_file_writer.cpp
@@ -50,6 +50,10 @@ IndexFileWriter::IndexFileWriter(io::FileSystemSPtr fs,
std::string index_path_p
_can_use_ram_dir(can_use_ram_dir) {
auto tmp_file_dir =
ExecEnv::GetInstance()->get_tmp_file_dirs()->get_tmp_file_dir();
_tmp_dir = tmp_file_dir.native();
+ LOG_INFO(
+ "IndexFileWriter created with index_path_prefix: {}, rowset_id:
{}, seg_id: {}, "
+ "tmp_file_dir: {}",
+ _index_path_prefix, _rowset_id, _seg_id, _tmp_dir);
if (_storage_format == InvertedIndexStorageFormatPB::V1) {
_index_storage_format = std::make_unique<IndexStorageFormatV1>(this);
} else {
diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
index ca191769af0..790c9362ffe 100644
--- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp
+++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp
@@ -268,6 +268,7 @@ private:
SegmentIterator::SegmentIterator(std::shared_ptr<Segment> segment, SchemaSPtr
schema)
: _segment(std::move(segment)),
_schema(schema),
+ _column_iterators(_schema->num_columns()),
_bitmap_index_iterators(_schema->num_columns()),
_index_iterators(_schema->num_columns()),
_cur_rowid(0),
@@ -302,7 +303,7 @@ Status SegmentIterator::_init_impl(const
StorageReadOptions& opts) {
}
_col_predicates.emplace_back(predicate);
}
- LOG_INFO("Segment iterator init, column predicates size: {}",
_col_predicates.size());
+
_tablet_id = opts.tablet_id;
// Read options will not change, so that just resize here
_block_rowids.resize(_opts.block_row_max);
@@ -312,12 +313,6 @@ Status SegmentIterator::_init_impl(const
StorageReadOptions& opts) {
if (_schema->rowid_col_idx() > 0) {
_record_rowids = true;
}
- // The final return block contains normal_column & virtual_column
- // _schema->num_columns is the size of normal columns
- // opts.vir_cid_to_idx_in_block.size is the the is virtual columns
- _column_iterators.resize(_schema->num_columns() +
opts.vir_cid_to_idx_in_block.size());
- LOG_INFO("Read tablet schema num_columns: {}, virtual columns num: {}",
_schema->num_columns(),
- opts.vir_cid_to_idx_in_block.size());
_virtual_column_exprs = _opts.virtual_column_exprs;
_ann_topn_runtime = _opts.ann_topn_runtime;
@@ -372,7 +367,7 @@ Status SegmentIterator::_init_impl(const
StorageReadOptions& opts) {
RETURN_IF_ERROR(_construct_compound_expr_context());
_enable_common_expr_pushdown = !_common_expr_ctxs_push_down.empty();
- LOG_INFO(
+ VLOG_DEBUG << fmt::format(
"Segment iterator init, virtual_column_exprs size: {}, has
ann_topn_runtime: {}, "
"_vir_cid_to_idx_in_block size: {}, common_expr_pushdown size: {}",
_opts.virtual_column_exprs.size(), _opts.ann_topn_runtime !=
nullptr,
@@ -570,8 +565,6 @@ Status
SegmentIterator::_get_row_ranges_by_column_conditions() {
for (auto cid : _schema->column_ids()) {
bool result_true =
_check_all_conditions_passed_inverted_index_for_column(cid);
- LOG_INFO("Check all conditions passed in inverted index for
column {}, result: {}",
- cid, result_true);
if (result_true) {
_need_read_data_indices[cid] = false;
}
@@ -623,14 +616,14 @@ Status SegmentIterator::_apply_ann_topn_predicate() {
return Status::OK();
}
- LOG_INFO("Try apply ann topn: {}", _ann_topn_runtime->debug_string());
+ VLOG_DEBUG << fmt::format("Try apply ann topn: {}",
_ann_topn_runtime->debug_string());
size_t src_col_idx = _ann_topn_runtime->get_src_column_idx();
ColumnId src_cid = _schema->column_id(src_col_idx);
IndexIterator* ann_index_iterator = _index_iterators[src_cid].get();
if (ann_index_iterator == nullptr || !_common_expr_ctxs_push_down.empty()
||
!_col_predicates.empty()) {
- LOG_INFO(
+ VLOG_DEBUG << fmt::format(
"Can not apply ann topn, has index iterators: {}, has common
expr ctxs "
"push down: {}, has column predicates: {}",
_index_iterators.size(), !_common_expr_ctxs_push_down.empty(),
@@ -644,18 +637,19 @@ Status SegmentIterator::_apply_ann_topn_predicate() {
DCHECK(ann_index_reader != nullptr);
if (ann_index_reader->get_metric_type() == Metric::IP) {
if (_ann_topn_runtime->is_asc()) {
- LOG_INFO("Asc topn for inner product can not be evaluated by ann
index");
+ VLOG_DEBUG << fmt::format(
+ "Asc topn for inner product can not be evaluated by ann
index");
return Status::OK();
}
} else {
if (!_ann_topn_runtime->is_asc()) {
- LOG_INFO("Desc topn for l2/cosine can not be evaluated by ann
index");
+ VLOG_DEBUG << fmt::format("Desc topn for l2/cosine can not be
evaluated by ann index");
return Status::OK();
}
}
if (ann_index_reader->get_metric_type() !=
_ann_topn_runtime->get_metric_type()) {
- LOG_INFO(
+ VLOG_DEBUG << fmt::format(
"Ann topn metric type {} not match index metric type {}, can
not be evaluated by "
"ann index",
metric_to_string(_ann_topn_runtime->get_metric_type()),
@@ -666,7 +660,7 @@ Status SegmentIterator::_apply_ann_topn_predicate() {
size_t pre_size = _row_bitmap.cardinality();
size_t rows_of_semgnet = _segment->num_rows();
if (pre_size < rows_of_semgnet * 0.3) {
- LOG_INFO(
+ VLOG_DEBUG << fmt::format(
"Ann topn predicate input rows {} < 30% of segment rows {},
will not use ann index "
"to "
"filter",
@@ -677,16 +671,17 @@ Status SegmentIterator::_apply_ann_topn_predicate() {
std::unique_ptr<std::vector<uint64_t>> result_row_ids;
RETURN_IF_ERROR(_ann_topn_runtime->evaluate_vector_ann_search(ann_index_iterator,
_row_bitmap,
result_column, result_row_ids));
- LOG_INFO("Ann topn filtered {} - {} = {} rows", pre_size,
_row_bitmap.cardinality(),
- pre_size - _row_bitmap.cardinality());
+ VLOG_DEBUG << fmt::format("Ann topn filtered {} - {} = {} rows", pre_size,
+ _row_bitmap.cardinality(), pre_size -
_row_bitmap.cardinality());
_opts.stats->rows_ann_index_topn_filtered += (pre_size -
_row_bitmap.cardinality());
const size_t dst_col_idx = _ann_topn_runtime->get_dest_column_idx();
ColumnIterator* column_iter =
_column_iterators[_schema->column_id(dst_col_idx)].get();
DCHECK(column_iter != nullptr);
VirtualColumnIterator* virtual_column_iter =
dynamic_cast<VirtualColumnIterator*>(column_iter);
DCHECK(virtual_column_iter != nullptr);
- LOG_INFO("Virtual column iterator, column_idx {}, is materialized with {}
rows", dst_col_idx,
- result_row_ids->size());
+ VLOG_DEBUG << fmt::format(
+ "Virtual column iterator, column_idx {}, is materialized with {}
rows", dst_col_idx,
+ result_row_ids->size());
// reference count of result_column should be 1, so move will not issue
any data copy.
virtual_column_iter->prepare_materialization(std::move(result_column),
std::move(result_row_ids));
@@ -1206,8 +1201,6 @@ Status SegmentIterator::_init_return_column_iterators() {
}
if
(_schema->column(cid)->name().starts_with(BeConsts::VIRTUAL_COLUMN_PREFIX)) {
- LOG_INFO("Virtual column iterator for column {}, cid: {}",
_schema->column(cid)->name(),
- cid);
_column_iterators[cid] = std::make_unique<VirtualColumnIterator>();
continue;
}
@@ -1823,7 +1816,7 @@ Status SegmentIterator::_read_columns(const
std::vector<ColumnId>& column_ids,
}
Status SegmentIterator::_init_return_columns(vectorized::Block* block,
uint32_t nrows_read_limit) {
- block->clear_column_data(_schema->num_column_ids() +
_virtual_column_exprs.size());
+ block->clear_column_data(_schema->num_column_ids());
for (size_t i = 0; i < _schema->num_column_ids(); i++) {
auto cid = _schema->column_id(i);
@@ -2151,6 +2144,11 @@ Status
SegmentIterator::_read_columns_by_rowids(std::vector<ColumnId>& read_colu
}
})
+ if (_current_return_columns[cid].get() == nullptr) {
+ return Status::InternalError(
+ "SegmentIterator meet invalid column, return columns size
{}, cid {}",
+ _current_return_columns.size(), cid);
+ }
RETURN_IF_ERROR(_column_iterators[cid]->read_by_rowids(rowids.data(),
select_size,
_current_return_columns[cid]));
}
@@ -2554,7 +2552,7 @@ Status
SegmentIterator::_next_batch_internal(vectorized::Block* block) {
// shrink char_type suffix zero data
block->shrink_char_type_column_suffix_zero(_char_type_idx);
- // #ifndef NDEBUG
+#ifndef NDEBUG
size_t rows = block->rows();
size_t idx = 0;
for (const auto& entry : *block) {
@@ -2589,9 +2587,7 @@ Status
SegmentIterator::_next_batch_internal(vectorized::Block* block) {
}
idx++;
}
- // #endif
- VLOG_DEBUG << "dump block " << block->dump_data(0, block->rows());
-
+#endif
return Status::OK();
}
@@ -2873,30 +2869,12 @@ bool SegmentIterator::_can_opt_topn_reads() {
}
void SegmentIterator::_init_virtual_columns(vectorized::Block* block) {
- // const size_t num_virtual_columns = _virtual_column_exprs.size();
- // if (block->columns() < _schema->num_column_ids() + num_virtual_columns)
{
- // std::vector<size_t> vir_col_idx;
- // for (const auto& pair : _vir_cid_to_idx_in_block) {
- // vir_col_idx.push_back(pair.second);
- // }
- // std::sort(vir_col_idx.begin(), vir_col_idx.end());
- // for (size_t i = 0; i < num_virtual_columns; ++i) {
- // auto iter = _opts.vir_col_idx_to_type.find(vir_col_idx[i]);
- // DCHECK(iter != _opts.vir_col_idx_to_type.end());
- // // Name of virtual currently is not used, so we just use a
dummy name.
- // block->insert({vectorized::ColumnNothing::create(0),
iter->second,
- // fmt::format("VIRTUAL_COLUMN_{}", i)});
- // }
- // } else {
// Before get next batch. make sure all virtual columns has type
ColumnNothing.
for (const auto& pair : _vir_cid_to_idx_in_block) {
auto& col_with_type_and_name = block->get_by_position(pair.second);
col_with_type_and_name.column = vectorized::ColumnNothing::create(0);
col_with_type_and_name.type = _opts.vir_col_idx_to_type[pair.second];
- LOG_INFO("Virtual column is reset, cid {}, idx_in_block {}, type {}",
pair.first,
- pair.second, col_with_type_and_name.type->get_name());
}
- // }
}
Status SegmentIterator::_materialization_of_virtual_column(vectorized::Block*
block) {
diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.h
b/be/src/olap/rowset/segment_v2/segment_iterator.h
index 31bfb1f75f2..b2d0cc47b2a 100644
--- a/be/src/olap/rowset/segment_v2/segment_iterator.h
+++ b/be/src/olap/rowset/segment_v2/segment_iterator.h
@@ -489,6 +489,7 @@ private:
std::shared_ptr<vectorized::AnnTopNRuntime> _ann_topn_runtime;
+ // cid to virtual column expr
std::map<ColumnId, vectorized::VExprContextSPtr> _virtual_column_exprs;
std::map<ColumnId, size_t> _vir_cid_to_idx_in_block;
};
diff --git a/be/src/olap/rowset/segment_v2/segment_writer.cpp
b/be/src/olap/rowset/segment_v2/segment_writer.cpp
index 40f0b570ba2..c25f32d1b5e 100644
--- a/be/src/olap/rowset/segment_v2/segment_writer.cpp
+++ b/be/src/olap/rowset/segment_v2/segment_writer.cpp
@@ -1189,7 +1189,7 @@ Status SegmentWriter::_write_inverted_index() {
Status SegmentWriter::_write_ann_index() {
for (auto& column_writer : _column_writers) {
- RETURN_IF_ERROR(column_writer->write_inverted_index());
+ RETURN_IF_ERROR(column_writer->write_ann_index());
}
return Status::OK();
}
diff --git a/be/src/pipeline/exec/olap_scan_operator.cpp
b/be/src/pipeline/exec/olap_scan_operator.cpp
index b42ef0c9fc1..171e21a4b6b 100644
--- a/be/src/pipeline/exec/olap_scan_operator.cpp
+++ b/be/src/pipeline/exec/olap_scan_operator.cpp
@@ -633,12 +633,6 @@ Status OlapScanLocalState::open(RuntimeState* state) {
} else {
_slot_id_to_index_in_block[slot_desc->id()] = col_pos;
}
-
- LOG_INFO(
- "OlapScanLocalState opening, virtual column expr slot id:
{}, col_pos: {}, "
- "expr: "
- "{}",
- slot_desc->id(), col_pos,
virtual_column_expr_ctx->root()->debug_string());
}
}
diff --git a/be/src/vec/exec/scan/olap_scanner.cpp
b/be/src/vec/exec/scan/olap_scanner.cpp
index 328b8a7026f..76c75b90d9a 100644
--- a/be/src/vec/exec/scan/olap_scanner.cpp
+++ b/be/src/vec/exec/scan/olap_scanner.cpp
@@ -135,9 +135,7 @@ Status OlapScanner::init() {
VExprContextSPtr context;
RETURN_IF_ERROR(ctx->clone(_state, context));
_common_expr_ctxs_push_down.emplace_back(context);
- LOG_INFO("Prepare ann range search.");
RETURN_IF_ERROR(context->prepare_ann_range_search(_vector_search_params));
- LOG_INFO("Finish prepare ann range search, query_id={}",
print_id(_state->query_id()));
}
for (auto pair : local_state->_slot_id_to_virtual_column_expr) {
diff --git a/be/src/vec/exprs/ann_topn_runtime.cpp
b/be/src/vec/exprs/ann_topn_runtime.cpp
index 3a550c0d6f2..64fdf324bff 100644
--- a/be/src/vec/exprs/ann_topn_runtime.cpp
+++ b/be/src/vec/exprs/ann_topn_runtime.cpp
@@ -157,8 +157,14 @@ Status
AnnTopNRuntime::evaluate_vector_ann_search(segment_v2::IndexIterator* ann
.distance = nullptr,
.row_ids = nullptr,
};
+ {
+ RuntimeProfile::Counter search_counter {TUnit::TIME_NS};
+ SCOPED_TIMER(&search_counter);
+
RETURN_IF_ERROR(ann_index_iterator->read_from_index(&ann_query_params));
+ LOG_INFO("Ann index search costs {} ms",
+ search_counter.value() / 1e6); // Convert to milliseconds
+ }
- RETURN_IF_ERROR(ann_index_iterator->read_from_index(&ann_query_params));
DCHECK(ann_query_params.distance != nullptr);
DCHECK(ann_query_params.row_ids != nullptr);
diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp
b/be/src/vec/exprs/vectorized_fn_call.cpp
index 34f47aafb2b..db1b1059e7b 100644
--- a/be/src/vec/exprs/vectorized_fn_call.cpp
+++ b/be/src/vec/exprs/vectorized_fn_call.cpp
@@ -459,7 +459,7 @@ Status VectorizedFnCall::prepare_ann_range_search(const
doris::VectorSearchUserP
}
range_search_runtime.is_ann_range_search = true;
range_search_runtime.user_params = user_params;
- LOG_INFO("Ann range search params: {}", range_search_runtime.to_string());
+ VLOG_DEBUG << fmt::format("Ann range search params: {}",
range_search_runtime.to_string());
return Status::OK();
}
diff --git a/be/src/vec/exprs/vexpr_context.cpp
b/be/src/vec/exprs/vexpr_context.cpp
index fc027538494..f35f9a58d1d 100644
--- a/be/src/vec/exprs/vexpr_context.cpp
+++ b/be/src/vec/exprs/vexpr_context.cpp
@@ -442,8 +442,9 @@ Status VExprContext::prepare_ann_range_search(const
doris::VectorSearchUserParam
RETURN_IF_ERROR(_root->prepare_ann_range_search(params,
_ann_range_search_runtime,
_suitable_for_ann_index));
- LOG_INFO("Prepare ann range search result {}, _suitable_for_ann_index {}",
- this->_ann_range_search_runtime.to_string(),
this->_suitable_for_ann_index);
+ VLOG_DEBUG << fmt::format("Prepare ann range search result {},
_suitable_for_ann_index {}",
+ this->_ann_range_search_runtime.to_string(),
+ this->_suitable_for_ann_index);
return Status::OK();
}
diff --git a/be/src/vec/olap/block_reader.cpp b/be/src/vec/olap/block_reader.cpp
index f0eef58371d..70ba4e86e94 100644
--- a/be/src/vec/olap/block_reader.cpp
+++ b/be/src/vec/olap/block_reader.cpp
@@ -227,7 +227,6 @@ Status BlockReader::init(const ReaderParams& read_params) {
/*
where abs()
*/
- LOG_INFO("Direct_mode: {}, key type: {}", _direct_mode,
tablet()->keys_type());
auto status = _init_collect_iter(read_params);
if (!status.ok()) [[unlikely]] {
if (!config::is_cloud_mode()) {
diff --git a/be/src/vector/faiss_vector_index.cpp
b/be/src/vector/faiss_vector_index.cpp
index 7cb8b49faef..9b8eb49fcff 100644
--- a/be/src/vector/faiss_vector_index.cpp
+++ b/be/src/vector/faiss_vector_index.cpp
@@ -23,6 +23,7 @@
#include <cstddef>
#include <cstdint>
#include <memory>
+#include <string>
#include "CLucene/store/IndexInput.h"
#include "CLucene/store/IndexOutput.h"
@@ -32,6 +33,7 @@
#include "faiss/IndexHNSW.h"
#include "faiss/impl/io.h"
#include "olap/rowset/segment_v2/ann_index/ann_search_params.h"
+#include "util/metrics.h"
#include "vector/vector_index.h"
namespace doris::segment_v2 {
@@ -231,19 +233,18 @@ doris::Status FaissVectorIndex::ann_topn_search(const
float* query_vec, int k,
result.row_ids = std::make_unique<std::vector<uint64_t>>();
if (_metric == Metric::L2) {
- // For inner product, we need to convert the distance to the actual
distance.
+ // For l2_distance, we need to convert the distance to the actual
distance.
// The distance returned by Faiss is actually the squared distance.
// So we need to take the square root of the squared distance.
for (size_t i = 0; i < roaring_cardinality; ++i) {
result.row_ids->push_back(labels[i]);
- result.distances[i] = distances[i]; // Convert squared distance to
actual distance
+ result.distances[i] = std::sqrt(distances[i]);
}
} else if (_metric == Metric::IP) {
- // For L2, we can use the distance directly.
+ // For inner product, we can use the distance directly.
for (size_t i = 0; i < roaring_cardinality; ++i) {
result.row_ids->push_back(labels[i]);
- result.distances[i] =
- std::sqrt(distances[i]); // Convert squared distance to
actual distance
+ result.distances[i] = distances[i]; // Convert squared distance to
actual distance
}
} else {
throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
"Unsupported metric type: {}",
@@ -256,29 +257,34 @@ doris::Status FaissVectorIndex::ann_topn_search(const
float* query_vec, int k,
return doris::Status::OK();
}
+// For l2 distance, range search radius is the squared distance.
+// For inner product, range search radius is the actual distance.
+// range search on inner product returns all vectors with inner product
greater than or equal to the radius.
+// For l2 distance, range search returns all vectors with squared distance
less than or equal to the radius.
doris::Status FaissVectorIndex::range_search(const float* query_vec, const
float& radius,
const
vectorized::IndexSearchParameters& params,
vectorized::IndexSearchResult&
result) {
DCHECK(_index != nullptr);
DCHECK(query_vec != nullptr);
- std::unique_ptr<faiss::IDSelector> sel = nullptr;
- if (params.roaring != nullptr) {
- sel = roaring_to_faiss_selector(*params.roaring);
- }
+ DCHECK(params.roaring != nullptr)
+ << "Roaring should not be null for range search, please set
roaring in params";
+ std::unique_ptr<faiss::IDSelector> sel =
roaring_to_faiss_selector(*params.roaring);
+
faiss::RangeSearchResult native_search_result(1, true);
const vectorized::HNSWSearchParameters* hnsw_params =
dynamic_cast<const vectorized::HNSWSearchParameters*>(&params);
- if (hnsw_params != nullptr) {
- faiss::SearchParametersHNSW param;
- param.efSearch = hnsw_params->ef_search;
- param.check_relative_distance = hnsw_params->check_relative_distance;
- param.bounded_queue = hnsw_params->bounded_queue;
- param.sel = sel ? sel.get() : nullptr;
- _index->range_search(1, query_vec, radius * radius,
&native_search_result, &param);
- } else {
- faiss::SearchParameters param;
- param.sel = sel ? sel.get() : nullptr;
+ // Currently only support HNSW index for range search.
+ DCHECK(hnsw_params != nullptr) << "HNSW search parameters should not be
null for HNSW index";
+
+ faiss::SearchParametersHNSW param;
+ param.efSearch = hnsw_params->ef_search;
+ param.check_relative_distance = hnsw_params->check_relative_distance;
+ param.bounded_queue = hnsw_params->bounded_queue;
+ param.sel = sel.get();
+ if (_metric == Metric::L2) {
_index->range_search(1, query_vec, radius * radius,
&native_search_result, &param);
+ } else if (_metric == Metric::IP) {
+ _index->range_search(1, query_vec, radius, &native_search_result,
&param);
}
size_t begin = native_search_result.lims[0];
@@ -287,11 +293,10 @@ doris::Status FaissVectorIndex::range_search(const float*
query_vec, const float
row_ids->resize(end - begin);
LOG_INFO("Range search result: begin {}, end {}", begin, end);
if (params.is_le_or_lt) {
- std::unique_ptr<float[]> distances_ptr = std::make_unique<float[]>(end
- begin);
- float* distances = distances_ptr.get();
- auto roaring = std::make_shared<roaring::Roaring>();
if (_metric == Metric::L2) {
- // For inner product, we need to convert the distance to the
actual distance.
+ std::unique_ptr<float[]> distances_ptr =
std::make_unique<float[]>(end - begin);
+ float* distances = distances_ptr.get();
+ auto roaring = std::make_shared<roaring::Roaring>();
// The distance returned by Faiss is actually the squared distance.
// So we need to take the square root of the squared distance.
for (size_t i = begin; i < end; ++i) {
@@ -299,32 +304,34 @@ doris::Status FaissVectorIndex::range_search(const float*
query_vec, const float
roaring->add(native_search_result.labels[i]);
distances[i - begin] = sqrt(native_search_result.distances[i]);
}
+ result.distances = std::move(distances_ptr);
+ result.row_ids = std::move(row_ids);
+ result.roaring = roaring;
+
+ DCHECK(result.row_ids->size() == result.roaring->cardinality())
+ << "row_ids size: " << result.row_ids->size()
+ << ", roaring size: " << result.roaring->cardinality();
} else if (_metric == Metric::IP) {
- // For L2, we can use the distance directly.
+ // For IP, we can use the distance directly.
+ // range search on ip gets all vectors with inner product greater
than or equal to the radius.
+ // so we need to do a convertion.
+ const roaring::Roaring& origin_row_ids = *params.roaring;
+ std::shared_ptr<roaring::Roaring> roaring =
std::make_shared<roaring::Roaring>();
for (size_t i = begin; i < end; ++i) {
- (*row_ids)[i] = native_search_result.labels[i];
roaring->add(native_search_result.labels[i]);
- distances[i - begin] = native_search_result.distances[i];
}
+ result.roaring = std::make_shared<roaring::Roaring>();
+ // remove all rows that should not be included.
+ *(result.roaring) = origin_row_ids - *roaring;
+ // Just update the roaring. distance can not be used.
} else {
throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
"Unsupported metric type: {}",
static_cast<int>(_metric));
}
-
- result.distances = std::move(distances_ptr);
- result.row_ids = std::move(row_ids);
- result.roaring = roaring;
-
- DCHECK(result.row_ids->size() == result.roaring->cardinality())
- << "row_ids size: " << result.row_ids->size()
- << ", roaring size: " << result.roaring->cardinality();
} else {
- // Faiss can only return labels in the range of radius.
- // If the precidate is not less than, we need to to a convertion.
- DCHECK(params.roaring != nullptr);
- if (params.roaring == nullptr) {
- return doris::Status::InvalidArgument("Row ids should not be
null");
- } else {
+ if (_metric == Metric::L2) {
+ // Faiss can only return labels in the range of radius.
+ // If the precidate is not less than, we need to to a convertion.
const roaring::Roaring& origin_row_ids = *params.roaring;
std::shared_ptr<roaring::Roaring> roaring =
std::make_shared<roaring::Roaring>();
for (size_t i = begin; i < end; ++i) {
@@ -334,6 +341,31 @@ doris::Status FaissVectorIndex::range_search(const float*
query_vec, const float
*(result.roaring) = origin_row_ids - *roaring;
result.distances = nullptr;
result.row_ids = nullptr;
+ } else if (_metric == Metric::IP) {
+ // For inner product, we can use the distance directly.
+ // range search on ip gets all vectors with inner product greater
than or equal to the radius.
+ // when query condition is not le_or_lt, we can use the roaring
and distance directly.
+ std::unique_ptr<float[]> distances_ptr =
std::make_unique<float[]>(end - begin);
+ float* distances = distances_ptr.get();
+ auto roaring = std::make_shared<roaring::Roaring>();
+ // The distance returned by Faiss is actually the squared distance.
+ // So we need to take the square root of the squared distance.
+ for (size_t i = begin; i < end; ++i) {
+ (*row_ids)[i] = native_search_result.labels[i];
+ roaring->add(native_search_result.labels[i]);
+ distances[i - begin] = native_search_result.distances[i];
+ }
+ result.distances = std::move(distances_ptr);
+ result.row_ids = std::move(row_ids);
+ result.roaring = roaring;
+
+ DCHECK(result.row_ids->size() == result.roaring->cardinality())
+ << "row_ids size: " << result.row_ids->size()
+ << ", roaring size: " << result.roaring->cardinality();
+
+ } else {
+ throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT,
+ "Unsupported metric type: {}",
static_cast<int>(_metric));
}
}
@@ -341,17 +373,29 @@ doris::Status FaissVectorIndex::range_search(const float*
query_vec, const float
}
doris::Status FaissVectorIndex::save(lucene::store::Directory* dir) {
+ auto start_time = std::chrono::high_resolution_clock::now();
+
lucene::store::IndexOutput* idx_output = dir->createOutput("faiss.idx");
auto writer = std::make_unique<FaissIndexWriter>(idx_output);
faiss::write_index(_index.get(), writer.get());
- VLOG_DEBUG << fmt::format("Faiss index saved to faiss.idx, rows {}",
_index->ntotal);
+
+ auto end_time = std::chrono::high_resolution_clock::now();
+ auto duration =
std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
+ LOG_INFO(fmt::format("Faiss index saved to {}, {}, rows {}, cost {} ms",
dir->toString(),
+ "faiss.idx", _index->ntotal, duration.count()));
return doris::Status::OK();
}
doris::Status FaissVectorIndex::load(lucene::store::Directory* dir) {
+ LOG_INFO("Loading Faiss index from: {}", dir->getObjectName());
+ auto start_time = std::chrono::high_resolution_clock::now();
lucene::store::IndexInput* idx_input = dir->openInput("faiss.idx");
auto reader = std::make_unique<FaissIndexReader>(idx_input);
faiss::Index* idx = faiss::read_index(reader.get());
+ auto end_time = std::chrono::high_resolution_clock::now();
+ auto duration =
std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
+ LOG_INFO("Load index from {} costs {} ms, rows {}", dir->getObjectName(),
duration.count(),
+ idx->ntotal);
_index.reset(idx);
return doris::Status::OK();
}
diff --git a/be/src/vector/vector_index.h b/be/src/vector/vector_index.h
index 63f9efa97ad..2f49a5c631e 100644
--- a/be/src/vector/vector_index.h
+++ b/be/src/vector/vector_index.h
@@ -17,11 +17,9 @@
#pragma once
-#include <memory>
#include <roaring/roaring.hh>
#include "common/status.h"
-#include "vec/functions/array/function_array_distance.h"
#include "vector/metric.h"
namespace lucene::store {
diff --git a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
index 376f10d7a79..edefb4e7391 100644
--- a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
+++ b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp
@@ -27,6 +27,8 @@
#include <memory>
#include "olap/rowset/segment_v2/ann_index/ann_search_params.h"
+#include "runtime/primitive_type.h"
+#include "vec/columns/column_nullable.h"
#include "vec/exprs/ann_topn_runtime.h"
#include "vec/exprs/virtual_slot_ref.h"
#include "vector_search_utils.h"
@@ -164,13 +166,15 @@ TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) {
return Status::OK();
}));
- _result_column = ColumnFloat64::create(0, 0);
+ _result_column = ColumnNullable::create(ColumnFloat64::create(0, 0),
ColumnUInt8::create(0, 0));
std::unique_ptr<std::vector<uint64_t>> row_ids =
std::make_unique<std::vector<uint64_t>>();
roaring::Roaring roaring;
st = predicate->evaluate_vector_ann_search(_ann_index_iterator.get(),
roaring, _result_column,
row_ids);
- ColumnFloat64* result_column_float =
assert_cast<ColumnFloat64*>(_result_column.get());
+ ColumnNullable* result_column_null =
assert_cast<ColumnNullable*>(_result_column.get());
+ ColumnFloat64* result_column_float =
+
assert_cast<ColumnFloat64*>(result_column_null->get_nested_column_ptr().get());
for (size_t i = 0; i < query_vector->size(); ++i) {
EXPECT_EQ(result_column_float->get_data()[i], (*query_vector)[i]);
}
diff --git a/be/test/olap/vector_search/faiss_vector_index_test.cpp
b/be/test/olap/vector_search/faiss_vector_index_test.cpp
index 1526e1800ce..dd34656cc7e 100644
--- a/be/test/olap/vector_search/faiss_vector_index_test.cpp
+++ b/be/test/olap/vector_search/faiss_vector_index_test.cpp
@@ -92,6 +92,11 @@ TEST_F(VectorSearchTest, TestSaveAndLoad) {
}
HNSWSearchParameters hnsw_params;
+ auto roaring_bitmap = std::make_unique<roaring::Roaring>();
+ hnsw_params.roaring = roaring_bitmap.get();
+ for (size_t i = 0; i < num_vectors; ++i) {
+ hnsw_params.roaring->add(i);
+ }
IndexSearchResult range_search_result1;
std::ignore = index1->range_search(vectors.data(), 10, hnsw_params,
range_search_result1);
IndexSearchResult range_search_result2;
@@ -298,12 +303,9 @@ TEST_F(VectorSearchTest, SearchAllVectors) {
}
TEST_F(VectorSearchTest, CompRangeSearch) {
- size_t iterations = 25;
- // 支持的metric类型集合
- std::vector<faiss::MetricType> metrics = {
- faiss::METRIC_L2, faiss::METRIC_INNER_PRODUCT
- // 如有更多metric可继续添加
- };
+ size_t iterations = 10;
+ // std::vector<faiss::MetricType> metrics = {faiss::METRIC_L2,
faiss::METRIC_INNER_PRODUCT};
+ std::vector<faiss::MetricType> metrics = {faiss::METRIC_INNER_PRODUCT};
for (size_t i = 0; i < iterations; ++i) {
for (auto metric : metrics) {
// Random parameters for each test iteration
@@ -334,7 +336,7 @@ TEST_F(VectorSearchTest, CompRangeSearch) {
auto vec =
vector_search_utils::generate_random_vector(params.d);
vectors.push_back(vec);
}
- // 创建native index时指定metric
+
std::unique_ptr<faiss::Index> native_index = nullptr;
if (metric == faiss::METRIC_L2) {
native_index =
std::make_unique<faiss::IndexHNSWFlat>(params.d, params.m,
@@ -345,19 +347,24 @@ TEST_F(VectorSearchTest, CompRangeSearch) {
} else {
throw std::runtime_error(fmt::format("Unsupported metric type:
{}", metric));
}
+
doris::vector_search_utils::add_vectors_to_indexes_serial_mode(
doris_index.get(), native_index.get(), vectors);
std::vector<float> query_vec = vectors.front();
float radius = 0;
-
radius =
doris::vector_search_utils::get_radius_from_matrix(query_vec.data(), params.d,
vectors, 0.4f, metric);
HNSWSearchParameters hnsw_params;
hnsw_params.ef_search = 16;
- hnsw_params.roaring = nullptr;
- hnsw_params.is_le_or_lt = true;
+ // Search on all rows;
+ auto roaring = std::make_unique<roaring::Roaring>();
+ hnsw_params.roaring = roaring.get();
+ for (size_t i = 0; i < vectors.size(); i++) {
+ hnsw_params.roaring->add(i);
+ }
+ hnsw_params.is_le_or_lt = metric == faiss::METRIC_L2;
IndexSearchResult doris_result;
std::ignore =
doris_index->range_search(query_vec.data(), radius,
hnsw_params, doris_result);
@@ -402,7 +409,7 @@ TEST_F(VectorSearchTest, CompRangeSearch) {
}
}
-TEST_F(VectorSearchTest, RangeSearchNoSelector1) {
+TEST_F(VectorSearchTest, RangeSearchAllRowsAsCandidates) {
size_t iterations = 5;
// Random parameters for each test iteration
@@ -461,6 +468,12 @@ TEST_F(VectorSearchTest, RangeSearchNoSelector1) {
// Use a vector we know is in the index
faiss::SearchParametersHNSW search_params;
+ std::unique_ptr<roaring::Roaring> all_rows =
std::make_unique<roaring::Roaring>();
+ for (size_t i = 0; i < num_vectors; ++i) {
+ all_rows->add(i);
+ }
+ auto sel = FaissVectorIndex::roaring_to_faiss_selector(*all_rows);
+ search_params.sel = sel.get();
search_params.efSearch = 16; // Set efSearch for better accuracy
faiss::RangeSearchResult native_search_result(1, true);
native_index->range_search(1, query_vec.data(), radius * radius,
&native_search_result,
@@ -476,6 +489,7 @@ TEST_F(VectorSearchTest, RangeSearchNoSelector1) {
HNSWSearchParameters doris_search_params;
doris_search_params.ef_search = 16; // Set efSearch for better accuracy
+ doris_search_params.roaring = all_rows.get();
IndexSearchResult search_result1;
IndexSearchResult search_result2;
@@ -507,7 +521,6 @@ TEST_F(VectorSearchTest, RangeSearchNoSelector1) {
}
doris_search_params.is_le_or_lt = false;
- std::unique_ptr<roaring::Roaring> all_rows =
std::make_unique<roaring::Roaring>();
doris_search_params.roaring = all_rows.get();
for (size_t i = 0; i < num_vectors; ++i) {
doris_search_params.roaring->add(i);
@@ -637,7 +650,7 @@ TEST_F(VectorSearchTest, RangeSearchWithSelector1) {
}
TEST_F(VectorSearchTest, RangeSearchEmptyResult) {
- for (size_t i = 0; i < 10; ++i) {
+ for (size_t i = 0; i < 5; ++i) {
const size_t d = 10;
const size_t m = 32;
const int num_vectors = 1000;
@@ -677,6 +690,11 @@ TEST_F(VectorSearchTest, RangeSearchEmptyResult) {
float radius = 5.0f;
doris::vectorized::HNSWSearchParameters search_params;
search_params.ef_search = 1000; // Set efSearch for better accuracy
+ std::unique_ptr<roaring::Roaring> sel_rows =
std::make_unique<roaring::Roaring>();
+ for (size_t i = 0; i < num_vectors; ++i) {
+ sel_rows->add(i);
+ }
+ search_params.roaring = sel_rows.get();
auto doris_search_result =
vector_search_utils::perform_doris_index_range_search(
index1.get(), query_vec.data(), radius, search_params);
auto native_search_result =
vector_search_utils::perform_native_index_range_search(
@@ -689,10 +707,6 @@ TEST_F(VectorSearchTest, RangeSearchEmptyResult) {
doris::vectorized::HNSWSearchParameters search_params_all_rows;
search_params_all_rows.ef_search = 1000; // Set efSearch for better
accuracy
search_params_all_rows.is_le_or_lt = true;
- std::unique_ptr<roaring::Roaring> sel_rows =
std::make_unique<roaring::Roaring>();
- for (size_t i = 0; i < num_vectors; ++i) {
- sel_rows->add(i);
- }
search_params_all_rows.roaring = sel_rows.get();
doris_search_result =
vector_search_utils::perform_doris_index_range_search(
index1.get(), query_vec.data(), radius,
search_params_all_rows);
diff --git a/be/test/olap/vector_search/vector_search_utils.h
b/be/test/olap/vector_search/vector_search_utils.h
index e28a79ed3a0..8c7444c9846 100644
--- a/be/test/olap/vector_search/vector_search_utils.h
+++ b/be/test/olap/vector_search/vector_search_utils.h
@@ -38,6 +38,7 @@
#include "olap/tablet_schema.h"
#include "runtime/descriptors.h"
#include "vec/exprs/vexpr_context.h"
+#include "vec/utils/util.hpp"
#include "vector_index.h"
// Add CLucene RAM Directory header
#include <CLucene/store/RAMDirectory.h>
@@ -90,7 +91,7 @@ float get_radius_from_flatten(const float* vector, int dim,
const std::vector<float>& flatten_vectors, float
percentile);
float get_radius_from_matrix(const float* vector, int dim,
const std::vector<std::vector<float>>&
matrix_vectors,
- float percentile);
+ float percentile, faiss::MetricType metric_type =
faiss::METRIC_L2);
// Helper function to compare search results between Doris and native Faiss
void compare_search_results(const vectorized::IndexSearchResult& doris_results,
const std::vector<float>& native_distances,
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
index d8b7ba55daf..b8c8bfd8f46 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java
@@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.InnerProductApproximate;
import
org.apache.doris.nereids.trees.expressions.functions.scalar.L2DistanceApproximate;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
@@ -91,17 +92,38 @@ public class PushDownVectorTopNIntoOlapScan implements
RewriteRuleFactory {
if (orderKeyExpr == null) {
return null;
}
- if (!(orderKeyExpr instanceof L2DistanceApproximate)) {
+
+ boolean l2Dist;
+ boolean innerProduct;
+ l2Dist = orderKeyExpr instanceof L2DistanceApproximate;
+ innerProduct = orderKeyExpr instanceof InnerProductApproximate;
+ if (!(l2Dist) && !(innerProduct)) {
return null;
}
- L2DistanceApproximate l2DistanceApproximate = (L2DistanceApproximate)
orderKeyExpr;
- Expression left = l2DistanceApproximate.left();
+
+ Expression left = null;
+ if (l2Dist) {
+ L2DistanceApproximate l2DistanceApproximate =
(L2DistanceApproximate) orderKeyExpr;
+ left = l2DistanceApproximate.left();
+ } else {
+ InnerProductApproximate innerProductApproximate =
(InnerProductApproximate) orderKeyExpr;
+ left = innerProductApproximate.left();
+ }
+
while (left instanceof Cast) {
left = ((Cast) left).child();
}
- if (!(left instanceof SlotReference && ((L2DistanceApproximate)
orderKeyExpr).right().isConstant())) {
- return null;
+
+ if (l2Dist) {
+ if (!(left instanceof SlotReference && ((L2DistanceApproximate)
orderKeyExpr).right().isConstant())) {
+ return null;
+ }
+ } else {
+ if (!(left instanceof SlotReference && ((InnerProductApproximate)
orderKeyExpr).right().isConstant())) {
+ return null;
+ }
}
+
SlotReference leftInput = (SlotReference) left;
if (!leftInput.getOriginalColumn().isPresent() ||
!leftInput.getOriginalTable().isPresent()) {
return null;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProductApproximate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProductApproximate.java
index bce8e038e78..8c66e396538 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProductApproximate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/InnerProductApproximate.java
@@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import
org.apache.doris.nereids.trees.expressions.functions.ComputePrecisionForArrayItemAgg;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
@@ -36,7 +37,7 @@ import java.util.List;
* inner_product function
*/
public class InnerProductApproximate extends ScalarFunction implements
ExplicitlyCastableSignature,
- ComputePrecisionForArrayItemAgg, UnaryExpression, AlwaysNullable {
+ ComputePrecisionForArrayItemAgg, BinaryExpression, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]