alamb commented on code in PR #3396: URL: https://github.com/apache/arrow-datafusion/pull/3396#discussion_r979222535
########## datafusion/expr/src/binary_rule.rs: ########## @@ -801,4 +804,230 @@ mod tests { Some(rhs_type.clone()) ); } + + macro_rules! test_coercion_binary_rule { + ($A_TYPE:expr, $B_TYPE:expr, $OP:expr, $C_TYPE:expr) => {{ + let result = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?; + assert_eq!(result, $C_TYPE); + }}; + } + + #[test] + fn test_type_coercion() -> Result<()> { + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::Like, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Date32, + Operator::Eq, + DataType::Date32 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Date64, + Operator::Lt, + DataType::Date64 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexNotMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexNotIMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Int16, + DataType::Int64, + Operator::BitwiseAnd, + DataType::Int64 + ); + Ok(()) + } + + #[test] + fn test_type_coercion_arithmetic() -> Result<()> { + // integer + test_coercion_binary_rule!( + DataType::Int32, + DataType::UInt32, + Operator::Plus, + DataType::Int32 + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::UInt16, + Operator::Minus, + DataType::Int32 + ); + test_coercion_binary_rule!( + DataType::Int8, + DataType::Int64, + Operator::Multiply, + DataType::Int64 + ); + // float + test_coercion_binary_rule!( + DataType::Float32, + DataType::Int32, + Operator::Plus, + DataType::Float32 + ); + test_coercion_binary_rule!( + DataType::Float32, + DataType::Float64, + Operator::Multiply, + DataType::Float64 + ); + // decimal + // bug: https://github.com/apache/arrow-datafusion/issues/3387 will be fixed in the next pr + // test_coercion_binary_rule!( + // DataType::Decimal128(10, 2), + 
// DataType::Decimal128(10, 2), + // Operator::Plus, + // DataType::Decimal128(11, 2) + // ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Plus, + DataType::Decimal128(13, 2) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Minus, + DataType::Decimal128(13, 2) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Multiply, + DataType::Decimal128(21, 2) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Divide, + DataType::Decimal128(23, 11) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Modulo, + DataType::Decimal128(10, 2) + ); + // TODO add other data type + Ok(()) + } + + #[test] + fn test_type_coercion_compare() -> Result<()> { + // boolean + test_coercion_binary_rule!( + DataType::Boolean, + DataType::Boolean, + Operator::Eq, + DataType::Boolean + ); + // float + test_coercion_binary_rule!( + DataType::Float32, + DataType::Int64, + Operator::Eq, + DataType::Float32 + ); + test_coercion_binary_rule!( + DataType::Float32, + DataType::Float64, + Operator::GtEq, + DataType::Float64 + ); + // signed integer + test_coercion_binary_rule!( + DataType::Int8, + DataType::Int32, + Operator::LtEq, + DataType::Int32 + ); + test_coercion_binary_rule!( + DataType::Int64, + DataType::Int32, + Operator::LtEq, + DataType::Int64 + ); + // unsigned integer + test_coercion_binary_rule!( + DataType::UInt32, + DataType::UInt8, + Operator::Gt, + DataType::UInt32 + ); + // numeric/decimal + test_coercion_binary_rule!( + DataType::Int64, + DataType::Decimal128(10, 0), + Operator::Eq, + DataType::Decimal128(20, 0) + ); + test_coercion_binary_rule!( + DataType::Int64, + DataType::Decimal128(10, 2), + Operator::Lt, + DataType::Decimal128(22, 2) + ); + test_coercion_binary_rule!( + DataType::Float64, + DataType::Decimal128(10, 3), + Operator::Gt, + 
DataType::Decimal128(30, 15) + ); + test_coercion_binary_rule!( + DataType::Int64, + DataType::Decimal128(10, 0), + Operator::Eq, + DataType::Decimal128(20, 0) + ); + test_coercion_binary_rule!( + DataType::Decimal128(14, 2), + DataType::Decimal128(10, 3), + Operator::GtEq, + DataType::Decimal128(15, 3) + ); + + // TODO add other data type + Ok(()) + } + + #[test] + fn test_type_coercion_logical_op() -> Result<()> { + test_coercion_binary_rule!( + DataType::Boolean, + DataType::Boolean, + Operator::And, + DataType::Boolean + ); + + test_coercion_binary_rule!( + DataType::Boolean, + DataType::Boolean, + Operator::Or, + DataType::Boolean Review Comment: Are there any other types that are coerced to boolean for logical operations? Or are the tests for boolean just showing that boolean types are not changed when coerced? ########## datafusion/optimizer/src/simplify_expressions.rs: ########## @@ -804,24 +817,28 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { // // Rules for Between // + // TODO https://github.com/apache/arrow-datafusion/issues/3587 Review Comment: makes sense -- thank you for the comments ########## datafusion/physical-expr/src/expressions/binary.rs: ########## @@ -1080,7 +1071,7 @@ mod tests { Operator::Plus, Int32Array, DataType::Int32, - vec![2i32, 4i32] + vec![2i32, 4i32], ); test_coercion!( Review Comment: I wonder if these tests are doing anything useful anymore now that we are coercing in the logical layer? 
🤔 It seems like they are now testing the test code 😆 ########## datafusion/optimizer/src/simplify_expressions.rs: ########## @@ -1573,13 +1597,11 @@ mod tests { low: Box::new(lit(0)), high: Box::new(lit(10)), }; + let between_expr = expr.clone(); let expr = expr.or(lit_bool_null()); let result = simplify(expr); - let expected_expr = or( - and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))), - lit_bool_null(), - ); + let expected_expr = or(between_expr, lit_bool_null()); Review Comment: ```suggestion let expected_expr = or(between_expr, lit_bool_null()); // let expected_expr = or( // and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))), // lit_bool_null(), //); ``` ########## datafusion/expr/src/binary_rule.rs: ########## @@ -801,4 +804,230 @@ mod tests { Some(rhs_type.clone()) ); } + + macro_rules! test_coercion_binary_rule { + ($A_TYPE:expr, $B_TYPE:expr, $OP:expr, $C_TYPE:expr) => {{ + let result = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?; + assert_eq!(result, $C_TYPE); + }}; + } + + #[test] + fn test_type_coercion() -> Result<()> { + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::Like, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Date32, + Operator::Eq, + DataType::Date32 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Date64, + Operator::Lt, + DataType::Date64 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexNotMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexNotIMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Int16, + DataType::Int64, + Operator::BitwiseAnd, + DataType::Int64 + ); + Ok(()) + } + + #[test] + fn test_type_coercion_arithmetic() -> Result<()> { + // integer + test_coercion_binary_rule!( + DataType::Int32, + DataType::UInt32, + 
Operator::Plus, + DataType::Int32 + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::UInt16, + Operator::Minus, + DataType::Int32 + ); + test_coercion_binary_rule!( + DataType::Int8, + DataType::Int64, + Operator::Multiply, + DataType::Int64 + ); + // float + test_coercion_binary_rule!( + DataType::Float32, + DataType::Int32, + Operator::Plus, + DataType::Float32 + ); + test_coercion_binary_rule!( + DataType::Float32, + DataType::Float64, + Operator::Multiply, + DataType::Float64 + ); + // decimal + // bug: https://github.com/apache/arrow-datafusion/issues/3387 will be fixed in the next pr + // test_coercion_binary_rule!( + // DataType::Decimal128(10, 2), + // DataType::Decimal128(10, 2), + // Operator::Plus, + // DataType::Decimal128(11, 2) + // ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Plus, + DataType::Decimal128(13, 2) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Minus, + DataType::Decimal128(13, 2) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Multiply, + DataType::Decimal128(21, 2) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Divide, + DataType::Decimal128(23, 11) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Modulo, + DataType::Decimal128(10, 2) + ); + // TODO add other data type + Ok(()) + } + + #[test] + fn test_type_coercion_compare() -> Result<()> { + // boolean + test_coercion_binary_rule!( + DataType::Boolean, + DataType::Boolean, + Operator::Eq, + DataType::Boolean + ); + // float + test_coercion_binary_rule!( + DataType::Float32, + DataType::Int64, + Operator::Eq, + DataType::Float32 + ); + test_coercion_binary_rule!( + DataType::Float32, + DataType::Float64, + Operator::GtEq, + DataType::Float64 + ); + // signed integer + test_coercion_binary_rule!( + DataType::Int8, + 
DataType::Int32, + Operator::LtEq, + DataType::Int32 + ); + test_coercion_binary_rule!( + DataType::Int64, + DataType::Int32, + Operator::LtEq, + DataType::Int64 + ); + // unsigned integer + test_coercion_binary_rule!( + DataType::UInt32, + DataType::UInt8, + Operator::Gt, + DataType::UInt32 + ); + // numeric/decimal + test_coercion_binary_rule!( + DataType::Int64, + DataType::Decimal128(10, 0), + Operator::Eq, + DataType::Decimal128(20, 0) + ); + test_coercion_binary_rule!( + DataType::Int64, + DataType::Decimal128(10, 2), + Operator::Lt, + DataType::Decimal128(22, 2) + ); + test_coercion_binary_rule!( + DataType::Float64, + DataType::Decimal128(10, 3), + Operator::Gt, + DataType::Decimal128(30, 15) + ); + test_coercion_binary_rule!( + DataType::Int64, + DataType::Decimal128(10, 0), + Operator::Eq, + DataType::Decimal128(20, 0) + ); + test_coercion_binary_rule!( + DataType::Decimal128(14, 2), + DataType::Decimal128(10, 3), + Operator::GtEq, + DataType::Decimal128(15, 3) + ); + Review Comment: I reviewed the test cases in this file carefully. Very nice 👌 ########## datafusion/expr/src/binary_rule.rs: ########## @@ -801,4 +804,230 @@ mod tests { Some(rhs_type.clone()) ); } + + macro_rules! 
test_coercion_binary_rule { + ($A_TYPE:expr, $B_TYPE:expr, $OP:expr, $C_TYPE:expr) => {{ + let result = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?; + assert_eq!(result, $C_TYPE); + }}; + } + + #[test] + fn test_type_coercion() -> Result<()> { + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::Like, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Date32, + Operator::Eq, + DataType::Date32 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Date64, + Operator::Lt, + DataType::Date64 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexNotMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::RegexNotIMatch, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::Int16, + DataType::Int64, + Operator::BitwiseAnd, + DataType::Int64 + ); + Ok(()) + } + + #[test] + fn test_type_coercion_arithmetic() -> Result<()> { + // integer + test_coercion_binary_rule!( + DataType::Int32, + DataType::UInt32, + Operator::Plus, + DataType::Int32 + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::UInt16, + Operator::Minus, + DataType::Int32 + ); + test_coercion_binary_rule!( + DataType::Int8, + DataType::Int64, + Operator::Multiply, + DataType::Int64 + ); + // float + test_coercion_binary_rule!( + DataType::Float32, + DataType::Int32, + Operator::Plus, + DataType::Float32 + ); + test_coercion_binary_rule!( + DataType::Float32, + DataType::Float64, + Operator::Multiply, + DataType::Float64 + ); + // decimal + // bug: https://github.com/apache/arrow-datafusion/issues/3387 will be fixed in the next pr + // test_coercion_binary_rule!( + // DataType::Decimal128(10, 2), + // DataType::Decimal128(10, 2), + // Operator::Plus, + // DataType::Decimal128(11, 2) + // ); + test_coercion_binary_rule!( + 
DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Plus, + DataType::Decimal128(13, 2) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Minus, + DataType::Decimal128(13, 2) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Multiply, + DataType::Decimal128(21, 2) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Divide, + DataType::Decimal128(23, 11) + ); + test_coercion_binary_rule!( + DataType::Int32, + DataType::Decimal128(10, 2), + Operator::Modulo, + DataType::Decimal128(10, 2) + ); + // TODO add other data type Review Comment: I wonder if we should file another ticket to track this gap? Specifically, a "help wanted" ticket that explained what was needed might encourage some additional contributions ########## datafusion/optimizer/tests/integration-test.rs: ########## @@ -79,31 +79,34 @@ fn intersect() -> Result<()> { #[test] fn between_date32_plus_interval() -> Result<()> { + // TODO: https://github.com/apache/arrow-datafusion/issues/3587 let sql = "SELECT count(1) FROM test \ WHERE col_date32 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'"; let plan = test_sql(sql)?; let expected = "Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ - \n Filter: #test.col_date32 >= Date32(\"10303\") AND #test.col_date32 <= Date32(\"10393\")\ + \n Filter: #test.col_date32 BETWEEN Date32(\"10303\") AND Date32(\"10393\")\ \n TableScan: test projection=[col_date32]"; assert_eq!(expected, format!("{:?}", plan)); Ok(()) } #[test] fn between_date64_plus_interval() -> Result<()> { + // TODO: https://github.com/apache/arrow-datafusion/issues/3587 let sql = "SELECT count(1) FROM test \ WHERE col_date64 between '1998-03-18T00:00:00' AND cast('1998-03-18' as date) + INTERVAL '90 days'"; let plan = test_sql(sql)?; let expected = "Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]], 
aggr=[[COUNT(UInt8(1))]]\ - \n Filter: #test.col_date64 >= Date64(\"890179200000\") AND #test.col_date64 <= Date64(\"897955200000\")\ + \n Filter: #test.col_date64 BETWEEN Date64(\"890179200000\") AND Date64(\"897955200000\")\ \n TableScan: test projection=[col_date64]"; assert_eq!(expected, format!("{:?}", plan)); Ok(()) } fn test_sql(sql: &str) -> Result<LogicalPlan> { + // TODO should make align with rules in the context Review Comment: ```suggestion // TODO should make align with rules in the context // https://github.com/apache/arrow-datafusion/issues/3524 ``` ########## datafusion/optimizer/src/simplify_expressions.rs: ########## @@ -360,6 +362,7 @@ pub struct ConstEvaluator<'a> { execution_props: &'a ExecutionProps, input_schema: DFSchema, input_batch: RecordBatch, + type_coercion_helper: TypeCoercionRewriter, Review Comment: ```suggestion // Needed until we ensure type coercion is done before any optimizations // https://github.com/apache/arrow-datafusion/issues/3557 type_coercion_helper: TypeCoercionRewriter, ``` ########## datafusion/core/tests/sql/aggregates.rs: ########## @@ -1834,11 +1834,11 @@ async fn aggregate_avg_add() -> Result<()> { assert_eq!(results.len(), 1); let expected = vec![ - "+--------------+---------------------------+---------------------------+---------------------------+", - "| AVG(test.c1) | AVG(test.c1) + Float64(1) | AVG(test.c1) + Float64(2) | Float64(1) + AVG(test.c1) |", - "+--------------+---------------------------+---------------------------+---------------------------+", - "| 1.5 | 2.5 | 3.5 | 2.5 |", - "+--------------+---------------------------+---------------------------+---------------------------+", + "+--------------+-------------------------+-------------------------+-------------------------+", Review Comment: 👍 ########## datafusion/physical-expr/src/expressions/binary.rs: ########## @@ -892,25 +891,6 @@ impl BinaryExpr { } } -/// return two physical expressions that are optionally coerced to a -/// common type 
that the binary operator supports. -fn binary_cast( Review Comment: 🎉 ########## datafusion/optimizer/src/type_coercion.rs: ########## @@ -69,14 +67,8 @@ impl OptimizerRule for TypeCoercion { }, ); - let mut execution_props = ExecutionProps::new(); Review Comment: Very nice ########## datafusion/optimizer/src/simplify_expressions.rs: ########## @@ -1611,17 +1635,17 @@ mod tests { low: Box::new(lit(0)), high: Box::new(lit(10)), }; + let between_expr = expr.clone(); let expr = expr.and(lit_bool_null()); let result = simplify(expr); - let expected_expr = and( - and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))), - lit_bool_null(), - ); + let expected_expr = and(between_expr, lit_bool_null()); Review Comment: ```suggestion let expected_expr = and(between_expr, lit_bool_null()); // let expected_expr = and( // and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))), // lit_bool_null(), // ); ``` ########## datafusion/optimizer/src/simplify_expressions.rs: ########## @@ -1107,6 +1124,12 @@ mod tests { assert_eq!(simplify(expr_eq), lit(true)); } + #[test] + fn test_simplify_with_type_coercion() { + let expr_plus = binary_expr(lit(1_i32), Operator::Plus, lit(1_i64)); + assert_eq!(simplify(expr_plus), lit(2_i64)); Review Comment: Nice ########## datafusion/optimizer/src/simplify_expressions.rs: ########## @@ -1573,13 +1597,11 @@ mod tests { low: Box::new(lit(0)), high: Box::new(lit(10)), }; + let between_expr = expr.clone(); let expr = expr.or(lit_bool_null()); let result = simplify(expr); - let expected_expr = or( - and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))), - lit_bool_null(), - ); + let expected_expr = or(between_expr, lit_bool_null()); Review Comment: I suggest keeping the old expected result in the code as comments to make it easier to understand what the desired effect is ########## datafusion/core/src/execution/context.rs: ########## @@ -1452,7 +1452,11 @@ impl SessionState { rules.push(Arc::new(FilterNullJoinKeys::default())); } 
rules.push(Arc::new(ReduceOuterJoin::new())); + // TODO: https://github.com/apache/arrow-datafusion/issues/3557 Review Comment: 👍 ########## datafusion/core/src/execution/context.rs: ########## @@ -1452,7 +1452,11 @@ impl SessionState { rules.push(Arc::new(FilterNullJoinKeys::default())); } rules.push(Arc::new(ReduceOuterJoin::new())); + // TODO: https://github.com/apache/arrow-datafusion/issues/3557 Review Comment: it makes sense to me that we need to simplify expressions after coercion -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org