alamb commented on code in PR #10009:
URL: https://github.com/apache/arrow-datafusion/pull/10009#discussion_r1557758257
##########
datafusion/physical-plan/src/repartition/mod.rs:
##########
@@ -77,6 +79,90 @@ struct RepartitionExecState {
 abort_helper: Arc<Vec<SpawnedTask<()>>>,
 }
+impl RepartitionExecState {
+    fn new(
+        input: Arc<dyn ExecutionPlan>,
+        partitioning: Partitioning,
+        metrics: ExecutionPlanMetricsSet,
+        preserve_order: bool,
+        name: String,
+        context: Arc<TaskContext>,
+    ) -> Self {
+        let num_input_partitions = input.output_partitioning().partition_count();
+        let num_output_partitions = partitioning.partition_count();
+
+        let (txs, rxs) = if preserve_order {
+            let (txs, rxs) =
+                partition_aware_channels(num_input_partitions, num_output_partitions);
+            // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
+            let txs = transpose(txs);
+            let rxs = transpose(rxs);
+            (txs, rxs)
+        } else {
+            // create one channel per *output* partition
+            // note we use a custom channel that ensures there is always data for each receiver
+            // but limits the amount of buffering if required.
+            let (txs, rxs) = channels(num_output_partitions);
+            // Clone sender for each input partitions
+            let txs = txs
+                .into_iter()
+                .map(|item| vec![item; num_input_partitions])
+                .collect::<Vec<_>>();
+            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
+            (txs, rxs)
+        };
+
+        let mut channels = HashMap::with_capacity(txs.len());
+        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
+            let reservation = Arc::new(Mutex::new(
+                MemoryConsumer::new(format!("{}[{partition}]", name))
+                    .register(context.memory_pool()),
+            ));
+            channels.insert(partition, (tx, rx, reservation));
+        }
+
+        // launch one async task per *input* partition
+        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
+        for i in 0..num_input_partitions {
+            let txs: HashMap<_, _> = channels
+                .iter()
+                .map(|(partition, (tx, _rx, reservation))| {
+                    (*partition, (tx[i].clone(), Arc::clone(reservation)))
+                })
+                .collect();
+
+            // TODO: metric input-output mapping is broken
Review Comment:
Filed https://github.com/apache/arrow-datafusion/issues/10015 to track
##########
datafusion/physical-plan/src/repartition/mod.rs:
##########
@@ -77,6 +79,90 @@ struct RepartitionExecState {
 abort_helper: Arc<Vec<SpawnedTask<()>>>,
 }
+impl RepartitionExecState {
+    fn new(
+        input: Arc<dyn ExecutionPlan>,
+        partitioning: Partitioning,
+        metrics: ExecutionPlanMetricsSet,
+        preserve_order: bool,
+        name: String,
+        context: Arc<TaskContext>,
+    ) -> Self {
+        let num_input_partitions = input.output_partitioning().partition_count();
+        let num_output_partitions = partitioning.partition_count();
+
+        let (txs, rxs) = if preserve_order {
+            let (txs, rxs) =
+                partition_aware_channels(num_input_partitions, num_output_partitions);
+            // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
+            let txs = transpose(txs);
+            let rxs = transpose(rxs);
+            (txs, rxs)
+        } else {
+            // create one channel per *output* partition
+            // note we use a custom channel that ensures there is always data for each receiver
+            // but limits the amount of buffering if required.
+            let (txs, rxs) = channels(num_output_partitions);
+            // Clone sender for each input partitions
+            let txs = txs
+                .into_iter()
+                .map(|item| vec![item; num_input_partitions])
+                .collect::<Vec<_>>();
+            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
+            (txs, rxs)
+        };
+
+        let mut channels = HashMap::with_capacity(txs.len());
+        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
+            let reservation = Arc::new(Mutex::new(
+                MemoryConsumer::new(format!("{}[{partition}]", name))
+                    .register(context.memory_pool()),
+            ));
+            channels.insert(partition, (tx, rx, reservation));
+        }
+
+        // launch one async task per *input* partition
+        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
+        for i in 0..num_input_partitions {
+            let txs: HashMap<_, _> = channels
+                .iter()
+                .map(|(partition, (tx, _rx, reservation))| {
+                    (*partition, (tx[i].clone(), Arc::clone(reservation)))
+                })
+                .collect();
+
+            // TODO: metric input-output mapping is broken
+            let r_metrics = RepartitionMetrics::new(i, 0, &metrics);
+
+            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
+                input.clone(),
+                i,
+                txs.clone(),
+                partitioning.clone(),
+                r_metrics,
+                context.clone(),
+            ));
+
+            // In a separate task, wait for each input to be done
+            // (and pass along any errors, including panic!s)
+            let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task(
+                input_task,
+                txs.into_iter()
+                    .map(|(partition, (tx, _reservation))| (partition, tx))
+                    .collect(),
+            ));
+            spawned_tasks.push(wait_for_task);
+        }
+
+        Self {
+            channels,
+            abort_helper: Arc::new(spawned_tasks),
+        }
+    }
+}
+
+type LazyState = Arc<OnceCell<Mutex<RepartitionExecState>>>;
Review Comment:
I think including the context from your PR description about *why* this is using a (`tokio`) `OnceCell` would help avoid future regressions (if someone in the future concluded incorrectly, for example, that the overhead of `async::OnceCell` was not needed)
```suggestion
/// Lazily initialized state
///
/// Note that the state is initialized ONCE for all partitions by a single task (thread).
/// This may take a short while. It is also likely that multiple threads
/// call execute at the same time, because we have just started "target partitions" tasks
/// which is commonly set to the number of CPU cores and all call execute at the same time.
///
/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles
/// in a futex lock but instead allow other threads to do something useful.
///
/// Uses a parking_lot `Mutex` to control other accesses as they are very short duration
/// (e.g. removing channels on completion) where the overhead of `await` is not warranted.
type LazyState = Arc<tokio::sync::OnceCell<Mutex<RepartitionExecState>>>;
```
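
For anyone reading this later, a minimal, self-contained sketch of the pattern (not DataFusion code; `State` and `expensive_init` are made-up names standing in for `RepartitionExecState` and the task-spawning setup). It shows why the async `OnceCell` matters: all racing tasks except the one running the initializer yield back to the runtime instead of blocking a thread:

```rust
use std::sync::Arc;
use tokio::sync::OnceCell;

// Made-up stand-in for RepartitionExecState.
struct State {
    value: u64,
}

// Hypothetical one-time setup (stands in for spawning the repartition
// tasks). Only the first caller actually runs it.
async fn expensive_init() -> State {
    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    State { value: 42 }
}

#[tokio::main]
async fn main() {
    let cell: Arc<OnceCell<State>> = Arc::new(OnceCell::new());

    // Mirrors "target partitions" tasks all calling execute() at once:
    // get_or_init runs the initializer future exactly once; the other
    // tasks await its completion without spinning on a lock.
    let mut handles = Vec::new();
    for _ in 0..8 {
        let cell = Arc::clone(&cell);
        handles.push(tokio::spawn(async move {
            cell.get_or_init(expensive_init).await.value
        }));
    }
    for h in handles {
        assert_eq!(h.await.unwrap(), 42);
    }
}
```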
##########
datafusion/physical-plan/src/repartition/mod.rs:
##########
@@ -1240,7 +1294,10 @@ mod tests {
         std::mem::drop(output_stream0);
 
         // Now, start sending input
-        input.wait().await;
+        let mut background_task = JoinSet::new();
Review Comment:
What is the purpose of this change? I tried this change without the other changes in this PR and the test still passes (I was expecting it would hang or something)
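
To make the question concrete, here is a hedged sketch of the before/after pattern as I read it, with a made-up `wait_for_input` future in place of the test's `input.wait()`: awaiting inline stalls the test task until the input is consumed, while spawning the wait onto a `tokio::task::JoinSet` lets the test proceed and join the task later:

```rust
use tokio::task::JoinSet;

// Made-up stand-in for the test's `input.wait()`, which resolves once
// the mock input has been consumed.
async fn wait_for_input() {
    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}

#[tokio::main]
async fn main() {
    // Before: awaiting inline blocks the test task right here.
    //     input.wait().await;

    // After: the wait runs as a background task and the test continues
    // immediately; the task can be joined (or aborted) later.
    let mut background_task = JoinSet::new();
    background_task.spawn(async move {
        wait_for_input().await;
    });

    // ... rest of the test runs while the input is drained ...

    // Propagate any panic from the background task.
    while let Some(res) = background_task.join_next().await {
        res.unwrap();
    }
}
```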