This is an automated email from the ASF dual-hosted git repository.
gabriellee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new b964ab76b33 [refactor](shuffle) Simplify hash partitioning strategy (#25596)
b964ab76b33 is described below
commit b964ab76b338f12ebd8d6536b4eab6f1add471ac
Author: Gabriel <[email protected]>
AuthorDate: Thu Oct 19 19:28:22 2023 +0800
[refactor](shuffle) Simplify hash partitioning strategy (#25596)
---
be/src/pipeline/exec/exchange_sink_operator.cpp | 123 ++++++++----------------
be/src/pipeline/exec/exchange_sink_operator.h | 24 ++---
be/src/vec/columns/column.h | 5 +-
be/src/vec/columns/column_array.cpp | 15 +--
be/src/vec/columns/column_array.h | 5 +-
be/src/vec/columns/column_const.cpp | 9 +-
be/src/vec/columns/column_const.h | 7 +-
be/src/vec/columns/column_decimal.cpp | 7 +-
be/src/vec/columns/column_decimal.h | 9 +-
be/src/vec/columns/column_map.cpp | 8 +-
be/src/vec/columns/column_map.h | 5 +-
be/src/vec/columns/column_nullable.cpp | 12 +--
be/src/vec/columns/column_nullable.h | 5 +-
be/src/vec/columns/column_string.cpp | 5 +-
be/src/vec/columns/column_string.h | 5 +-
be/src/vec/columns/column_struct.cpp | 7 +-
be/src/vec/columns/column_struct.h | 5 +-
be/src/vec/columns/column_vector.cpp | 5 +-
be/src/vec/columns/column_vector.h | 7 +-
be/src/vec/common/hash_table/hash_map_context.h | 8 +-
be/src/vec/exec/scan/pip_scanner_context.h | 10 +-
be/src/vec/runtime/partitioner.cpp | 70 ++++++++++++++
be/src/vec/runtime/partitioner.h | 105 ++++++++++++++++++++
be/src/vec/sink/vdata_stream_sender.cpp | 89 +++++------------
be/src/vec/sink/vdata_stream_sender.h | 25 ++---
be/test/vec/columns/column_hash_func_test.cpp | 26 ++---
26 files changed, 343 insertions(+), 258 deletions(-)
diff --git a/be/src/pipeline/exec/exchange_sink_operator.cpp b/be/src/pipeline/exec/exchange_sink_operator.cpp
index c778e397b83..89679ceff89 100644
--- a/be/src/pipeline/exec/exchange_sink_operator.cpp
+++ b/be/src/pipeline/exec/exchange_sink_operator.cpp
@@ -161,11 +161,6 @@ Status ExchangeSinkLocalState::init(RuntimeState* state,
LocalSinkStateInfo& inf
std::random_device rd;
std::mt19937 g(rd());
shuffle(channels.begin(), channels.end(), g);
- } else {
- partition_expr_ctxs.resize(p._partition_expr_ctxs.size());
- for (size_t i = 0; i < p._partition_expr_ctxs.size(); i++) {
- RETURN_IF_ERROR(p._partition_expr_ctxs[i]->clone(state,
partition_expr_ctxs[i]));
- }
}
only_local_exchange = local_size == channels.size();
@@ -211,6 +206,28 @@ Status ExchangeSinkLocalState::init(RuntimeState* state,
LocalSinkStateInfo& inf
}
_exchange_sink_dependency->add_child(deps_for_channels);
}
+ if (p._part_type == TPartitionType::HASH_PARTITIONED) {
+ _partition_count = channels.size();
+ _partitioner.reset(new vectorized::HashPartitioner(channels.size()));
+ RETURN_IF_ERROR(_partitioner->init(p._texprs));
+ RETURN_IF_ERROR(_partitioner->prepare(state, p._row_desc));
+ } else if (p._part_type ==
TPartitionType::BUCKET_SHFFULE_HASH_PARTITIONED) {
+ _partition_count = channel_shared_ptrs.size();
+ _partitioner.reset(new
vectorized::BucketHashPartitioner(channel_shared_ptrs.size()));
+ RETURN_IF_ERROR(_partitioner->init(p._texprs));
+ RETURN_IF_ERROR(_partitioner->prepare(state, p._row_desc));
+ }
+
+ return Status::OK();
+}
+
+Status ExchangeSinkLocalState::open(RuntimeState* state) {
+ RETURN_IF_ERROR(PipelineXSinkLocalState<>::open(state));
+ auto& p = _parent->cast<ExchangeSinkOperatorX>();
+ if (p._part_type == TPartitionType::HASH_PARTITIONED ||
+ p._part_type == TPartitionType::BUCKET_SHFFULE_HASH_PARTITIONED) {
+ RETURN_IF_ERROR(_partitioner->open(state));
+ }
return Status::OK();
}
@@ -223,6 +240,7 @@ ExchangeSinkOperatorX::ExchangeSinkOperatorX(
const std::vector<TPlanFragmentDestination>& destinations,
bool send_query_statistics_with_every_batch)
: DataSinkOperatorX(sink.dest_node_id),
+ _texprs(sink.output_partition.partition_exprs),
_row_desc(row_desc),
_part_type(sink.output_partition.type),
_dests(destinations),
@@ -240,37 +258,20 @@ ExchangeSinkOperatorX::ExchangeSinkOperatorX(
Status ExchangeSinkOperatorX::init(const TDataSink& tsink) {
RETURN_IF_ERROR(DataSinkOperatorX::init(tsink));
- const TDataStreamSink& t_stream_sink = tsink.stream_sink;
- if (_part_type == TPartitionType::HASH_PARTITIONED ||
- _part_type == TPartitionType::BUCKET_SHFFULE_HASH_PARTITIONED) {
- RETURN_IF_ERROR(vectorized::VExpr::create_expr_trees(
- t_stream_sink.output_partition.partition_exprs,
_partition_expr_ctxs));
- } else if (_part_type == TPartitionType::RANGE_PARTITIONED) {
+ if (_part_type == TPartitionType::RANGE_PARTITIONED) {
return Status::InternalError("TPartitionType::RANGE_PARTITIONED should
not be used");
- } else {
- // UNPARTITIONED
}
return Status::OK();
}
Status ExchangeSinkOperatorX::prepare(RuntimeState* state) {
_state = state;
-
_mem_tracker = std::make_unique<MemTracker>("ExchangeSinkOperatorX:");
- SCOPED_CONSUME_MEM_TRACKER(_mem_tracker.get());
-
- if (!(_part_type == TPartitionType::UNPARTITIONED) && !(_part_type ==
TPartitionType::RANDOM)) {
- RETURN_IF_ERROR(vectorized::VExpr::prepare(_partition_expr_ctxs,
state, _row_desc));
- }
return Status::OK();
}
Status ExchangeSinkOperatorX::open(RuntimeState* state) {
DCHECK(state != nullptr);
-
- SCOPED_CONSUME_MEM_TRACKER(_mem_tracker.get());
- RETURN_IF_ERROR(vectorized::VExpr::open(_partition_expr_ctxs, state));
-
_compression_type = state->fragement_transmission_compression_type();
return Status::OK();
}
@@ -378,68 +379,20 @@ Status ExchangeSinkOperatorX::sink(RuntimeState* state,
vectorized::Block* block
(local_state.current_channel_idx + 1) %
local_state.channels.size();
} else if (_part_type == TPartitionType::HASH_PARTITIONED ||
_part_type == TPartitionType::BUCKET_SHFFULE_HASH_PARTITIONED) {
- // will only copy schema
- // we don't want send temp columns
- auto column_to_keep = block->columns();
-
- int result_size = _partition_expr_ctxs.size();
- int result[result_size];
-
- // vectorized calculate hash
- int rows = block->rows();
- auto element_size = _part_type == TPartitionType::HASH_PARTITIONED
- ? local_state.channels.size()
- : local_state.channel_shared_ptrs.size();
- std::vector<uint64_t> hash_vals(rows);
- auto* __restrict hashes = hash_vals.data();
-
- if (rows > 0) {
- {
- SCOPED_CONSUME_MEM_TRACKER(_mem_tracker.get());
- RETURN_IF_ERROR(get_partition_column_result(block, result));
- }
- // TODO: after we support new shuffle hash method, should simple
the code
- if (_part_type == TPartitionType::HASH_PARTITIONED) {
- SCOPED_TIMER(local_state._split_block_hash_compute_timer);
- // result[j] means column index, i means rows index, here to
calculate the xxhash value
- for (int j = 0; j < result_size; ++j) {
- // complex type most not implement get_data_at() method
which column_const will call
- unpack_if_const(block->get_by_position(result[j]).column)
- .first->update_hashes_with_value(hashes);
- }
-
- for (int i = 0; i < rows; i++) {
- hashes[i] = hashes[i] % element_size;
- }
-
- {
- SCOPED_CONSUME_MEM_TRACKER(_mem_tracker.get());
- vectorized::Block::erase_useless_column(block,
column_to_keep);
- }
- } else {
- for (int j = 0; j < result_size; ++j) {
- // complex type most not implement get_data_at() method
which column_const will call
- unpack_if_const(block->get_by_position(result[j]).column)
- .first->update_crcs_with_value(
- hash_vals,
_partition_expr_ctxs[j]->root()->type().type);
- }
- for (int i = 0; i < rows; i++) {
- hashes[i] = hashes[i] % element_size;
- }
-
- {
- SCOPED_CONSUME_MEM_TRACKER(_mem_tracker.get());
- vectorized::Block::erase_useless_column(block,
column_to_keep);
- }
- }
- }
+ auto rows = block->rows();
+ SCOPED_TIMER(local_state._split_block_hash_compute_timer);
+ RETURN_IF_ERROR(
+ local_state._partitioner->do_partitioning(state, block,
_mem_tracker.get()));
if (_part_type == TPartitionType::HASH_PARTITIONED) {
- RETURN_IF_ERROR(channel_add_rows(state, local_state.channels,
element_size, hashes,
+ RETURN_IF_ERROR(channel_add_rows(state, local_state.channels,
+ local_state._partition_count,
+
(uint64_t*)local_state._partitioner->get_hash_values(),
rows, block, source_state ==
SourceState::FINISHED));
} else {
- RETURN_IF_ERROR(channel_add_rows(state,
local_state.channel_shared_ptrs, element_size,
- hashes, rows, block,
- source_state ==
SourceState::FINISHED));
+ RETURN_IF_ERROR(channel_add_rows(state,
local_state.channel_shared_ptrs,
+ local_state._partition_count,
+
(uint32_t*)local_state._partitioner->get_hash_values(),
+ rows, block, source_state ==
SourceState::FINISHED));
}
} else {
// Range partition
@@ -487,11 +440,11 @@ Status ExchangeSinkLocalState::get_next_available_buffer(
return Status::InternalError("No broadcast buffer left!");
}
-template <typename Channels>
+template <typename Channels, typename HashValueType>
Status ExchangeSinkOperatorX::channel_add_rows(RuntimeState* state, Channels&
channels,
int num_channels,
- const uint64_t* __restrict
channel_ids, int rows,
- vectorized::Block* block, bool
eos) {
+ const HashValueType* __restrict
channel_ids,
+ int rows, vectorized::Block*
block, bool eos) {
std::vector<int> channel2rows[num_channels];
for (int i = 0; i < rows; i++) {
diff --git a/be/src/pipeline/exec/exchange_sink_operator.h b/be/src/pipeline/exec/exchange_sink_operator.h
index f76f24479e9..9575df01a2a 100644
--- a/be/src/pipeline/exec/exchange_sink_operator.h
+++ b/be/src/pipeline/exec/exchange_sink_operator.h
@@ -133,11 +133,11 @@ public:
_serializer(this) {}
Status init(RuntimeState* state, LocalSinkStateInfo& info) override;
+ Status open(RuntimeState* state) override;
Status close(RuntimeState* state, Status exec_status) override;
Status serialize_block(vectorized::Block* src, PBlock* dest, int
num_receivers = 1);
void
register_channels(pipeline::ExchangeSinkBuffer<ExchangeSinkLocalState>* buffer);
- bool channel_all_can_write();
Status get_next_available_buffer(vectorized::BroadcastPBlockHolder**
holder);
RuntimeProfile::Counter* brpc_wait_timer() { return _brpc_wait_timer; }
@@ -163,8 +163,6 @@ public:
segment_v2::CompressionTypePB& compression_type();
- vectorized::VExprContextSPtrs partition_expr_ctxs;
-
std::vector<vectorized::PipChannel<ExchangeSinkLocalState>*> channels;
std::vector<std::shared_ptr<vectorized::PipChannel<ExchangeSinkLocalState>>>
channel_shared_ptrs;
@@ -212,6 +210,8 @@ private:
std::shared_ptr<AndDependency> _exchange_sink_dependency = nullptr;
std::shared_ptr<BroadcastDependency> _broadcast_dependency = nullptr;
std::vector<std::shared_ptr<ChannelDependency>> _channels_dependency;
+ std::unique_ptr<vectorized::PartitionerBase> _partitioner;
+ int _partition_count;
};
class ExchangeSinkOperatorX final : public
DataSinkOperatorX<ExchangeSinkLocalState> {
@@ -243,20 +243,14 @@ private:
template <typename ChannelPtrType>
void _handle_eof_channel(RuntimeState* state, ChannelPtrType channel,
Status st);
- Status get_partition_column_result(vectorized::Block* block, int* result)
const {
- int counter = 0;
- for (auto ctx : _partition_expr_ctxs) {
- RETURN_IF_ERROR(ctx->execute(block, &result[counter++]));
- }
- return Status::OK();
- }
-
- template <typename Channels>
+ template <typename Channels, typename HashValueType>
Status channel_add_rows(RuntimeState* state, Channels& channels, int
num_channels,
- const uint64_t* channel_ids, int rows,
vectorized::Block* block,
+ const HashValueType* channel_ids, int rows,
vectorized::Block* block,
bool eos);
RuntimeState* _state = nullptr;
+ const std::vector<TExpr>& _texprs;
+
const RowDescriptor& _row_desc;
TPartitionType::type _part_type;
@@ -265,10 +259,6 @@ private:
// one while the other one is still being sent
PBlock _pb_block1;
PBlock _pb_block2;
- PBlock* _cur_pb_block = nullptr;
-
- // compute per-row partition values
- vectorized::VExprContextSPtrs _partition_expr_ctxs;
const std::vector<TPlanFragmentDestination> _dests;
const bool _send_query_statistics_with_every_batch;
diff --git a/be/src/vec/columns/column.h b/be/src/vec/columns/column.h
index b21c5d83862..68afe3947d9 100644
--- a/be/src/vec/columns/column.h
+++ b/be/src/vec/columns/column.h
@@ -388,13 +388,14 @@ public:
/// Update state of crc32 hash function with value of n elements to avoid
the virtual function call
/// null_data to mark whether need to do hash compute, null_data == nullptr
/// means all element need to do hash function, else only *null_data != 0
need to do hash func
- virtual void update_crcs_with_value(std::vector<uint64_t>& hash,
PrimitiveType type,
+ virtual void update_crcs_with_value(uint32_t* __restrict hash,
PrimitiveType type,
+ uint32_t rows, uint32_t offset = 0,
const uint8_t* __restrict null_data =
nullptr) const {
LOG(FATAL) << get_name() << "update_crcs_with_value not supported";
}
// use range for one hash value to avoid virtual function call in loop
- virtual void update_crc_with_value(size_t start, size_t end, uint64_t&
hash,
+ virtual void update_crc_with_value(size_t start, size_t end, uint32_t&
hash,
const uint8_t* __restrict null_data)
const {
LOG(FATAL) << get_name() << " update_crc_with_value not supported";
}
diff --git a/be/src/vec/columns/column_array.cpp
b/be/src/vec/columns/column_array.cpp
index 53144d883f8..47949580bea 100644
--- a/be/src/vec/columns/column_array.cpp
+++ b/be/src/vec/columns/column_array.cpp
@@ -288,8 +288,8 @@ void ColumnArray::update_xxHash_with_value(size_t start,
size_t end, uint64_t& h
hash = HashUtil::xxHash64WithSeed(reinterpret_cast<const
char*>(&elem_size),
sizeof(elem_size), hash);
} else {
- get_data().update_crc_with_value(offsets_column[i - 1],
offsets_column[i], hash,
- nullptr);
+ get_data().update_xxHash_with_value(offsets_column[i - 1],
offsets_column[i],
+ hash, nullptr);
}
}
}
@@ -300,15 +300,15 @@ void ColumnArray::update_xxHash_with_value(size_t start,
size_t end, uint64_t& h
hash = HashUtil::xxHash64WithSeed(reinterpret_cast<const
char*>(&elem_size),
sizeof(elem_size), hash);
} else {
- get_data().update_crc_with_value(offsets_column[i - 1],
offsets_column[i], hash,
- nullptr);
+ get_data().update_xxHash_with_value(offsets_column[i - 1],
offsets_column[i], hash,
+ nullptr);
}
}
}
}
// for every array row calculate crcHash
-void ColumnArray::update_crc_with_value(size_t start, size_t end, uint64_t&
hash,
+void ColumnArray::update_crc_with_value(size_t start, size_t end, uint32_t&
hash,
const uint8_t* __restrict null_data)
const {
auto& offsets_column = get_offsets();
if (null_data) {
@@ -354,9 +354,10 @@ void ColumnArray::update_hashes_with_value(uint64_t*
__restrict hashes,
}
}
-void ColumnArray::update_crcs_with_value(std::vector<uint64_t>& hash,
PrimitiveType type,
+void ColumnArray::update_crcs_with_value(uint32_t* __restrict hash,
PrimitiveType type,
+ uint32_t rows, uint32_t offset,
const uint8_t* __restrict null_data)
const {
- auto s = hash.size();
+ auto s = rows;
DCHECK(s == size());
if (null_data) {
diff --git a/be/src/vec/columns/column_array.h
b/be/src/vec/columns/column_array.h
index 668abd1ef62..44391ae8c74 100644
--- a/be/src/vec/columns/column_array.h
+++ b/be/src/vec/columns/column_array.h
@@ -140,7 +140,7 @@ public:
void update_hash_with_value(size_t n, SipHash& hash) const override;
void update_xxHash_with_value(size_t start, size_t end, uint64_t& hash,
const uint8_t* __restrict null_data) const
override;
- void update_crc_with_value(size_t start, size_t end, uint64_t& hash,
+ void update_crc_with_value(size_t start, size_t end, uint32_t& hash,
const uint8_t* __restrict null_data) const
override;
void update_hashes_with_value(std::vector<SipHash>& hashes,
@@ -149,7 +149,8 @@ public:
void update_hashes_with_value(uint64_t* __restrict hashes,
const uint8_t* __restrict null_data =
nullptr) const override;
- void update_crcs_with_value(std::vector<uint64_t>& hash, PrimitiveType
type,
+ void update_crcs_with_value(uint32_t* __restrict hash, PrimitiveType type,
uint32_t rows,
+ uint32_t offset = 0,
const uint8_t* __restrict null_data = nullptr)
const override;
void insert_range_from(const IColumn& src, size_t start, size_t length)
override;
diff --git a/be/src/vec/columns/column_const.cpp
b/be/src/vec/columns/column_const.cpp
index d8dbae40f43..3fb851b2a9c 100644
--- a/be/src/vec/columns/column_const.cpp
+++ b/be/src/vec/columns/column_const.cpp
@@ -115,17 +115,18 @@ void
ColumnConst::update_hashes_with_value(std::vector<SipHash>& hashes,
}
}
-void ColumnConst::update_crcs_with_value(std::vector<uint64_t>& hashes,
doris::PrimitiveType type,
+void ColumnConst::update_crcs_with_value(uint32_t* __restrict hashes,
doris::PrimitiveType type,
+ uint32_t rows, uint32_t offset,
const uint8_t* __restrict null_data)
const {
DCHECK(null_data == nullptr);
- DCHECK(hashes.size() == size());
+ DCHECK(rows == size());
auto real_data = data->get_data_at(0);
if (real_data.data == nullptr) {
- for (int i = 0; i < hashes.size(); ++i) {
+ for (int i = 0; i < rows; ++i) {
hashes[i] = HashUtil::zlib_crc_hash_null(hashes[i]);
}
} else {
- for (int i = 0; i < hashes.size(); ++i) {
+ for (int i = 0; i < rows; ++i) {
hashes[i] = RawValue::zlib_crc32(real_data.data, real_data.size,
type, hashes[i]);
}
}
diff --git a/be/src/vec/columns/column_const.h
b/be/src/vec/columns/column_const.h
index 016a18f216f..307066a7ae9 100644
--- a/be/src/vec/columns/column_const.h
+++ b/be/src/vec/columns/column_const.h
@@ -161,7 +161,7 @@ public:
}
}
- void update_crc_with_value(size_t start, size_t end, uint64_t& hash,
+ void update_crc_with_value(size_t start, size_t end, uint32_t& hash,
const uint8_t* __restrict null_data) const
override {
get_data_column_ptr()->update_crc_with_value(start, end, hash,
nullptr);
}
@@ -179,8 +179,9 @@ public:
const uint8_t* __restrict null_data) const
override;
// (TODO.Amory) here may not use column_const update hash, and
PrimitiveType is not used.
- void update_crcs_with_value(std::vector<uint64_t>& hashes, PrimitiveType
type,
- const uint8_t* __restrict null_data) const
override;
+ void update_crcs_with_value(uint32_t* __restrict hashes, PrimitiveType
type, uint32_t rows,
+ uint32_t offset = 0,
+ const uint8_t* __restrict null_data = nullptr)
const override;
void update_hashes_with_value(uint64_t* __restrict hashes,
const uint8_t* __restrict null_data) const
override;
diff --git a/be/src/vec/columns/column_decimal.cpp
b/be/src/vec/columns/column_decimal.cpp
index edc8a5777ff..b4574fd7b1a 100644
--- a/be/src/vec/columns/column_decimal.cpp
+++ b/be/src/vec/columns/column_decimal.cpp
@@ -137,7 +137,7 @@ void
ColumnDecimal<T>::update_hashes_with_value(std::vector<SipHash>& hashes,
}
template <typename T>
-void ColumnDecimal<T>::update_crc_with_value(size_t start, size_t end,
uint64_t& hash,
+void ColumnDecimal<T>::update_crc_with_value(size_t start, size_t end,
uint32_t& hash,
const uint8_t* __restrict
null_data) const {
if (null_data == nullptr) {
for (size_t i = start; i < end; i++) {
@@ -161,9 +161,10 @@ void ColumnDecimal<T>::update_crc_with_value(size_t start,
size_t end, uint64_t&
}
template <typename T>
-void ColumnDecimal<T>::update_crcs_with_value(std::vector<uint64_t>& hashes,
PrimitiveType type,
+void ColumnDecimal<T>::update_crcs_with_value(uint32_t* __restrict hashes,
PrimitiveType type,
+ uint32_t rows, uint32_t offset,
const uint8_t* __restrict
null_data) const {
- auto s = hashes.size();
+ auto s = rows;
DCHECK(s == size());
if constexpr (!IsDecimalV2<T>) {
diff --git a/be/src/vec/columns/column_decimal.h
b/be/src/vec/columns/column_decimal.h
index 85ce339608d..dcd135d46b9 100644
--- a/be/src/vec/columns/column_decimal.h
+++ b/be/src/vec/columns/column_decimal.h
@@ -176,12 +176,13 @@ public:
const uint8_t* __restrict null_data) const
override;
void update_hashes_with_value(uint64_t* __restrict hashes,
const uint8_t* __restrict null_data) const
override;
- void update_crcs_with_value(std::vector<uint64_t>& hashes, PrimitiveType
type,
+ void update_crcs_with_value(uint32_t* __restrict hashes, PrimitiveType
type, uint32_t rows,
+ uint32_t offset,
const uint8_t* __restrict null_data) const
override;
void update_xxHash_with_value(size_t start, size_t end, uint64_t& hash,
const uint8_t* __restrict null_data) const
override;
- void update_crc_with_value(size_t start, size_t end, uint64_t& hash,
+ void update_crc_with_value(size_t start, size_t end, uint32_t& hash,
const uint8_t* __restrict null_data) const
override;
int compare_at(size_t n, size_t m, const IColumn& rhs_, int
nan_direction_hint) const override;
@@ -295,8 +296,8 @@ protected:
[this](size_t a, size_t b) { return data[a] <
data[b]; });
}
- void ALWAYS_INLINE decimalv2_do_crc(size_t i, uint64_t& hash) const {
- const DecimalV2Value& dec_val = (const DecimalV2Value&)data[i];
+ void ALWAYS_INLINE decimalv2_do_crc(size_t i, uint32_t& hash) const {
+ const auto& dec_val = (const DecimalV2Value&)data[i];
int64_t int_val = dec_val.int_value();
int32_t frac_val = dec_val.frac_value();
hash = HashUtil::zlib_crc_hash(&int_val, sizeof(int_val), hash);
diff --git a/be/src/vec/columns/column_map.cpp
b/be/src/vec/columns/column_map.cpp
index f7c456a19c1..e25cfd52dd8 100644
--- a/be/src/vec/columns/column_map.cpp
+++ b/be/src/vec/columns/column_map.cpp
@@ -282,7 +282,7 @@ void ColumnMap::update_xxHash_with_value(size_t start,
size_t end, uint64_t& has
}
}
-void ColumnMap::update_crc_with_value(size_t start, size_t end, uint64_t& hash,
+void ColumnMap::update_crc_with_value(size_t start, size_t end, uint32_t& hash,
const uint8_t* __restrict null_data)
const {
auto& offsets = get_offsets();
if (null_data) {
@@ -328,9 +328,9 @@ void ColumnMap::update_hashes_with_value(uint64_t* hashes,
const uint8_t* null_d
}
}
-void ColumnMap::update_crcs_with_value(std::vector<uint64_t>& hash,
PrimitiveType type,
- const uint8_t* __restrict null_data)
const {
- auto s = hash.size();
+void ColumnMap::update_crcs_with_value(uint32_t* __restrict hash,
PrimitiveType type, uint32_t rows,
+ uint32_t offset, const uint8_t*
__restrict null_data) const {
+ auto s = rows;
DCHECK(s == size());
if (null_data) {
diff --git a/be/src/vec/columns/column_map.h b/be/src/vec/columns/column_map.h
index 7464aa18946..7da2200fe2d 100644
--- a/be/src/vec/columns/column_map.h
+++ b/be/src/vec/columns/column_map.h
@@ -180,7 +180,7 @@ public:
void update_xxHash_with_value(size_t start, size_t end, uint64_t& hash,
const uint8_t* __restrict null_data) const
override;
- void update_crc_with_value(size_t start, size_t end, uint64_t& hash,
+ void update_crc_with_value(size_t start, size_t end, uint32_t& hash,
const uint8_t* __restrict null_data) const
override;
void update_hashes_with_value(std::vector<SipHash>& hashes,
@@ -189,7 +189,8 @@ public:
void update_hashes_with_value(uint64_t* __restrict hashes,
const uint8_t* __restrict null_data =
nullptr) const override;
- void update_crcs_with_value(std::vector<uint64_t>& hash, PrimitiveType
type,
+ void update_crcs_with_value(uint32_t* __restrict hash, PrimitiveType type,
uint32_t rows,
+ uint32_t offset = 0,
const uint8_t* __restrict null_data = nullptr)
const override;
/******************** keys and values ***************/
diff --git a/be/src/vec/columns/column_nullable.cpp
b/be/src/vec/columns/column_nullable.cpp
index 494a85eabe7..42b88ac7ae9 100644
--- a/be/src/vec/columns/column_nullable.cpp
+++ b/be/src/vec/columns/column_nullable.cpp
@@ -75,7 +75,7 @@ void ColumnNullable::update_xxHash_with_value(size_t start,
size_t end, uint64_t
}
}
-void ColumnNullable::update_crc_with_value(size_t start, size_t end, uint64_t&
hash,
+void ColumnNullable::update_crc_with_value(size_t start, size_t end, uint32_t&
hash,
const uint8_t* __restrict
null_data) const {
if (!has_null()) {
nested_column->update_crc_with_value(start, end, hash, nullptr);
@@ -118,23 +118,23 @@ void
ColumnNullable::update_hashes_with_value(std::vector<SipHash>& hashes,
}
}
-void ColumnNullable::update_crcs_with_value(std::vector<uint64_t>& hashes,
- doris::PrimitiveType type,
+void ColumnNullable::update_crcs_with_value(uint32_t* __restrict hashes,
doris::PrimitiveType type,
+ uint32_t rows, uint32_t offset,
const uint8_t* __restrict
null_data) const {
DCHECK(null_data == nullptr);
- auto s = hashes.size();
+ auto s = rows;
DCHECK(s == size());
const auto* __restrict real_null_data =
assert_cast<const ColumnUInt8&>(*null_map).get_data().data();
if (!has_null()) {
- nested_column->update_crcs_with_value(hashes, type, nullptr);
+ nested_column->update_crcs_with_value(hashes, type, rows, offset,
nullptr);
} else {
for (int i = 0; i < s; ++i) {
if (real_null_data[i] != 0) {
hashes[i] = HashUtil::zlib_crc_hash_null(hashes[i]);
}
}
- nested_column->update_crcs_with_value(hashes, type, real_null_data);
+ nested_column->update_crcs_with_value(hashes, type, rows, offset,
real_null_data);
}
}
diff --git a/be/src/vec/columns/column_nullable.h
b/be/src/vec/columns/column_nullable.h
index e26b5a8cc09..953c66e45bc 100644
--- a/be/src/vec/columns/column_nullable.h
+++ b/be/src/vec/columns/column_nullable.h
@@ -215,13 +215,14 @@ public:
void replicate(const uint32_t* counts, size_t target_size, IColumn&
column) const override;
void update_xxHash_with_value(size_t start, size_t end, uint64_t& hash,
const uint8_t* __restrict null_data) const
override;
- void update_crc_with_value(size_t start, size_t end, uint64_t& hash,
+ void update_crc_with_value(size_t start, size_t end, uint32_t& hash,
const uint8_t* __restrict null_data) const
override;
void update_hash_with_value(size_t n, SipHash& hash) const override;
void update_hashes_with_value(std::vector<SipHash>& hashes,
const uint8_t* __restrict null_data) const
override;
- void update_crcs_with_value(std::vector<uint64_t>& hash, PrimitiveType
type,
+ void update_crcs_with_value(uint32_t* __restrict hash, PrimitiveType type,
uint32_t rows,
+ uint32_t offset,
const uint8_t* __restrict null_data) const
override;
void update_hashes_with_value(uint64_t* __restrict hashes,
const uint8_t* __restrict null_data) const
override;
diff --git a/be/src/vec/columns/column_string.cpp
b/be/src/vec/columns/column_string.cpp
index 2664ea3bafc..5d5abd64349 100644
--- a/be/src/vec/columns/column_string.cpp
+++ b/be/src/vec/columns/column_string.cpp
@@ -161,9 +161,10 @@ void ColumnString::insert_indices_from(const IColumn& src,
const int* indices_be
}
}
-void ColumnString::update_crcs_with_value(std::vector<uint64_t>& hashes,
doris::PrimitiveType type,
+void ColumnString::update_crcs_with_value(uint32_t* __restrict hashes,
doris::PrimitiveType type,
+ uint32_t rows, uint32_t offset,
const uint8_t* __restrict null_data)
const {
- auto s = hashes.size();
+ auto s = rows;
DCHECK(s == size());
if (null_data == nullptr) {
diff --git a/be/src/vec/columns/column_string.h
b/be/src/vec/columns/column_string.h
index 0b7ebe08b21..ae2bb9d25f9 100644
--- a/be/src/vec/columns/column_string.h
+++ b/be/src/vec/columns/column_string.h
@@ -413,7 +413,7 @@ public:
}
}
- void update_crc_with_value(size_t start, size_t end, uint64_t& hash,
+ void update_crc_with_value(size_t start, size_t end, uint32_t& hash,
const uint8_t* __restrict null_data) const
override {
if (null_data) {
for (size_t i = start; i < end; ++i) {
@@ -444,7 +444,8 @@ public:
SIP_HASHES_FUNCTION_COLUMN_IMPL();
}
- void update_crcs_with_value(std::vector<uint64_t>& hashes, PrimitiveType
type,
+ void update_crcs_with_value(uint32_t* __restrict hashes, PrimitiveType
type, uint32_t rows,
+ uint32_t offset,
const uint8_t* __restrict null_data) const
override;
void update_hashes_with_value(uint64_t* __restrict hashes,
diff --git a/be/src/vec/columns/column_struct.cpp
b/be/src/vec/columns/column_struct.cpp
index c5fbc4b4bf4..832bc32189c 100644
--- a/be/src/vec/columns/column_struct.cpp
+++ b/be/src/vec/columns/column_struct.cpp
@@ -203,7 +203,7 @@ void ColumnStruct::update_xxHash_with_value(size_t start,
size_t end, uint64_t&
}
}
-void ColumnStruct::update_crc_with_value(size_t start, size_t end, uint64_t&
hash,
+void ColumnStruct::update_crc_with_value(size_t start, size_t end, uint32_t&
hash,
const uint8_t* __restrict null_data)
const {
for (const auto& column : columns) {
column->update_crc_with_value(start, end, hash, nullptr);
@@ -217,10 +217,11 @@ void ColumnStruct::update_hashes_with_value(uint64_t*
__restrict hashes,
}
}
-void ColumnStruct::update_crcs_with_value(std::vector<uint64_t>& hash,
PrimitiveType type,
+void ColumnStruct::update_crcs_with_value(uint32_t* __restrict hash,
PrimitiveType type,
+ uint32_t rows, uint32_t offset,
const uint8_t* __restrict null_data)
const {
for (const auto& column : columns) {
- column->update_crcs_with_value(hash, type, null_data);
+ column->update_crcs_with_value(hash, type, rows, offset, null_data);
}
}
diff --git a/be/src/vec/columns/column_struct.h
b/be/src/vec/columns/column_struct.h
index 535604f7260..23f50582780 100644
--- a/be/src/vec/columns/column_struct.h
+++ b/be/src/vec/columns/column_struct.h
@@ -108,7 +108,7 @@ public:
void update_hash_with_value(size_t n, SipHash& hash) const override;
void update_xxHash_with_value(size_t start, size_t end, uint64_t& hash,
const uint8_t* __restrict null_data) const
override;
- void update_crc_with_value(size_t start, size_t end, uint64_t& hash,
+ void update_crc_with_value(size_t start, size_t end, uint32_t& hash,
const uint8_t* __restrict null_data) const
override;
void update_hashes_with_value(std::vector<SipHash>& hashes,
@@ -117,7 +117,8 @@ public:
void update_hashes_with_value(uint64_t* __restrict hashes,
const uint8_t* __restrict null_data =
nullptr) const override;
- void update_crcs_with_value(std::vector<uint64_t>& hash, PrimitiveType
type,
+ void update_crcs_with_value(uint32_t* __restrict hash, PrimitiveType type,
uint32_t rows,
+ uint32_t offset = 0,
const uint8_t* __restrict null_data = nullptr)
const override;
void insert_indices_from(const IColumn& src, const int* indices_begin,
diff --git a/be/src/vec/columns/column_vector.cpp
b/be/src/vec/columns/column_vector.cpp
index d61b4a831ae..bae633d1490 100644
--- a/be/src/vec/columns/column_vector.cpp
+++ b/be/src/vec/columns/column_vector.cpp
@@ -168,9 +168,10 @@ void ColumnVector<T>::compare_internal(size_t rhs_row_id,
const IColumn& rhs,
}
template <typename T>
-void ColumnVector<T>::update_crcs_with_value(std::vector<uint64_t>& hashes,
PrimitiveType type,
+void ColumnVector<T>::update_crcs_with_value(uint32_t* __restrict hashes,
PrimitiveType type,
+ uint32_t rows, uint32_t offset,
const uint8_t* __restrict
null_data) const {
- auto s = hashes.size();
+ auto s = rows;
DCHECK(s == size());
if constexpr (!std::is_same_v<T, Int64>) {
diff --git a/be/src/vec/columns/column_vector.h
b/be/src/vec/columns/column_vector.h
index 0cf100ce07b..5f6ff285ab2 100644
--- a/be/src/vec/columns/column_vector.h
+++ b/be/src/vec/columns/column_vector.h
@@ -288,7 +288,7 @@ public:
}
}
- void ALWAYS_INLINE update_crc_with_value_without_null(size_t idx,
uint64_t& hash) const {
+ void ALWAYS_INLINE update_crc_with_value_without_null(size_t idx,
uint32_t& hash) const {
if constexpr (!std::is_same_v<T, Int64>) {
hash = HashUtil::zlib_crc_hash(&data[idx], sizeof(T), hash);
} else {
@@ -303,7 +303,7 @@ public:
}
}
- void update_crc_with_value(size_t start, size_t end, uint64_t& hash,
+ void update_crc_with_value(size_t start, size_t end, uint32_t& hash,
const uint8_t* __restrict null_data) const
override {
if (null_data) {
for (size_t i = start; i < end; i++) {
@@ -322,7 +322,8 @@ public:
void update_hashes_with_value(std::vector<SipHash>& hashes,
const uint8_t* __restrict null_data) const
override;
- void update_crcs_with_value(std::vector<uint64_t>& hashes, PrimitiveType
type,
+ void update_crcs_with_value(uint32_t* __restrict hashes, PrimitiveType
type, uint32_t rows,
+ uint32_t offset,
const uint8_t* __restrict null_data) const
override;
void update_hashes_with_value(uint64_t* __restrict hashes,
diff --git a/be/src/vec/common/hash_table/hash_map_context.h
b/be/src/vec/common/hash_table/hash_map_context.h
index 0d2ad598140..f40a351f9d8 100644
--- a/be/src/vec/common/hash_table/hash_map_context.h
+++ b/be/src/vec/common/hash_table/hash_map_context.h
@@ -93,10 +93,10 @@ struct MethodBase {
}
template <bool read>
- void prefetch(int currrent) {
- if (LIKELY(currrent + HASH_MAP_PREFETCH_DIST < hash_values.size())) {
- hash_table->template prefetch<read>(keys[currrent +
HASH_MAP_PREFETCH_DIST],
- hash_values[currrent +
HASH_MAP_PREFETCH_DIST]);
+ void prefetch(int current) {
+ if (LIKELY(current + HASH_MAP_PREFETCH_DIST < hash_values.size())) {
+ hash_table->template prefetch<read>(keys[current +
HASH_MAP_PREFETCH_DIST],
+ hash_values[current +
HASH_MAP_PREFETCH_DIST]);
}
}
diff --git a/be/src/vec/exec/scan/pip_scanner_context.h
b/be/src/vec/exec/scan/pip_scanner_context.h
index 159cf2ba658..66eaed7f284 100644
--- a/be/src/vec/exec/scan/pip_scanner_context.h
+++ b/be/src/vec/exec/scan/pip_scanner_context.h
@@ -103,7 +103,7 @@ public:
int64_t local_bytes = 0;
if (_need_colocate_distribute) {
- std::vector<uint64_t> hash_vals;
+ std::vector<uint32_t> hash_vals;
for (const auto& block : blocks) {
// vectorized calculate hash
int rows = block->rows();
@@ -115,9 +115,11 @@ public:
for (int j = 0; j < _col_distribute_ids.size(); ++j) {
block->get_by_position(_col_distribute_ids[j])
.column->update_crcs_with_value(
- hash_vals,
_output_tuple_desc->slots()[_col_distribute_ids[j]]
- ->type()
- .type);
+ hash_vals.data(),
+
_output_tuple_desc->slots()[_col_distribute_ids[j]]
+ ->type()
+ .type,
+ rows);
}
for (int i = 0; i < rows; i++) {
hashes[i] = hashes[i] % element_size;
diff --git a/be/src/vec/runtime/partitioner.cpp
b/be/src/vec/runtime/partitioner.cpp
new file mode 100644
index 00000000000..bb95dcbb6b4
--- /dev/null
+++ b/be/src/vec/runtime/partitioner.cpp
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "partitioner.h"
+
+#include "runtime/thread_context.h"
+#include "vec/columns/column_const.h"
+
+namespace doris::vectorized {
+
+template <typename HashValueType>
+Status Partitioner<HashValueType>::do_partitioning(RuntimeState* state, Block*
block,
+ MemTracker* mem_tracker)
const {
+ int rows = block->rows();
+
+ if (rows > 0) {
+ auto column_to_keep = block->columns();
+
+ int result_size = _partition_expr_ctxs.size();
+ std::vector<int> result(result_size);
+
+ _hash_vals.resize(rows);
+ std::fill(_hash_vals.begin(), _hash_vals.end(), 0);
+ auto* __restrict hashes = _hash_vals.data();
+ {
+ SCOPED_CONSUME_MEM_TRACKER(mem_tracker);
+ RETURN_IF_ERROR(_get_partition_column_result(block, result));
+ }
+ for (int j = 0; j < result_size; ++j) {
+
_do_hash(unpack_if_const(block->get_by_position(result[j]).column).first,
hashes, j);
+ }
+
+ for (int i = 0; i < rows; i++) {
+ hashes[i] = hashes[i] % _partition_count;
+ }
+
+ {
+ SCOPED_CONSUME_MEM_TRACKER(mem_tracker);
+ Block::erase_useless_column(block, column_to_keep);
+ }
+ }
+ return Status::OK();
+}
+
+void BucketHashPartitioner::_do_hash(const ColumnPtr& column, uint32_t*
__restrict result,
+ int idx) const {
+ column->update_crcs_with_value(result,
_partition_expr_ctxs[idx]->root()->type().type,
+ column->size());
+}
+
+void HashPartitioner::_do_hash(const ColumnPtr& column, uint64_t* __restrict
result,
+ int /*idx*/) const {
+ column->update_hashes_with_value(result);
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/runtime/partitioner.h b/be/src/vec/runtime/partitioner.h
new file mode 100644
index 00000000000..c0ee400012d
--- /dev/null
+++ b/be/src/vec/runtime/partitioner.h
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "util/runtime_profile.h"
+#include "vec/exprs/vexpr.h"
+#include "vec/exprs/vexpr_context.h"
+
+namespace doris {
+class MemTracker;
+
+namespace vectorized {
+
+class PartitionerBase {
+public:
+ PartitionerBase(size_t partition_count) :
_partition_count(partition_count) {}
+ virtual ~PartitionerBase() = default;
+
+ virtual Status init(const std::vector<TExpr>& texprs) = 0;
+
+ virtual Status prepare(RuntimeState* state, const RowDescriptor& row_desc)
= 0;
+
+ virtual Status open(RuntimeState* state) = 0;
+
+ virtual Status do_partitioning(RuntimeState* state, Block* block,
+ MemTracker* mem_tracker) const = 0;
+
+ virtual void* get_hash_values() const = 0;
+
+protected:
+ const size_t _partition_count;
+};
+
+template <typename HashValueType>
+class Partitioner : public PartitionerBase {
+public:
+ Partitioner(int partition_count) : PartitionerBase(partition_count) {}
+ ~Partitioner() override = default;
+
+ Status init(const std::vector<TExpr>& texprs) override {
+ return VExpr::create_expr_trees(texprs, _partition_expr_ctxs);
+ }
+
+ Status prepare(RuntimeState* state, const RowDescriptor& row_desc)
override {
+ return VExpr::prepare(_partition_expr_ctxs, state, row_desc);
+ }
+
+ Status open(RuntimeState* state) override { return
VExpr::open(_partition_expr_ctxs, state); }
+
+ Status do_partitioning(RuntimeState* state, Block* block,
+ MemTracker* mem_tracker) const override;
+
+ void* get_hash_values() const override { return _hash_vals.data(); }
+
+protected:
+ Status _get_partition_column_result(Block* block, std::vector<int>&
result) const {
+ int counter = 0;
+ for (auto ctx : _partition_expr_ctxs) {
+ RETURN_IF_ERROR(ctx->execute(block, &result[counter++]));
+ }
+ return Status::OK();
+ }
+
+ virtual void _do_hash(const ColumnPtr& column, HashValueType* __restrict
result,
+ int idx) const = 0;
+
+ VExprContextSPtrs _partition_expr_ctxs;
+ mutable std::vector<HashValueType> _hash_vals;
+};
+
+class HashPartitioner final : public Partitioner<uint64_t> {
+public:
+ HashPartitioner(int partition_count) :
Partitioner<uint64_t>(partition_count) {}
+ ~HashPartitioner() override = default;
+
+private:
+ void _do_hash(const ColumnPtr& column, uint64_t* __restrict result, int
idx) const override;
+};
+
+class BucketHashPartitioner final : public Partitioner<uint32_t> {
+public:
+ BucketHashPartitioner(int partition_count) :
Partitioner<uint32_t>(partition_count) {}
+ ~BucketHashPartitioner() override = default;
+
+private:
+ void _do_hash(const ColumnPtr& column, uint32_t* __restrict result, int
idx) const override;
+};
+
+} // namespace vectorized
+} // namespace doris
diff --git a/be/src/vec/sink/vdata_stream_sender.cpp
b/be/src/vec/sink/vdata_stream_sender.cpp
index ad19dcd9cdf..3bce57eda98 100644
--- a/be/src/vec/sink/vdata_stream_sender.cpp
+++ b/be/src/vec/sink/vdata_stream_sender.cpp
@@ -416,10 +416,14 @@ VDataStreamSender::~VDataStreamSender() {
Status VDataStreamSender::init(const TDataSink& tsink) {
RETURN_IF_ERROR(DataSink::init(tsink));
const TDataStreamSink& t_stream_sink = tsink.stream_sink;
- if (_part_type == TPartitionType::HASH_PARTITIONED ||
- _part_type == TPartitionType::BUCKET_SHFFULE_HASH_PARTITIONED) {
-
RETURN_IF_ERROR(VExpr::create_expr_trees(t_stream_sink.output_partition.partition_exprs,
- _partition_expr_ctxs));
+ if (_part_type == TPartitionType::HASH_PARTITIONED) {
+ _partition_count = _channels.size();
+ _partitioner.reset(new HashPartitioner(_channels.size()));
+
RETURN_IF_ERROR(_partitioner->init(t_stream_sink.output_partition.partition_exprs));
+ } else if (_part_type == TPartitionType::BUCKET_SHFFULE_HASH_PARTITIONED) {
+ _partition_count = _channel_shared_ptrs.size();
+ _partitioner.reset(new
BucketHashPartitioner(_channel_shared_ptrs.size()));
+
RETURN_IF_ERROR(_partitioner->init(t_stream_sink.output_partition.partition_exprs));
} else if (_part_type == TPartitionType::RANGE_PARTITIONED) {
return Status::InternalError("TPartitionType::RANGE_PARTITIONED should
not be used");
} else {
@@ -449,9 +453,7 @@ Status VDataStreamSender::prepare(RuntimeState* state) {
shuffle(_channels.begin(), _channels.end(), g);
} else if (_part_type == TPartitionType::HASH_PARTITIONED ||
_part_type == TPartitionType::BUCKET_SHFFULE_HASH_PARTITIONED) {
- RETURN_IF_ERROR(VExpr::prepare(_partition_expr_ctxs, state,
_row_desc));
- } else {
- RETURN_IF_ERROR(VExpr::prepare(_partition_expr_ctxs, state,
_row_desc));
+ RETURN_IF_ERROR(_partitioner->prepare(state, _row_desc));
}
_bytes_sent_counter = ADD_COUNTER(profile(), "BytesSent", TUnit::BYTES);
@@ -490,7 +492,10 @@ Status VDataStreamSender::open(RuntimeState* state) {
}
_only_local_exchange = local_size == _channels.size();
SCOPED_CONSUME_MEM_TRACKER(_mem_tracker.get());
- RETURN_IF_ERROR(VExpr::open(_partition_expr_ctxs, state));
+ if (_part_type == TPartitionType::HASH_PARTITIONED ||
+ _part_type == TPartitionType::BUCKET_SHFFULE_HASH_PARTITIONED) {
+ RETURN_IF_ERROR(_partitioner->open(state));
+ }
_compression_type = state->fragement_transmission_compression_type();
return Status::OK();
@@ -613,67 +618,17 @@ Status VDataStreamSender::send(RuntimeState* state,
Block* block, bool eos) {
_current_channel_idx = (_current_channel_idx + 1) % _channels.size();
} else if (_part_type == TPartitionType::HASH_PARTITIONED ||
_part_type == TPartitionType::BUCKET_SHFFULE_HASH_PARTITIONED) {
- // will only copy schema
- // we don't want send temp columns
- auto column_to_keep = block->columns();
-
- int result_size = _partition_expr_ctxs.size();
- int result[result_size];
-
- // vectorized calculate hash
- int rows = block->rows();
- auto element_size = _part_type == TPartitionType::HASH_PARTITIONED
- ? _channels.size()
- : _channel_shared_ptrs.size();
- std::vector<uint64_t> hash_vals(rows);
- auto* __restrict hashes = hash_vals.data();
-
- if (rows > 0) {
- {
- SCOPED_CONSUME_MEM_TRACKER(_mem_tracker.get());
- RETURN_IF_ERROR(get_partition_column_result(block, result));
- }
- // TODO: after we support new shuffle hash method, should simple
the code
- if (_part_type == TPartitionType::HASH_PARTITIONED) {
- SCOPED_TIMER(_split_block_hash_compute_timer);
- // result[j] means column index, i means rows index, here to
calculate the xxhash value
- for (int j = 0; j < result_size; ++j) {
- // complex type most not implement get_data_at() method
which column_const will call
- unpack_if_const(block->get_by_position(result[j]).column)
- .first->update_hashes_with_value(hashes);
- }
-
- for (int i = 0; i < rows; i++) {
- hashes[i] = hashes[i] % element_size;
- }
-
- {
- SCOPED_CONSUME_MEM_TRACKER(_mem_tracker.get());
- Block::erase_useless_column(block, column_to_keep);
- }
- } else {
- for (int j = 0; j < result_size; ++j) {
- // complex type most not implement get_data_at() method
which column_const will call
- unpack_if_const(block->get_by_position(result[j]).column)
- .first->update_crcs_with_value(
- hash_vals,
_partition_expr_ctxs[j]->root()->type().type);
- }
- for (int i = 0; i < rows; i++) {
- hashes[i] = hashes[i] % element_size;
- }
-
- {
- SCOPED_CONSUME_MEM_TRACKER(_mem_tracker.get());
- Block::erase_useless_column(block, column_to_keep);
- }
- }
- }
+ auto rows = block->rows();
+ SCOPED_TIMER(_split_block_hash_compute_timer);
+ RETURN_IF_ERROR(_partitioner->do_partitioning(state, block,
_mem_tracker.get()));
if (_part_type == TPartitionType::HASH_PARTITIONED) {
- RETURN_IF_ERROR(channel_add_rows(state, _channels, element_size,
hashes, rows, block,
- _enable_pipeline_exec ? eos :
false));
+ RETURN_IF_ERROR(channel_add_rows(state, _channels,
_partition_count,
+
(uint64_t*)_partitioner->get_hash_values(), rows,
+ block, _enable_pipeline_exec ?
eos : false));
} else {
- RETURN_IF_ERROR(channel_add_rows(state, _channel_shared_ptrs,
element_size, hashes,
- rows, block,
_enable_pipeline_exec ? eos : false));
+ RETURN_IF_ERROR(channel_add_rows(state, _channel_shared_ptrs,
_partition_count,
+
(uint32_t*)_partitioner->get_hash_values(), rows,
+ block, _enable_pipeline_exec ?
eos : false));
}
} else {
// Range partition
diff --git a/be/src/vec/sink/vdata_stream_sender.h
b/be/src/vec/sink/vdata_stream_sender.h
index d72e1dad39a..203d59dc664 100644
--- a/be/src/vec/sink/vdata_stream_sender.h
+++ b/be/src/vec/sink/vdata_stream_sender.h
@@ -46,6 +46,7 @@
#include "util/uid_util.h"
#include "vec/core/block.h"
#include "vec/exprs/vexpr_context.h"
+#include "vec/runtime/partitioner.h"
#include "vec/runtime/vdata_stream_recvr.h"
namespace doris {
@@ -151,17 +152,10 @@ protected:
void _roll_pb_block();
Status _get_next_available_buffer(BroadcastPBlockHolder** holder);
- Status get_partition_column_result(Block* block, int* result) const {
- int counter = 0;
- for (auto ctx : _partition_expr_ctxs) {
- RETURN_IF_ERROR(ctx->execute(block, &result[counter++]));
- }
- return Status::OK();
- }
-
- template <typename Channels>
+ template <typename Channels, typename HashValueType>
Status channel_add_rows(RuntimeState* state, Channels& channels, int
num_channels,
- const uint64_t* channel_ids, int rows, Block*
block, bool eos);
+ const HashValueType* __restrict channel_ids, int
rows, Block* block,
+ bool eos);
template <typename ChannelPtrType>
void _handle_eof_channel(RuntimeState* state, ChannelPtrType channel,
Status st);
@@ -186,8 +180,8 @@ protected:
std::vector<BroadcastPBlockHolder> _broadcast_pb_blocks;
int _broadcast_pb_block_idx;
- // compute per-row partition values
- VExprContextSPtrs _partition_expr_ctxs;
+ std::unique_ptr<PartitionerBase> _partitioner;
+ size_t _partition_count;
std::vector<Channel<VDataStreamSender>*> _channels;
std::vector<std::shared_ptr<Channel<VDataStreamSender>>>
_channel_shared_ptrs;
@@ -416,10 +410,11 @@ protected:
} \
} while (0)
-template <typename Channels>
+template <typename Channels, typename HashValueType>
Status VDataStreamSender::channel_add_rows(RuntimeState* state, Channels&
channels,
- int num_channels, const uint64_t*
__restrict channel_ids,
- int rows, Block* block, bool eos) {
+ int num_channels,
+ const HashValueType* __restrict
channel_ids, int rows,
+ Block* block, bool eos) {
std::vector<int> channel2rows[num_channels];
for (int i = 0; i < rows; i++) {
diff --git a/be/test/vec/columns/column_hash_func_test.cpp
b/be/test/vec/columns/column_hash_func_test.cpp
index f80edb035f0..bdde0f33a62 100644
--- a/be/test/vec/columns/column_hash_func_test.cpp
+++ b/be/test/vec/columns/column_hash_func_test.cpp
@@ -64,7 +64,7 @@ TEST(HashFuncTest, ArrayTypeTest) {
std::vector<uint64_t> sip_hash_vals(1);
std::vector<uint64_t> xx_hash_vals(1);
- std::vector<uint64_t> crc_hash_vals(1);
+ std::vector<uint32_t> crc_hash_vals(1);
auto* __restrict sip_hashes = sip_hash_vals.data();
auto* __restrict xx_hashes = xx_hash_vals.data();
auto* __restrict crc_hashes = crc_hash_vals.data();
@@ -83,7 +83,7 @@ TEST(HashFuncTest, ArrayTypeTest) {
std::cout << xx_hashes[0] << std::endl;
// crcHash
EXPECT_NO_FATAL_FAILURE(
- col_a->update_crcs_with_value(crc_hash_vals,
PrimitiveType::TYPE_ARRAY));
+ col_a->update_crcs_with_value(crc_hashes,
PrimitiveType::TYPE_ARRAY, 1));
std::cout << crc_hashes[0] << std::endl;
}
}
@@ -103,12 +103,12 @@ TEST(HashFuncTest, ArraySimpleBenchmarkTest) {
}
array_mutable_col->insert(a);
}
- std::vector<uint64_t> crc_hash_vals(r_num);
+ std::vector<uint32_t> crc_hash_vals(r_num);
int64_t time_t = 0;
{
SCOPED_RAW_TIMER(&time_t);
EXPECT_NO_FATAL_FAILURE(array_mutable_col->update_crcs_with_value(
- crc_hash_vals, PrimitiveType::TYPE_ARRAY));
+ crc_hash_vals.data(), PrimitiveType::TYPE_ARRAY, r_num));
}
std::cout << time_t << "ns" << std::endl;
}
@@ -150,7 +150,7 @@ TEST(HashFuncTest, ArrayNestedArrayTest) {
EXPECT_EQ(nested_col->size(), 8);
std::vector<uint64_t> xx_hash_vals(4);
- std::vector<uint64_t> crc_hash_vals(4);
+ std::vector<uint32_t> crc_hash_vals(4);
auto* __restrict xx_hashes = xx_hash_vals.data();
auto* __restrict crc_hashes = crc_hash_vals.data();
@@ -160,7 +160,7 @@ TEST(HashFuncTest, ArrayNestedArrayTest) {
EXPECT_TRUE(xx_hashes[2] != xx_hashes[3]);
// crcHash
EXPECT_NO_FATAL_FAILURE(
- array_mutable_col->update_crcs_with_value(crc_hash_vals,
PrimitiveType::TYPE_ARRAY));
+ array_mutable_col->update_crcs_with_value(crc_hashes,
PrimitiveType::TYPE_ARRAY, 4));
EXPECT_TRUE(crc_hashes[0] != crc_hashes[1]);
EXPECT_TRUE(crc_hashes[2] != crc_hashes[3]);
}
@@ -186,7 +186,7 @@ TEST(HashFuncTest, ArrayCornerCaseTest) {
std::vector<uint64_t> sip_hash_vals(3);
std::vector<uint64_t> xx_hash_vals(3);
- std::vector<uint64_t> crc_hash_vals(3);
+ std::vector<uint32_t> crc_hash_vals(3);
auto* __restrict sip_hashes = sip_hash_vals.data();
auto* __restrict xx_hashes = xx_hash_vals.data();
auto* __restrict crc_hashes = crc_hash_vals.data();
@@ -205,8 +205,8 @@ TEST(HashFuncTest, ArrayCornerCaseTest) {
EXPECT_EQ(xx_hashes[0], xx_hashes[1]);
EXPECT_TRUE(xx_hashes[0] != xx_hashes[2]);
// crcHash
- EXPECT_NO_FATAL_FAILURE(
- array_mutable_col->update_crcs_with_value(crc_hash_vals,
PrimitiveType::TYPE_ARRAY));
+ EXPECT_NO_FATAL_FAILURE(array_mutable_col->update_crcs_with_value(
+ crc_hashes, PrimitiveType::TYPE_ARRAY, array_mutable_col->size()));
EXPECT_EQ(crc_hashes[0], crc_hashes[1]);
EXPECT_TRUE(xx_hashes[0] != xx_hashes[2]);
}
@@ -216,7 +216,7 @@ TEST(HashFuncTest, MapTypeTest) {
std::vector<uint64_t> sip_hash_vals(1);
std::vector<uint64_t> xx_hash_vals(1);
- std::vector<uint64_t> crc_hash_vals(1);
+ std::vector<uint32_t> crc_hash_vals(1);
auto* __restrict sip_hashes = sip_hash_vals.data();
auto* __restrict xx_hashes = xx_hash_vals.data();
auto* __restrict crc_hashes = crc_hash_vals.data();
@@ -234,7 +234,7 @@ TEST(HashFuncTest, MapTypeTest) {
std::cout << xx_hashes[0] << std::endl;
// crcHash
EXPECT_NO_FATAL_FAILURE(unpack_if_const(col_a).first->update_crcs_with_value(
- crc_hash_vals, PrimitiveType::TYPE_MAP));
+ crc_hashes, PrimitiveType::TYPE_MAP, 1));
std::cout << crc_hashes[0] << std::endl;
}
}
@@ -244,7 +244,7 @@ TEST(HashFuncTest, StructTypeTest) {
std::vector<uint64_t> sip_hash_vals(1);
std::vector<uint64_t> xx_hash_vals(1);
- std::vector<uint64_t> crc_hash_vals(1);
+ std::vector<uint32_t> crc_hash_vals(1);
auto* __restrict sip_hashes = sip_hash_vals.data();
auto* __restrict xx_hashes = xx_hash_vals.data();
auto* __restrict crc_hashes = crc_hash_vals.data();
@@ -262,7 +262,7 @@ TEST(HashFuncTest, StructTypeTest) {
std::cout << xx_hashes[0] << std::endl;
// crcHash
EXPECT_NO_FATAL_FAILURE(unpack_if_const(col_a).first->update_crcs_with_value(
- crc_hash_vals, PrimitiveType::TYPE_STRUCT));
+ crc_hashes, PrimitiveType::TYPE_STRUCT, 1));
std::cout << crc_hashes[0] << std::endl;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]