irenjj commented on code in PR #16611:
URL: https://github.com/apache/datafusion/pull/16611#discussion_r2174245375


##########
datafusion/optimizer/src/decorrelate_dependent_join.rs:
##########
@@ -1096,6 +1097,88 @@ mod tests {
         ");
         Ok(())
     }
+    #[test]
+    fn paper() -> Result<()> {
+        let outer_table = test_table_scan_with_name("T1")?;
+        let inner_table_lv1 = test_table_scan_with_name("T2")?;
+
+        let inner_table_lv2 = test_table_scan_with_name("T3")?;
+        let scalar_sq_level2 = Arc::new(
+            LogicalPlanBuilder::from(inner_table_lv2)
+                .filter(
+                    col("T3.b")
+                        .eq(out_ref_col(ArrowDataType::UInt32, "T2.b"))
+                        .and(col("T3.a").eq(out_ref_col(ArrowDataType::UInt32, 
"T1.a"))),
+                )?
+                .aggregate(Vec::<Expr>::new(), vec![sum(col("T3.a"))])?
+                .build()?,
+        );
+        let scalar_sq_level1 = Arc::new(
+            LogicalPlanBuilder::from(inner_table_lv1.clone())
+                .filter(
+                    col("T2.a")
+                        .eq(out_ref_col(ArrowDataType::UInt32, "T1.a"))
+                        
.and(scalar_subquery(scalar_sq_level2).gt(lit(300000))),
+                )?
+                .aggregate(Vec::<Expr>::new(), vec![count(col("T2.a"))])?
+                .build()?,
+        );
+
+        let plan = LogicalPlanBuilder::from(outer_table.clone())
+            .filter(
+                col("T1.c")
+                    .eq(lit(123))
+                    .and(scalar_subquery(scalar_sq_level1).gt(lit(5))),
+            )?
+            .build()?;
+        print_graphviz(&plan);
+
+        // Projection: outer_table.a, outer_table.b, outer_table.c
+        //   Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = 
outer_table.a
+        //     DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] 
with expr (<subquery>) depth 1
+        //       TableScan: outer_table
+        //       Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]]
+        //         Projection: inner_table_lv1.a, inner_table_lv1.b, 
inner_table_lv1.c
+        //           Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND 
__scalar_sq_1.output = Int32(1)
+        //             DependentJoin on [inner_table_lv1.b lvl 2] with expr 
(<subquery>) depth 2
+        //               TableScan: inner_table_lv1
+        //               Aggregate: groupBy=[[]], 
aggr=[[count(inner_table_lv2.a)]]
+        //                 Filter: inner_table_lv2.a = 
outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b)
+        //                   TableScan: inner_table_lv2
+        assert_decorrelate!(plan, @r"
+        Projection: t1.a, t1.b, t1.c [a:UInt32, b:UInt32, c:UInt32]
+          Filter: t1.c = Int32(123) AND __scalar_sq_2.output > Int32(5) 
[a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE 
count(t2.a) END:Int32;N, t1_a:UInt32;N, __scalar_sq_2.output:Int32;N]
+            Projection: t1.a, t1.b, t1.c, CASE WHEN count(t2.a) IS NULL THEN 
Int32(0) ELSE count(t2.a) END, delim_scan_2.t1_a, CASE WHEN count(t2.a) IS NULL 
THEN Int32(0) ELSE count(t2.a) END AS __scalar_sq_2.output [a:UInt32, b:UInt32, 
c:UInt32, CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE count(t2.a) 
END:Int32;N, t1_a:UInt32;N, __scalar_sq_2.output:Int32;N]
+              Left Join(ComparisonJoin):  Filter: t1.a IS NOT DISTINCT FROM 
delim_scan_4.t1_a [a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(t2.a) IS NULL 
THEN Int32(0) ELSE count(t2.a) END:Int32;N, t1_a:UInt32;N]
+                TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
+                Projection: CASE WHEN count(t2.a) IS NULL THEN Int32(0) ELSE 
count(t2.a) END, delim_scan_2.t1_a [CASE WHEN count(t2.a) IS NULL THEN Int32(0) 
ELSE count(t2.a) END:Int32, t1_a:UInt32;N]
+                  Inner Join(DelimJoin):  Filter: delim_scan_2.t1_a IS NOT 
DISTINCT FROM delim_scan_1.t1_a [count(t2.a):Int64, t1_a:UInt32;N, 
t1_a:UInt32;N]
+                    Projection: CASE WHEN count(t2.a) IS NULL THEN Int32(0) 
ELSE count(t2.a) END, delim_scan_2.t1_a [count(t2.a):Int64, t1_a:UInt32;N]
+                      Aggregate: groupBy=[[delim_scan_2.t1_a]], 
aggr=[[count(t2.a)]] [t1_a:UInt32;N, count(t2.a):Int64]
+                        Projection: t2.a, t2.b, t2.c, delim_scan_2.t1_a 
[a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N]
+                          Filter: t2.a = delim_scan_2.t1_a AND 
__scalar_sq_1.output > Int32(300000) [a:UInt32, b:UInt32, c:UInt32, 
t1_a:UInt32;N, sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N, 
__scalar_sq_1.output:UInt64;N]
+                            Projection: t2.a, t2.b, t2.c, delim_scan_2.t1_a, 
sum(t3.a), delim_scan_4.t1_a, delim_scan_4.t2_b, sum(t3.a) AS 
__scalar_sq_1.output [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N, 
sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N, __scalar_sq_1.output:UInt64;N]
+                              Left Join(ComparisonJoin):  Filter: t2.b IS NOT 
DISTINCT FROM delim_scan_4.t2_b [a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N, 
sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N]
+                                Inner Join(DelimJoin):  Filter: Boolean(true) 
[a:UInt32, b:UInt32, c:UInt32, t1_a:UInt32;N]
+                                  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
+                                  SubqueryAlias: delim_scan_2 [t1_a:UInt32;N]
+                                    DelimGet: t1.a [t1_a:UInt32;N]
+                                Projection: sum(t3.a), delim_scan_4.t1_a, 
delim_scan_4.t2_b [sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N]
+                                  Inner Join(DelimJoin):  Filter: 
delim_scan_4.t2_b IS NOT DISTINCT FROM delim_scan_3.t2_b AND delim_scan_4.t1_a 
IS NOT DISTINCT FROM delim_scan_3.t1_a [sum(t3.a):UInt64;N, t1_a:UInt32;N, 
t2_b:UInt32;N, t2_b:UInt32;N, t1_a:UInt32;N]
+                                    Projection: sum(t3.a), delim_scan_4.t1_a, 
delim_scan_4.t2_b [sum(t3.a):UInt64;N, t1_a:UInt32;N, t2_b:UInt32;N]
+                                      Aggregate: groupBy=[[delim_scan_4.t2_b, 
delim_scan_4.t1_a]], aggr=[[sum(t3.a)]] [t2_b:UInt32;N, t1_a:UInt32;N, 
sum(t3.a):UInt64;N]

Review Comment:
   We still need to add the two outer columns to the group by if we let one
DelimGet scan two different outer tables, right?
   `delim_scan_4.t2_b, delim_scan_4.t1_a`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to