This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new b39ed882 feat: Implement Spark-compatible CAST between integer types 
(#340)
b39ed882 is described below

commit b39ed8823a6bf5ce7b624918c37cec40afcf7b36
Author: గణేష్ <[email protected]>
AuthorDate: Sat May 4 03:37:47 2024 +0530

    feat: Implement Spark-compatible CAST between integer types (#340)
    
    * handled cast for long to short
    
    * handled cast for all overflow cases
    
    * ran make format
    
    * added check for overflow exception for Spark versions below 3.4.
    
    * added comments on why we do the overflow check.
    added a check before we fetch the sparkInvalidValue
    
    * -1 instead of 0, -1 indicates the provided character is not present
    
    * ran mvn spotless:apply
    
    * check for presence of ':' and have asserts accordingly
    
    * reusing existing test functions
    
    * added one more check in assert when ':' is not present
    
    * redo the compare logic as per andy's suggestions.
    
    ---------
    
    Co-authored-by: ganesh.maddula <[email protected]>
---
 core/src/errors.rs                                 |  9 ++
 core/src/execution/datafusion/expressions/cast.rs  | 98 ++++++++++++++++++++++
 .../scala/org/apache/comet/CometCastSuite.scala    | 37 +++++---
 3 files changed, 131 insertions(+), 13 deletions(-)

diff --git a/core/src/errors.rs b/core/src/errors.rs
index f02bd196..a06c613a 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -72,6 +72,15 @@ pub enum CometError {
         to_type: String,
     },
 
+    #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" 
cannot be cast to \"{to_type}\" \
+        due to an overflow. Use `try_cast` to tolerate overflow and return 
NULL instead. If necessary \
+        set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+    CastOverFlow {
+        value: String,
+        from_type: String,
+        to_type: String,
+    },
+
     #[error(transparent)]
     Arrow {
         #[from]
diff --git a/core/src/execution/datafusion/expressions/cast.rs 
b/core/src/execution/datafusion/expressions/cast.rs
index 45859c5f..a6e3adac 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -176,6 +176,62 @@ macro_rules! cast_float_to_string {
     }};
 }
 
+macro_rules! cast_int_to_int_macro {
+    (
+        $array: expr,
+        $eval_mode:expr,
+        $from_arrow_primitive_type: ty,
+        $to_arrow_primitive_type: ty,
+        $from_data_type: expr,
+        $to_native_type: ty,
+        $spark_from_data_type_name: expr,
+        $spark_to_data_type_name: expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
+            .unwrap();
+        let spark_int_literal_suffix = match $from_data_type {
+            &DataType::Int64 => "L",
+            &DataType::Int16 => "S",
+            &DataType::Int8 => "T",
+            _ => "",
+        };
+
+        let output_array = match $eval_mode {
+            EvalMode::Legacy => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        Ok::<Option<$to_native_type>, CometError>(Some(value 
as $to_native_type))
+                    }
+                    _ => Ok(None),
+                })
+                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, 
_>>(),
+            _ => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let res = <$to_native_type>::try_from(value);
+                        if res.is_err() {
+                            Err(CometError::CastOverFlow {
+                                value: value.to_string() + 
spark_int_literal_suffix,
+                                from_type: 
$spark_from_data_type_name.to_string(),
+                                to_type: $spark_to_data_type_name.to_string(),
+                            })
+                        } else {
+                            Ok::<Option<$to_native_type>, 
CometError>(Some(res.unwrap()))
+                        }
+                    }
+                    _ => Ok(None),
+                })
+                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, 
_>>(),
+        }?;
+        let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as 
ArrayRef);
+        result
+    }};
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -218,6 +274,16 @@ impl Cast {
             (DataType::Utf8, DataType::Timestamp(_, _)) => {
                 Self::cast_string_to_timestamp(&array, to_type, 
self.eval_mode)?
             }
+            (DataType::Int64, DataType::Int32)
+            | (DataType::Int64, DataType::Int16)
+            | (DataType::Int64, DataType::Int8)
+            | (DataType::Int32, DataType::Int16)
+            | (DataType::Int32, DataType::Int8)
+            | (DataType::Int16, DataType::Int8)
+                if self.eval_mode != EvalMode::Try =>
+            {
+                Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, 
to_type)?
+            }
             (
                 DataType::Utf8,
                 DataType::Int8 | DataType::Int16 | DataType::Int32 | 
DataType::Int64,
@@ -349,6 +415,38 @@ impl Cast {
         cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
     }
 
+    fn spark_cast_int_to_int(
+        array: &dyn Array,
+        eval_mode: EvalMode,
+        from_type: &DataType,
+        to_type: &DataType,
+    ) -> CometResult<ArrayRef> {
+        match (from_type, to_type) {
+            (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
+                array, eval_mode, Int64Type, Int32Type, from_type, i32, 
"BIGINT", "INT"
+            ),
+            (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
+                array, eval_mode, Int64Type, Int16Type, from_type, i16, 
"BIGINT", "SMALLINT"
+            ),
+            (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
+                array, eval_mode, Int64Type, Int8Type, from_type, i8, 
"BIGINT", "TINYINT"
+            ),
+            (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
+                array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", 
"SMALLINT"
+            ),
+            (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
+                array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", 
"TINYINT"
+            ),
+            (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
+                array, eval_mode, Int16Type, Int8Type, from_type, i8, 
"SMALLINT", "TINYINT"
+            ),
+            _ => unreachable!(
+                "{}",
+                format!("invalid integer type {to_type} in cast from 
{from_type}")
+            ),
+        }
+    }
+
     fn spark_cast_utf8_to_boolean<OffsetSize>(
         from: &dyn Array,
         eval_mode: EvalMode,
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 54b13679..483301e0 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -166,7 +166,7 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(generateShorts(), DataTypes.BooleanType)
   }
 
-  ignore("cast ShortType to ByteType") {
+  test("cast ShortType to ByteType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateShorts(), DataTypes.ByteType)
   }
@@ -210,12 +210,12 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(generateInts(), DataTypes.BooleanType)
   }
 
-  ignore("cast IntegerType to ByteType") {
+  test("cast IntegerType to ByteType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateInts(), DataTypes.ByteType)
   }
 
-  ignore("cast IntegerType to ShortType") {
+  test("cast IntegerType to ShortType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateInts(), DataTypes.ShortType)
   }
@@ -256,17 +256,17 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(generateLongs(), DataTypes.BooleanType)
   }
 
-  ignore("cast LongType to ByteType") {
+  test("cast LongType to ByteType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateLongs(), DataTypes.ByteType)
   }
 
-  ignore("cast LongType to ShortType") {
+  test("cast LongType to ShortType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateLongs(), DataTypes.ShortType)
   }
 
-  ignore("cast LongType to IntegerType") {
+  test("cast LongType to IntegerType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateLongs(), DataTypes.IntegerType)
   }
@@ -921,15 +921,26 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
             val cometMessage = cometException.getCause.getMessage
               .replace("Execution error: ", "")
             if (CometSparkSessionExtensions.isSpark34Plus) {
+              // for Spark 3.4 we expect to reproduce the error message exactly
               assert(cometMessage == sparkMessage)
+            } else if (CometSparkSessionExtensions.isSpark33Plus) {
+              // for Spark 3.3 we just need to strip the prefix from the Comet 
message
+              // before comparing
+              val cometMessageModified = cometMessage
+                .replace("[CAST_INVALID_INPUT] ", "")
+                .replace("[CAST_OVERFLOW] ", "")
+              assert(cometMessageModified == sparkMessage)
             } else {
-              // Spark 3.2 and 3.3 have a different error message format so we 
can't do a direct
-              // comparison between Spark and Comet.
-              // Spark message is in format `invalid input syntax for type 
TYPE: VALUE`
-              // Comet message is in format `The value 'VALUE' of the type 
FROM_TYPE cannot be cast to TO_TYPE`
-              // We just check that the comet message contains the same 
invalid value as the Spark message
-              val sparkInvalidValue = 
sparkMessage.substring(sparkMessage.indexOf(':') + 2)
-              assert(cometMessage.contains(sparkInvalidValue))
+              // for Spark 3.2 we just make sure we are seeing a similar type 
of error
+              if (sparkMessage.contains("causes overflow")) {
+                assert(cometMessage.contains("due to an overflow"))
+              } else {
+                // assume that this is an invalid input message in the form:
+                // `invalid input syntax for type numeric: 
-9223372036854775809`
+                // we just check that the Comet message contains the same 
literal value
+                val sparkInvalidValue = 
sparkMessage.substring(sparkMessage.indexOf(':') + 2)
+                assert(cometMessage.contains(sparkInvalidValue))
+              }
             }
         }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to