This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3e011742c1 Bugfix - Projection Removal Conditions (#9215)
3e011742c1 is described below
commit 3e011742c18e9065040030bf2103572d9d2fe984
Author: Berkay Şahin <[email protected]>
AuthorDate: Wed Feb 14 19:25:37 2024 +0300
Bugfix - Projection Removal Conditions (#9215)
* Update projection_pushdown.rs
* Test added
---
.../src/physical_optimizer/projection_pushdown.rs | 113 ++++++++++++++++++---
1 file changed, 100 insertions(+), 13 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index 437d63dad2..79d22374f9 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -813,20 +813,16 @@ fn try_swapping_with_sym_hash_join(
)?)))
}
-/// Compare the inputs and outputs of the projection. If the projection causes
-/// any change in the fields, it returns `false`.
+/// Compare the inputs and outputs of the projection. All expressions must be
+/// columns without alias, and projection does not change the order of fields.
fn is_projection_removable(projection: &ProjectionExec) -> bool {
- all_alias_free_columns(projection.expr()) && {
- let schema = projection.schema();
- let input_schema = projection.input().schema();
- let fields = schema.fields();
- let input_fields = input_schema.fields();
- fields.len() == input_fields.len()
- && fields
- .iter()
- .zip(input_fields.iter())
- .all(|(out, input)| out.eq(input))
- }
+ let exprs = projection.expr();
+ exprs.iter().enumerate().all(|(idx, (expr, alias))| {
+ let Some(col) = expr.as_any().downcast_ref::<Column>() else {
+ return false;
+ };
+ col.name() == alias && col.index() == idx
+ }) && exprs.len() == projection.input().schema().fields().len()
}
/// Given the expression set of a projection, checks if the projection causes
@@ -2156,6 +2152,97 @@ mod tests {
Ok(())
}
+ #[test]
+ fn test_join_after_required_projection() -> Result<()> {
+ let left_csv = create_simple_csv_exec();
+ let right_csv = create_simple_csv_exec();
+
+ let join: Arc<dyn ExecutionPlan> = Arc::new(SymmetricHashJoinExec::try_new(
+ left_csv,
+ right_csv,
+ vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))],
+ // b_left-(1+a_right)<=a_right+c_left
+ Some(JoinFilter::new(
+ Arc::new(BinaryExpr::new(
+ Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("b_left_inter", 0)),
+ Operator::Minus,
+ Arc::new(BinaryExpr::new(
+ Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+ Operator::Plus,
+ Arc::new(Column::new("a_right_inter", 1)),
+ )),
+ )),
+ Operator::LtEq,
+ Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("a_right_inter", 1)),
+ Operator::Plus,
+ Arc::new(Column::new("c_left_inter", 2)),
+ )),
+ )),
+ vec![
+ ColumnIndex {
+ index: 1,
+ side: JoinSide::Left,
+ },
+ ColumnIndex {
+ index: 0,
+ side: JoinSide::Right,
+ },
+ ColumnIndex {
+ index: 2,
+ side: JoinSide::Left,
+ },
+ ],
+ Schema::new(vec![
+ Field::new("b_left_inter", DataType::Int32, true),
+ Field::new("a_right_inter", DataType::Int32, true),
+ Field::new("c_left_inter", DataType::Int32, true),
+ ]),
+ )),
+ &JoinType::Inner,
+ true,
+ None,
+ None,
+ StreamJoinPartitionMode::SinglePartition,
+ )?);
+ let projection: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
+ vec![
+ (Arc::new(Column::new("a", 5)), "a".to_string()),
+ (Arc::new(Column::new("b", 6)), "b".to_string()),
+ (Arc::new(Column::new("c", 7)), "c".to_string()),
+ (Arc::new(Column::new("d", 8)), "d".to_string()),
+ (Arc::new(Column::new("e", 9)), "e".to_string()),
+ (Arc::new(Column::new("a", 0)), "a".to_string()),
+ (Arc::new(Column::new("b", 1)), "b".to_string()),
+ (Arc::new(Column::new("c", 2)), "c".to_string()),
+ (Arc::new(Column::new("d", 3)), "d".to_string()),
+ (Arc::new(Column::new("e", 4)), "e".to_string()),
+ ],
+ join,
+ )?);
+ let initial = get_plan_string(&projection);
+ let expected_initial = [
+ "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]",
+ " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2",
+ " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
+ " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false"
+ ];
+ assert_eq!(initial, expected_initial);
+
+ let after_optimize =
+ ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?;
+
+ let expected = [
+ "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]",
+ " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2",
+ " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
+ " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false"
+ ];
+ assert_eq!(get_plan_string(&after_optimize), expected);
+ Ok(())
+ }
+
#[test]
fn test_repartition_after_projection() -> Result<()> {
let csv = create_simple_csv_exec();