This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b26c1b819d fix: Fix the incorrect null joined rows for SMJ outer join with join filter (#10892)
b26c1b819d is described below
commit b26c1b819dff7ed1b48ab20c66b4bd1226ff8d79
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Jun 18 04:02:59 2024 -0700
fix: Fix the incorrect null joined rows for SMJ outer join with join filter (#10892)
* fix: Fix the incorrect null joined rows for outer join with join filter
* Update datafusion/physical-plan/src/joins/sort_merge_join.rs
Co-authored-by: Oleks V <[email protected]>
* Update datafusion/physical-plan/src/joins/sort_merge_join.rs
Co-authored-by: Oleks V <[email protected]>
* Update datafusion/physical-plan/src/joins/sort_merge_join.rs
Co-authored-by: Oleks V <[email protected]>
* Update datafusion/physical-plan/src/joins/sort_merge_join.rs
Co-authored-by: Oleks V <[email protected]>
* For review
---------
Co-authored-by: Oleks V <[email protected]>
---
.../physical-plan/src/joins/sort_merge_join.rs | 278 +++++++++++++--------
.../sqllogictest/test_files/sort_merge_join.slt | 29 ++-
2 files changed, 194 insertions(+), 113 deletions(-)
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index 01abb30181..420fab51da 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -46,6 +46,7 @@ use arrow::array::*;
use arrow::compute::{self, concat_batches, take, SortOptions};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
+use arrow_array::types::UInt64Type;
use datafusion_common::{
internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType,
Result,
@@ -471,6 +472,7 @@ struct StreamedJoinedChunk {
/// Array builder for streamed indices
streamed_indices: UInt64Builder,
/// Array builder for buffered indices
+ /// This could contain nulls if the join is null-joined
buffered_indices: UInt64Builder,
}
@@ -559,6 +561,9 @@ struct BufferedBatch {
pub null_joined: Vec<usize>,
/// Size estimation used for reserving / releasing memory
pub size_estimation: usize,
+ /// The indices of buffered batch that failed the join filter.
+ /// When dequeuing the buffered batch, we need to produce null joined rows for these indices.
+ pub join_filter_failed_idxs: HashSet<u64>,
}
impl BufferedBatch {
@@ -590,6 +595,7 @@ impl BufferedBatch {
join_arrays,
null_joined: vec![],
size_estimation,
+ join_filter_failed_idxs: HashSet::new(),
}
}
}
@@ -847,6 +853,7 @@ impl SMJStream {
// pop previous buffered batches
while !self.buffered_data.batches.is_empty() {
let head_batch = self.buffered_data.head_batch();
+ // If the head batch is fully processed, dequeue it and produce output of it.
if head_batch.range.end == head_batch.batch.num_rows()
{
self.freeze_dequeuing_buffered()?;
if let Some(buffered_batch) =
@@ -855,6 +862,8 @@ impl SMJStream {
self.reservation.shrink(buffered_batch.size_estimation);
}
} else {
+ // If the head batch is not fully processed, break the loop.
+ // Streamed batch will be joined with the head batch in the next step.
break;
}
}
@@ -1050,7 +1059,7 @@ impl SMJStream {
Some(scanning_idx),
);
} else {
- // Join nulls and buffered row
+ // Join nulls and buffered row for FULL join
self.buffered_data
.scanning_batch_mut()
.null_joined
@@ -1083,7 +1092,7 @@ impl SMJStream {
fn freeze_all(&mut self) -> Result<()> {
self.freeze_streamed()?;
- self.freeze_buffered(self.buffered_data.batches.len())?;
+ self.freeze_buffered(self.buffered_data.batches.len(), false)?;
Ok(())
}
@@ -1093,7 +1102,8 @@ impl SMJStream {
// 2. freezes NULLs joined to dequeued buffered batch to "release" it
fn freeze_dequeuing_buffered(&mut self) -> Result<()> {
self.freeze_streamed()?;
- self.freeze_buffered(1)?;
+ // Only freeze and produce the first batch in buffered_data as the batch is fully processed
+ self.freeze_buffered(1, true)?;
Ok(())
}
@@ -1101,7 +1111,14 @@ impl SMJStream {
// NULLs on streamed side.
//
// Applicable only in case of Full join.
- fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> {
+ //
+ // If `output_not_matched_filter` is true, this will also produce record batches
+ // for buffered rows which are joined with streamed side but don't match join filter.
+ fn freeze_buffered(
+ &mut self,
+ batch_count: usize,
+ output_not_matched_filter: bool,
+ ) -> Result<()> {
if !matches!(self.join_type, JoinType::Full) {
return Ok(());
}
@@ -1109,33 +1126,31 @@ impl SMJStream {
let buffered_indices = UInt64Array::from_iter_values(
buffered_batch.null_joined.iter().map(|&index| index as u64),
);
- if buffered_indices.is_empty() {
- continue;
+ if let Some(record_batch) = produce_buffered_null_batch(
+ &self.schema,
+ &self.streamed_schema,
+ &buffered_indices,
+ buffered_batch,
+ )? {
+ self.output_record_batches.push(record_batch);
}
buffered_batch.null_joined.clear();
- // Take buffered (right) columns
- let buffered_columns = buffered_batch
- .batch
- .columns()
- .iter()
- .map(|column| take(column, &buffered_indices, None))
- .collect::<Result<Vec<_>, ArrowError>>()
- .map_err(Into::<DataFusionError>::into)?;
-
- // Create null streamed (left) columns
- let mut streamed_columns = self
- .streamed_schema
- .fields()
- .iter()
- .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
- .collect::<Vec<_>>();
-
- streamed_columns.extend(buffered_columns);
- let columns = streamed_columns;
-
- self.output_record_batches
- .push(RecordBatch::try_new(self.schema.clone(), columns)?);
+ // For buffered rows which are joined with streamed side but don't satisfy the join filter
+ if output_not_matched_filter {
+ let buffered_indices = UInt64Array::from_iter_values(
+ buffered_batch.join_filter_failed_idxs.iter().copied(),
+ );
+ if let Some(record_batch) = produce_buffered_null_batch(
+ &self.schema,
+ &self.streamed_schema,
+ &buffered_indices,
+ buffered_batch,
+ )? {
+ self.output_record_batches.push(record_batch);
+ }
+ buffered_batch.join_filter_failed_idxs.clear();
+ }
}
Ok(())
}
@@ -1144,6 +1159,7 @@ impl SMJStream {
// for current streamed batch and clears staged output indices.
fn freeze_streamed(&mut self) -> Result<()> {
for chunk in self.streamed_batch.output_indices.iter_mut() {
+ // The row indices of joined streamed batch
let streamed_indices = chunk.streamed_indices.finish();
if streamed_indices.is_empty() {
@@ -1158,6 +1174,7 @@ impl SMJStream {
.map(|column| take(column, &streamed_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()?;
+ // The row indices of joined buffered batch
let buffered_indices: UInt64Array =
chunk.buffered_indices.finish();
let mut buffered_columns =
if matches!(self.join_type, JoinType::LeftSemi |
JoinType::LeftAnti) {
@@ -1169,6 +1186,8 @@ impl SMJStream {
&buffered_indices,
)?
} else {
+ // If the buffered batch is None, this is a null-joined chunk.
+ // We need to create null arrays for buffered columns to join with streamed rows.
self.buffered_schema
.fields()
.iter()
@@ -1200,7 +1219,8 @@ impl SMJStream {
get_filter_column(&self.filter, &streamed_columns,
&buffered_columns)
}
} else {
- // This chunk is for null joined rows (outer join), we don't need to apply join filter.
+ // This chunk is entirely for null joined rows (outer join), we don't need to apply join filter.
+ // Any join filter that applies only to the streamed or the buffered side has already been pushed down.
vec![]
};
@@ -1229,49 +1249,73 @@ impl SMJStream {
.evaluate(&filter_batch)?
.into_array(filter_batch.num_rows())?;
- // The selection mask of the filter
- let mut mask =
+ // The boolean selection mask of the join filter result
+ let pre_mask =
datafusion_common::cast::as_boolean_array(&filter_result)?;
+ // If there are nulls in join filter result, exclude them from selecting
+ // the rows to output.
+ let mask = if pre_mask.null_count() > 0 {
+ compute::prep_null_mask_filter(
+ datafusion_common::cast::as_boolean_array(&filter_result)?,
+ )
+ } else {
+ pre_mask.clone()
+ };
+
+ // For certain join types, we need to adjust the initial
mask to handle the join filter.
let maybe_filtered_join_mask: Option<(BooleanArray,
Vec<u64>)> =
get_filtered_join_mask(
self.join_type,
- streamed_indices,
- mask,
+ &streamed_indices,
+ &mask,
&self.streamed_batch.join_filter_matched_idxs,
&self.buffered_data.scanning_offset,
);
- if let Some(ref filtered_join_mask) =
maybe_filtered_join_mask {
- mask = &filtered_join_mask.0;
- self.streamed_batch
- .join_filter_matched_idxs
- .extend(&filtered_join_mask.1);
- }
+ let mask =
+ if let Some(ref filtered_join_mask) =
maybe_filtered_join_mask {
+ self.streamed_batch
+ .join_filter_matched_idxs
+ .extend(&filtered_join_mask.1);
+ &filtered_join_mask.0
+ } else {
+ &mask
+ };
- // Push the filtered batch to the output
+ // Push the filtered batch which contains rows passing
join filter to the output
let filtered_batch =
compute::filter_record_batch(&output_batch, mask)?;
self.output_record_batches.push(filtered_batch);
- // For outer joins, we need to push the null joined rows
to the output.
+ // For outer joins, we need to push the null joined rows
to the output if
+ // all joined rows are failed on the join filter.
+ // I.e., if all rows joined from a streamed row are failed
with the join filter,
+ // we need to join it with nulls as buffered side.
if matches!(
self.join_type,
JoinType::Left | JoinType::Right | JoinType::Full
) {
- // The reverse of the selection mask. For the rows not
pass join filter above,
- // we need to join them (left or right) with null rows
for outer joins.
- let not_mask = if mask.null_count() > 0 {
- // If the mask contains nulls, we need to use
`prep_null_mask_filter` to
- // handle the nulls in the mask as false to
produce rows where the mask
- // was null itself.
- compute::not(&compute::prep_null_mask_filter(mask))?
- } else {
- compute::not(mask)?
- };
+ // We need to get the mask for row indices that the
joined rows are failed
+ // on the join filter. I.e., for a row in streamed
side, if all joined rows
+ // between it and all buffered rows are failed on the
join filter, we need to
+ // output it with null columns from buffered side. For
the mask here, it
+ // behaves like LeftAnti join.
+ let null_mask: BooleanArray = get_filtered_join_mask(
+ // Set a mask slot as true only if all joined rows
of same streamed index
+ // are failed on the join filter.
+ // The masking behavior is like LeftAnti join.
+ JoinType::LeftAnti,
+ &streamed_indices,
+ mask,
+ &self.streamed_batch.join_filter_matched_idxs,
+ &self.buffered_data.scanning_offset,
+ )
+ .unwrap()
+ .0;
let null_joined_batch =
- compute::filter_record_batch(&output_batch, &not_mask)?;
+ compute::filter_record_batch(&output_batch,
&null_mask)?;
let mut buffered_columns = self
.buffered_schema
@@ -1308,51 +1352,37 @@ impl SMJStream {
streamed_columns
};
+ // Push the streamed/buffered batch joined nulls to
the output
let null_joined_streamed_batch =
RecordBatch::try_new(self.schema.clone(),
columns.clone())?;
self.output_record_batches.push(null_joined_streamed_batch);
- // For full join, we also need to output the null
joined rows from the buffered side
+ // For full join, we also need to output the null
joined rows from the buffered side.
+ // Usually this is done by `freeze_buffered`. However,
if a buffered row is joined with
+ // streamed side, it won't be outputted by
`freeze_buffered`.
+ // We need to check if a buffered row is joined with
streamed side and output.
+ // If it is joined with streamed side, but doesn't
match the join filter,
+ // we need to output it with nulls as streamed side.
if matches!(self.join_type, JoinType::Full) {
- // Handle not mask for buffered side further.
- // For buffered side, we want to output the rows
that are not null joined with
- // the streamed side. i.e. the rows that are not
null in the `buffered_indices`.
- let not_mask = if let Some(nulls) =
buffered_indices.nulls() {
- let mask = not_mask.values() & nulls.inner();
- BooleanArray::new(mask, None)
- } else {
- not_mask
- };
-
- let null_joined_batch =
- compute::filter_record_batch(&output_batch, &not_mask)?;
-
- let mut streamed_columns = self
- .streamed_schema
- .fields()
- .iter()
- .map(|f| {
- new_null_array(
- f.data_type(),
- null_joined_batch.num_rows(),
- )
- })
- .collect::<Vec<_>>();
-
- let buffered_columns = null_joined_batch
- .columns()
- .iter()
- .skip(streamed_columns_length)
- .cloned()
- .collect::<Vec<_>>();
-
- streamed_columns.extend(buffered_columns);
-
- let null_joined_buffered_batch =
RecordBatch::try_new(
- self.schema.clone(),
- streamed_columns,
- )?;
- self.output_record_batches.push(null_joined_buffered_batch);
+ for i in 0..pre_mask.len() {
+ let buffered_batch = &mut
self.buffered_data.batches
+ [chunk.buffered_batch_idx.unwrap()];
+ let buffered_index = buffered_indices.value(i);
+
+ if !pre_mask.value(i) {
+ // For a buffered row that is joined with
streamed side but doesn't satisfy the join filter,
+ buffered_batch
+ .join_filter_failed_idxs
+ .insert(buffered_index);
+ } else if buffered_batch
+ .join_filter_failed_idxs
+ .contains(&buffered_index)
+ {
+ buffered_batch
+ .join_filter_failed_idxs
+ .remove(&buffered_index);
+ }
+ }
}
}
} else {
@@ -1417,6 +1447,40 @@ fn get_filter_column(
filter_columns
}
+fn produce_buffered_null_batch(
+ schema: &SchemaRef,
+ streamed_schema: &SchemaRef,
+ buffered_indices: &PrimitiveArray<UInt64Type>,
+ buffered_batch: &BufferedBatch,
+) -> Result<Option<RecordBatch>> {
+ if buffered_indices.is_empty() {
+ return Ok(None);
+ }
+
+ // Take buffered (right) columns
+ let buffered_columns = buffered_batch
+ .batch
+ .columns()
+ .iter()
+ .map(|column| take(column, &buffered_indices, None))
+ .collect::<Result<Vec<_>, ArrowError>>()
+ .map_err(Into::<DataFusionError>::into)?;
+
+ // Create null streamed (left) columns
+ let mut streamed_columns = streamed_schema
+ .fields()
+ .iter()
+ .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
+ .collect::<Vec<_>>();
+
+ streamed_columns.extend(buffered_columns);
+
+ Ok(Some(RecordBatch::try_new(
+ schema.clone(),
+ streamed_columns,
+ )?))
+}
+
/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]`
#[inline(always)]
fn get_buffered_columns(
@@ -1440,9 +1504,13 @@ fn get_buffered_columns(
/// `streamed_indices` have the same length as `mask`
/// `matched_indices` array of streaming indices that already has a join filter match
/// `scanning_buffered_offset` current buffered offset across batches
+///
+/// This returns a tuple of:
+/// - corrected mask with respect to the join type
+/// - indices of rows in streamed batch that have a join filter match
fn get_filtered_join_mask(
join_type: JoinType,
- streamed_indices: UInt64Array,
+ streamed_indices: &UInt64Array,
mask: &BooleanArray,
matched_indices: &HashSet<u64>,
scanning_buffered_offset: &usize,
@@ -2803,7 +2871,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
- UInt64Array::from(vec![0, 0, 1, 1]),
+ &UInt64Array::from(vec![0, 0, 1, 1]),
&BooleanArray::from(vec![true, true, false, false]),
&HashSet::new(),
&0,
@@ -2814,7 +2882,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
- UInt64Array::from(vec![0, 1]),
+ &UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, true]),
&HashSet::new(),
&0,
@@ -2825,7 +2893,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
- UInt64Array::from(vec![0, 1]),
+ &UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![false, true]),
&HashSet::new(),
&0,
@@ -2836,7 +2904,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
- UInt64Array::from(vec![0, 1]),
+ &UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, false]),
&HashSet::new(),
&0,
@@ -2847,7 +2915,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
- UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+ &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, true, true, true, true, true]),
&HashSet::new(),
&0,
@@ -2861,7 +2929,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
- UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+ &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, false, false, false, false,
true]),
&HashSet::new(),
&0,
@@ -2880,7 +2948,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
- UInt64Array::from(vec![0, 0, 1, 1]),
+ &UInt64Array::from(vec![0, 0, 1, 1]),
&BooleanArray::from(vec![true, true, false, false]),
&HashSet::new(),
&0,
@@ -2891,7 +2959,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
- UInt64Array::from(vec![0, 1]),
+ &UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, true]),
&HashSet::new(),
&0,
@@ -2902,7 +2970,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
- UInt64Array::from(vec![0, 1]),
+ &UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![false, true]),
&HashSet::new(),
&0,
@@ -2913,7 +2981,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
- UInt64Array::from(vec![0, 1]),
+ &UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, false]),
&HashSet::new(),
&0,
@@ -2924,7 +2992,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
- UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+ &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, true, true, true, true, true]),
&HashSet::new(),
&0,
@@ -2938,7 +3006,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
- UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+ &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, false, false, false, false,
true]),
&HashSet::new(),
&0,
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt
index b4deb43a72..5a6334602c 100644
--- a/datafusion/sqllogictest/test_files/sort_merge_join.slt
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -84,7 +84,6 @@ SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <=
t1.b
Alice 100 Alice 1
Alice 100 Alice 2
Alice 50 Alice 1
-Alice 50 NULL NULL
Bob 1 NULL NULL
query TITI rowsort
@@ -112,7 +111,6 @@ SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50
<= t1.b
Alice 100 Alice 1
Alice 100 Alice 2
Alice 50 Alice 1
-NULL NULL Alice 2
query TITI rowsort
SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b
@@ -137,12 +135,9 @@ query TITI rowsort
SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b
----
Alice 100 NULL NULL
-Alice 100 NULL NULL
Alice 50 Alice 2
-Alice 50 NULL NULL
Bob 1 NULL NULL
NULL NULL Alice 1
-NULL NULL Alice 1
NULL NULL Alice 2
query TITI rowsort
@@ -151,10 +146,7 @@ SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b >
t2.b + 50
Alice 100 Alice 1
Alice 100 Alice 2
Alice 50 NULL NULL
-Alice 50 NULL NULL
Bob 1 NULL NULL
-NULL NULL Alice 1
-NULL NULL Alice 2
statement ok
DROP TABLE t1;
@@ -613,6 +605,27 @@ select t1.* from t1 where not exists (select 1 from t2
where t2.a = t1.a and t1.
) order by 1, 2
----
+query IIII
+select * from (
+with t as (
+ select id, id % 5 id1 from (select unnest(range(0,10)) id)
+), t1 as (
+ select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id)
+)
+select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1
+) order by 1, 2, 3, 4
+----
+5 0 0 2
+6 1 1 3
+7 2 2 4
+8 3 3 5
+9 4 4 6
+NULL NULL 5 7
+NULL NULL 6 8
+NULL NULL 7 9
+NULL NULL 8 10
+NULL NULL 9 11
+
# return sql params back to default values
statement ok
set datafusion.optimizer.prefer_hash_join = true;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]