alamb commented on code in PR #4826:
URL: https://github.com/apache/arrow-datafusion/pull/4826#discussion_r1070566001
##########
datafusion/core/tests/sql/joins.rs:
##########
@@ -2868,3 +2868,278 @@ async fn
test_cross_join_to_groupby_with_different_key_ordering() -> Result<()>
Ok(())
}
+
+#[tokio::test]
+async fn subquery_to_join_with_both_side_expr() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id +
12 in (select t2.t2_id + 1 from t2)";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan().unwrap();
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) =
__correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) +
Int64(1):Int64;N]",
+ " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id
AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 11 | a | 1 |",
+ "| 33 | c | 3 |",
+ "| 44 | d | 4 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn subquery_to_join_with_muti_filter() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id +
12 in
+ (select t2.t2_id + 1 from t2 where t1.t1_int <=
t2.t2_int and t2.t2_int > 0)";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan().unwrap();
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) =
__correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <=
__correlated_sq_1.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) +
Int64(1):Int64;N, t2_int:UInt32;N]",
+ " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id
AS Int64) + Int64(1), t2.t2_int [CAST(t2_id AS Int64) + Int64(1):Int64;N,
t2_int:UInt32;N]",
+ " Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N,
t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N,
t2_int:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 11 | a | 1 |",
+ "| 33 | c | 3 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn three_projection_exprs_subquery_to_join() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id +
12 in
+ (select t2.t2_id + 1 from t2 where t1.t1_int <=
t2.t2_int and t1.t1_name != t2.t2_name and t2.t2_int > 0)";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan().unwrap();
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) =
__correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <=
__correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) +
Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+ " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id
AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) +
Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+ " Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N,
t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int]
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 11 | a | 1 |",
+ "| 33 | c | 3 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn in_subquery_to_join_with_correlated_outer_filter() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id +
12 in
+ (select t2.t2_id + 1 from t2 where t1.t1_int > 0)";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan().unwrap();
+
+ // The `t1.t1_int > UInt32(0)` should be pushdown by `filter push down
rule`.
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) =
__correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int > UInt32(0)
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) +
Int64(1):Int64;N]",
+ " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id
AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
Review Comment:
Is that something you would like to fix in this PR, or would you prefer to fix
it in a follow-on PR?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]