This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 01698cb605 chore: refactor `BuildProbeJoinMetrics` to use `BaselineMetrics` (#16500)
01698cb605 is described below

commit 01698cb6050cf680a19413e4323798d2aa9e3a94
Author: Samyak Sarnayak <[email protected]>
AuthorDate: Mon Jul 7 16:57:37 2025 +0530

    chore: refactor `BuildProbeJoinMetrics` to use `BaselineMetrics` (#16500)
    
    * chore: refactor `BuildProbeJoinMetrics` to use `BaselineMetrics`
    
    Closes #16495
    
    Here's an example of an `explain analyze` of a hash join showing these metrics:
    ```
    [(WatchID@0, WatchID@0)], metrics=[output_rows=100, elapsed_compute=2.313624ms, build_input_batches=1, build_input_rows=100, input_batches=1, input_rows=100, output_batches=1, build_mem_used=3688, build_time=865.832µs, join_time=1.369875ms]
    ```
    
    Notice `output_rows=100, elapsed_compute=2.313624ms` in the above.
    
    * test: add checks for join metrics in tests
    
    * fix: add record_poll to ExhaustedProbeSide for nested_loop_join
    
    This was needed because the ExhaustedProbeSide state can also return
    output rows for certain types of joins. Without this, the output_rows
    metric for nested loop join was wrong!
---
 datafusion/physical-plan/src/joins/cross_join.rs   |  15 +--
 datafusion/physical-plan/src/joins/hash_join.rs    | 103 +++++++++++++++------
 .../physical-plan/src/joins/nested_loop_join.rs    |  58 ++++++++----
 datafusion/physical-plan/src/joins/utils.rs        |  30 ++++--
 datafusion/physical-plan/src/test.rs               |  30 ++++++
 5 files changed, 179 insertions(+), 57 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/cross_join.rs 
b/datafusion/physical-plan/src/joins/cross_join.rs
index e4d554ceb6..a41e668ab4 100644
--- a/datafusion/physical-plan/src/joins/cross_join.rs
+++ b/datafusion/physical-plan/src/joins/cross_join.rs
@@ -559,7 +559,8 @@ impl<T: BatchTransformer> CrossJoinStream<T> {
                     handle_state!(ready!(self.fetch_probe_batch(cx)))
                 }
                 CrossJoinStreamState::BuildBatches(_) => {
-                    handle_state!(self.build_batches())
+                    let poll = handle_state!(self.build_batches());
+                    self.join_metrics.baseline.record_poll(poll)
                 }
             };
         }
@@ -632,7 +633,6 @@ impl<T: BatchTransformer> CrossJoinStream<T> {
                     }
 
                     self.join_metrics.output_batches.add(1);
-                    self.join_metrics.output_rows.add(batch.num_rows());
                     return Ok(StatefulStreamResult::Ready(Some(batch)));
                 }
             }
