irenjj commented on code in PR #16611: URL: https://github.com/apache/datafusion/pull/16611#discussion_r2174245375
########## datafusion/optimizer/src/decorrelate_dependent_join.rs: ########## @@ -1096,6 +1097,88 @@ mod tests { "); Ok(()) } + #[test] + fn paper() -> Result<()> { + let outer_table = test_table_scan_with_name("T1")?; + let inner_table_lv1 = test_table_scan_with_name("T2")?; + + let inner_table_lv2 = test_table_scan_with_name("T3")?; + let scalar_sq_level2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("T3.b") + .eq(out_ref_col(ArrowDataType::UInt32, "T2.b")) + .and(col("T3.a").eq(out_ref_col(ArrowDataType::UInt32, "T1.a"))), + )? + .aggregate(Vec::<Expr>::new(), vec![sum(col("T3.a"))])? + .build()?, + ); + let scalar_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("T2.a") + .eq(out_ref_col(ArrowDataType::UInt32, "T1.a")) + .and(scalar_subquery(scalar_sq_level2).gt(lit(300000))), + )? + .aggregate(Vec::<Expr>::new(), vec![count(col("T2.a"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("T1.c") + .eq(lit(123)) + .and(scalar_subquery(scalar_sq_level1).gt(lit(5))), + )? 
+ .build()?; + print_graphviz(&plan); + + // Projection: outer_table.a, outer_table.b, outer_table.c + // Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a + // DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr (<subquery>) depth 1 + // TableScan: outer_table + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) + // DependentJoin on [inner_table_lv1.b lvl 2] with expr (<subquery>) depth 2 + // TableScan: inner_table_lv1 + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] + // Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) + // TableScan: inner_table_lv2 + assert_decorrelate!(plan, @r" + Projection: t1.a, t1.b, t1.c [a:UInt32, b:UInt32, c:UInt32] + Filter: t1.c = Int32(123) AND __scalar_sq_2.output > Int32(5) [a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE count(t2.a) END:Int32;N, t1_a:UInt32;N, __scalar_sq_2.output:Int32;N] + Projection: t1.a, t1.b, t1.c, CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE count(t2.a) END, delim_scan_2.t1_a, CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE count(t2.a) END AS __scalar_sq_2.output [a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE count(t2.a) END:Int32;N, t1_a:UInt32;N, __scalar_sq_2.output:Int32;N] + Left Join(ComparisonJoin): Filter: t1.a IS NOT DISTINCT FROM delim_scan_4.t1_a [a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE count(t2.a) END:Int32;N, t1_a:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + Projection: CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE count(t2.a) END, delim_scan_2.t1_a [CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE count(t2.a) END:Int32, t1_a:UInt32;N] + Inner Join(DelimJoin): Filter: 
delim_scan_2.t1_a IS NOT DISTINCT FROM delim_scan_1.t1_a [count(t2.a):Int64, t1_a:UInt32;N, t1_a:UInt32;N] + Projection: CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE count(t2.a) END, delim_scan_2.t1_a [count(t2.a):Int64, t1_a:UInt32;N] + Aggregate: groupBy=[[delim_scan_2.t1_a]], aggr=[[count(t2.a)]] [t1_a:UInt32;N, count(t2.a):Int64] + Projection: t2.a, t2.b, t2.c, delim_scan_2.t1_a [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N] + Filter: t2.a = delim_scan_2.t1_a AND __scalar_sq_1.output > Int32(300000) [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N, sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N, __scalar_sq_1.output:UInt64;N] + Projection: t2.a, t2.b, t2.c, delim_scan_2.t1_a, sum(t3.a), delim_scan_4.t1_a, delim_scan_4.t2_b, sum(t3.a) AS __scalar_sq_1.output [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N, sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N, __scalar_sq_1.output:UInt64;N] + Left Join(ComparisonJoin): Filter: t2.b IS NOT DISTINCT FROM delim_scan_4.t2_b [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N, sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N] + Inner Join(DelimJoin): Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: delim_scan_2 [t1_a:UInt32;N] + DelimGet: t1.a [t1_a:UInt32;N] + Projection: sum(t3.a), delim_scan_4.t1_a, delim_scan_4.t2_b [sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N] + Inner Join(DelimJoin): Filter: delim_scan_4.t2_b IS NOT DISTINCT FROM delim_scan_3.t2_b AND delim_scan_4.t1_a IS NOT DISTINCT FROM delim_scan_3.t1_a [sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N, t2_b:UInt32;N, t1_a:UInt32;N] + Projection: sum(t3.a), delim_scan_4.t1_a, delim_scan_4.t2_b [sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N] + Aggregate: groupBy=[[delim_scan_4.t2_b, delim_scan_4.t1_a]], aggr=[[sum(t3.a)]] [t2_b:UInt32;N, t1_a:UInt32;N, sum(t3.a):UInt64;N] Review Comment: we still need to add two outercolumn into group by if we let one delimget to scan two different 
outer tables, right? `delim_scan_4.t2_b, delim_scan_4.t1_a` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org