This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 439863f660 refactor: separate get_result_type from `coerce_type` 
(#6221)
439863f660 is described below

commit 439863f6604d6a90a9739eb5dadd428710160937
Author: jakevin <[email protected]>
AuthorDate: Fri May 5 11:11:23 2023 +0800

    refactor: separate get_result_type from `coerce_type` (#6221)
    
    * feat: separate get_result_type and coerce_type
    
    * comment
    
    * Apply suggestions from code review
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * fix slt
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/physical_plan/planner.rs       |  5 +-
 .../core/tests/sqllogictests/test_files/dates.slt  |  2 +-
 .../tests/sqllogictests/test_files/timestamps.slt  |  5 +-
 datafusion/expr/src/type_coercion/binary.rs        | 65 ++++++++++++++++++----
 .../physical-expr/src/expressions/datetime.rs      |  4 +-
 .../physical-expr/src/intervals/cp_solver.rs       |  4 +-
 .../src/intervals/interval_aritmetic.rs            |  6 +-
 7 files changed, 64 insertions(+), 27 deletions(-)

diff --git a/datafusion/core/src/physical_plan/planner.rs 
b/datafusion/core/src/physical_plan/planner.rs
index 458b1b2a21..f685d5a5f4 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -2093,10 +2093,7 @@ mod tests {
         ];
         for case in cases {
             let logical_plan = 
test_csv_scan().await?.project(vec![case.clone()]);
-            let message = format!(
-                "Expression {case:?} expected to error due to impossible 
coercion"
-            );
-            assert!(logical_plan.is_err(), "{}", message);
+            assert!(logical_plan.is_ok());
         }
         Ok(())
     }
diff --git a/datafusion/core/tests/sqllogictests/test_files/dates.slt 
b/datafusion/core/tests/sqllogictests/test_files/dates.slt
index 6ab4730ef4..0cb96d8bc7 100644
--- a/datafusion/core/tests/sqllogictests/test_files/dates.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/dates.slt
@@ -90,7 +90,7 @@ select i_item_desc from test
 where d3_date > now() + '5 days';
 
 # DATE minus DATE
-query error DataFusion error: Error during planning: Date32 \- Date32 can't be 
evaluated because there isn't a common type to coerce the types to
+query error DataFusion error: Error during planning: Unsupported argument 
types\. Can not evaluate Date32 \- Date32
 SELECT DATE '2023-04-09' - DATE '2023-04-02';
 
 # DATE minus Timestamp
diff --git a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt 
b/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
index 16b1e20e8f..e1102d626c 100644
--- a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
@@ -1014,7 +1014,7 @@ SELECT ts1 + i FROM foo;
 2003-07-12T01:31:15.000123463
 
 # Timestamp + Timestamp => error
-statement error DataFusion error: Error during planning: 
Timestamp\(Nanosecond, None\) \+ Timestamp\(Nanosecond, None\) can't be 
evaluated because there isn't a common type to coerce the types to
+statement error DataFusion error: Error during planning: Unsupported argument 
types\. Can not evaluate Timestamp\(Nanosecond, None\) \+ 
Timestamp\(Nanosecond, None\)
 SELECT ts1 + ts2
 FROM foo;
 
@@ -1031,8 +1031,7 @@ SELECT '2000-01-01T00:00:00'::timestamp - 
'2010-01-01T00:00:00'::timestamp;
 0 years 0 mons -3653 days 0 hours 0 mins 0.000000000 secs
 
 # Interval - Timestamp => error
-statement error DataFusion error: Error during planning: interval can't 
subtract timestamp/date
-
+statement error DataFusion error: type_coercion\ncaused by\nError during 
planning: interval can't subtract timestamp/date
 SELECT i - ts1 from FOO;
 
 statement ok
diff --git a/datafusion/expr/src/type_coercion/binary.rs 
b/datafusion/expr/src/type_coercion/binary.rs
index 3d88491c68..e4d09aba01 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -42,7 +42,7 @@ pub fn binary_operator_data_type(
     let result_type = if !any_decimal(lhs_type, rhs_type) {
         // validate that it is possible to perform the operation on incoming 
types.
         // (or the return datatype cannot be inferred)
-        coerce_types(lhs_type, op, rhs_type)?
+        get_result_type(lhs_type, op, rhs_type)?
     } else {
         let (coerced_lhs_type, coerced_rhs_type) =
             math_decimal_coercion(lhs_type, rhs_type);
@@ -106,19 +106,60 @@ pub fn binary_operator_data_type(
     }
 }
 
-/// Coercion rules for all binary operators. Returns the output type
-/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
+/// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types
+pub fn get_result_type(
+    lhs_type: &DataType,
+    op: &Operator,
+    rhs_type: &DataType,
+) -> Result<DataType> {
+    let result = match op {
+        Operator::And
+        | Operator::Or
+        | Operator::Eq
+        | Operator::NotEq
+        | Operator::Lt
+        | Operator::Gt
+        | Operator::GtEq
+        | Operator::LtEq
+        | Operator::IsDistinctFrom
+        | Operator::IsNotDistinctFrom => Some(DataType::Boolean),
+        Operator::Plus | Operator::Minus
+            if is_datetime(lhs_type)
+                || is_datetime(rhs_type)
+                || is_interval(lhs_type)
+                || is_interval(rhs_type) =>
+        {
+            temporal_add_sub_coercion(lhs_type, rhs_type, op)
+        }
+        Operator::BitwiseAnd
+        | Operator::BitwiseOr
+        | Operator::BitwiseXor
+        | Operator::BitwiseShiftRight
+        | Operator::BitwiseShiftLeft
+        | Operator::Plus
+        | Operator::Minus
+        | Operator::Modulo
+        | Operator::Divide
+        | Operator::Multiply
+        | Operator::RegexMatch
+        | Operator::RegexIMatch
+        | Operator::RegexNotMatch
+        | Operator::RegexNotIMatch
+        | Operator::StringConcat => coerce_types(lhs_type, op, rhs_type).ok(),
+    };
+
+    match result {
+        None => Err(DataFusionError::Plan(format!(
+            "Unsupported argument types. Can not evaluate {lhs_type:?} {op} 
{rhs_type:?}"
+        ))),
+        Some(t) => Ok(t),
+    }
+}
+
+/// Coercion rules for all binary operators. `coerce_types` returns
+/// the type that the arguments should be coerced to.
 ///
 /// Returns None if no suitable type can be found.
-///
-/// TODO this function is trying to serve two purposes at once; it
-/// determines the result type of the binary operation and also
-/// determines how the inputs can be coerced but this results in
-/// inconsistencies in some cases (particular around date + interval)
-/// when the input argument types do not match the output argument
-/// types
-///
-/// Tracking issue is <https://github.com/apache/arrow-datafusion/issues/3419>
 pub fn coerce_types(
     lhs_type: &DataType,
     op: &Operator,
diff --git a/datafusion/physical-expr/src/expressions/datetime.rs 
b/datafusion/physical-expr/src/expressions/datetime.rs
index 6e7b8a43fa..976ef845aa 100644
--- a/datafusion/physical-expr/src/expressions/datetime.rs
+++ b/datafusion/physical-expr/src/expressions/datetime.rs
@@ -23,7 +23,7 @@ use arrow::datatypes::{DataType, Schema};
 use arrow::record_batch::RecordBatch;
 
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::type_coercion::binary::coerce_types;
+use datafusion_expr::type_coercion::binary::get_result_type;
 use datafusion_expr::{ColumnarValue, Operator};
 use std::any::Any;
 use std::fmt::{Display, Formatter};
@@ -77,7 +77,7 @@ impl PhysicalExpr for DateTimeIntervalExpr {
     }
 
     fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
-        coerce_types(
+        get_result_type(
             &self.lhs.data_type(input_schema)?,
             &Operator::Minus,
             &self.rhs.data_type(input_schema)?,
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs 
b/datafusion/physical-expr/src/intervals/cp_solver.rs
index 3a682049a0..a1698e6651 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -23,7 +23,7 @@ use std::sync::Arc;
 
 use arrow_schema::DataType;
 use datafusion_common::{Result, ScalarValue};
-use datafusion_expr::type_coercion::binary::coerce_types;
+use datafusion_expr::type_coercion::binary::get_result_type;
 use datafusion_expr::Operator;
 use petgraph::graph::NodeIndex;
 use petgraph::stable_graph::{DefaultIx, StableGraph};
@@ -260,7 +260,7 @@ fn comparison_operator_target(
     op: &Operator,
     right_datatype: &DataType,
 ) -> Result<Interval> {
-    let datatype = coerce_types(left_datatype, &Operator::Minus, 
right_datatype)?;
+    let datatype = get_result_type(left_datatype, &Operator::Minus, 
right_datatype)?;
     let unbounded = IntervalBound::make_unbounded(&datatype)?;
     let zero = ScalarValue::new_zero(&datatype)?;
     Ok(match *op {
diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs 
b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
index 9745234a53..3e2b4697a1 100644
--- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
+++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
@@ -24,7 +24,7 @@ use std::fmt::{Display, Formatter};
 use arrow::compute::{cast_with_options, CastOptions};
 use arrow::datatypes::DataType;
 use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::type_coercion::binary::coerce_types;
+use datafusion_expr::type_coercion::binary::get_result_type;
 use datafusion_expr::Operator;
 
 use crate::aggregate::min_max::{max, min};
@@ -82,7 +82,7 @@ impl IntervalBound {
     ) -> Result<IntervalBound> {
         let rhs = other.borrow();
         if self.is_unbounded() || rhs.is_unbounded() {
-            return IntervalBound::make_unbounded(coerce_types(
+            return IntervalBound::make_unbounded(get_result_type(
                 &self.get_datatype(),
                 &Operator::Plus,
                 &rhs.get_datatype(),
@@ -109,7 +109,7 @@ impl IntervalBound {
     ) -> Result<IntervalBound> {
         let rhs = other.borrow();
         if self.is_unbounded() || rhs.is_unbounded() {
-            return IntervalBound::make_unbounded(coerce_types(
+            return IntervalBound::make_unbounded(get_result_type(
                 &self.get_datatype(),
                 &Operator::Minus,
                 &rhs.get_datatype(),

Reply via email to