ygf11 commented on code in PR #4826:
URL: https://github.com/apache/arrow-datafusion/pull/4826#discussion_r1070554693


##########
datafusion/core/tests/sql/joins.rs:
##########
@@ -2868,3 +2868,278 @@ async fn 
test_cross_join_to_groupby_with_different_key_ordering() -> Result<()>
 
     Ok(())
 }
+
+#[tokio::test]
+async fn subquery_to_join_with_both_side_expr() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 
12 in (select t2.t2_id + 1 from t2)";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+        "    LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = 
__correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + 
Int64(1):Int64;N]",
+        "        Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id 
AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+        "          TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int |",
+        "+-------+---------+--------+",
+        "| 11    | a       | 1      |",
+        "| 33    | c       | 3      |",
+        "| 44    | d       | 4      |",
+        "+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn subquery_to_join_with_muti_filter() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 
12 in 
+                         (select t2.t2_id + 1 from t2 where t1.t1_int <= 
t2.t2_int and t2.t2_int > 0)";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+        "    LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = 
__correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= 
__correlated_sq_1.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + 
Int64(1):Int64;N, t2_int:UInt32;N]",
+        "        Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id 
AS Int64) + Int64(1), t2.t2_int [CAST(t2_id AS Int64) + Int64(1):Int64;N, 
t2_int:UInt32;N]",
+        "          Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, 
t2_int:UInt32;N]",
+        "            TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, 
t2_int:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int |",
+        "+-------+---------+--------+",
+        "| 11    | a       | 1      |",
+        "| 33    | c       | 3      |",
+        "+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn three_projection_exprs_subquery_to_join() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 
12 in 
+                         (select t2.t2_id + 1 from t2 where t1.t1_int <= 
t2.t2_int and t1.t1_name != t2.t2_name and t2.t2_int > 0)";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+        "    LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = 
__correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= 
__correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + 
Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+        "        Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id 
AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + 
Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+        "          Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, 
t2_name:Utf8;N, t2_int:UInt32;N]",
+        "            TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int |",
+        "+-------+---------+--------+",
+        "| 11    | a       | 1      |",
+        "| 33    | c       | 3      |",
+        "+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn in_subquery_to_join_with_correlated_outer_filter() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 
12 in 
+                         (select t2.t2_id + 1 from t2 where t1.t1_int > 0)";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    // The `t1.t1_int > UInt32(0)` should be pushdown by `filter push down 
rule`.
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+        "    LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = 
__correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int > UInt32(0) 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + 
Int64(1):Int64;N]",
+        "        Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id 
AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+        "          TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+    ];

Review Comment:
   This test case shows that, in this special case, the predicate `t1.t1_int > 0` 
is not pushed down after optimization; it remains in the join filter.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to