[refactor](local exchange) Simplify block wrapper #47620

Merged Feb 8, 2025 · 3 commits

Changes from 1 commit
169 changes: 61 additions & 108 deletions be/src/pipeline/local_exchange/local_exchanger.cpp
@@ -35,35 +35,19 @@ void Exchanger<BlockType>::_enqueue_data_and_set_ready(int channel_id,
_enqueue_data_and_set_ready(channel_id, std::move(block));
return;
}
size_t allocated_bytes = 0;
// PartitionedBlock is used by the shuffle exchanger.
// A PartitionedBlock is pushed into multiple queues with different row ranges, so it is
// referenced multiple times. Other block types are referenced only once because they are
// pushed into a single queue.
std::unique_lock l(*_m[channel_id]);
if constexpr (std::is_same_v<PartitionedBlock, BlockType> ||
std::is_same_v<BroadcastBlock, BlockType>) {
allocated_bytes = block.first->data_block.allocated_bytes();
block.first->record_channel_id(channel_id, false);
} else {
block->ref(1);
allocated_bytes = block->data_block.allocated_bytes();
block->record_channel_id(channel_id, true);
}
std::unique_lock l(*_m[channel_id]);
local_state->_shared_state->add_mem_usage(channel_id, allocated_bytes,
!std::is_same_v<PartitionedBlock, BlockType> &&
!std::is_same_v<BroadcastBlock, BlockType>);
if (_data_queue[channel_id].enqueue(std::move(block))) {
local_state->_shared_state->set_ready_to_read(channel_id);
} else {
local_state->_shared_state->sub_mem_usage(channel_id, allocated_bytes);
// `enqueue(block)` returns false iff this queue's source operator is already closed, so we
// just unref the block.
if constexpr (std::is_same_v<PartitionedBlock, BlockType> ||
std::is_same_v<BroadcastBlock, BlockType>) {
block.first->unref(local_state->_shared_state, allocated_bytes, channel_id);
} else {
block->unref(local_state->_shared_state, allocated_bytes, channel_id);
DCHECK_EQ(block->ref_value(), 0);
}
}
}
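Throughout this diff, the explicit cleanup on a rejected enqueue shrinks to a no-op because the wrapper is owned by `std::shared_ptr`. Below is a minimal standalone sketch of that rule, using stand-in types rather than the actual Doris queue API: if the queue declines the block, the caller's `shared_ptr` is left untouched and releases the wrapper on scope exit.

```cpp
#include <cstdio>
#include <memory>

struct Wrapper {
    ~Wrapper() { std::puts("wrapper released"); }
};

// Stand-in for a bounded queue whose enqueue is rejected once the source is closed.
bool try_enqueue(std::shared_ptr<Wrapper>&& w, bool source_closed) {
    if (source_closed) {
        return false;  // rejected: `w` was never moved from; the caller still owns it
    }
    std::shared_ptr<Wrapper> stored = std::move(w);  // accepted: queue takes ownership
    return true;
}

int main() {
    auto block = std::make_shared<Wrapper>();
    if (!try_enqueue(std::move(block), /*source_closed=*/true)) {
        // The old code had to unref() here; with shared ownership, doing
        // nothing is correct because `block` still holds the only reference.
    }
    return 0;  // "wrapper released" prints when `block` goes out of scope
}
```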

@@ -84,8 +68,6 @@ bool Exchanger<BlockType>::_dequeue_data(LocalExchangeSourceLocalState* local_st
local_state->_shared_state->sub_mem_usage(channel_id,
block->data_block.allocated_bytes());
data_block->swap(block->data_block);
block->unref(local_state->_shared_state, data_block->allocated_bytes(), channel_id);
DCHECK_EQ(block->ref_value(), 0);
}
return true;
} else if (all_finished) {
@@ -101,8 +83,6 @@ bool Exchanger<BlockType>::_dequeue_data(LocalExchangeSourceLocalState* local_st
local_state->_shared_state->sub_mem_usage(channel_id,
block->data_block.allocated_bytes());
data_block->swap(block->data_block);
block->unref(local_state->_shared_state, data_block->allocated_bytes(), channel_id);
DCHECK_EQ(block->ref_value(), 0);
}
return true;
}
@@ -114,18 +94,14 @@ bool Exchanger<BlockType>::_dequeue_data(LocalExchangeSourceLocalState* local_st

template <typename BlockType>
void Exchanger<BlockType>::_enqueue_data_and_set_ready(int channel_id, BlockType&& block) {
if constexpr (!std::is_same_v<PartitionedBlock, BlockType> &&
!std::is_same_v<BroadcastBlock, BlockType>) {
block->ref(1);
if constexpr (std::is_same_v<PartitionedBlock, BlockType> ||
std::is_same_v<BroadcastBlock, BlockType>) {
block.first->record_channel_id(channel_id, false);
} else {
block->record_channel_id(channel_id, true);
}
if (!_data_queue[channel_id].enqueue(std::move(block))) {
if constexpr (std::is_same_v<PartitionedBlock, BlockType> ||
std::is_same_v<BroadcastBlock, BlockType>) {
block.first->unref();
} else {
block->unref();
DCHECK_EQ(block->ref_value(), 0);
}
// do nothing
}
}

@@ -136,8 +112,6 @@ bool Exchanger<BlockType>::_dequeue_data(BlockType& block, bool* eos, vectorized
if constexpr (!std::is_same_v<PartitionedBlock, BlockType> &&
!std::is_same_v<BroadcastBlock, BlockType>) {
data_block->swap(block->data_block);
block->unref();
DCHECK_EQ(block->ref_value(), 0);
}
return true;
}
@@ -170,9 +144,7 @@ void ShuffleExchanger::close(SourceInfo&& source_info) {
_data_queue[source_info.channel_id].set_eos();
while (_dequeue_data(source_info.local_state, partitioned_block, &eos, &block,
source_info.channel_id)) {
partitioned_block.first->unref(
source_info.local_state ? source_info.local_state->_shared_state : nullptr,
source_info.channel_id);
// do nothing
}
}

@@ -186,11 +158,6 @@ Status ShuffleExchanger::get_block(RuntimeState* state, vectorized::Block* block
const auto* offset_start = partitioned_block.second.row_idxs->data() +
partitioned_block.second.offset_start;
auto block_wrapper = partitioned_block.first;
Defer defer {[&]() {
block_wrapper->unref(
source_info.local_state ? source_info.local_state->_shared_state : nullptr,
source_info.channel_id);
}};
RETURN_IF_ERROR(mutable_block.add_rows(&block_wrapper->data_block, offset_start,
offset_start + partitioned_block.second.length));
} while (mutable_block.rows() < state->batch_size() && !*eos &&
@@ -235,52 +202,33 @@ Status ShuffleExchanger::_split_rows(RuntimeState* state, const uint32_t* __rest

vectorized::Block data_block;
std::shared_ptr<BlockWrapper> new_block_wrapper;
if (_free_blocks.try_dequeue(data_block)) {
new_block_wrapper = BlockWrapper::create_shared(std::move(data_block));
} else {
new_block_wrapper = BlockWrapper::create_shared(block->clone_empty());
if (!_free_blocks.try_dequeue(data_block)) {
data_block = block->clone_empty();
}

new_block_wrapper->data_block.swap(*block);
data_block.swap(*block);
new_block_wrapper =
BlockWrapper::create_shared(std::move(data_block), local_state->_shared_state);
if (new_block_wrapper->data_block.empty()) {
return Status::OK();
}
local_state->_shared_state->add_total_mem_usage(new_block_wrapper->data_block.allocated_bytes(),
channel_id);
if (get_type() == ExchangeType::HASH_SHUFFLE) {
/**
 * If type is `HASH_SHUFFLE`, data are hash-shuffled and distributed to all instances of
 * all BEs, so we need a shuffle-id-to-instance-id mapping.
 * For example, row 1 gets hash value 1, which means it should go to instance 1 on BE 1,
 * and row 2 gets hash value 2, which means it should go to instance 1 on BE 3.
 */
DCHECK(shuffle_idx_to_instance_idx && shuffle_idx_to_instance_idx->size() > 0);
const auto& map = *shuffle_idx_to_instance_idx;
new_block_wrapper->ref(cast_set<int>(map.size()));
for (const auto& it : map) {
DCHECK(it.second >= 0 && it.second < _num_partitions)
<< it.first << " : " << it.second << " " << _num_partitions;
uint32_t start = partition_rows_histogram[it.first];
uint32_t size = partition_rows_histogram[it.first + 1] - start;
if (size > 0) {
_enqueue_data_and_set_ready(it.second, local_state,
{new_block_wrapper, {row_idx, start, size}});
} else {
new_block_wrapper->unref(local_state->_shared_state, channel_id);
}
}
} else {
DCHECK(shuffle_idx_to_instance_idx && shuffle_idx_to_instance_idx->size() > 0);
new_block_wrapper->ref(_num_partitions);
for (int i = 0; i < _num_partitions; i++) {
uint32_t start = partition_rows_histogram[i];
uint32_t size = partition_rows_histogram[i + 1] - start;
if (size > 0) {
_enqueue_data_and_set_ready((*shuffle_idx_to_instance_idx)[i], local_state,
{new_block_wrapper, {row_idx, start, size}});
} else {
new_block_wrapper->unref(local_state->_shared_state, channel_id);
}
/**
 * Data are hash-shuffled and distributed to all instances of all BEs, so we need a
 * shuffle-id-to-instance-id mapping.
 * For example, row 1 gets hash value 1, which means it should go to instance 1 on BE 1,
 * and row 2 gets hash value 2, which means it should go to instance 1 on BE 3.
 */
DCHECK(shuffle_idx_to_instance_idx && shuffle_idx_to_instance_idx->size() > 0);
const auto& map = *shuffle_idx_to_instance_idx;
for (const auto& it : map) {
DCHECK(it.second >= 0 && it.second < _num_partitions)
<< it.first << " : " << it.second << " " << _num_partitions;
uint32_t start = partition_rows_histogram[it.first];
uint32_t size = partition_rows_histogram[it.first + 1] - start;
if (size > 0) {
_enqueue_data_and_set_ready(it.second, local_state,
{new_block_wrapper, {row_idx, start, size}});
}
}
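The `(start, size)` pairs consumed by the loop above come from `partition_rows_histogram`, which is built outside this hunk. Here is a self-contained sketch of that bookkeeping, with a hard-coded partition assignment standing in for the hash partitioner (hypothetical standalone code, not the Doris implementation): counts are prefix-summed into offsets, row numbers are scattered into `row_idx`, and each partition then owns one contiguous slice.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t num_partitions = 3;
    // Partition chosen by the hash partitioner for each of 6 input rows.
    const std::vector<uint32_t> channel_ids = {2, 0, 2, 1, 0, 2};

    // Count rows per partition, then prefix-sum the counts into offsets.
    std::vector<uint32_t> histogram(num_partitions + 1, 0);
    for (uint32_t cid : channel_ids) histogram[cid + 1]++;
    for (uint32_t p = 1; p <= num_partitions; ++p) histogram[p] += histogram[p - 1];

    // Scatter row numbers so each partition's rows sit contiguously in row_idx.
    std::vector<uint32_t> row_idx(channel_ids.size());
    std::vector<uint32_t> cursor(histogram.begin(), histogram.end() - 1);
    for (uint32_t row = 0; row < channel_ids.size(); ++row) {
        row_idx[cursor[channel_ids[row]]++] = row;
    }

    // Each partition maps to one (start, size) slice -- the same arithmetic as
    // the enqueue loop in the hunk above.
    for (uint32_t p = 0; p < num_partitions; ++p) {
        const uint32_t start = histogram[p];
        const uint32_t size = histogram[p + 1] - start;
        std::printf("partition %u: start=%u size=%u rows:", p, start, size);
        for (uint32_t i = start; i < start + size; ++i) std::printf(" %u", row_idx[i]);
        std::printf("\n");
    }
    return 0;
}
```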

@@ -308,24 +256,19 @@ Status ShuffleExchanger::_split_rows(RuntimeState* state, const uint32_t* __rest

vectorized::Block data_block;
std::shared_ptr<BlockWrapper> new_block_wrapper;
if (_free_blocks.try_dequeue(data_block)) {
new_block_wrapper = BlockWrapper::create_shared(std::move(data_block));
} else {
new_block_wrapper = BlockWrapper::create_shared(block->clone_empty());
if (!_free_blocks.try_dequeue(data_block)) {
data_block = block->clone_empty();
}

new_block_wrapper->data_block.swap(*block);
data_block.swap(*block);
new_block_wrapper = BlockWrapper::create_shared(std::move(data_block), nullptr);
if (new_block_wrapper->data_block.empty()) {
return Status::OK();
}
new_block_wrapper->ref(cast_set<int>(_num_partitions));
for (int i = 0; i < _num_partitions; i++) {
uint32_t start = partition_rows_histogram[i];
uint32_t size = partition_rows_histogram[i + 1] - start;
if (size > 0) {
_enqueue_data_and_set_ready(i, {new_block_wrapper, {row_idx, start, size}});
} else {
new_block_wrapper->unref();
}
}
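Both `_split_rows` variants, and the sink methods below, share one recycling idiom: take an empty block from `_free_blocks` if one is available, otherwise `clone_empty()` the input's layout, then `swap` the rows in so the caller's block stays reusable. A compact sketch of the pattern, with a plain `std::deque` standing in for the lock-free free-block queue (assumed shapes, not the real Doris types):

```cpp
#include <cstdio>
#include <deque>
#include <string>
#include <utility>

struct Block {
    std::string payload;                           // stand-in for the column data
    Block clone_empty() const { return Block{}; }  // same layout, zero rows
    void swap(Block& other) { std::swap(payload, other.payload); }
};

std::deque<Block> free_blocks;  // recycled empty blocks (a lock-free queue in Doris)

bool try_dequeue(Block& out) {
    if (free_blocks.empty()) return false;
    out = std::move(free_blocks.front());
    free_blocks.pop_front();
    return true;
}

// Move the rows out of `in` into a (possibly recycled) block; `in` keeps its
// layout and can be refilled by the caller.
Block take_rows(Block& in) {
    Block data_block;
    if (!try_dequeue(data_block)) {
        data_block = in.clone_empty();
    }
    data_block.swap(in);
    return data_block;
}

int main() {
    Block input{"row data"};
    Block taken = take_rows(input);
    std::printf("taken='%s' input='%s'\n", taken.payload.c_str(), input.payload.c_str());
    free_blocks.push_back(std::move(input));  // recycle the now-empty block
    return 0;
}
```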

@@ -343,7 +286,9 @@ Status PassthroughExchanger::sink(RuntimeState* state, vectorized::Block* in_blo
new_block = {in_block->clone_empty()};
}
new_block.swap(*in_block);
wrapper = BlockWrapper::create_shared(std::move(new_block));
wrapper = BlockWrapper::create_shared(
std::move(new_block),
sink_info.local_state ? sink_info.local_state->_shared_state : nullptr);
auto channel_id = ((*sink_info.channel_id)++) % _num_partitions;
_enqueue_data_and_set_ready(channel_id, sink_info.local_state, std::move(wrapper));

@@ -390,7 +335,9 @@ Status PassToOneExchanger::sink(RuntimeState* state, vectorized::Block* in_block
}
new_block.swap(*in_block);

BlockWrapperSPtr wrapper = BlockWrapper::create_shared(std::move(new_block));
BlockWrapperSPtr wrapper = BlockWrapper::create_shared(
std::move(new_block),
sink_info.local_state ? sink_info.local_state->_shared_state : nullptr);
_enqueue_data_and_set_ready(0, sink_info.local_state, std::move(wrapper));

return Status::OK();
@@ -417,8 +364,11 @@ Status LocalMergeSortExchanger::sink(RuntimeState* state, vectorized::Block* in_
DCHECK_LE(*sink_info.channel_id, _data_queue.size());

new_block.swap(*in_block);
_enqueue_data_and_set_ready(*sink_info.channel_id, sink_info.local_state,
BlockWrapper::create_shared(std::move(new_block)));
_enqueue_data_and_set_ready(
*sink_info.channel_id, sink_info.local_state,
BlockWrapper::create_shared(
std::move(new_block),
sink_info.local_state ? sink_info.local_state->_shared_state : nullptr));
}
if (eos && sink_info.local_state) {
sink_info.local_state->_shared_state->source_deps[*sink_info.channel_id]
@@ -503,13 +453,14 @@ Status BroadcastExchanger::sink(RuntimeState* state, vectorized::Block* in_block
new_block = {in_block->clone_empty()};
}
new_block.swap(*in_block);
auto wrapper = BlockWrapper::create_shared(std::move(new_block));
auto wrapper = BlockWrapper::create_shared(
std::move(new_block),
sink_info.local_state ? sink_info.local_state->_shared_state : nullptr);
if (sink_info.local_state) {
sink_info.local_state->_shared_state->add_total_mem_usage(
wrapper->data_block.allocated_bytes(), *sink_info.channel_id);
}

wrapper->ref(_num_partitions);
for (int i = 0; i < _num_partitions; i++) {
_enqueue_data_and_set_ready(i, sink_info.local_state,
{wrapper, {0, wrapper->data_block.rows()}});
@@ -525,9 +476,7 @@ void BroadcastExchanger::close(SourceInfo&& source_info) {
_data_queue[source_info.channel_id].set_eos();
while (_dequeue_data(source_info.local_state, partitioned_block, &eos, &block,
source_info.channel_id)) {
partitioned_block.first->unref(
source_info.local_state ? source_info.local_state->_shared_state : nullptr,
source_info.channel_id);
// do nothing
}
}

@@ -545,9 +494,6 @@ Status BroadcastExchanger::get_block(RuntimeState* state, vectorized::Block* blo
RETURN_IF_ERROR(mutable_block.add_rows(&block_wrapper->data_block,
partitioned_block.second.offset_start,
partitioned_block.second.length));
block_wrapper->unref(
source_info.local_state ? source_info.local_state->_shared_state : nullptr,
source_info.channel_id);
}

return Status::OK();
@@ -562,8 +508,11 @@ Status AdaptivePassthroughExchanger::_passthrough_sink(RuntimeState* state,
}
new_block.swap(*in_block);
auto channel_id = ((*sink_info.channel_id)++) % _num_partitions;
_enqueue_data_and_set_ready(channel_id, sink_info.local_state,
BlockWrapper::create_shared(std::move(new_block)));
_enqueue_data_and_set_ready(
channel_id, sink_info.local_state,
BlockWrapper::create_shared(
std::move(new_block),
sink_info.local_state ? sink_info.local_state->_shared_state : nullptr));

return Status::OK();
}
@@ -616,8 +565,12 @@ Status AdaptivePassthroughExchanger::_split_rows(RuntimeState* state,
RETURN_IF_ERROR(mutable_block->add_rows(block, start, size));
auto new_block = mutable_block->to_block();

_enqueue_data_and_set_ready(i, sink_info.local_state,
BlockWrapper::create_shared(std::move(new_block)));
_enqueue_data_and_set_ready(
i, sink_info.local_state,
BlockWrapper::create_shared(std::move(new_block),
sink_info.local_state
? sink_info.local_state->_shared_state
: nullptr));
}
}
return Status::OK();
30 changes: 18 additions & 12 deletions be/src/pipeline/local_exchange/local_exchanger.h
@@ -213,13 +213,16 @@ class LocalExchangeSinkLocalState;
*/
struct BlockWrapper {
ENABLE_FACTORY_CREATOR(BlockWrapper);
BlockWrapper(vectorized::Block&& data_block_) : data_block(std::move(data_block_)) {}
~BlockWrapper() { DCHECK_EQ(ref_count.load(), 0); }
void ref(int delta) { ref_count += delta; }
void unref(LocalExchangeSharedState* shared_state, size_t allocated_bytes, int channel_id) {
if (ref_count.fetch_sub(1) == 1 && shared_state != nullptr) {
BlockWrapper(vectorized::Block&& data_block_, LocalExchangeSharedState* shared_state_)
: data_block(std::move(data_block_)),
shared_state(shared_state_),
allocated_bytes(data_block.allocated_bytes()) {}
~BlockWrapper() {
if (shared_state != nullptr) {
DCHECK_GT(allocated_bytes, 0);
shared_state->sub_total_mem_usage(allocated_bytes, channel_id);
std::for_each(channel_ids.begin(), channel_ids.end(), [&](int& channel_id) {
shared_state->sub_total_mem_usage(allocated_bytes, channel_id);
});
if (shared_state->exchanger->_free_block_limit == 0 ||
shared_state->exchanger->_free_blocks.size_approx() <
shared_state->exchanger->_free_block_limit *
@@ -229,15 +232,18 @@ struct BlockWrapper {
// free block will not incur any bad result so just ignore the return value.
shared_state->exchanger->_free_blocks.enqueue(std::move(data_block));
}
}
};
}

void unref(LocalExchangeSharedState* shared_state = nullptr, int channel_id = 0) {
unref(shared_state, data_block.allocated_bytes(), channel_id);
void record_channel_id(int channel_id, bool update_total_mem_usage) {
channel_ids.push_back(channel_id);
if (shared_state) {
shared_state->add_mem_usage(channel_id, allocated_bytes, update_total_mem_usage);
}
}
int ref_value() const { return ref_count.load(); }
std::atomic<int> ref_count = 0;
vectorized::Block data_block;
LocalExchangeSharedState* shared_state;
std::vector<int> channel_ids;
const size_t allocated_bytes;
};
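Taken together, the header change replaces the manual `ref_count` protocol: the `std::shared_ptr` control block now tracks liveness, `record_channel_id` logs every queue the wrapper enters (optionally bumping memory accounting), and the destructor releases the per-channel accounting exactly once, when the last owner drops the wrapper. A condensed model of that lifecycle, with hypothetical simplified types and the free-block recycling shown above omitted:

```cpp
#include <cstddef>
#include <cstdio>
#include <memory>
#include <vector>

struct SharedState {
    long total_mem = 0;
    void add_mem_usage(int /*channel_id*/, std::size_t bytes) { total_mem += (long)bytes; }
    void sub_total_mem_usage(std::size_t bytes, int /*channel_id*/) { total_mem -= (long)bytes; }
};

struct WrapperModel {
    WrapperModel(std::size_t bytes, SharedState* ss) : allocated_bytes(bytes), shared_state(ss) {}
    ~WrapperModel() {
        if (shared_state != nullptr) {
            // Symmetric to every record_channel_id() call; runs exactly once,
            // when the last shared_ptr owner is destroyed.
            for (int id : channel_ids) shared_state->sub_total_mem_usage(allocated_bytes, id);
        }
    }
    void record_channel_id(int channel_id) {
        channel_ids.push_back(channel_id);
        if (shared_state != nullptr) shared_state->add_mem_usage(channel_id, allocated_bytes);
    }
    const std::size_t allocated_bytes;
    SharedState* shared_state;
    std::vector<int> channel_ids;
};

int main() {
    SharedState state;
    {
        auto w = std::make_shared<WrapperModel>(4096, &state);
        w->record_channel_id(0);  // enqueued to channel 0
        w->record_channel_id(1);  // same block, a different row range
        auto q0 = w, q1 = w;      // two queues each hold a copy; no manual ref(2)
        std::printf("tracked: %ld bytes\n", state.total_mem);  // 8192
    }  // last owner gone: destructor subtracts both channels' usage
    std::printf("tracked: %ld bytes\n", state.total_mem);      // 0
    return 0;
}
```

Pairing each `record_channel_id` with a release in the destructor is plain RAII: the accounting stays balanced even on early-exit paths such as a closed queue or a drained exchanger.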

class ShuffleExchanger : public Exchanger<PartitionedBlock> {