alamb commented on code in PR #9080:
URL: https://github.com/apache/arrow-datafusion/pull/9080#discussion_r1478428075


##########
datafusion/physical-plan/src/joins/sort_merge_join.rs:
##########
@@ -1121,16 +1151,138 @@ impl SMJStream {
                         .collect::<Vec<_>>()
                 };
 
+            let streamed_columns_length = streamed_columns.len();
+            let buffered_columns_length = buffered_columns.len();
+
+            // Prepare the columns we apply join filter on later.
+            // Only for joined rows between streamed and buffered.
+            let filter_columns = if chunk.buffered_batch_idx.is_some() {
+                if matches!(self.join_type, JoinType::Right) {
+                    get_filter_column(&self.filter, &buffered_columns, 
&streamed_columns)
+                } else {
+                    get_filter_column(&self.filter, &streamed_columns, 
&buffered_columns)
+                }
+            } else {
+                vec![]
+            };
+
             let columns = if matches!(self.join_type, JoinType::Right) {
-                buffered_columns.extend(streamed_columns);
+                buffered_columns.extend(streamed_columns.clone());
                 buffered_columns
             } else {
                 streamed_columns.extend(buffered_columns);
                 streamed_columns
             };
 
-            self.output_record_batches
-                .push(RecordBatch::try_new(self.schema.clone(), columns)?);
+            let output_batch =
+                RecordBatch::try_new(self.schema.clone(), columns.clone())?;
+
+            // Apply join filter if any
+            if !filter_columns.is_empty() {

Review Comment:
   I don't understand why there is a check for the filter columns *and* a check
that `self.filter` is `Some`. I expected the check to simply be whether `self.filter`
is `Some` (and the `else` case is the same for both below).
   
   If the filter has no columns, it seems like the `else` clause does the same 
thing in both cases.
   
   Thus, I wonder if we could remove the check for `filter_columns` entirely 
:thinking: 



##########
datafusion/physical-plan/src/joins/sort_merge_join.rs:
##########
@@ -1142,12 +1294,49 @@ impl SMJStream {
         let record_batch = concat_batches(&self.schema, 
&self.output_record_batches)?;
         self.join_metrics.output_batches.add(1);
         self.join_metrics.output_rows.add(record_batch.num_rows());
-        self.output_size -= record_batch.num_rows();
+        // If join filter exists, `self.output_size` is not accurate as we 
don't know the exact

Review Comment:
   Is the idea here that `output_size` tracks the number of rows remaining
to be output? If so, it seems like the `filter` could only decrease the number of
output rows (never increase it).
   
   However, I can see how the SMJ code could overshoot for LEFT/RIGHT/FULL 
joins, so maybe this fix was needed because now there is more test coverage of 
SMJ :thinking:



##########
datafusion/physical-plan/src/joins/sort_merge_join.rs:
##########
@@ -1121,16 +1151,138 @@ impl SMJStream {
                         .collect::<Vec<_>>()
                 };
 
+            let streamed_columns_length = streamed_columns.len();
+            let buffered_columns_length = buffered_columns.len();
+
+            // Prepare the columns we apply join filter on later.
+            // Only for joined rows between streamed and buffered.
+            let filter_columns = if chunk.buffered_batch_idx.is_some() {
+                if matches!(self.join_type, JoinType::Right) {
+                    get_filter_column(&self.filter, &buffered_columns, 
&streamed_columns)
+                } else {
+                    get_filter_column(&self.filter, &streamed_columns, 
&buffered_columns)
+                }
+            } else {
+                vec![]
+            };
+
             let columns = if matches!(self.join_type, JoinType::Right) {
-                buffered_columns.extend(streamed_columns);
+                buffered_columns.extend(streamed_columns.clone());
                 buffered_columns
             } else {
                 streamed_columns.extend(buffered_columns);
                 streamed_columns
             };
 
-            self.output_record_batches
-                .push(RecordBatch::try_new(self.schema.clone(), columns)?);
+            let output_batch =
+                RecordBatch::try_new(self.schema.clone(), columns.clone())?;
+
+            // Apply join filter if any
+            if !filter_columns.is_empty() {
+                if let Some(f) = &self.filter {
+                    // Construct batch with only filter columns
+                    let filter_batch = RecordBatch::try_new(
+                        Arc::new(f.schema().clone()),
+                        filter_columns,
+                    )?;
+
+                    let filter_result = f
+                        .expression()
+                        .evaluate(&filter_batch)?
+                        .into_array(filter_batch.num_rows())?;
+
+                    // The selection mask of the filter
+                    let mask = 
datafusion_common::cast::as_boolean_array(&filter_result)?;
+
+                    // Push the filtered batch to the output
+                    let filtered_batch =
+                        compute::filter_record_batch(&output_batch, mask)?;
+                    self.output_record_batches.push(filtered_batch);
+
+                    // For outer joins, we need to push the null joined rows 
to the output.
+                    if matches!(
+                        self.join_type,
+                        JoinType::Left | JoinType::Right | JoinType::Full
+                    ) {
+                        // The reverse of the selection mask, which is for 
null joined rows
+                        let not_mask = compute::not(mask)?;
+                        let null_joined_batch =
+                            compute::filter_record_batch(&output_batch, 
&not_mask)?;
+
+                        let mut buffered_columns = self
+                            .buffered_schema
+                            .fields()
+                            .iter()
+                            .map(|f| {
+                                new_null_array(
+                                    f.data_type(),
+                                    null_joined_batch.num_rows(),
+                                )
+                            })
+                            .collect::<Vec<_>>();
+
+                        let columns = if matches!(self.join_type, 
JoinType::Right) {
+                            let streamed_columns = null_joined_batch
+                                .columns()
+                                .iter()
+                                .skip(buffered_columns_length)
+                                .cloned()
+                                .collect::<Vec<_>>();
+
+                            buffered_columns.extend(streamed_columns);
+                            buffered_columns
+                        } else {

Review Comment:
   I missed the fact that this branch handles both left and full joins (not just left).
   
   ```suggestion
                           } 
                          // Left join or full outer join
                         else {
   ```



##########
datafusion/physical-plan/src/joins/sort_merge_join.rs:
##########
@@ -1121,16 +1151,138 @@ impl SMJStream {
                         .collect::<Vec<_>>()
                 };
 
+            let streamed_columns_length = streamed_columns.len();
+            let buffered_columns_length = buffered_columns.len();
+
+            // Prepare the columns we apply join filter on later.
+            // Only for joined rows between streamed and buffered.
+            let filter_columns = if chunk.buffered_batch_idx.is_some() {
+                if matches!(self.join_type, JoinType::Right) {
+                    get_filter_column(&self.filter, &buffered_columns, 
&streamed_columns)
+                } else {
+                    get_filter_column(&self.filter, &streamed_columns, 
&buffered_columns)
+                }
+            } else {
+                vec![]
+            };
+
             let columns = if matches!(self.join_type, JoinType::Right) {
-                buffered_columns.extend(streamed_columns);
+                buffered_columns.extend(streamed_columns.clone());
                 buffered_columns
             } else {
                 streamed_columns.extend(buffered_columns);
                 streamed_columns
             };
 
-            self.output_record_batches
-                .push(RecordBatch::try_new(self.schema.clone(), columns)?);
+            let output_batch =
+                RecordBatch::try_new(self.schema.clone(), columns.clone())?;
+
+            // Apply join filter if any
+            if !filter_columns.is_empty() {
+                if let Some(f) = &self.filter {
+                    // Construct batch with only filter columns
+                    let filter_batch = RecordBatch::try_new(
+                        Arc::new(f.schema().clone()),
+                        filter_columns,
+                    )?;
+
+                    let filter_result = f
+                        .expression()
+                        .evaluate(&filter_batch)?
+                        .into_array(filter_batch.num_rows())?;
+
+                    // The selection mask of the filter
+                    let mask = 
datafusion_common::cast::as_boolean_array(&filter_result)?;
+
+                    // Push the filtered batch to the output
+                    let filtered_batch =
+                        compute::filter_record_batch(&output_batch, mask)?;
+                    self.output_record_batches.push(filtered_batch);
+
+                    // For outer joins, we need to push the null joined rows 
to the output.
+                    if matches!(
+                        self.join_type,
+                        JoinType::Left | JoinType::Right | JoinType::Full
+                    ) {
+                        // The reverse of the selection mask, which is for 
null joined rows

Review Comment:
   Does 'null joined rows' mean 'rows that passed neither the equijoin
predicates NOR the filter'? If so, I would find a term like `non_matching_rows`
easier to understand. But that is a personal preference.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to