This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit d36ad56dce2ba3c15330cd3e7faa9b238bf37786 Author: Pxl <[email protected]> AuthorDate: Wed Feb 28 21:40:12 2024 +0800 [Opt](Exec) Support runtime update topn filter (#31250) --- be/src/olap/accept_null_predicate.h | 17 +- be/src/olap/column_predicate.h | 5 - be/src/olap/comparison_predicate.h | 31 +--- be/src/olap/rowset/segment_v2/segment.cpp | 34 ++-- be/src/olap/rowset/segment_v2/segment_iterator.cpp | 39 ++--- be/src/olap/rowset/segment_v2/segment_iterator.h | 2 - be/src/olap/shared_predicate.h | 179 +++++++++++++++++++++ be/src/olap/tablet_schema.cpp | 4 +- be/src/pipeline/exec/scan_operator.cpp | 13 -- be/src/pipeline/exec/sort_sink_operator.cpp | 18 +-- be/src/runtime/runtime_predicate.cpp | 104 +++++------- be/src/runtime/runtime_predicate.h | 120 ++++---------- be/src/vec/exec/vsort_node.cpp | 23 ++- be/src/vec/olap/vcollect_iterator.cpp | 9 +- 14 files changed, 298 insertions(+), 300 deletions(-) diff --git a/be/src/olap/accept_null_predicate.h b/be/src/olap/accept_null_predicate.h index 3d6103e81cd..35d2546582b 100644 --- a/be/src/olap/accept_null_predicate.h +++ b/be/src/olap/accept_null_predicate.h @@ -44,6 +44,8 @@ public: PredicateType type() const override { return _nested->type(); } + void set_nested(ColumnPredicate* nested) { _nested.reset(nested); } + Status evaluate(BitmapIndexIterator* iterator, uint32_t num_rows, roaring::Roaring* roaring) const override { return _nested->evaluate(iterator, num_rows, roaring); @@ -149,21 +151,6 @@ public: std::string get_search_str() const override { return _nested->get_search_str(); } - std::string debug_string() const override { - return "passnull predicate for " + _nested->debug_string(); - } - - /// Some predicates need to be cloned for each segment. - bool need_to_clone() const override { return _nested->need_to_clone(); } - - void clone(ColumnPredicate** to) const override { - if (need_to_clone()) { - ColumnPredicate* clone_nested; - _nested->clone(&clone_nested); - *to = new AcceptNullPredicate(clone_nested); - } - } - private: uint16_t _evaluate_inner(const vectorized::IColumn& column, uint16_t* sel, uint16_t size) const override { diff --git a/be/src/olap/column_predicate.h b/be/src/olap/column_predicate.h index adbb9a695ed..d549540e5c7 100644 --- a/be/src/olap/column_predicate.h +++ b/be/src/olap/column_predicate.h @@ -267,11 +267,6 @@ public: ", opposite=" + (_opposite ? "true" : "false"); } - /// Some predicates need to be cloned for each segment. - virtual bool need_to_clone() const { return false; } - - virtual void clone(ColumnPredicate** to) const { LOG(FATAL) << "clone not supported"; } - virtual int get_filter_id() const { return -1; } // now InListPredicateBase BloomFilterColumnPredicate BitmapFilterColumnPredicate = true virtual bool is_filter() const { return false; } diff --git a/be/src/olap/comparison_predicate.h b/be/src/olap/comparison_predicate.h index 17b334d7b8d..826b1414b2a 100644 --- a/be/src/olap/comparison_predicate.h +++ b/be/src/olap/comparison_predicate.h @@ -34,25 +34,12 @@ class ComparisonPredicateBase : public ColumnPredicate { public: using T = typename PrimitiveTypeTraits<Type>::CppType; ComparisonPredicateBase(uint32_t column_id, const T& value, bool opposite = false) - : ColumnPredicate(column_id, opposite), - _cached_code(_InvalidateCodeValue), - _value(value) {} - - void clone(ColumnPredicate** to) const override { - auto* cloned = new ComparisonPredicateBase(_column_id, _value, _opposite); - cloned->predicate_params()->value = _predicate_params->value; - cloned->_cache_code_enabled = true; - cloned->predicate_params()->marked_by_runtime_filter = - _predicate_params->marked_by_runtime_filter; - *to = cloned; - } + : ColumnPredicate(column_id, opposite), _value(value) {} bool can_do_apply_safely(PrimitiveType input_type, bool is_null) const override { return input_type == Type || (is_string_type(input_type) && is_string_type(Type)); } - bool need_to_clone() const override { return true; } - PredicateType type() const override { return PT; } Status evaluate(BitmapIndexIterator* iterator, uint32_t num_rows, @@ -591,16 +578,10 @@ private: __attribute__((flatten)) int32_t _find_code_from_dictionary_column( const vectorized::ColumnDictI32& column) const { - /// if _cache_code_enabled is false, always find the code from dict. - if (UNLIKELY(!_cache_code_enabled || _cached_code == _InvalidateCodeValue)) { + if (!_segment_id_to_cached_code.contains(column.get_rowset_segment_id())) { int32_t code = _is_range() ? column.find_code_by_bound(_value, _is_greater(), _is_eq()) : column.find_code(_value); - // Protect the invalid code logic, to avoid data error. - if (code == _InvalidateCodeValue) { - LOG(FATAL) << "column dictionary should not return the code " << code - << ", because it is assumed as an invalid code in comparison predicate"; - } // Sometimes the dict is not initialized when run comparison predicate here, for example, // the full page is null, then the reader will skip read, so that the dictionary is not // inited. The cached code is wrong during this case, because the following page maybe not @@ -612,9 +593,9 @@ private: return code; } // If the dict is not empty, then the dict is inited and we could cache the value. - _cached_code = code; + _segment_id_to_cached_code[column.get_rowset_segment_id()] = code; } - return _cached_code; + return _segment_id_to_cached_code[column.get_rowset_segment_id()]; } std::string _debug_string() const override { @@ -623,9 +604,7 @@ private: return info; } - static constexpr int32_t _InvalidateCodeValue = std::numeric_limits<int32_t>::max(); - mutable int32_t _cached_code; - bool _cache_code_enabled = false; + mutable std::map<std::pair<RowsetId, uint32_t>, int32_t> _segment_id_to_cached_code; T _value; }; diff --git a/be/src/olap/rowset/segment_v2/segment.cpp b/be/src/olap/rowset/segment_v2/segment.cpp index f63af78358b..29fc93e9243 100644 --- a/be/src/olap/rowset/segment_v2/segment.cpp +++ b/be/src/olap/rowset/segment_v2/segment.cpp @@ -150,24 +150,22 @@ Status Segment::new_iterator(SchemaSPtr schema, const StorageReadOptions& read_o } } if (read_options.use_topn_opt) { - auto query_ctx = read_options.runtime_state->get_query_ctx(); - auto runtime_predicate = query_ctx->get_runtime_predicate().get_predictate(); - if (runtime_predicate) { - // TODO handle var path - int32_t uid = - read_options.tablet_schema->column(runtime_predicate->column_id()).unique_id(); - AndBlockColumnPredicate and_predicate; - auto single_predicate = new SingleColumnBlockPredicate(runtime_predicate.get()); - and_predicate.add_column_predicate(single_predicate); - if (_column_readers.count(uid) >= 1 && - can_apply_predicate_safely(runtime_predicate->column_id(), runtime_predicate.get(), - *schema, read_options.io_ctx.reader_type) && - !_column_readers.at(uid)->match_condition(&and_predicate)) { - // any condition not satisfied, return. - iter->reset(new EmptySegmentIterator(*schema)); - read_options.stats->filtered_segment_number++; - return Status::OK(); - } + auto* query_ctx = read_options.runtime_state->get_query_ctx(); + auto runtime_predicate = query_ctx->get_runtime_predicate().get_predicate(); + + int32_t uid = + read_options.tablet_schema->column(runtime_predicate->column_id()).unique_id(); + AndBlockColumnPredicate and_predicate; + auto* single_predicate = new SingleColumnBlockPredicate(runtime_predicate.get()); + and_predicate.add_column_predicate(single_predicate); + if (_column_readers.contains(uid) && + can_apply_predicate_safely(runtime_predicate->column_id(), runtime_predicate.get(), + *schema, read_options.io_ctx.reader_type) && + !_column_readers.at(uid)->match_condition(&and_predicate)) { + // any condition not satisfied, return. + *iter = std::make_unique<EmptySegmentIterator>(*schema); + read_options.stats->filtered_segment_number++; + return Status::OK(); } } diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp b/be/src/olap/rowset/segment_v2/segment_iterator.cpp index f1cb1f2c61c..360e4cac2d6 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp @@ -283,19 +283,12 @@ Status SegmentIterator::_init_impl(const StorageReadOptions& opts) { _opts = opts; _col_predicates.clear(); - for (auto& predicate : opts.column_predicates) { + for (const auto& predicate : opts.column_predicates) { if (!_segment->can_apply_predicate_safely(predicate->column_id(), predicate, *_schema, _opts.io_ctx.reader_type)) { continue; } - if (predicate->need_to_clone()) { - ColumnPredicate* cloned; - predicate->clone(&cloned); - _pool->add(cloned); - _col_predicates.emplace_back(cloned); - } else { - _col_predicates.emplace_back(predicate); - } + _col_predicates.emplace_back(predicate); } _tablet_id = opts.tablet_id; // Read options will not change, so that just resize here @@ -303,7 +296,7 @@ Status SegmentIterator::_init_impl(const StorageReadOptions& opts) { // compound predicates _col_preds_except_leafnode_of_andnode.clear(); - for (auto& predicate : opts.column_predicates_except_leafnode_of_andnode) { + for (const auto& predicate : opts.column_predicates_except_leafnode_of_andnode) { if (!_segment->can_apply_predicate_safely(predicate->column_id(), predicate, *_schema, _opts.io_ctx.reader_type)) { continue; @@ -520,14 +513,8 @@ Status SegmentIterator::_get_row_ranges_by_column_conditions() { RETURN_IF_ERROR(_apply_bitmap_index()); RETURN_IF_ERROR(_apply_inverted_index()); - std::shared_ptr<doris::ColumnPredicate> runtime_predicate = nullptr; - if (_opts.use_topn_opt) { - auto* query_ctx = _opts.runtime_state->get_query_ctx(); - runtime_predicate = query_ctx->get_runtime_predicate().get_predictate(); - } - if (!_row_bitmap.isEmpty() && - (runtime_predicate || !_opts.col_id_to_predicates.empty() || + (_opts.use_topn_opt || !_opts.col_id_to_predicates.empty() || _opts.delete_condition_predicates->num_of_column_predicate() > 0)) { RowRanges condition_row_ranges = RowRanges::create_single(_segment->num_rows()); RETURN_IF_ERROR(_get_row_ranges_from_conditions(&condition_row_ranges)); @@ -604,17 +591,16 @@ Status SegmentIterator::_get_row_ranges_from_conditions(RowRanges* condition_row RowRanges::ranges_intersection(*condition_row_ranges, zone_map_row_ranges, condition_row_ranges); - std::shared_ptr<doris::ColumnPredicate> runtime_predicate = nullptr; if (_opts.use_topn_opt) { SCOPED_RAW_TIMER(&_opts.stats->block_conditions_filtered_zonemap_ns); - auto query_ctx = _opts.runtime_state->get_query_ctx(); - runtime_predicate = query_ctx->get_runtime_predicate().get_predictate(); - if (runtime_predicate && - _segment->can_apply_predicate_safely(runtime_predicate->column_id(), + auto* query_ctx = _opts.runtime_state->get_query_ctx(); + std::shared_ptr<doris::ColumnPredicate> runtime_predicate = + query_ctx->get_runtime_predicate().get_predicate(); + if (_segment->can_apply_predicate_safely(runtime_predicate->column_id(), runtime_predicate.get(), *_schema, _opts.io_ctx.reader_type)) { AndBlockColumnPredicate and_predicate; - auto single_predicate = new SingleColumnBlockPredicate(runtime_predicate.get()); + auto* single_predicate = new SingleColumnBlockPredicate(runtime_predicate.get()); and_predicate.add_column_predicate(single_predicate); RowRanges column_rp_row_ranges = RowRanges::create_single(num_rows()); @@ -1522,12 +1508,9 @@ Status SegmentIterator::_vec_init_lazy_materialization() { // should add add for order by none-key column, since none-key column is not sorted and // all rows should be read, so runtime predicate will reduce rows for topn node if (_opts.use_topn_opt && - !(_opts.read_orderby_key_columns != nullptr && !_opts.read_orderby_key_columns->empty())) { + (_opts.read_orderby_key_columns == nullptr || _opts.read_orderby_key_columns->empty())) { auto& runtime_predicate = _opts.runtime_state->get_query_ctx()->get_runtime_predicate(); - _runtime_predicate = runtime_predicate.get_predictate(); - if (_runtime_predicate) { - _col_predicates.push_back(_runtime_predicate.get()); - } + _col_predicates.push_back(runtime_predicate.get_predicate().get()); } // Step1: extract columns that can be lazy materialization diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.h b/be/src/olap/rowset/segment_v2/segment_iterator.h index 01d34b8cddf..fb039246384 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.h +++ b/be/src/olap/rowset/segment_v2/segment_iterator.h @@ -451,8 +451,6 @@ private: _column_pred_in_remaining_vconjunct; std::set<ColumnId> _not_apply_index_pred; - std::shared_ptr<ColumnPredicate> _runtime_predicate; - // row schema of the key to seek // only used in `_get_row_ranges_by_keys` std::unique_ptr<Schema> _seek_schema; diff --git a/be/src/olap/shared_predicate.h b/be/src/olap/shared_predicate.h new file mode 100644 index 00000000000..83c4ae62515 --- /dev/null +++ b/be/src/olap/shared_predicate.h @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> +#include <memory> + +#include "common/factory_creator.h" +#include "olap/column_predicate.h" +#include "olap/rowset/segment_v2/bloom_filter.h" +#include "olap/rowset/segment_v2/inverted_index_reader.h" +#include "olap/wrapper_field.h" +#include "vec/columns/column_dictionary.h" + +namespace doris { + +// SharedPredicate only used on topn runtime predicate. +// Runtime predicate globally share one predicate, to ensure that updates can be real-time. +// At the beginning nested predicate may be nullptr, in which case predicate always returns true. +class SharedPredicate : public ColumnPredicate { + ENABLE_FACTORY_CREATOR(SharedPredicate); + +public: + SharedPredicate(uint32_t column_id) : ColumnPredicate(column_id) {} + + PredicateType type() const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + // topn filter is le or ge + return PredicateType::LE; + } + return _nested->type(); + } + + void set_nested(ColumnPredicate* nested) { + std::unique_lock<std::shared_mutex> lock(_mtx); + _nested.reset(nested); + } + + Status evaluate(BitmapIndexIterator* iterator, uint32_t num_rows, + roaring::Roaring* roaring) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return Status::OK(); + } + return _nested->evaluate(iterator, num_rows, roaring); + } + + Status evaluate(const vectorized::NameAndTypePair& name_with_type, + InvertedIndexIterator* iterator, uint32_t num_rows, + roaring::Roaring* bitmap) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return Status::OK(); + } + return _nested->evaluate(name_with_type, iterator, num_rows, bitmap); + } + + bool can_do_apply_safely(PrimitiveType input_type, bool is_null) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return true; + } + return _nested->can_do_apply_safely(input_type, is_null); + } + + void evaluate_and(const vectorized::IColumn& column, const uint16_t* sel, uint16_t size, + bool* flags) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return; + } + return _nested->evaluate_and(column, sel, size, flags); + } + + void evaluate_or(const vectorized::IColumn& column, const uint16_t* sel, uint16_t size, + bool* flags) const override { + DCHECK(false) << "should not reach here"; + } + + bool evaluate_and(const std::pair<WrapperField*, WrapperField*>& statistic) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return ColumnPredicate::evaluate_and(statistic); + } + return _nested->evaluate_and(statistic); + } + + bool evaluate_del(const std::pair<WrapperField*, WrapperField*>& statistic) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return ColumnPredicate::evaluate_del(statistic); + } + return _nested->evaluate_del(statistic); + } + + bool evaluate_and(const BloomFilter* bf) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return ColumnPredicate::evaluate_and(bf); + } + return _nested->evaluate_and(bf); + } + + bool can_do_bloom_filter(bool ngram) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return ColumnPredicate::can_do_bloom_filter(ngram); + } + return _nested->can_do_bloom_filter(ngram); + } + + void evaluate_vec(const vectorized::IColumn& column, uint16_t size, + bool* flags) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + for (uint16_t i = 0; i < size; ++i) { + flags[i] = true; + } + return; + } + _nested->evaluate_vec(column, size, flags); + } + + void evaluate_and_vec(const vectorized::IColumn& column, uint16_t size, + bool* flags) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return; + } + _nested->evaluate_and_vec(column, size, flags); + } + + std::string get_search_str() const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + DCHECK(false) << "should not reach here"; + } + return _nested->get_search_str(); + } + +private: + uint16_t _evaluate_inner(const vectorized::IColumn& column, uint16_t* sel, + uint16_t size) const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return size; + } + return _nested->evaluate(column, sel, size); + } + + std::string _debug_string() const override { + std::shared_lock<std::shared_mutex> lock(_mtx); + if (!_nested) { + return "shared_predicate<unknow>"; + } + return "shared_predicate<" + _nested->debug_string() + ">"; + } + + mutable std::shared_mutex _mtx; + std::shared_ptr<ColumnPredicate> _nested; +}; + +} //namespace doris diff --git a/be/src/olap/tablet_schema.cpp b/be/src/olap/tablet_schema.cpp index 79806c4703e..78c7e694c6f 100644 --- a/be/src/olap/tablet_schema.cpp +++ b/be/src/olap/tablet_schema.cpp @@ -1198,7 +1198,7 @@ void TabletSchema::update_indexes_from_thrift(const std::vector<doris::TOlapTabl } Status TabletSchema::have_column(const std::string& field_name) const { - if (!_field_name_to_index.count(field_name)) { + if (!_field_name_to_index.contains(field_name)) { return Status::Error<ErrorCode::INTERNAL_ERROR>( "Not found field_name, field_name:{}, schema:{}", field_name, get_all_field_names()); @@ -1207,7 +1207,7 @@ Status TabletSchema::have_column(const std::string& field_name) const { } const TabletColumn& TabletSchema::column(const std::string& field_name) const { - DCHECK(_field_name_to_index.count(field_name) != 0) + DCHECK(_field_name_to_index.contains(field_name)) << ", field_name=" << field_name << ", field_name_to_index=" << get_all_field_names(); const auto& found = _field_name_to_index.find(field_name); return _cols[found->second]; diff --git a/be/src/pipeline/exec/scan_operator.cpp b/be/src/pipeline/exec/scan_operator.cpp index 08c58d4180d..e3b0fe38a1d 100644 --- a/be/src/pipeline/exec/scan_operator.cpp +++ b/be/src/pipeline/exec/scan_operator.cpp @@ -1219,19 +1219,6 @@ Status ScanLocalState<Derived>::_start_scanners( _scanner_ctx = PipXScannerContext::create_shared( state(), this, p._output_tuple_desc, p.output_row_descriptor(), scanners, p.limit(), state()->scan_queue_mem_limit(), _dependency->shared_from_this()); - if constexpr (std::is_same_v<OlapScanLocalState, Derived>) { - /** - * If `use_topn_opt` is true, - * we let 1/4 scanners run first to update the value of runtime predicate, - * and the other 3/4 scanners could then read fewer rows. - */ - if (static_cast<OlapScanLocalState*>(this)->olap_scan_node().use_topn_opt) { - int32_t max_thread_num = std::max<int32_t>(4, scanners.size() / 4); - if (max_thread_num < _scanner_ctx->get_max_thread_num()) { - _scanner_ctx->set_max_thread_num(max_thread_num); - } - } - } return Status::OK(); } diff --git a/be/src/pipeline/exec/sort_sink_operator.cpp b/be/src/pipeline/exec/sort_sink_operator.cpp index 16d7e6fba06..854b6f165ca 100644 --- a/be/src/pipeline/exec/sort_sink_operator.cpp +++ b/be/src/pipeline/exec/sort_sink_operator.cpp @@ -94,18 +94,19 @@ Status SortSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) { // init runtime predicate if (_use_topn_opt) { - auto query_ctx = state->get_query_ctx(); + auto* query_ctx = state->get_query_ctx(); auto first_sort_expr_node = tnode.sort_node.sort_info.ordering_exprs[0].nodes[0]; if (first_sort_expr_node.node_type == TExprNodeType::SLOT_REF) { auto first_sort_slot = first_sort_expr_node.slot_ref; - for (auto tuple_desc : _row_descriptor.tuple_descriptors()) { + for (auto* tuple_desc : _row_descriptor.tuple_descriptors()) { if (tuple_desc->id() != first_sort_slot.tuple_id) { continue; } - for (auto slot : tuple_desc->slots()) { + for (auto* slot : tuple_desc->slots()) { if (slot->id() == first_sort_slot.slot_id) { - RETURN_IF_ERROR(query_ctx->get_runtime_predicate().init(slot->type().type, - _nulls_first[0])); + RETURN_IF_ERROR(query_ctx->get_runtime_predicate().init( + slot->type().type, _nulls_first[0], _is_asc_order[0], + slot->col_name())); break; } } @@ -158,13 +159,8 @@ Status SortSinkOperatorX::sink(doris::RuntimeState* state, vectorized::Block* in if (_use_topn_opt) { vectorized::Field new_top = local_state._shared_state->sorter->get_top_value(); if (!new_top.is_null() && new_top != local_state.old_top) { - const auto& sort_description = - local_state._shared_state->sorter->get_sort_description(); - auto col = in_block->get_by_position(sort_description[0].column_number); - bool is_reverse = sort_description[0].direction < 0; auto* query_ctx = state->get_query_ctx(); - RETURN_IF_ERROR( - query_ctx->get_runtime_predicate().update(new_top, col.name, is_reverse)); + RETURN_IF_ERROR(query_ctx->get_runtime_predicate().update(new_top)); local_state.old_top = std::move(new_top); } } diff --git a/be/src/runtime/runtime_predicate.cpp b/be/src/runtime/runtime_predicate.cpp index 9e300dab4a0..032a5d505c9 100644 --- a/be/src/runtime/runtime_predicate.cpp +++ b/be/src/runtime/runtime_predicate.cpp @@ -26,11 +26,10 @@ #include "olap/column_predicate.h" #include "olap/predicate_creator.h" -namespace doris { +namespace doris::vectorized { -namespace vectorized { - -Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first) { +Status RuntimePredicate::init(PrimitiveType type, bool nulls_first, bool is_asc, + const std::string& col_name) { std::unique_lock<std::shared_mutex> wlock(_rwlock); if (_inited) { @@ -38,55 +37,53 @@ Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first) } _nulls_first = nulls_first; - - _predicate_arena.reset(new Arena()); + _is_asc = is_asc; + // For ASC sort, create runtime predicate col_name <= max_top_value + // since values that > min_top_value are large than any value in current topn values + // For DESC sort, create runtime predicate col_name >= min_top_value + // since values that < min_top_value are less than any value in current topn values + _pred_constructor = is_asc ? create_comparison_predicate<PredicateType::LE> + : create_comparison_predicate<PredicateType::GE>; + _col_name = col_name; // set get value function switch (type) { case PrimitiveType::TYPE_BOOLEAN: { - _get_value_fn = get_bool_value; + _get_value_fn = get_normal_value<TYPE_BOOLEAN>; break; } case PrimitiveType::TYPE_TINYINT: { - _get_value_fn = get_tinyint_value; + _get_value_fn = get_normal_value<TYPE_TINYINT>; break; } case PrimitiveType::TYPE_SMALLINT: { - _get_value_fn = get_smallint_value; + _get_value_fn = get_normal_value<TYPE_SMALLINT>; break; } case PrimitiveType::TYPE_INT: { - _get_value_fn = get_int_value; + _get_value_fn = get_normal_value<TYPE_INT>; break; } case PrimitiveType::TYPE_BIGINT: { - _get_value_fn = get_bigint_value; + _get_value_fn = get_normal_value<TYPE_BIGINT>; break; } case PrimitiveType::TYPE_LARGEINT: { - _get_value_fn = get_largeint_value; - break; - } - case PrimitiveType::TYPE_FLOAT: { - _get_value_fn = get_float_value; - break; - } - case PrimitiveType::TYPE_DOUBLE: { - _get_value_fn = get_double_value; + _get_value_fn = get_normal_value<TYPE_LARGEINT>; break; } case PrimitiveType::TYPE_CHAR: case PrimitiveType::TYPE_VARCHAR: case PrimitiveType::TYPE_STRING: { - _get_value_fn = get_string_value; + _get_value_fn = [](const Field& field) { return field.get<String>(); }; break; } case PrimitiveType::TYPE_DATEV2: { - _get_value_fn = get_datev2_value; + _get_value_fn = get_normal_value<TYPE_DATEV2>; break; } case PrimitiveType::TYPE_DATETIMEV2: { - _get_value_fn = get_datetimev2_value; + _get_value_fn = get_normal_value<TYPE_DATETIMEV2>; break; } case PrimitiveType::TYPE_DATE: { @@ -98,11 +95,11 @@ Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first) break; } case PrimitiveType::TYPE_DECIMAL32: { - _get_value_fn = get_decimal32_value; + _get_value_fn = get_decimal_value<TYPE_DECIMAL32>; break; } case PrimitiveType::TYPE_DECIMAL64: { - _get_value_fn = get_decimal64_value; + _get_value_fn = get_decimal_value<TYPE_DECIMAL64>; break; } case PrimitiveType::TYPE_DECIMALV2: { @@ -110,19 +107,19 @@ Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first) break; } case PrimitiveType::TYPE_DECIMAL128I: { - _get_value_fn = get_decimal128_value; + _get_value_fn = get_decimal_value<TYPE_DECIMAL128I>; break; } case PrimitiveType::TYPE_DECIMAL256: { - _get_value_fn = get_decimal256_value; + _get_value_fn = get_decimal_value<TYPE_DECIMAL256>; break; } case PrimitiveType::TYPE_IPV4: { - _get_value_fn = get_ipv4_value; + _get_value_fn = get_normal_value<TYPE_IPV4>; break; } case PrimitiveType::TYPE_IPV6: { - _get_value_fn = get_ipv6_value; + _get_value_fn = get_normal_value<TYPE_IPV6>; break; } default: @@ -133,30 +130,20 @@ Status RuntimePredicate::init(const PrimitiveType type, const bool nulls_first) return Status::OK(); } -Status RuntimePredicate::update(const Field& value, const String& col_name, bool is_reverse) { +Status RuntimePredicate::update(const Field& value) { + std::unique_lock<std::shared_mutex> wlock(_rwlock); // skip null value - if (value.is_null()) { + if (value.is_null() || !_inited || !_tablet_schema) { return Status::OK(); } - if (!_inited) { - return Status::OK(); - } - - std::unique_lock<std::shared_mutex> wlock(_rwlock); - bool updated = false; if (UNLIKELY(_orderby_extrem.is_null())) { _orderby_extrem = value; updated = true; - } else if (is_reverse) { - if (value > _orderby_extrem) { - _orderby_extrem = value; - updated = true; - } } else { - if (value < _orderby_extrem) { + if ((_is_asc && value < _orderby_extrem) || (!_is_asc && value > _orderby_extrem)) { _orderby_extrem = value; updated = true; } @@ -166,38 +153,19 @@ Status RuntimePredicate::update(const Field& value, const String& col_name, bool return Status::OK(); } - // TODO defensive code - if (!_tablet_schema || !_tablet_schema->have_column(col_name)) { - return Status::OK(); - } - // update _predictate - int32_t col_unique_id = _tablet_schema->column(col_name).unique_id(); - const TabletColumn& column = _tablet_schema->column_by_uid(col_unique_id); - uint32_t index = _tablet_schema->field_index(col_unique_id); - auto val = _get_value_fn(_orderby_extrem); - std::unique_ptr<ColumnPredicate> pred {nullptr}; - if (is_reverse) { - // For DESC sort, create runtime predicate col_name >= min_top_value - // since values that < min_top_value are less than any value in current topn values - pred.reset(create_comparison_predicate<PredicateType::GE>(column, index, val, false, - _predicate_arena.get())); - } else { - // For ASC sort, create runtime predicate col_name <= max_top_value - // since values that > min_top_value are large than any value in current topn values - pred.reset(create_comparison_predicate<PredicateType::LE>(column, index, val, false, - _predicate_arena.get())); - } - + std::unique_ptr<ColumnPredicate> pred { + _pred_constructor(_tablet_schema->column(_col_name), _predicate->column_id(), + _get_value_fn(_orderby_extrem), false, &_predicate_arena)}; // For NULLS FIRST, wrap a AcceptNullPredicate to return true for NULL // since ORDER BY ASC/DESC should get NULL first but pred returns NULL // and NULL in where predicate will be treated as FALSE if (_nulls_first) { pred = AcceptNullPredicate::create_unique(pred.release()); } - _predictate.reset(pred.release()); + + ((SharedPredicate*)_predicate.get())->set_nested(pred.release()); return Status::OK(); } -} // namespace vectorized -} // namespace doris +} // namespace doris::vectorized diff --git a/be/src/runtime/runtime_predicate.h b/be/src/runtime/runtime_predicate.h index 297d90979ec..ec5b9612cbc 100644 --- a/be/src/runtime/runtime_predicate.h +++ b/be/src/runtime/runtime_predicate.h @@ -25,6 +25,7 @@ #include "common/status.h" #include "exec/olap_common.h" +#include "olap/shared_predicate.h" #include "olap/tablet_schema.h" #include "runtime/define_primitive_type.h" #include "runtime/primitive_type.h" @@ -43,7 +44,7 @@ class RuntimePredicate { public: RuntimePredicate() = default; - Status init(const PrimitiveType type, const bool nulls_first); + Status init(PrimitiveType type, bool nulls_first, bool is_asc, const std::string& col_name); bool inited() { std::unique_lock<std::shared_mutex> wlock(_rwlock); @@ -52,73 +53,47 @@ public: void set_tablet_schema(TabletSchemaSPtr tablet_schema) { std::unique_lock<std::shared_mutex> wlock(_rwlock); + if (_tablet_schema) { + return; + } _tablet_schema = tablet_schema; + _predicate = SharedPredicate::create_shared( + tablet_schema->field_index(_tablet_schema->column(_col_name).unique_id())); } - std::shared_ptr<ColumnPredicate> get_predictate() { + std::shared_ptr<ColumnPredicate> get_predicate() { std::shared_lock<std::shared_mutex> rlock(_rwlock); - return _predictate; + return _predicate; } - Status update(const Field& value, const String& col_name, bool is_reverse); + Status update(const Field& value); private: mutable std::shared_mutex _rwlock; Field _orderby_extrem {Field::Types::Null}; - std::shared_ptr<ColumnPredicate> _predictate; + std::shared_ptr<ColumnPredicate> _predicate; TabletSchemaSPtr _tablet_schema = nullptr; - std::unique_ptr<Arena> _predicate_arena; + Arena _predicate_arena; std::function<std::string(const Field&)> _get_value_fn; bool _nulls_first = true; + bool _is_asc; + std::function<ColumnPredicate*(const TabletColumn&, int, const std::string&, bool, + vectorized::Arena*)> + _pred_constructor; bool _inited = false; + std::string _col_name; - static std::string get_bool_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_BOOLEAN>::CppType; - return cast_to_string<TYPE_BOOLEAN, ValueType>(field.get<ValueType>(), 0); + template <PrimitiveType type> + static std::string get_normal_value(const Field& field) { + using ValueType = typename PrimitiveTypeTraits<type>::CppType; + return cast_to_string<type, ValueType>(field.get<ValueType>(), 0); } - static std::string get_tinyint_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_TINYINT>::CppType; - return cast_to_string<TYPE_TINYINT, ValueType>(field.get<ValueType>(), 0); - } - - static std::string get_smallint_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_SMALLINT>::CppType; - return cast_to_string<TYPE_SMALLINT, ValueType>(field.get<ValueType>(), 0); - } - - static std::string get_int_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_INT>::CppType; - return cast_to_string<TYPE_INT, ValueType>(field.get<ValueType>(), 0); - } - - static std::string get_bigint_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_BIGINT>::CppType; - return cast_to_string<TYPE_BIGINT, ValueType>(field.get<ValueType>(), 0); - } - - static std::string get_largeint_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_LARGEINT>::CppType; - return cast_to_string<TYPE_LARGEINT, ValueType>(field.get<ValueType>(), 0); - } - - static std::string get_float_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_FLOAT>::CppType; - return cast_to_string<TYPE_FLOAT, ValueType>(field.get<ValueType>(), 0); - } - - static std::string get_double_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_DOUBLE>::CppType; - return cast_to_string<TYPE_DOUBLE, ValueType>(field.get<ValueType>(), 0); - } - - static std::string get_string_value(const Field& field) { return field.get<String>(); } - static std::string get_date_value(const Field& field) { using ValueType = typename PrimitiveTypeTraits<TYPE_DATE>::CppType; ValueType value; Int64 v = field.get<Int64>(); - VecDateTimeValue* p = (VecDateTimeValue*)&v; + auto* p = (VecDateTimeValue*)&v; value.from_olap_date(p->to_olap_date()); value.cast_to_date(); return cast_to_string<TYPE_DATE, ValueType>(value, 0); @@ -128,24 +103,12 @@ private: using ValueType = typename PrimitiveTypeTraits<TYPE_DATETIME>::CppType; ValueType value; Int64 v = field.get<Int64>(); - VecDateTimeValue* p = (VecDateTimeValue*)&v; + auto* p = (VecDateTimeValue*)&v; value.from_olap_datetime(p->to_olap_datetime()); value.to_datetime(); return cast_to_string<TYPE_DATETIME, ValueType>(value, 0); } - static std::string get_datev2_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_DATEV2>::CppType; - return cast_to_string<TYPE_DATEV2, ValueType>( - binary_cast<UInt32, ValueType>(field.get<UInt32>()), 0); - } - - static std::string get_datetimev2_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_DATETIMEV2>::CppType; - return cast_to_string<TYPE_DATETIMEV2, ValueType>( - binary_cast<UInt64, ValueType>(field.get<UInt64>()), 0); - } - static std::string get_decimalv2_value(const Field& field) { // can NOT use PrimitiveTypeTraits<TYPE_DECIMALV2>::CppType since // it is DecimalV2Value and Decimal128V2 can not convert to it implicitly @@ -156,38 +119,11 @@ private: return cast_to_string<TYPE_DECIMAL128I, ValueType>(v.get_value(), v.get_scale()); } - static std::string get_decimal32_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_DECIMAL32>::CppType; - auto v = field.get<DecimalField<Decimal32>>(); - return cast_to_string<TYPE_DECIMAL32, ValueType>(v.get_value(), v.get_scale()); - } - - static std::string get_decimal64_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_DECIMAL64>::CppType; - auto v = field.get<DecimalField<Decimal64>>(); - return cast_to_string<TYPE_DECIMAL64, ValueType>(v.get_value(), v.get_scale()); - } - - static std::string get_decimal128_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_DECIMAL128I>::CppType; - auto v = field.get<DecimalField<Decimal128V3>>(); - return cast_to_string<TYPE_DECIMAL128I, ValueType>(v.get_value(), v.get_scale()); - } - - static std::string get_decimal256_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_DECIMAL256>::CppType; - auto v = field.get<DecimalField<Decimal256>>(); - return cast_to_string<TYPE_DECIMAL256, ValueType>(v.get_value(), v.get_scale()); - } - - static std::string get_ipv4_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_IPV4>::CppType; - return cast_to_string<TYPE_IPV4, ValueType>(field.get<ValueType>(), 0); - } - - static std::string get_ipv6_value(const Field& field) { - using ValueType = typename PrimitiveTypeTraits<TYPE_IPV6>::CppType; - return cast_to_string<TYPE_IPV6, ValueType>(field.get<ValueType>(), 0); + template <PrimitiveType type> + static std::string get_decimal_value(const Field& field) { + using ValueType = typename PrimitiveTypeTraits<type>::CppType; + auto v = field.get<DecimalField<ValueType>>(); + return cast_to_string<type, ValueType>(v.get_value(), v.get_scale()); } }; diff --git a/be/src/vec/exec/vsort_node.cpp b/be/src/vec/exec/vsort_node.cpp index e313e3f74ac..4206d9b6a2f 100644 --- a/be/src/vec/exec/vsort_node.cpp +++ b/be/src/vec/exec/vsort_node.cpp @@ -82,18 +82,19 @@ Status VSortNode::init(const TPlanNode& tnode, RuntimeState* state) { // init runtime predicate _use_topn_opt = tnode.sort_node.use_topn_opt; if (_use_topn_opt) { - auto query_ctx = state->get_query_ctx(); + auto* query_ctx = state->get_query_ctx(); auto first_sort_expr_node = tnode.sort_node.sort_info.ordering_exprs[0].nodes[0]; if (first_sort_expr_node.node_type == TExprNodeType::SLOT_REF) { auto first_sort_slot = first_sort_expr_node.slot_ref; - for (auto tuple_desc : this->intermediate_row_desc().tuple_descriptors()) { + for (auto* tuple_desc : this->intermediate_row_desc().tuple_descriptors()) { if (tuple_desc->id() != first_sort_slot.tuple_id) { continue; } - for (auto slot : tuple_desc->slots()) { + for (auto* slot : tuple_desc->slots()) { if (slot->id() == first_sort_slot.slot_id) { - RETURN_IF_ERROR(query_ctx->get_runtime_predicate().init(slot->type().type, - _nulls_first[0])); + RETURN_IF_ERROR(query_ctx->get_runtime_predicate().init( + slot->type().type, _nulls_first[0], _is_asc_order[0], + slot->col_name())); break; } } @@ -144,15 +145,9 @@ Status VSortNode::sink(RuntimeState* state, vectorized::Block* input_block, bool if (_use_topn_opt) { Field new_top = _sorter->get_top_value(); if (!new_top.is_null() && new_top != old_top) { - auto& sort_description = _sorter->get_sort_description(); - auto col = input_block->get_by_position(sort_description[0].column_number); - if (!col.name.empty()) { - bool is_reverse = sort_description[0].direction < 0; - auto* query_ctx = state->get_query_ctx(); - RETURN_IF_ERROR(query_ctx->get_runtime_predicate().update(new_top, col.name, - is_reverse)); - old_top = std::move(new_top); - } + auto* query_ctx = state->get_query_ctx(); + RETURN_IF_ERROR(query_ctx->get_runtime_predicate().update(new_top)); + old_top = std::move(new_top); } } if (!_reuse_mem) { diff --git a/be/src/vec/olap/vcollect_iterator.cpp b/be/src/vec/olap/vcollect_iterator.cpp index b13f5333d5e..4eca53146ed 100644 --- a/be/src/vec/olap/vcollect_iterator.cpp +++ b/be/src/vec/olap/vcollect_iterator.cpp @@ -305,8 +305,6 @@ Status VCollectIterator::_topn_next(Block* block) { } } - auto col_name = block->get_names()[first_sort_column_idx]; - // filter block RETURN_IF_ERROR(VExprContext::filter_block( _reader->_reader_context.filter_block_conjuncts, block, block->columns())); @@ -417,14 +415,13 @@ Status VCollectIterator::_topn_next(Block* block) { sorted_row_pos.size() >= _topn_limit) { // get field value from column size_t last_sorted_row = *sorted_row_pos.rbegin(); - auto col_ptr = mutable_block.get_column_by_position(first_sort_column_idx).get(); + auto* col_ptr = mutable_block.get_column_by_position(first_sort_column_idx).get(); Field new_top; col_ptr->get(last_sorted_row, new_top); // update orderby_extrems in query global context - auto query_ctx = _reader->_reader_context.runtime_state->get_query_ctx(); - RETURN_IF_ERROR( - query_ctx->get_runtime_predicate().update(new_top, col_name, _is_reverse)); + auto* query_ctx = _reader->_reader_context.runtime_state->get_query_ctx(); + RETURN_IF_ERROR(query_ctx->get_runtime_predicate().update(new_top)); } } // end of while (read_rows < _topn_limit && !eof) VLOG_DEBUG << "topn debug rowset " << i << " read_rows=" << read_rows << " eof=" << eof --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
