comphead commented on code in PR #21184:
URL: https://github.com/apache/datafusion/pull/21184#discussion_r3002327433
##########
datafusion/physical-plan/src/joins/sort_merge_join/filter.rs:
##########
@@ -282,314 +255,131 @@ pub fn get_corrected_filter_mask(
let mut seen_true = false;
match join_type {
- JoinType::Left | JoinType::Right => {
- // For outer joins: Keep first matching row per input row,
- // convert rest to nulls, add null-joined rows for unmatched
- for i in 0..row_indices_length {
- let last_index =
- last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
- if filter_mask.value(i) {
- seen_true = true;
- corrected_mask.append_value(true);
- } else if seen_true || !filter_mask.value(i) && !last_index {
- corrected_mask.append_null(); // to be ignored and not set
to output
- } else {
- corrected_mask.append_value(false); // to be converted to
null joined row
- }
-
- if last_index {
- seen_true = false;
- }
- }
-
- // Generate null joined rows for records which have no matching
join key
- corrected_mask.append_n(expected_size - corrected_mask.len(),
false);
- Some(corrected_mask.finish())
- }
- JoinType::LeftMark | JoinType::RightMark => {
- // For mark joins: Like outer but only keep first match, mark with
boolean
+ JoinType::Left | JoinType::Right | JoinType::Full => {
+ // For each input row group: keep first filter-passing row,
+ // discard (null) remaining matches, null-join if none passed.
+ // Null metadata entries are already-null-joined rows that
+ // flow through unchanged to preserve output ordering.
for i in 0..row_indices_length {
let last_index =
last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
- if filter_mask.value(i) && !seen_true {
- seen_true = true;
- corrected_mask.append_value(true);
- } else if seen_true || !filter_mask.value(i) && !last_index {
- corrected_mask.append_null(); // to be ignored and not set
to output
- } else {
- corrected_mask.append_value(false); // to be converted to
null joined row
- }
-
- if last_index {
- seen_true = false;
- }
- }
-
- // Generate null joined rows for records which have no matching
join key
- corrected_mask.append_n(expected_size - corrected_mask.len(),
false);
- Some(corrected_mask.finish())
- }
- JoinType::LeftSemi | JoinType::RightSemi => {
- // For semi joins: Keep only first matching row per input row,
discard rest
- for i in 0..row_indices_length {
- let last_index =
- last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
- if filter_mask.value(i) && !seen_true {
- seen_true = true;
- corrected_mask.append_value(true);
- } else {
- corrected_mask.append_null(); // to be ignored and not set
to output
- }
-
- if last_index {
- seen_true = false;
- }
- }
-
- Some(corrected_mask.finish())
- }
- JoinType::LeftAnti | JoinType::RightAnti => {
- // For anti joins: Keep row only if NO matches passed the filter
- for i in 0..row_indices_length {
- let last_index =
- last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
-
- if filter_mask.value(i) {
- seen_true = true;
- }
-
- if last_index {
- if !seen_true {
- corrected_mask.append_value(true);
- } else {
- corrected_mask.append_null();
- }
-
- seen_true = false;
- } else {
- corrected_mask.append_null();
- }
- }
- // Generate null joined rows for records which have no matching
join key,
- // for LeftAnti non-matched considered as true
- corrected_mask.append_n(expected_size - corrected_mask.len(),
true);
- Some(corrected_mask.finish())
- }
- JoinType::Full => {
- // For full joins: Similar to outer but handle both sides
- for i in 0..row_indices_length {
- let last_index =
- last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
-
if filter_mask.is_null(i) {
- // null joined
corrected_mask.append_value(true);
} else if filter_mask.value(i) {
seen_true = true;
corrected_mask.append_value(true);
} else if seen_true || !filter_mask.value(i) && !last_index {
- corrected_mask.append_null(); // to be ignored and not set
to output
+ corrected_mask.append_null();
} else {
- corrected_mask.append_value(false); // to be converted to
null joined row
+ corrected_mask.append_value(false);
}
if last_index {
seen_true = false;
}
}
- // Generate null joined rows for records which have no matching
join key
+
corrected_mask.append_n(expected_size - corrected_mask.len(),
false);
Some(corrected_mask.finish())
}
- JoinType::Inner => {
- // Inner joins don't need deferred filtering
- None
+ JoinType::LeftMark
+ | JoinType::RightMark
+ | JoinType::LeftSemi
+ | JoinType::RightSemi
+ | JoinType::LeftAnti
+ | JoinType::RightAnti => {
+ unreachable!(
+ "Semi/anti/mark joins are handled by
SemiAntiMarkSortMergeJoinStream"
+ )
}
+ JoinType::Inner => None,
}
}
/// Applies corrected filter mask to record batch based on join type
///
-/// Different join types require different handling of filtered results:
-/// - Outer joins: Add null-joined rows for false mask values
-/// - Semi/Anti joins: May need projection to remove right columns
-/// - Full joins: Add null-joined rows for both sides
+/// The corrected mask has three possible values per row:
+/// - `true`: Keep the row as-is (matched and passed filter)
+/// - `false`: Convert to null-joined row (all filter matches failed for this
input row)
+/// - `null`: Discard the row entirely (duplicate match for an already-output
input row)
+///
+/// This function preserves input row ordering by processing each row in place
+/// rather than separating matched/unmatched rows.
pub fn filter_record_batch_by_join_type(
record_batch: &RecordBatch,
corrected_mask: &BooleanArray,
join_type: JoinType,
schema: &SchemaRef,
- streamed_schema: &SchemaRef,
buffered_schema: &SchemaRef,
) -> Result<RecordBatch> {
- let filtered_record_batch = filter_record_batch(record_batch,
corrected_mask)?;
-
match join_type {
- JoinType::Left | JoinType::LeftMark => {
- // For left joins, add null-joined rows where mask is false
- let null_mask = compute::not(corrected_mask)?;
- let null_joined_batch = filter_record_batch(record_batch,
&null_mask)?;
+ JoinType::Left | JoinType::Right | JoinType::Full => {
+ // Discard null-masked rows (keep true + false only)
+ let keep_mask = compute::is_not_null(corrected_mask)?;
+ let kept_batch = filter_record_batch(record_batch, &keep_mask)?;
+ let kept_corrected = compute::filter(corrected_mask, &keep_mask)?;
+ let kept_corrected = kept_corrected
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .unwrap();
+
+ if kept_batch.num_rows() == 0 {
Review Comment:
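This can be returned earlier: the empty-batch check only needs `kept_batch`, so it can go right after `kept_batch` is computed, before `corrected_mask` is filtered and downcast into `kept_corrected`.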
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]