tustvold commented on code in PR #2317:
URL: https://github.com/apache/arrow-datafusion/pull/2317#discussion_r856067898


##########
datafusion/core/src/physical_plan/hash_join.rs:
##########
@@ -294,150 +296,85 @@ impl ExecutionPlan for HashJoinExec {
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         let on_left = self.on.iter().map(|on| 
on.0.clone()).collect::<Vec<_>>();
-        // we only want to compute the build side once for 
PartitionMode::CollectLeft
-        let left_data = {
-            match self.mode {
-                PartitionMode::CollectLeft => {
-                    let mut build_side = self.build_side.lock().await;
-
-                    match build_side.as_ref() {
-                        Some(stream) => stream.clone(),
-                        None => {
-                            let start = Instant::now();
-
-                            // merge all left parts into a single stream
-                            let merge = 
CoalescePartitionsExec::new(self.left.clone());
-                            let stream = merge.execute(0, 
context.clone()).await?;
-
-                            // This operation performs 2 steps at once:
-                            // 1. creates a [JoinHashMap] of all batches from 
the stream
-                            // 2. stores the batches in a vector.
-                            let initial = (0, Vec::new());
-                            let (num_rows, batches) = stream
-                                .try_fold(initial, |mut acc, batch| async {
-                                    acc.0 += batch.num_rows();
-                                    acc.1.push(batch);
-                                    Ok(acc)
-                                })
-                                .await?;
-                            let mut hashmap =
-                                JoinHashMap(RawTable::with_capacity(num_rows));
-                            let mut hashes_buffer = Vec::new();
-                            let mut offset = 0;
-                            for batch in batches.iter() {
-                                hashes_buffer.clear();
-                                hashes_buffer.resize(batch.num_rows(), 0);
-                                update_hash(
-                                    &on_left,
-                                    batch,
-                                    &mut hashmap,
-                                    offset,
-                                    &self.random_state,
-                                    &mut hashes_buffer,
-                                )?;
-                                offset += batch.num_rows();
-                            }
-                            // Merge all batches into a single batch, so we
-                            // can directly index into the arrays
-                            let single_batch =
-                                concat_batches(&self.left.schema(), &batches, 
num_rows)?;
-
-                            let left_side = Arc::new((hashmap, single_batch));
-
-                            *build_side = Some(left_side.clone());
-
-                            debug!(
-                                "Built build-side of hash join containing {} 
rows in {} ms",
-                                num_rows,
-                                start.elapsed().as_millis()
-                            );
-
-                            left_side
-                        }
-                    }
-                }
-                PartitionMode::Partitioned => {
-                    let start = Instant::now();
-
-                    // Load 1 partition of left side in memory
-                    let stream = self.left.execute(partition, 
context.clone()).await?;
-
-                    // This operation performs 2 steps at once:
-                    // 1. creates a [JoinHashMap] of all batches from the 
stream
-                    // 2. stores the batches in a vector.
-                    let initial = (0, Vec::new());
-                    let (num_rows, batches) = stream
-                        .try_fold(initial, |mut acc, batch| async {
-                            acc.0 += batch.num_rows();
-                            acc.1.push(batch);
-                            Ok(acc)
-                        })
-                        .await?;
-                    let mut hashmap = 
JoinHashMap(RawTable::with_capacity(num_rows));
-                    let mut hashes_buffer = Vec::new();
-                    let mut offset = 0;
-                    for batch in batches.iter() {
-                        hashes_buffer.clear();
-                        hashes_buffer.resize(batch.num_rows(), 0);
-                        update_hash(
-                            &on_left,
-                            batch,
-                            &mut hashmap,
-                            offset,
-                            &self.random_state,
-                            &mut hashes_buffer,
-                        )?;
-                        offset += batch.num_rows();
-                    }
-                    // Merge all batches into a single batch, so we
-                    // can directly index into the arrays
-                    let single_batch =
-                        concat_batches(&self.left.schema(), &batches, 
num_rows)?;
-
-                    let left_side = Arc::new((hashmap, single_batch));
-
-                    debug!(
-                        "Built build-side {} of hash join containing {} rows 
in {} ms",
-                        partition,
-                        num_rows,
-                        start.elapsed().as_millis()
-                    );
+        let on_right = self.on.iter().map(|on| 
on.1.clone()).collect::<Vec<_>>();
 
-                    left_side
+        let left_fut = match self.mode {
+            PartitionMode::CollectLeft => self.left_fut.once(|| {
+                collect_left_input(
+                    self.random_state.clone(),
+                    self.left.clone(),
+                    on_left.clone(),
+                    context.clone(),
+                )
+            }),
+            PartitionMode::Partitioned => {
+                let start = Instant::now();
+
+                // Load 1 partition of left side in memory
+                let stream = self.left.execute(partition, 
context.clone()).await?;
+
+                // This operation performs 2 steps at once:
+                // 1. creates a [JoinHashMap] of all batches from the stream
+                // 2. stores the batches in a vector.
+                let initial = (0, Vec::new());
+                let (num_rows, batches) = stream
+                    .try_fold(initial, |mut acc, batch| async {
+                        acc.0 += batch.num_rows();
+                        acc.1.push(batch);
+                        Ok(acc)
+                    })
+                    .await?;
+
+                let mut hashmap = 
JoinHashMap(RawTable::with_capacity(num_rows));
+                let mut hashes_buffer = Vec::new();
+                let mut offset = 0;
+                for batch in batches.iter() {
+                    hashes_buffer.clear();
+                    hashes_buffer.resize(batch.num_rows(), 0);
+                    update_hash(
+                        &on_left,
+                        batch,
+                        &mut hashmap,
+                        offset,
+                        &self.random_state,
+                        &mut hashes_buffer,
+                    )?;
+                    offset += batch.num_rows();
                 }
+                // Merge all batches into a single batch, so we
+                // can directly index into the arrays
+                let single_batch =
+                    concat_batches(&self.left.schema(), &batches, num_rows)?;
+
+                debug!(
+                    "Built build-side {} of hash join containing {} rows in {} 
ms",
+                    partition,
+                    num_rows,
+                    start.elapsed().as_millis()
+                );
+
+                OnceFut::ready(Ok((hashmap, single_batch)))
             }
         };
 
        // we have the batches and the hash map with their keys. We can now 
create a stream
        // over the right that uses this information to issue new batches.
+        let right_stream = self.right.execute(partition, context).await?;
 
-        let right_stream = self.right.execute(partition, 
context.clone()).await?;
-        let on_right = self.on.iter().map(|on| 
on.1.clone()).collect::<Vec<_>>();
-
-        let num_rows = left_data.1.num_rows();

Review Comment:
   The removed `num_rows` computation on the left-side data is moved into the stream implementation, which now resolves the build side lazily via `OnceFut`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to