alamb commented on code in PR #4666:
URL: https://github.com/apache/arrow-datafusion/pull/4666#discussion_r1052569474
##########
datafusion/expr/src/utils.rs:
##########
@@ -478,20 +479,33 @@ pub fn from_plan(
}) => {
let schema =
build_join_schema(inputs[0].schema(), inputs[1].schema(),
join_type)?;
+
+ let equi_expr_count = on.len();
+ // The preceding part of expr is equi-exprs,
+ // and the struct of each equi-expr is like `left-expr =
right-expr`.
+ let new_on:Vec<(Expr,Expr)> =
expr.iter().take(equi_expr_count).map(|equi_expr| {
+ if let Expr::BinaryExpr(BinaryExpr { left, op, right }) =
equi_expr {
+ assert!(op == &Operator::Eq);
+ Ok(((**left).clone(), (**right).clone()))
+ } else {
+ Err(DataFusionError::Internal(format!(
Review Comment:
👍
##########
datafusion/core/tests/sql/joins.rs:
##########
@@ -2778,3 +2781,141 @@ async fn select_wildcard_with_expr_key_inner_join() ->
Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn join_with_type_coercion_for_equi_expr() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on
t1.t1_id + 11 = t2.t2_id";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t2_id:UInt32;N]",
+ " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id
AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N,
t1_name:Utf8;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| 11 | a | 22 |",
+ "| 33 | c | 44 |",
+ "| 44 | d | 55 |",
+ "+-------+---------+-------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn join_only_with_filter() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on
t1.t1_id * 4 < t2.t2_id";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t2_id:UInt32;N]",
+ " Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS
Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " CrossJoin: [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N,
t1_name:Utf8;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let expected = vec![
+ "+-------+---------+-------+",
+ "| t1_id | t1_name | t2_id |",
+ "+-------+---------+-------+",
+ "| 11 | a | 55 |",
+ "+-------+---------+-------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \
+ from t1 \
+ inner join t2 \
+ on t1.t1_id * 5 = t2.t2_id and t1.t1_id * 4 < t2.t2_id";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t2_id:UInt32;N]",
+ " Inner Join: CAST(t1.t1_id AS Int64) * Int64(5) = CAST(t2.t2_id AS
Int64) Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64)
[t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
Review Comment:
👍
##########
datafusion/core/tests/sql/joins.rs:
##########
@@ -2778,3 +2781,141 @@ async fn select_wildcard_with_expr_key_inner_join() ->
Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn join_with_type_coercion_for_equi_expr() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on
t1.t1_id + 11 = t2.t2_id";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t2_id:UInt32;N]",
+ " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id
AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
Review Comment:
I think eventually it would be great to have these casts unwrapped too, like
```suggestion
" Inner Join: t1.t1_id + Int32(11) = t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t2_id:UInt32;N]",
```
That would avoid the casting at runtime.
I am not quite sure why
https://github.com/apache/arrow-datafusion/blob/master/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
is not already doing so.
##########
datafusion/expr/src/utils.rs:
##########
@@ -478,20 +479,33 @@ pub fn from_plan(
}) => {
let schema =
build_join_schema(inputs[0].schema(), inputs[1].schema(),
join_type)?;
+
+ let equi_expr_count = on.len();
+ // The preceding part of expr is equi-exprs,
+ // and the struct of each equi-expr is like `left-expr =
right-expr`.
+ let new_on:Vec<(Expr,Expr)> =
expr.iter().take(equi_expr_count).map(|equi_expr| {
Review Comment:
I think we should return an error here if `expr` does not have at least
`equi_expr_count` elements left. Otherwise I think `take` will silently return
fewer than `equi_expr_count` elements, which might result in bugs that are quite
hard to track down:
https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.take
##########
datafusion/core/tests/sql/joins.rs:
##########
@@ -1448,11 +1448,11 @@ async fn hash_join_with_decimal() -> Result<()> {
let state = ctx.state();
let plan = state.optimize(&plan)?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4
[c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N,
c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32,
Utf8);N]",
- " Right Join: t1.c3 = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5,
2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10,
2);N, c4:Dictionary(Int32, Utf8);N]",
Review Comment:
What happened previously with this plan? Would it error at runtime?
##########
datafusion/expr/src/logical_plan/plan.rs:
##########
@@ -253,9 +253,12 @@ impl LogicalPlan {
aggr_expr,
..
}) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
+ // There are two part of expression for join, equijoin(on) and
non-equijoin(filter).
+ // 1. the first part is `on.len()` equijoin expressions, and the
struct of each expr is `left-on = right-on`.
+ // 2. the second part is non-equijoin(filter).
LogicalPlan::Join(Join { on, filter, .. }) => on
.iter()
- .flat_map(|(l, r)| vec![l.clone(), r.clone()])
+ .map(|(l, r)| Expr::eq(l.clone(), r.clone()))
Review Comment:
This is the fix, right? It then exposes the `<l> = <r>` expression to the existing
type coercion logic?
Very nice 👍
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]