ygf11 commented on code in PR #4826:
URL: https://github.com/apache/arrow-datafusion/pull/4826#discussion_r1070553578


##########
datafusion/optimizer/src/decorrelate_where_in.rs:
##########
@@ -926,4 +1035,153 @@ mod tests {
         );
         Ok(())
     }
+
+    #[test]
+    fn in_subquery_both_side_expr() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let subquery_scan = test_table_scan_with_name("sq")?;
+
+        let subquery = LogicalPlanBuilder::from(subquery_scan)
+            .project(vec![col("c") * lit(2u32)])?
+            .build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
+            .project(vec![col("test.b")])?
+            .build()?;
+
+        let expected = "Projection: test.b [b:UInt32]\
+        \n  LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.c * 
UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
+        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+        \n    SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32]\
+        \n      Projection: sq.c * UInt32(2) AS c * UInt32(2) [c * 
UInt32(2):UInt32]\
+        \n        TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq_display_indent(
+            Arc::new(DecorrelateWhereIn::new()),
+            &plan,
+            expected,
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn in_subquery_join_filter_and_inner_filter() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let subquery_scan = test_table_scan_with_name("sq")?;
+
+        let subquery = LogicalPlanBuilder::from(subquery_scan)
+            .filter(
+                col("test.a")
+                    .eq(col("sq.a"))
+                    .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
+            )?
+            .project(vec![col("c") * lit(2u32)])?
+            .build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
+            .project(vec![col("test.b")])?
+            .build()?;
+
+        let expected = "Projection: test.b [b:UInt32]\
+        \n  LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.c * 
UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+        \n    SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, 
a:UInt32]\
+        \n      Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a [c * 
UInt32(2):UInt32, a:UInt32]\
+        \n        Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, 
c:UInt32]\
+        \n          TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq_display_indent(
+            Arc::new(DecorrelateWhereIn::new()),
+            &plan,
+            expected,
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn in_subquery_muti_project_subquery_cols() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let subquery_scan = test_table_scan_with_name("sq")?;
+
+        let subquery = LogicalPlanBuilder::from(subquery_scan)
+            .filter(
+                col("test.a")
+                    .add(col("test.b"))
+                    .eq(col("sq.a").add(col("sq.b")))
+                    .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
+            )?
+            .project(vec![col("c") * lit(2u32)])?
+            .build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
+            .project(vec![col("test.b")])?
+            .build()?;
+
+        let expected = "Projection: test.b [b:UInt32]\
+        \n  LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.c * 
UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b 
[a:UInt32, b:UInt32, c:UInt32]\
+        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+        \n    SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, 
a:UInt32, b:UInt32]\
+        \n      Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a, sq.b [c * 
UInt32(2):UInt32, a:UInt32, b:UInt32]\
+        \n        Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, 
c:UInt32]\
+        \n          TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq_display_indent(
+            Arc::new(DecorrelateWhereIn::new()),
+            &plan,
+            expected,
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn two_in_subquery_with_outer_filter() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let subquery_scan1 = test_table_scan_with_name("sq1")?;
+        let subquery_scan2 = test_table_scan_with_name("sq2")?;
+
+        let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
+            .filter(col("test.a").gt(col("sq1.a")))?
+            .project(vec![col("c") * lit(2u32)])?
+            .build()?;
+
+        let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
+            .filter(col("test.a").gt(col("sq2.a")))?
+            .project(vec![col("c") * lit(2u32)])?
+            .build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(
+                in_subquery(col("c") + lit(1u32), Arc::new(subquery1)).and(
+                    in_subquery(col("c") * lit(2u32), Arc::new(subquery2))
+                        .and(col("test.c").gt(lit(1u32))),
+                ),
+            )?
+            .project(vec![col("test.b")])?
+            .build()?;
+
+        // Filter: test.c > UInt32(1) happen twice.
+        // issue: https://github.com/apache/arrow-datafusion/issues/4914
+        let expected = "Projection: test.b [b:UInt32]\
+        \n  Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
+        \n    LeftSemi Join:  Filter: test.c * UInt32(2) = __correlated_sq_2.c 
* UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\
+        \n      Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
+        \n        LeftSemi Join:  Filter: test.c + UInt32(1) = 
__correlated_sq_1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, 
b:UInt32, c:UInt32]\
+        \n          TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+        \n          SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, 
a:UInt32]\
+        \n            Projection: sq1.c * UInt32(2) AS c * UInt32(2), sq1.a [c 
* UInt32(2):UInt32, a:UInt32]\
+        \n              TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
+        \n      SubqueryAlias: __correlated_sq_2 [c * UInt32(2):UInt32, 
a:UInt32]\
+        \n        Projection: sq2.c * UInt32(2) AS c * UInt32(2), sq2.a [c * 
UInt32(2):UInt32, a:UInt32]\
+        \n          TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";

Review Comment:
   `Filter: test.c > UInt32(1)` appears twice in the optimized plan; it would be better to add it only once. I will fix this in a follow-up PR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to