This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 7afd43773 Use coerced type in inlist expr planning (#2794)
7afd43773 is described below

commit 7afd4377309eba65100b7bfc07e0c6cd5bf3d780
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Jun 27 00:51:01 2022 -0700

    Use coerced type in inlist expr planning (#2794)
    
    * Use coerced type in inlist expr planning
    
    * Add coerce rule
    
    * Update datafusion/physical-expr/src/planner.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Update datafusion/physical-expr/src/planner.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Update datafusion/physical-expr/src/planner.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Fix test
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/physical_plan/planner.rs       | 11 +--
 datafusion/expr/src/binary_rule.rs                 | 12 +++
 .../physical-expr/src/expressions/in_list.rs       | 58 +------------
 datafusion/physical-expr/src/planner.rs            | 98 +++++++++++++---------
 4 files changed, 79 insertions(+), 100 deletions(-)

diff --git a/datafusion/core/src/physical_plan/planner.rs 
b/datafusion/core/src/physical_plan/planner.rs
index 527638668..4a8ad5d77 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -1739,8 +1739,6 @@ mod tests {
             col("c1").and(col("c1")),
             // u8 AND u8
             col("c3").and(col("c3")),
-            // utf8 = u32
-            col("c1").eq(col("c2")),
             // utf8 = bool
             col("c1").eq(bool_expr.clone()),
             // u32 AND bool
@@ -1842,7 +1840,7 @@ mod tests {
             .build()?;
         let execution_plan = plan(&logical_plan).await?;
         // verify that the plan correctly adds cast from Int64(1) to Utf8
-        let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, 
list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: 
Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], 
negated: false, set: None }";
+        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [Literal { value: Utf8(\"a\") }, TryCastExpr { expr: Literal 
{ value: Int64(1) }, cast_type: Utf8 }], negated: false, set: None }";
         assert!(format!("{:?}", execution_plan).contains(expected));
 
         // expression: "a in (struct::null, 'a')"
@@ -1857,8 +1855,7 @@ mod tests {
         let execution_plan = plan(&logical_plan).await;
 
         let e = execution_plan.unwrap_err().to_string();
-        assert_contains!(&e, "Unsupported CAST from Struct");
-        assert_contains!(&e, "to Boolean");
+        assert_contains!(&e, "Can not find compatible types to compare Boolean 
with [Struct([Field { name: \"foo\", data_type: Boolean, nullable: false, 
dict_id: 0, dict_is_ordered: false, metadata: None }]), Utf8]");
 
         Ok(())
     }
@@ -1887,7 +1884,7 @@ mod tests {
             .project(vec![col("c1").in_list(list, false)])?
             .build()?;
         let execution_plan = plan(&logical_plan).await?;
-        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { 
value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } 
}, CastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8, cast_options: 
CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(3) }, 
cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: 
Literal { value: Int64(4)  [...]
+        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [Literal { value: Utf8(\"a\") }, TryCastExpr { expr: Literal 
{ value: Int64(1) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: 
Int64(2) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(3) }, 
cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(4) }, cast_type: 
Utf8 }, TryCastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8 }, 
TryCastExpr { expr: Literal [...]
         assert!(format!("{:?}", execution_plan).contains(expected));
         Ok(())
     }
@@ -1906,7 +1903,7 @@ mod tests {
             .project(vec![col("c1").in_list(list, false)])?
             .build()?;
         let execution_plan = plan(&logical_plan).await?;
-        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [CastExpr { expr: Literal { value: Int64(NULL) }, cast_type: 
Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal { 
value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } 
}, CastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8, cast_options: 
CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(3) }, 
cast_type: Utf8, cast_opti [...]
+        let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", 
index: 0 }, list: [TryCastExpr { expr: Literal { value: Int64(NULL) }, 
cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type: 
Utf8 }, TryCastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8 }, 
TryCastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8 }, TryCastExpr 
{ expr: Literal { value: Int64(4) }, cast_type: Utf8 }, TryCastExpr { expr: 
Literal { value: Int64(5) }, cast_ty [...]
         assert!(format!("{:?}", execution_plan).contains(expected));
         Ok(())
     }
diff --git a/datafusion/expr/src/binary_rule.rs 
b/datafusion/expr/src/binary_rule.rs
index 88b4d95ec..b7b2c57e8 100644
--- a/datafusion/expr/src/binary_rule.rs
+++ b/datafusion/expr/src/binary_rule.rs
@@ -166,6 +166,7 @@ pub fn comparison_eq_coercion(
         .or_else(|| temporal_coercion(lhs_type, rhs_type))
         .or_else(|| string_coercion(lhs_type, rhs_type))
         .or_else(|| null_coercion(lhs_type, rhs_type))
+        .or_else(|| string_numeric_coercion(lhs_type, rhs_type))
 }
 
 fn comparison_order_coercion(
@@ -185,6 +186,17 @@ fn comparison_order_coercion(
         .or_else(|| null_coercion(lhs_type, rhs_type))
 }
 
+fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> 
Option<DataType> {
+    use arrow::datatypes::DataType::*;
+    match (lhs_type, rhs_type) {
+        (Utf8, _) if DataType::is_numeric(rhs_type) => Some(Utf8),
+        (LargeUtf8, _) if DataType::is_numeric(rhs_type) => Some(LargeUtf8),
+        (_, Utf8) if DataType::is_numeric(lhs_type) => Some(Utf8),
+        (_, LargeUtf8) if DataType::is_numeric(lhs_type) => Some(LargeUtf8),
+        _ => None,
+    }
+}
+
 fn comparison_binary_numeric_coercion(
     lhs_type: &DataType,
     rhs_type: &DataType,
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs 
b/datafusion/physical-expr/src/expressions/in_list.rs
index 346eea472..a0448a1f1 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -33,14 +33,12 @@ use arrow::{
     record_batch::RecordBatch,
 };
 
-use crate::expressions::try_cast;
 use crate::{expressions, PhysicalExpr};
 use arrow::array::*;
 use arrow::buffer::{Buffer, MutableBuffer};
 use datafusion_common::ScalarValue;
 use datafusion_common::ScalarValue::Decimal128;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::binary_rule::comparison_eq_coercion;
 use datafusion_expr::ColumnarValue;
 
 /// Size at which to use a Set rather than Vec for `IN` / `NOT IN`
@@ -745,63 +743,13 @@ impl PhysicalExpr for InListExpr {
     }
 }
 
-type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);
-
 /// Creates a unary expression InList
 pub fn in_list(
     expr: Arc<dyn PhysicalExpr>,
     list: Vec<Arc<dyn PhysicalExpr>>,
     negated: &bool,
-    input_schema: &Schema,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    let (cast_expr, cast_list) = in_list_cast(expr, list, input_schema)?;
-    Ok(Arc::new(InListExpr::new(cast_expr, cast_list, *negated)))
-}
-
-fn in_list_cast(
-    expr: Arc<dyn PhysicalExpr>,
-    list: Vec<Arc<dyn PhysicalExpr>>,
-    input_schema: &Schema,
-) -> Result<InListCastResult> {
-    let expr_type = &expr.data_type(input_schema)?;
-    let list_types: Vec<DataType> = list
-        .iter()
-        .map(|list_expr| list_expr.data_type(input_schema).unwrap())
-        .collect();
-    // TODO in the arrow-rs, should support NULL type to Decimal Data type
-    // TODO support in the arrow-rs, NULL value cast to Decimal Value
-    // https://github.com/apache/arrow-datafusion/issues/2759
-    let result_type = get_coerce_type(expr_type, &list_types);
-    match result_type {
-        None => Err(DataFusionError::Internal(format!(
-            "In expr can find the coerced type for {:?} in {:?}",
-            expr_type, list_types
-        ))),
-        Some(data_type) => {
-            // find the coerced type
-            let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
-            let cast_list_expr = list
-                .into_iter()
-                .map(|list_expr| {
-                    try_cast(list_expr, input_schema, 
data_type.clone()).unwrap()
-                })
-                .collect();
-            Ok((cast_expr, cast_list_expr))
-        }
-    }
-}
-
-fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> 
Option<DataType> {
-    // get the equal coerced data type
-    list_type
-        .iter()
-        .fold(Some(expr_type.clone()), |left, right_type| {
-            match left {
-                None => None,
-                // TODO refactor a framework to do the data type coercion
-                Some(left_type) => comparison_eq_coercion(&left_type, 
right_type),
-            }
-        })
+    Ok(Arc::new(InListExpr::new(expr, list, *negated)))
 }
 
 #[cfg(test)]
@@ -810,12 +758,14 @@ mod tests {
 
     use super::*;
     use crate::expressions::{col, lit};
+    use crate::planner::in_list_cast;
     use datafusion_common::Result;
 
     // applies the in_list expr to an input batch and list
     macro_rules! in_list {
         ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, 
$SCHEMA:expr) => {{
-            let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap();
+            let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, 
$SCHEMA)?;
+            let expr = in_list(cast_expr, cast_list_exprs, $NEGATED).unwrap();
             let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
             let result = result
                 .as_any()
diff --git a/datafusion/physical-expr/src/planner.rs 
b/datafusion/physical-expr/src/planner.rs
index 26583cd28..d8a7a3004 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::expressions::try_cast;
 use crate::{
     execution_props::ExecutionProps,
     expressions::{
@@ -24,11 +25,9 @@ use crate::{
     var_provider::VarType,
     PhysicalExpr,
 };
-use arrow::{
-    compute::can_cast_types,
-    datatypes::{DataType, Schema},
-};
+use arrow::datatypes::{DataType, Schema};
 use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
+use datafusion_expr::binary_rule::comparison_eq_coercion;
 use datafusion_expr::{Expr, Operator};
 use std::sync::Arc;
 
@@ -283,7 +282,6 @@ pub fn create_physical_expr(
                     input_schema,
                     execution_props,
                 )?;
-                let value_expr_data_type = value_expr.data_type(input_schema)?;
 
                 let list_exprs = list
                     .iter()
@@ -294,43 +292,18 @@ pub fn create_physical_expr(
                             input_schema,
                             execution_props,
                         ),
-                        // TODO refactor the logic of coercion the data type
-                        // data type in the `list expr` may be conflict with 
`value expr`,
-                        // we should not just compare data type between `value 
expr` with each `list expr`.
-                        _ => {
-                            let list_expr = create_physical_expr(
-                                expr,
-                                input_dfschema,
-                                input_schema,
-                                execution_props,
-                            )?;
-                            let list_expr_data_type =
-                                list_expr.data_type(input_schema)?;
-
-                            if list_expr_data_type == value_expr_data_type {
-                                Ok(list_expr)
-                            } else if can_cast_types(
-                                &list_expr_data_type,
-                                &value_expr_data_type,
-                            ) {
-                                // TODO: Can't cast from list type to value 
type directly
-                                // We should use the coercion rule to get the 
common data type
-                                expressions::cast(
-                                    list_expr,
-                                    input_schema,
-                                    value_expr.data_type(input_schema)?,
-                                )
-                            } else {
-                                Err(DataFusionError::Plan(format!(
-                                    "Unsupported CAST from {:?} to {:?}",
-                                    list_expr_data_type, value_expr_data_type
-                                )))
-                            }
-                        }
+                        _ => create_physical_expr(
+                            expr,
+                            input_dfschema,
+                            input_schema,
+                            execution_props,
+                        ),
                     })
                     .collect::<Result<Vec<_>>>()?;
 
-                expressions::in_list(value_expr, list_exprs, negated, 
input_schema)
+                let (cast_expr, cast_list_exprs) =
+                    in_list_cast(value_expr, list_exprs, input_schema)?;
+                expressions::in_list(cast_expr, cast_list_exprs, negated)
             }
         },
         other => Err(DataFusionError::NotImplemented(format!(
@@ -339,3 +312,50 @@ pub fn create_physical_expr(
         ))),
     }
 }
+
+type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);
+
+pub(crate) fn in_list_cast(
+    expr: Arc<dyn PhysicalExpr>,
+    list: Vec<Arc<dyn PhysicalExpr>>,
+    input_schema: &Schema,
+) -> Result<InListCastResult> {
+    let expr_type = &expr.data_type(input_schema)?;
+    let list_types: Vec<DataType> = list
+        .iter()
+        .map(|list_expr| list_expr.data_type(input_schema).unwrap())
+        .collect();
+    let result_type = get_coerce_type(expr_type, &list_types);
+    match result_type {
+        None => Err(DataFusionError::Plan(format!(
+            "Can not find compatible types to compare {:?} with {:?}",
+            expr_type, list_types
+        ))),
+        Some(data_type) => {
+            // find the coerced type
+            let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
+            let cast_list_expr = list
+                .into_iter()
+                .map(|list_expr| {
+                    try_cast(list_expr, input_schema, 
data_type.clone()).unwrap()
+                })
+                .collect();
+            Ok((cast_expr, cast_list_expr))
+        }
+    }
+}
+
+/// Attempts to coerce the types of `list_type` to be comparable with the
+/// `expr_type`
+fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> 
Option<DataType> {
+    // get the equal coerced data type
+    list_type
+        .iter()
+        .fold(Some(expr_type.clone()), |left, right_type| {
+            match left {
+                None => None,
+                // TODO refactor a framework to do the data type coercion
+                Some(left_type) => comparison_eq_coercion(&left_type, 
right_type),
+            }
+        })
+}

Reply via email to