This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 79431f88 fix: Only delegate to DataFusion cast when we know that it is compatible with Spark (#461)
79431f88 is described below

commit 79431f8837f21f91cb24f3a0b4c11a364dcd70f9
Author: Andy Grove <[email protected]>
AuthorDate: Sat May 25 12:48:23 2024 -0600

    fix: Only delegate to DataFusion cast when we know that it is compatible with Spark (#461)
    
    * only delegate to DataFusion cast when we know that it is compatible with Spark
    
    * add more supported casts
    
    * improve support for dictionary-encoded string arrays
    
    * clippy
    
    * fix merge conflict
    
    * fix a regression
    
    * fix a regression
    
    * fix a regression
    
    * fix regression
    
    * fix regression
    
    * fix regression
    
    * remove TODO comment now that issue has been filed
    
    * remove cast int32/int64 -> decimal from datafusion compatible list
    
    * Revert "remove cast int32/int64 -> decimal from datafusion compatible list"
    
    This reverts commit 340e00007575e7f91affa018caa7cd9e2d8964f9.
    
    * add comment
---
 core/src/execution/datafusion/expressions/cast.rs | 182 ++++++++++++++--------
 1 file changed, 115 insertions(+), 67 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index fd1f9166..7e8a96f2 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -503,41 +503,37 @@ impl Cast {
     fn cast_array(&self, array: ArrayRef) -> DataFusionResult<ArrayRef> {
         let to_type = &self.data_type;
         let array = array_with_timezone(array, self.timezone.clone(), 
Some(to_type));
+        let from_type = array.data_type().clone();
+
+        // unpack dictionary string arrays first
+        // TODO: we are unpacking a dictionary-encoded array and then 
performing
+        // the cast. We could potentially improve performance here by casting 
the
+        // dictionary values directly without unpacking the array first, 
although this
+        // would add more complexity to the code
+        let array = match &from_type {
+            DataType::Dictionary(key_type, value_type)
+                if key_type.as_ref() == &DataType::Int32
+                    && (value_type.as_ref() == &DataType::Utf8
+                        || value_type.as_ref() == &DataType::LargeUtf8) =>
+            {
+                cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)?
+            }
+            _ => array,
+        };
         let from_type = array.data_type();
+
         let cast_result = match (from_type, to_type) {
             (DataType::Utf8, DataType::Boolean) => {
-                Self::spark_cast_utf8_to_boolean::<i32>(&array, 
self.eval_mode)?
+                Self::spark_cast_utf8_to_boolean::<i32>(&array, self.eval_mode)
             }
             (DataType::LargeUtf8, DataType::Boolean) => {
-                Self::spark_cast_utf8_to_boolean::<i64>(&array, 
self.eval_mode)?
+                Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)
             }
             (DataType::Utf8, DataType::Timestamp(_, _)) => {
-                Self::cast_string_to_timestamp(&array, to_type, 
self.eval_mode)?
+                Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)
             }
             (DataType::Utf8, DataType::Date32) => {
-                Self::cast_string_to_date(&array, to_type, self.eval_mode)?
-            }
-            (DataType::Dictionary(key_type, value_type), DataType::Date32)
-                if key_type.as_ref() == &DataType::Int32
-                    && (value_type.as_ref() == &DataType::Utf8
-                        || value_type.as_ref() == &DataType::LargeUtf8) =>
-            {
-                match value_type.as_ref() {
-                    DataType::Utf8 => {
-                        let unpacked_array =
-                            cast_with_options(&array, &DataType::Utf8, 
&CAST_OPTIONS)?;
-                        Self::cast_string_to_date(&unpacked_array, to_type, 
self.eval_mode)?
-                    }
-                    DataType::LargeUtf8 => {
-                        let unpacked_array =
-                            cast_with_options(&array, &DataType::LargeUtf8, 
&CAST_OPTIONS)?;
-                        Self::cast_string_to_date(&unpacked_array, to_type, 
self.eval_mode)?
-                    }
-                    dt => unreachable!(
-                        "{}",
-                        format!("invalid value type {dt} for 
dictionary-encoded string array")
-                    ),
-                }
+                Self::cast_string_to_date(&array, to_type, self.eval_mode)
             }
             (DataType::Int64, DataType::Int32)
             | (DataType::Int64, DataType::Int16)
@@ -547,61 +543,33 @@ impl Cast {
             | (DataType::Int16, DataType::Int8)
                 if self.eval_mode != EvalMode::Try =>
             {
-                Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, 
to_type)?
+                Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, 
to_type)
             }
             (
                 DataType::Utf8,
                 DataType::Int8 | DataType::Int16 | DataType::Int32 | 
DataType::Int64,
-            ) => Self::cast_string_to_int::<i32>(to_type, &array, 
self.eval_mode)?,
+            ) => Self::cast_string_to_int::<i32>(to_type, &array, 
self.eval_mode),
             (
                 DataType::LargeUtf8,
                 DataType::Int8 | DataType::Int16 | DataType::Int32 | 
DataType::Int64,
-            ) => Self::cast_string_to_int::<i64>(to_type, &array, 
self.eval_mode)?,
-            (
-                DataType::Dictionary(key_type, value_type),
-                DataType::Int8 | DataType::Int16 | DataType::Int32 | 
DataType::Int64,
-            ) if key_type.as_ref() == &DataType::Int32
-                && (value_type.as_ref() == &DataType::Utf8
-                    || value_type.as_ref() == &DataType::LargeUtf8) =>
-            {
-                // TODO: we are unpacking a dictionary-encoded array and then 
performing
-                // the cast. We could potentially improve performance here by 
casting the
-                // dictionary values directly without unpacking the array 
first, although this
-                // would add more complexity to the code
-                match value_type.as_ref() {
-                    DataType::Utf8 => {
-                        let unpacked_array =
-                            cast_with_options(&array, &DataType::Utf8, 
&CAST_OPTIONS)?;
-                        Self::cast_string_to_int::<i32>(to_type, 
&unpacked_array, self.eval_mode)?
-                    }
-                    DataType::LargeUtf8 => {
-                        let unpacked_array =
-                            cast_with_options(&array, &DataType::LargeUtf8, 
&CAST_OPTIONS)?;
-                        Self::cast_string_to_int::<i64>(to_type, 
&unpacked_array, self.eval_mode)?
-                    }
-                    dt => unreachable!(
-                        "{}",
-                        format!("invalid value type {dt} for 
dictionary-encoded string array")
-                    ),
-                }
-            }
+            ) => Self::cast_string_to_int::<i64>(to_type, &array, 
self.eval_mode),
             (DataType::Float64, DataType::Utf8) => {
-                Self::spark_cast_float64_to_utf8::<i32>(&array, 
self.eval_mode)?
+                Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode)
             }
             (DataType::Float64, DataType::LargeUtf8) => {
-                Self::spark_cast_float64_to_utf8::<i64>(&array, 
self.eval_mode)?
+                Self::spark_cast_float64_to_utf8::<i64>(&array, self.eval_mode)
             }
             (DataType::Float32, DataType::Utf8) => {
-                Self::spark_cast_float32_to_utf8::<i32>(&array, 
self.eval_mode)?
+                Self::spark_cast_float32_to_utf8::<i32>(&array, self.eval_mode)
             }
             (DataType::Float32, DataType::LargeUtf8) => {
-                Self::spark_cast_float32_to_utf8::<i64>(&array, 
self.eval_mode)?
+                Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)
             }
             (DataType::Float32, DataType::Decimal128(precision, scale)) => {
-                Self::cast_float32_to_decimal128(&array, *precision, *scale, 
self.eval_mode)?
+                Self::cast_float32_to_decimal128(&array, *precision, *scale, 
self.eval_mode)
             }
             (DataType::Float64, DataType::Decimal128(precision, scale)) => {
-                Self::cast_float64_to_decimal128(&array, *precision, *scale, 
self.eval_mode)?
+                Self::cast_float64_to_decimal128(&array, *precision, *scale, 
self.eval_mode)
             }
             (DataType::Float32, DataType::Int8)
             | (DataType::Float32, DataType::Int16)
