AssHero commented on code in PR #2750:
URL: https://github.com/apache/arrow-datafusion/pull/2750#discussion_r905641685
##########
datafusion/core/tests/sql/joins.rs:
##########
@@ -1375,3 +1375,351 @@ async fn hash_join_with_dictionary() -> Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn reduce_left_join_1() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id")?;
+
+ // reduce to inner join
+ let sql = "select * from t1 left join t2 on t1.t1_id = t2.t2_id where
t2.t2_id < 100";
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id,
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
+ " Filter: #t1.t1_id < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N,
t1_int:UInt32;N]",
+ " TableScan: t1 projection=Some([t1_id, t1_name, t1_int])
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Filter: #t2.t2_id < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
+ " TableScan: t2 projection=Some([t2_id, t2_name, t2_int])
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let expected = vec![
+ "+-------+---------+--------+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+ "+-------+---------+--------+-------+---------+--------+",
+ "| 11 | a | 1 | 11 | z | 3 |",
+ "| 22 | b | 2 | 22 | y | 1 |",
+ "| 44 | d | 4 | 44 | x | 3 |",
+ "+-------+---------+--------+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn reduce_left_join_2() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id")?;
+
+ // reduce to inner join
+ let sql = "select * from t1 left join t2 on t1.t1_id = t2.t2_id where
t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')";
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id,
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: #t2.t2_int < Int64(10) OR #t1.t1_int > Int64(2) AND
#t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
+ " TableScan: t1 projection=Some([t1_id, t1_name, t1_int])
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t2 projection=Some([t2_id, t2_name, t2_int])
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let expected = vec![
+ "+-------+---------+--------+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+ "+-------+---------+--------+-------+---------+--------+",
+ "| 11 | a | 1 | 11 | z | 3 |",
+ "| 22 | b | 2 | 22 | y | 1 |",
+ "| 44 | d | 4 | 44 | x | 3 |",
+ "+-------+---------+--------+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn reduce_left_join_3() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id")?;
+
+ // reduce subquery to inner join
+ let sql = "select * from (select t1.* from t1 left join t2 on t1.t1_id =
t2.t2_id where t2.t2_int < 3) t3 left join t2 on t3.t1_int = t2.t2_int where
t3.t1_id < 100";
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: #t3.t1_id, #t3.t1_name, #t3.t1_int, #t2.t2_id,
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Left Join: #t3.t1_int = #t2.t2_int [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
+ " Projection: #t3.t1_id, #t3.t1_name, #t3.t1_int, alias=t3
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, alias=t3
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
+ " Filter: #t1.t1_id < Int64(100) [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=Some([t1_id, t1_name, t1_int])
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Filter: #t2.t2_int < Int64(3) AND #t2.t2_id < Int64(100)
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=Some([t2_id, t2_name, t2_int])
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=Some([t2_id, t2_name, t2_int])
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let expected = vec![
+ "+-------+---------+--------+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+ "+-------+---------+--------+-------+---------+--------+",
+ "| 22 | b | 2 | | | |",
+ "+-------+---------+--------+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
Review Comment:
I'll add more test cases to this rule's own unit tests, including this one.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]