@@ -647,7 +647,7 @@ impl<T: BatchTransformer> CrossJoinStream<T> {
 mod tests {
     use super::*;
     use crate::common;
-    use crate::test::build_table_scan_i32;
+    use crate::test::{assert_join_metrics, build_table_scan_i32};
 
     use datafusion_common::{assert_contains, 
test_util::batches_to_sort_string};
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
@@ -657,14 +657,15 @@ mod tests {
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
         context: Arc<TaskContext>,
-    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
         let join = CrossJoinExec::new(left, right);
         let columns_header = columns(&join.schema());
 
         let stream = join.execute(0, context)?;
         let batches = common::collect(stream).await?;
+        let metrics = join.metrics().unwrap();
 
-        Ok((columns_header, batches))
+        Ok((columns_header, batches, metrics))
     }
 
     #[tokio::test]
@@ -831,7 +832,7 @@ mod tests {
             ("c2", &vec![14, 15]),
         );
 
-        let (columns, batches) = join_collect(left, right, task_ctx).await?;
+        let (columns, batches, metrics) = join_collect(left, right, 
task_ctx).await?;
 
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
 
@@ -848,6 +849,8 @@ mod tests {
             +----+----+----+----+----+----+
             "#);
 
+        assert_join_metrics!(metrics, 6);
+
         Ok(())
     }
 
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs 
b/datafusion/physical-plan/src/joins/hash_join.rs
index 770399290d..652f4a7915 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -1403,10 +1403,12 @@ impl HashJoinStream {
                     handle_state!(ready!(self.fetch_probe_batch(cx)))
                 }
                 HashJoinStreamState::ProcessProbeBatch(_) => {
-                    handle_state!(self.process_probe_batch())
+                    let poll = handle_state!(self.process_probe_batch());
+                    self.join_metrics.baseline.record_poll(poll)
                 }
                 HashJoinStreamState::ExhaustedProbeSide => {
-                    handle_state!(self.process_unmatched_build_batch())
+                    let poll = 
handle_state!(self.process_unmatched_build_batch());
+                    self.join_metrics.baseline.record_poll(poll)
                 }
                 HashJoinStreamState::Completed => Poll::Ready(None),
             };
@@ -1582,7 +1584,6 @@ impl HashJoinStream {
         };
 
         self.join_metrics.output_batches.add(1);
-        self.join_metrics.output_rows.add(result.num_rows());
         timer.done();
 
         if next_offset.is_none() {
@@ -1639,7 +1640,6 @@ impl HashJoinStream {
             self.join_metrics.input_rows.add(batch.num_rows());
 
             self.join_metrics.output_batches.add(1);
-            self.join_metrics.output_rows.add(batch.num_rows());
         }
         timer.done();
 
@@ -1670,7 +1670,7 @@ impl EmbeddedProjection for HashJoinExec {
 mod tests {
     use super::*;
     use crate::coalesce_partitions::CoalescePartitionsExec;
-    use crate::test::TestMemoryExec;
+    use crate::test::{assert_join_metrics, TestMemoryExec};
     use crate::{
         common, expressions::Column, repartition::RepartitionExec, 
test::build_table_i32,
         test::exec::MockExec,
@@ -1763,14 +1763,15 @@ mod tests {
         join_type: &JoinType,
         null_equality: NullEquality,
         context: Arc<TaskContext>,
-    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
         let join = join(left, right, on, join_type, null_equality)?;
         let columns_header = columns(&join.schema());
 
         let stream = join.execute(0, context)?;
         let batches = common::collect(stream).await?;
+        let metrics = join.metrics().unwrap();
 
-        Ok((columns_header, batches))
+        Ok((columns_header, batches, metrics))
     }
 
     async fn partitioned_join_collect(
@@ -1780,7 +1781,7 @@ mod tests {
         join_type: &JoinType,
         null_equality: NullEquality,
         context: Arc<TaskContext>,
-    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
         join_collect_with_partition_mode(
             left,
             right,
@@ -1801,7 +1802,7 @@ mod tests {
         partition_mode: PartitionMode,
         null_equality: NullEquality,
         context: Arc<TaskContext>,
-    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
         let partition_count = 4;
 
         let (left_expr, right_expr) = on
@@ -1865,8 +1866,9 @@ mod tests {
                     .collect::<Vec<_>>(),
             );
         }
+        let metrics = join.metrics().unwrap();
 
-        Ok((columns, batches))
+        Ok((columns, batches, metrics))
     }
 
     #[apply(batch_sizes)]
@@ -1889,7 +1891,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             Arc::clone(&left),
             Arc::clone(&right),
             on.clone(),
@@ -1914,6 +1916,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -1936,7 +1940,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = partitioned_join_collect(
+        let (columns, batches, metrics) = partitioned_join_collect(
             Arc::clone(&left),
             Arc::clone(&right),
             on.clone(),
@@ -1960,6 +1964,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -1981,7 +1987,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             left,
             right,
             on,
@@ -2006,6 +2012,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -2027,7 +2035,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             left,
             right,
             on,
@@ -2053,6 +2061,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 4);
+
         Ok(())
     }
 
@@ -2081,7 +2091,7 @@ mod tests {
             ),
         ];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             left,
             right,
             on,
@@ -2122,6 +2132,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -2159,7 +2171,7 @@ mod tests {
             ),
         ];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             left,
             right,
             on,
@@ -2200,6 +2212,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -2232,7 +2246,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             left,
             right,
             on,
@@ -2258,6 +2272,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 4);
+
         Ok(())
     }
 
@@ -2577,7 +2593,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             Arc::clone(&left),
             Arc::clone(&right),
             on.clone(),
@@ -2586,6 +2602,7 @@ mod tests {
             task_ctx,
         )
         .await?;
+
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
 
         allow_duplicates! {
@@ -2600,6 +2617,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -2622,7 +2641,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = partitioned_join_collect(
+        let (columns, batches, metrics) = partitioned_join_collect(
             Arc::clone(&left),
             Arc::clone(&right),
             on.clone(),
@@ -2631,6 +2650,7 @@ mod tests {
             task_ctx,
         )
         .await?;
+
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
 
         allow_duplicates! {
@@ -2645,6 +2665,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -3267,7 +3289,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             left,
             right,
             on,
@@ -3291,6 +3313,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -3313,7 +3337,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = partitioned_join_collect(
+        let (columns, batches, metrics) = partitioned_join_collect(
             left,
             right,
             on,
@@ -3337,6 +3361,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -3408,7 +3434,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             Arc::clone(&left),
             Arc::clone(&right),
             on.clone(),
@@ -3417,6 +3443,7 @@ mod tests {
             task_ctx,
         )
         .await?;
+
         assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
 
         allow_duplicates! {
@@ -3431,6 +3458,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -3453,7 +3482,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = partitioned_join_collect(
+        let (columns, batches, metrics) = partitioned_join_collect(
             Arc::clone(&left),
             Arc::clone(&right),
             on.clone(),
@@ -3462,6 +3491,7 @@ mod tests {
             task_ctx,
         )
         .await?;
+
         assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
 
         allow_duplicates! {
@@ -3476,6 +3506,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -3498,7 +3530,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             Arc::clone(&left),
             Arc::clone(&right),
             on.clone(),
@@ -3507,6 +3539,7 @@ mod tests {
             task_ctx,
         )
         .await?;
+
         assert_eq!(columns, vec!["a2", "b1", "c2", "mark"]);
 
         let expected = [
@@ -3520,6 +3553,8 @@ mod tests {
         ];
         assert_batches_sorted_eq!(expected, &batches);
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -3542,7 +3577,7 @@ mod tests {
             Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = partitioned_join_collect(
+        let (columns, batches, metrics) = partitioned_join_collect(
             Arc::clone(&left),
             Arc::clone(&right),
             on.clone(),
@@ -3551,6 +3586,7 @@ mod tests {
             task_ctx,
         )
         .await?;
+
         assert_eq!(columns, vec!["a2", "b1", "c2", "mark"]);
 
         let expected = [
@@ -3565,6 +3601,8 @@ mod tests {
         ];
         assert_batches_sorted_eq!(expected, &batches);
 
+        assert_join_metrics!(metrics, 4);
+
         Ok(())
     }
 
@@ -4054,7 +4092,7 @@ mod tests {
         ];
 
         for (join_type, expected) in test_cases {
-            let (_, batches) = join_collect_with_partition_mode(
+            let (_, batches, metrics) = join_collect_with_partition_mode(
                 Arc::clone(&left),
                 Arc::clone(&right),
                 on.clone(),
@@ -4065,6 +4103,7 @@ mod tests {
             )
             .await?;
             assert_batches_sorted_eq!(expected, &batches);
+            assert_join_metrics!(metrics, expected.len() - 4);
         }
 
         Ok(())
@@ -4492,7 +4531,7 @@ mod tests {
             Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
         )];
 
-        let (columns, batches) = join_collect(
+        let (columns, batches, metrics) = join_collect(
             left,
             right,
             on,
@@ -4516,6 +4555,8 @@ mod tests {
                 "#);
         }
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -4531,7 +4572,7 @@ mod tests {
             Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
         )];
 
-        let (_, batches_null_eq) = join_collect(
+        let (_, batches_null_eq, metrics) = join_collect(
             Arc::clone(&left),
             Arc::clone(&right),
             on.clone(),
@@ -4551,7 +4592,9 @@ mod tests {
                 "#);
         }
 
-        let (_, batches_null_neq) = join_collect(
+        assert_join_metrics!(metrics, 1);
+
+        let (_, batches_null_neq, metrics) = join_collect(
             left,
             right,
             on,
@@ -4561,6 +4604,8 @@ mod tests {
         )
         .await?;
 
+        assert_join_metrics!(metrics, 0);
+
         let expected_null_neq =
             ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"];
         assert_batches_eq!(expected_null_neq, &batches_null_neq);
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs 
b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index fcc1107a0e..c84b3a9d40 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -825,10 +825,12 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
                     handle_state!(ready!(self.fetch_probe_batch(cx)))
                 }
                 NestedLoopJoinStreamState::ProcessProbeBatch(_) => {
-                    handle_state!(self.process_probe_batch())
+                    let poll = handle_state!(self.process_probe_batch());
+                    self.join_metrics.baseline.record_poll(poll)
                 }
                 NestedLoopJoinStreamState::ExhaustedProbeSide => {
-                    handle_state!(self.process_unmatched_build_batch())
+                    let poll = 
handle_state!(self.process_unmatched_build_batch());
+                    self.join_metrics.baseline.record_poll(poll)
                 }
                 NestedLoopJoinStreamState::Completed => Poll::Ready(None),
             };
@@ -912,7 +914,6 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
                 }
 
                 self.join_metrics.output_batches.add(1);
-                self.join_metrics.output_rows.add(batch.num_rows());
                 Ok(StatefulStreamResult::Ready(Some(batch)))
             }
         }
@@ -963,6 +964,8 @@ impl<T: BatchTransformer> NestedLoopJoinStream<T> {
                 timer.done();
             }
 
+            self.join_metrics.output_batches.add(1);
+
             Ok(StatefulStreamResult::Ready(Some(result?)))
         } else {
             // end of the join loop
@@ -1062,7 +1065,7 @@ impl EmbeddedProjection for NestedLoopJoinExec {
 #[cfg(test)]
 pub(crate) mod tests {
     use super::*;
-    use crate::test::TestMemoryExec;
+    use crate::test::{assert_join_metrics, TestMemoryExec};
     use crate::{
         common, expressions::Column, repartition::RepartitionExec, 
test::build_table_i32,
     };
@@ -1195,7 +1198,7 @@ pub(crate) mod tests {
         join_type: &JoinType,
         join_filter: Option<JoinFilter>,
         context: Arc<TaskContext>,
-    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+    ) -> Result<(Vec<String>, Vec<RecordBatch>, MetricsSet)> {
         let partition_count = 4;
 
         // Redistributing right input
@@ -1219,7 +1222,10 @@ pub(crate) mod tests {
                     .collect::<Vec<_>>(),
             );
         }
-        Ok((columns, batches))
+
+        let metrics = nested_loop_join.metrics().unwrap();
+
+        Ok((columns, batches, metrics))
     }
 
     #[tokio::test]
@@ -1228,7 +1234,7 @@ pub(crate) mod tests {
         let left = build_left_table();
         let right = build_right_table();
         let filter = prepare_join_filter();
-        let (columns, batches) = multi_partitioned_join_collect(
+        let (columns, batches, metrics) = multi_partitioned_join_collect(
             left,
             right,
             &JoinType::Inner,
@@ -1245,6 +1251,8 @@ pub(crate) mod tests {
             +----+----+----+----+----+----+
             "#);
 
+        assert_join_metrics!(metrics, 1);
+
         Ok(())
     }
 
@@ -1255,7 +1263,7 @@ pub(crate) mod tests {
         let right = build_right_table();
 
         let filter = prepare_join_filter();
-        let (columns, batches) = multi_partitioned_join_collect(
+        let (columns, batches, metrics) = multi_partitioned_join_collect(
             left,
             right,
             &JoinType::Left,
@@ -1274,6 +1282,8 @@ pub(crate) mod tests {
             +----+----+-----+----+----+----+
             "#);
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -1284,7 +1294,7 @@ pub(crate) mod tests {
         let right = build_right_table();
 
         let filter = prepare_join_filter();
-        let (columns, batches) = multi_partitioned_join_collect(
+        let (columns, batches, metrics) = multi_partitioned_join_collect(
             left,
             right,
             &JoinType::Right,
@@ -1303,6 +1313,8 @@ pub(crate) mod tests {
             +----+----+----+----+----+-----+
             "#);
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -1313,7 +1325,7 @@ pub(crate) mod tests {
         let right = build_right_table();
 
         let filter = prepare_join_filter();
-        let (columns, batches) = multi_partitioned_join_collect(
+        let (columns, batches, metrics) = multi_partitioned_join_collect(
             left,
             right,
             &JoinType::Full,
@@ -1334,6 +1346,8 @@ pub(crate) mod tests {
             +----+----+-----+----+----+-----+
             "#);
 
+        assert_join_metrics!(metrics, 5);
+
         Ok(())
     }
 
@@ -1344,7 +1358,7 @@ pub(crate) mod tests {
         let right = build_right_table();
 
         let filter = prepare_join_filter();
-        let (columns, batches) = multi_partitioned_join_collect(
+        let (columns, batches, metrics) = multi_partitioned_join_collect(
             left,
             right,
             &JoinType::LeftSemi,
@@ -1361,6 +1375,8 @@ pub(crate) mod tests {
             +----+----+----+
             "#);
 
+        assert_join_metrics!(metrics, 1);
+
         Ok(())
     }
 
@@ -1371,7 +1387,7 @@ pub(crate) mod tests {
         let right = build_right_table();
 
         let filter = prepare_join_filter();
-        let (columns, batches) = multi_partitioned_join_collect(
+        let (columns, batches, metrics) = multi_partitioned_join_collect(
             left,
             right,
             &JoinType::LeftAnti,
@@ -1389,6 +1405,8 @@ pub(crate) mod tests {
             +----+----+-----+
             "#);
 
+        assert_join_metrics!(metrics, 2);
+
         Ok(())
     }
 
@@ -1399,7 +1417,7 @@ pub(crate) mod tests {
         let right = build_right_table();
 
         let filter = prepare_join_filter();
-        let (columns, batches) = multi_partitioned_join_collect(
+        let (columns, batches, metrics) = multi_partitioned_join_collect(
             left,
             right,
             &JoinType::RightSemi,
@@ -1416,6 +1434,8 @@ pub(crate) mod tests {
             +----+----+----+
             "#);
 
+        assert_join_metrics!(metrics, 1);
+
         Ok(())
     }
 
@@ -1426,7 +1446,7 @@ pub(crate) mod tests {
         let right = build_right_table();
 
         let filter = prepare_join_filter();
-        let (columns, batches) = multi_partitioned_join_collect(
+        let (columns, batches, metrics) = multi_partitioned_join_collect(
             left,
             right,
             &JoinType::RightAnti,
@@ -1444,6 +1464,8 @@ pub(crate) mod tests {
             +----+----+-----+
             "#);
 
+        assert_join_metrics!(metrics, 2);
+
         Ok(())
     }
 
@@ -1454,7 +1476,7 @@ pub(crate) mod tests {
         let right = build_right_table();
 
         let filter = prepare_join_filter();
-        let (columns, batches) = multi_partitioned_join_collect(
+        let (columns, batches, metrics) = multi_partitioned_join_collect(
             left,
             right,
             &JoinType::LeftMark,
@@ -1473,6 +1495,8 @@ pub(crate) mod tests {
             +----+----+-----+-------+
             "#);
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
@@ -1483,7 +1507,7 @@ pub(crate) mod tests {
         let right = build_right_table();
 
         let filter = prepare_join_filter();
-        let (columns, batches) = multi_partitioned_join_collect(
+        let (columns, batches, metrics) = multi_partitioned_join_collect(
             left,
             right,
             &JoinType::RightMark,
@@ -1503,6 +1527,8 @@ pub(crate) mod tests {
             +----+----+-----+-------+
             "#);
 
+        assert_join_metrics!(metrics, 3);
+
         Ok(())
     }
 
diff --git a/datafusion/physical-plan/src/joins/utils.rs 
b/datafusion/physical-plan/src/joins/utils.rs
index c5f7087ac1..4249e479c9 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -26,7 +26,7 @@ use std::sync::Arc;
 use std::task::{Context, Poll};
 
 use crate::joins::SharedBitmapBuilder;
-use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
+use crate::metrics::{self, BaselineMetrics, ExecutionPlanMetricsSet, 
MetricBuilder};
 use crate::projection::ProjectionExec;
 use crate::{
     ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, 
Statistics,
@@ -1196,6 +1196,7 @@ fn append_probe_indices_in_order(
 /// Metrics for build & probe joins
 #[derive(Clone, Debug)]
 pub(crate) struct BuildProbeJoinMetrics {
+    pub(crate) baseline: BaselineMetrics,
     /// Total time for collecting build-side of join
     pub(crate) build_time: metrics::Time,
     /// Number of batches consumed by build-side
@@ -1212,12 +1213,31 @@ pub(crate) struct BuildProbeJoinMetrics {
     pub(crate) input_rows: metrics::Count,
     /// Number of batches produced by this operator
     pub(crate) output_batches: metrics::Count,
-    /// Number of rows produced by this operator
-    pub(crate) output_rows: metrics::Count,
+}
+
+// This Drop implementation updates the elapsed compute part of the metrics.
+//
+// Why is this in a Drop?
+// - We keep track of build_time and join_time separately, but baseline metrics have
+// a total elapsed_compute time. Instead of remembering to update both the metrics
+// at the same time, we chose to update elapsed_compute once at the end - summing up
+// both the parts.
+//
+// How does this work?
+// - The elapsed_compute `Time` is represented by an `Arc<AtomicUsize>`. So even when
+// this `BuildProbeJoinMetrics` is dropped, the elapsed_compute is usable through the
+// Arc reference.
+impl Drop for BuildProbeJoinMetrics {
+    fn drop(&mut self) {
+        self.baseline.elapsed_compute().add(&self.build_time);
+        self.baseline.elapsed_compute().add(&self.join_time);
+    }
 }
 
 impl BuildProbeJoinMetrics {
     pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
+        let baseline = BaselineMetrics::new(metrics, partition);
+
         let join_time = MetricBuilder::new(metrics).subset_time("join_time", 
partition);
 
         let build_time = MetricBuilder::new(metrics).subset_time("build_time", 
partition);
@@ -1239,8 +1259,6 @@ impl BuildProbeJoinMetrics {
         let output_batches =
             MetricBuilder::new(metrics).counter("output_batches", partition);
 
-        let output_rows = MetricBuilder::new(metrics).output_rows(partition);
-
         Self {
             build_time,
             build_input_batches,
@@ -1250,7 +1268,7 @@ impl BuildProbeJoinMetrics {
             input_batches,
             input_rows,
             output_batches,
-            output_rows,
+            baseline,
         }
     }
 }
diff --git a/datafusion/physical-plan/src/test.rs 
b/datafusion/physical-plan/src/test.rs
index 5e6410a017..be921e0581 100644
--- a/datafusion/physical-plan/src/test.rs
+++ b/datafusion/physical-plan/src/test.rs
@@ -522,3 +522,33 @@ impl PartitionStream for TestPartitionStream {
         ))
     }
 }
+
+#[cfg(test)]
+macro_rules! assert_join_metrics {
+    ($metrics:expr, $expected_rows:expr) => {
+        assert_eq!($metrics.output_rows().unwrap(), $expected_rows);
+
+        let elapsed_compute = $metrics
+            .elapsed_compute()
+            .expect("did not find elapsed_compute metric");
+        let join_time = $metrics
+            .sum_by_name("join_time")
+            .expect("did not find join_time metric")
+            .as_usize();
+        let build_time = $metrics
+            .sum_by_name("build_time")
+            .expect("did not find build_time metric")
+            .as_usize();
+        // ensure join_time and build_time are considered in elapsed_compute
+        assert!(
+            join_time + build_time <= elapsed_compute,
+            "join_time ({}) + build_time ({}) = {} was <= elapsed_compute = 
{}",
+            join_time,
+            build_time,
+            join_time + build_time,
+            elapsed_compute
+        );
+    };
+}
+#[cfg(test)]
+pub(crate) use assert_join_metrics;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to