@@ -622,14 +590,94 @@ impl Cast {
                     self.eval_mode,
                     from_type,
                     to_type,
-                )?
+                )
+            }
+            _ if Self::is_datafusion_spark_compatible(from_type, to_type) => {
+                // use DataFusion cast only when we know that it is compatible 
with Spark
+                Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
             }
             _ => {
-                // when we have no Spark-specific casting we delegate to 
DataFusion
-                cast_with_options(&array, to_type, &CAST_OPTIONS)?
+                // we should never reach this code because the Scala code 
should be checking
+                // for supported cast operations and falling back to Spark for 
anything that
+                // is not yet supported
+                Err(CometError::Internal(format!(
+                    "Native cast invoked for unsupported cast from 
{from_type:?} to {to_type:?}"
+                )))
             }
         };
-        Ok(spark_cast(cast_result, from_type, to_type))
+        Ok(spark_cast(cast_result?, from_type, to_type))
+    }
+
+    /// Determines if DataFusion supports the given cast in a way that is
+    /// compatible with Spark
+    fn is_datafusion_spark_compatible(from_type: &DataType, to_type: 
&DataType) -> bool {
+        if from_type == to_type {
+            return true;
+        }
+        match from_type {
+            DataType::Boolean => matches!(
+                to_type,
+                DataType::Int8
+                    | DataType::Int16
+                    | DataType::Int32
+                    | DataType::Int64
+                    | DataType::Float32
+                    | DataType::Float64
+                    | DataType::Utf8
+            ),
+            DataType::Int8 | DataType::Int16 | DataType::Int32 | 
DataType::Int64 => {
+                // note that the cast from Int32/Int64 -> Decimal128 here is 
actually
+                // not compatible with Spark (no overflow checks) but we have 
tests that
+                // rely on this cast working so we have to leave it here for 
now
+                matches!(
+                    to_type,
+                    DataType::Boolean
+                        | DataType::Int8
+                        | DataType::Int16
+                        | DataType::Int32
+                        | DataType::Int64
+                        | DataType::Float32
+                        | DataType::Float64
+                        | DataType::Decimal128(_, _)
+                        | DataType::Utf8
+                )
+            }
+            DataType::Float32 | DataType::Float64 => matches!(
+                to_type,
+                DataType::Boolean
+                    | DataType::Int8
+                    | DataType::Int16
+                    | DataType::Int32
+                    | DataType::Int64
+                    | DataType::Float32
+                    | DataType::Float64
+            ),
+            DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => 
matches!(
+                to_type,
+                DataType::Int8
+                    | DataType::Int16
+                    | DataType::Int32
+                    | DataType::Int64
+                    | DataType::Float32
+                    | DataType::Float64
+                    | DataType::Decimal128(_, _)
+                    | DataType::Decimal256(_, _)
+            ),
+            DataType::Utf8 => matches!(to_type, DataType::Binary),
+            DataType::Date32 => matches!(to_type, DataType::Utf8),
+            DataType::Timestamp(_, _) => {
+                matches!(
+                    to_type,
+                    DataType::Int64 | DataType::Date32 | DataType::Utf8 | 
DataType::Timestamp(_, _)
+                )
+            }
+            DataType::Binary => {
+                // note that this is not completely Spark compatible because
+                // DataFusion only supports binary data containing valid UTF-8 
strings
+                matches!(to_type, DataType::Utf8)
+            }
+            _ => false,
+        }
     }
 
     fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to