This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new c261af34 feat: Implement Spark-compatible CAST from non-integral numeric types to integral types (#399)
c261af34 is described below

commit c261af34527c17d1b5eb74d6c06f1f657819fe11
Author: Rohit Rastogi <[email protected]>
AuthorDate: Wed May 8 12:40:36 2024 -0700

    feat: Implement Spark-compatible CAST from non-integral numeric types to integral types (#399)
    
    * WIP - float to int, sketchy
    
    * WIP - extremely ugly but functional
    
    * WIP - use macro
    
    * simplify further
    
    * delete dead code
    
    * make format
    
    * progress on decimals
    
    * refactor
    
    * format decimal value in overflow exception
    
    * wip - have to use 4 macros, need more decimal tests
    
    * ready for review
    
    * forgot to commit whoops
    
    * bad merge
    
    * address pr comments
    
    * commit missed compatibility
    
    * improve error message
    
    * improve error message again
    
    * revert perf regression in cast_int_to_int_macro
    
    * remove branching in loop for legacy case
    
    ---------
    
    Co-authored-by: Rohit Rastogi <[email protected]>
---
 core/src/execution/datafusion/expressions/cast.rs  | 418 ++++++++++++++++++++-
 docs/source/user-guide/compatibility.md            |  12 +
 .../org/apache/comet/expressions/CometCast.scala   |  12 +-
 .../scala/org/apache/comet/CometCastSuite.scala    | 116 ++++--
 4 files changed, 511 insertions(+), 47 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index a6e3adac..2ad9c40d 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -31,8 +31,8 @@ use arrow::{
 };
 use arrow_array::{
     types::{Int16Type, Int32Type, Int64Type, Int8Type},
-    Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait,
-    PrimitiveArray,
+    Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, GenericStringArray,
+    Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray,
 };
 use arrow_schema::{DataType, Schema};
 use chrono::{TimeZone, Timelike};
@@ -214,11 +214,11 @@ macro_rules! cast_int_to_int_macro {
                     Some(value) => {
                         let res = <$to_native_type>::try_from(value);
                         if res.is_err() {
-                            Err(CometError::CastOverFlow {
-                                value: value.to_string() + spark_int_literal_suffix,
-                                from_type: $spark_from_data_type_name.to_string(),
-                                to_type: $spark_to_data_type_name.to_string(),
-                            })
+                            Err(cast_overflow(
+                                &(value.to_string() + spark_int_literal_suffix),
+                                $spark_from_data_type_name,
+                                $spark_to_data_type_name,
+                            ))
                         } else {
                             Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
                         }
@@ -232,6 +232,240 @@ macro_rules! cast_int_to_int_macro {
     }};
 }
 
+// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short.
+// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast,
+// this can cause unexpected Short/Byte cast results. Replicate this behavior.
+macro_rules! cast_float_to_int16_down {
+    (
+        $array:expr,
+        $eval_mode:expr,
+        $src_array_type:ty,
+        $dest_array_type:ty,
+        $rust_src_type:ty,
+        $rust_dest_type:ty,
+        $src_type_str:expr,
+        $dest_type_str:expr,
+        $format_str:expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<$src_array_type>()
+            .expect(concat!("Expected a ", stringify!($src_array_type)));
+
+        let output_array = match $eval_mode {
+            EvalMode::Ansi => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let is_overflow = value.is_nan() || value.abs() as i32 == std::i32::MAX;
+                        if is_overflow {
+                            return Err(cast_overflow(
+                                &format!($format_str, value).replace("e", "E"),
+                                $src_type_str,
+                                $dest_type_str,
+                            ));
+                        }
+                        let i32_value = value as i32;
+                        <$rust_dest_type>::try_from(i32_value)
+                            .map_err(|_| {
+                                cast_overflow(
+                                    &format!($format_str, value).replace("e", "E"),
+                                    $src_type_str,
+                                    $dest_type_str,
+                                )
+                            })
+                            .map(Some)
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+            _ => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let i32_value = value as i32;
+                        Ok::<Option<$rust_dest_type>, CometError>(Some(
+                            i32_value as $rust_dest_type,
+                        ))
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+        };
+        Ok(Arc::new(output_array) as ArrayRef)
+    }};
+}
+
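
For context, a minimal standalone sketch (not part of the patch) of why the macro above routes through i32 before narrowing: Rust `as` casts saturate float-to-int, so a direct cast to i8 clamps, while going through i32 first truncates the low bits, reproducing Spark's legacy (non-ANSI) results:

    fn main() {
        let v: f32 = 2_000_000_000.0; // exactly representable in f32, fits in i32
        let direct = v as i8; // saturating cast clamps straight to i8::MAX = 127
        let via_int = (v as i32) as i8; // 2_000_000_000 wraps modulo 256 to 0
        assert_eq!(direct, 127);
        assert_eq!(via_int, 0);
    }
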
+macro_rules! cast_float_to_int32_up {
+    (
+        $array:expr,
+        $eval_mode:expr,
+        $src_array_type:ty,
+        $dest_array_type:ty,
+        $rust_src_type:ty,
+        $rust_dest_type:ty,
+        $src_type_str:expr,
+        $dest_type_str:expr,
+        $max_dest_val:expr,
+        $format_str:expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<$src_array_type>()
+            .expect(concat!("Expected a ", stringify!($src_array_type)));
+
+        let output_array = match $eval_mode {
+            EvalMode::Ansi => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let is_overflow =
+                            value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val;
+                        if is_overflow {
+                            return Err(cast_overflow(
+                                &format!($format_str, value).replace("e", "E"),
+                                $src_type_str,
+                                $dest_type_str,
+                            ));
+                        }
+                        Ok(Some(value as $rust_dest_type))
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+            _ => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        Ok::<Option<$rust_dest_type>, CometError>(Some(value as $rust_dest_type))
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+        };
+        Ok(Arc::new(output_array) as ArrayRef)
+    }};
+}
+
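
Likewise, a standalone sketch (not part of the patch) of the ANSI overflow-detection idiom shared by both float macros: since `as` saturates, any out-of-range magnitude collapses onto the destination MAX after abs(), so equality with MAX signals overflow; NaN needs a separate check because `NaN as i32` yields 0:

    fn overflows_i32(v: f64) -> bool {
        v.is_nan() || v.abs() as i32 == i32::MAX
    }

    fn main() {
        assert!(overflows_i32(1e300)); // saturates to i32::MAX
        assert!(overflows_i32(f64::NAN)); // caught by the is_nan check
        assert!(!overflows_i32(123.456)); // in range, casts cleanly
        assert!(overflows_i32(2_147_483_647.0)); // the boundary itself is also rejected
    }
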
+// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short.
+// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast,
+// this can cause unexpected Short/Byte cast results. Replicate this behavior.
+macro_rules! cast_decimal_to_int16_down {
+    (
+        $array:expr,
+        $eval_mode:expr,
+        $dest_array_type:ty,
+        $rust_dest_type:ty,
+        $dest_type_str:expr,
+        $precision:expr,
+        $scale:expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<Decimal128Array>()
+            .expect(concat!("Expected a Decimal128ArrayType"));
+
+        let output_array = match $eval_mode {
+            EvalMode::Ansi => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let divisor = 10_i128.pow($scale as u32);
+                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+                        let is_overflow = truncated.abs() > std::i32::MAX.into();
+                        if is_overflow {
+                            return Err(cast_overflow(
+                                &format!("{}.{}BD", truncated, decimal),
+                                &format!("DECIMAL({},{})", $precision, $scale),
+                                $dest_type_str,
+                            ));
+                        }
+                        let i32_value = truncated as i32;
+                        <$rust_dest_type>::try_from(i32_value)
+                            .map_err(|_| {
+                                cast_overflow(
+                                    &format!("{}.{}BD", truncated, decimal),
+                                    &format!("DECIMAL({},{})", $precision, $scale),
+                                    $dest_type_str,
+                                )
+                            })
+                            .map(Some)
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+            _ => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let divisor = 10_i128.pow($scale as u32);
+                        let i32_value = (value / divisor) as i32;
+                        Ok::<Option<$rust_dest_type>, CometError>(Some(
+                            i32_value as $rust_dest_type,
+                        ))
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+        };
+        Ok(Arc::new(output_array) as ArrayRef)
+    }};
+}
+
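
A standalone sketch (not part of the patch) of the truncation arithmetic used above: Decimal128 stores an unscaled i128, so integer division by 10^scale recovers the integral part (truncating toward zero) and the remainder supplies the fractional digits quoted in the overflow message:

    fn main() {
        let unscaled: i128 = -1_234_567; // represents -12345.67 at scale 2
        let divisor = 10_i128.pow(2);
        let (truncated, decimal) = (unscaled / divisor, (unscaled % divisor).abs());
        assert_eq!(truncated, -12345); // i128 division truncates toward zero
        assert_eq!(decimal, 67);
        println!("{}.{}BD", truncated, decimal); // -12345.67BD, Spark's decimal literal form
    }
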
+macro_rules! cast_decimal_to_int32_up {
+    (
+        $array:expr,
+        $eval_mode:expr,
+        $dest_array_type:ty,
+        $rust_dest_type:ty,
+        $dest_type_str:expr,
+        $max_dest_val:expr,
+        $precision:expr,
+        $scale:expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<Decimal128Array>()
+            .expect(concat!("Expected a Decimal128ArrayType"));
+
+        let output_array = match $eval_mode {
+            EvalMode::Ansi => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let divisor = 10_i128.pow($scale as u32);
+                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+                        let is_overflow = truncated.abs() > $max_dest_val.into();
+                        if is_overflow {
+                            return Err(cast_overflow(
+                                &format!("{}.{}BD", truncated, decimal),
+                                &format!("DECIMAL({},{})", $precision, $scale),
+                                $dest_type_str,
+                            ));
+                        }
+                        Ok(Some(truncated as $rust_dest_type))
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+            _ => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let divisor = 10_i128.pow($scale as u32);
+                        let truncated = value / divisor;
+                        Ok::<Option<$rust_dest_type>, CometError>(Some(
+                            truncated as $rust_dest_type,
+                        ))
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+        };
+        Ok(Arc::new(output_array) as ArrayRef)
+    }};
+}
+
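
And a standalone sketch (not part of the patch) of the matching ANSI overflow check: the comparison runs in i128 space before any narrowing cast, so no wraparound can slip past it:

    fn decimal_overflows_i64(unscaled: i128, scale: u32) -> bool {
        let truncated = unscaled / 10_i128.pow(scale);
        truncated.abs() > i128::from(i64::MAX)
    }

    fn main() {
        // 99999999999999999999.99 (scale 2) does not fit in BIGINT
        assert!(decimal_overflows_i64(9_999_999_999_999_999_999_999, 2));
        // 9223372036854775807.00 (i64::MAX) still fits
        assert!(!decimal_overflows_i64(922_337_203_685_477_580_700, 2));
    }
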
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -332,6 +566,27 @@ impl Cast {
             (DataType::Float32, DataType::LargeUtf8) => {
                Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
             }
+            (DataType::Float32, DataType::Int8)
+            | (DataType::Float32, DataType::Int16)
+            | (DataType::Float32, DataType::Int32)
+            | (DataType::Float32, DataType::Int64)
+            | (DataType::Float64, DataType::Int8)
+            | (DataType::Float64, DataType::Int16)
+            | (DataType::Float64, DataType::Int32)
+            | (DataType::Float64, DataType::Int64)
+            | (DataType::Decimal128(_, _), DataType::Int8)
+            | (DataType::Decimal128(_, _), DataType::Int16)
+            | (DataType::Decimal128(_, _), DataType::Int32)
+            | (DataType::Decimal128(_, _), DataType::Int64)
+                if self.eval_mode != EvalMode::Try =>
+            {
+                Self::spark_cast_nonintegral_numeric_to_integral(
+                    &array,
+                    self.eval_mode,
+                    from_type,
+                    to_type,
+                )?
+            }
             _ => {
                // when we have no Spark-specific casting we delegate to DataFusion
                 cast_with_options(&array, to_type, &CAST_OPTIONS)?
@@ -478,6 +733,146 @@ impl Cast {
 
         Ok(Arc::new(output_array))
     }
+
+    fn spark_cast_nonintegral_numeric_to_integral(
+        array: &dyn Array,
+        eval_mode: EvalMode,
+        from_type: &DataType,
+        to_type: &DataType,
+    ) -> CometResult<ArrayRef> {
+        match (from_type, to_type) {
+            (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!(
+                array,
+                eval_mode,
+                Float32Array,
+                Int8Array,
+                f32,
+                i8,
+                "FLOAT",
+                "TINYINT",
+                "{:e}"
+            ),
+            (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!(
+                array,
+                eval_mode,
+                Float32Array,
+                Int16Array,
+                f32,
+                i16,
+                "FLOAT",
+                "SMALLINT",
+                "{:e}"
+            ),
+            (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!(
+                array,
+                eval_mode,
+                Float32Array,
+                Int32Array,
+                f32,
+                i32,
+                "FLOAT",
+                "INT",
+                std::i32::MAX,
+                "{:e}"
+            ),
+            (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!(
+                array,
+                eval_mode,
+                Float32Array,
+                Int64Array,
+                f32,
+                i64,
+                "FLOAT",
+                "BIGINT",
+                std::i64::MAX,
+                "{:e}"
+            ),
+            (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!(
+                array,
+                eval_mode,
+                Float64Array,
+                Int8Array,
+                f64,
+                i8,
+                "DOUBLE",
+                "TINYINT",
+                "{:e}D"
+            ),
+            (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!(
+                array,
+                eval_mode,
+                Float64Array,
+                Int16Array,
+                f64,
+                i16,
+                "DOUBLE",
+                "SMALLINT",
+                "{:e}D"
+            ),
+            (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!(
+                array,
+                eval_mode,
+                Float64Array,
+                Int32Array,
+                f64,
+                i32,
+                "DOUBLE",
+                "INT",
+                std::i32::MAX,
+                "{:e}D"
+            ),
+            (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!(
+                array,
+                eval_mode,
+                Float64Array,
+                Int64Array,
+                f64,
+                i64,
+                "DOUBLE",
+                "BIGINT",
+                std::i64::MAX,
+                "{:e}D"
+            ),
+            (DataType::Decimal128(precision, scale), DataType::Int8) => {
+                cast_decimal_to_int16_down!(
+                    array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
+                )
+            }
+            (DataType::Decimal128(precision, scale), DataType::Int16) => {
+                cast_decimal_to_int16_down!(
+                    array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
+                )
+            }
+            (DataType::Decimal128(precision, scale), DataType::Int32) => {
+                cast_decimal_to_int32_up!(
+                    array,
+                    eval_mode,
+                    Int32Array,
+                    i32,
+                    "INT",
+                    std::i32::MAX,
+                    *precision,
+                    *scale
+                )
+            }
+            (DataType::Decimal128(precision, scale), DataType::Int64) => {
+                cast_decimal_to_int32_up!(
+                    array,
+                    eval_mode,
+                    Int64Array,
+                    i64,
+                    "BIGINT",
+                    std::i64::MAX,
+                    *precision,
+                    *scale
+                )
+            }
+            _ => unreachable!(
+                "{}",
+                format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}")
+            ),
+        }
+    }
 }
 
 /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
@@ -676,6 +1071,15 @@ fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
     }
 }
 
+#[inline]
+fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> CometError {
+    CometError::CastOverFlow {
+        value: value.to_string(),
+        from_type: from_type.to_string(),
+        to_type: to_type.to_string(),
+    }
+}
+
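
A last standalone sketch (not part of the patch), showing how the ANSI branches format the offending value passed to this helper: Rust's `{:e}` emits lowercase scientific notation, which the macros uppercase to match Spark's error text:

    fn main() {
        let v: f32 = f32::MAX;
        let spark_style = format!("{:e}", v).replace("e", "E");
        assert_eq!(spark_style, "3.4028235E38");
    }
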
 impl Display for Cast {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         write!(
diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md
index 57a4271f..2fd4b09b 100644
--- a/docs/source/user-guide/compatibility.md
+++ b/docs/source/user-guide/compatibility.md
@@ -88,11 +88,23 @@ The following cast operations are generally compatible with Spark except for the
 | long | double |  |
 | long | string |  |
 | float | boolean |  |
+| float | byte |  |
+| float | short |  |
+| float | integer |  |
+| float | long |  |
 | float | double |  |
 | float | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
 | double | boolean |  |
+| double | byte |  |
+| double | short |  |
+| double | integer |  |
+| double | long |  |
 | double | float |  |
 | double | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
+| decimal | byte |  |
+| decimal | short |  |
+| decimal | integer |  |
+| decimal | long |  |
 | decimal | float |  |
 | decimal | double |  |
 | string | boolean |  |
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 57e07b8c..5c225e3b 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -226,19 +226,25 @@ object CometCast {
   }
 
   private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
-    case DataTypes.BooleanType | DataTypes.DoubleType => Compatible()
+    case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
+        DataTypes.IntegerType | DataTypes.LongType =>
+      Compatible()
     case _: DecimalType => Incompatible(Some("No overflow check"))
     case _ => Unsupported
   }
 
   private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
-    case DataTypes.BooleanType | DataTypes.FloatType => Compatible()
+    case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | DataTypes.ShortType |
+        DataTypes.IntegerType | DataTypes.LongType =>
+      Compatible()
     case _: DecimalType => Incompatible(Some("No overflow check"))
     case _ => Unsupported
   }
 
   private def canCastFromDecimal(toType: DataType): SupportLevel = toType match {
-    case DataTypes.FloatType | DataTypes.DoubleType => Compatible()
+    case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
+        DataTypes.IntegerType | DataTypes.LongType =>
+      Compatible()
     case _ => Unsupported
   }
 
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 219feca1..827f4238 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, DataTypes}
+import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType}
 
 import org.apache.comet.expressions.{CometCast, Compatible}
 
@@ -320,23 +320,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateFloats(), DataTypes.BooleanType)
   }
 
-  ignore("cast FloatType to ByteType") {
-    // https://github.com/apache/datafusion-comet/issues/350
+  test("cast FloatType to ByteType") {
     castTest(generateFloats(), DataTypes.ByteType)
   }
 
-  ignore("cast FloatType to ShortType") {
-    // https://github.com/apache/datafusion-comet/issues/350
+  test("cast FloatType to ShortType") {
     castTest(generateFloats(), DataTypes.ShortType)
   }
 
-  ignore("cast FloatType to IntegerType") {
-    // https://github.com/apache/datafusion-comet/issues/350
+  test("cast FloatType to IntegerType") {
     castTest(generateFloats(), DataTypes.IntegerType)
   }
 
-  ignore("cast FloatType to LongType") {
-    // https://github.com/apache/datafusion-comet/issues/350
+  test("cast FloatType to LongType") {
     castTest(generateFloats(), DataTypes.LongType)
   }
 
@@ -378,23 +374,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateDoubles(), DataTypes.BooleanType)
   }
 
-  ignore("cast DoubleType to ByteType") {
-    // https://github.com/apache/datafusion-comet/issues/350
+  test("cast DoubleType to ByteType") {
     castTest(generateDoubles(), DataTypes.ByteType)
   }
 
-  ignore("cast DoubleType to ShortType") {
-    // https://github.com/apache/datafusion-comet/issues/350
+  test("cast DoubleType to ShortType") {
     castTest(generateDoubles(), DataTypes.ShortType)
   }
 
-  ignore("cast DoubleType to IntegerType") {
-    // https://github.com/apache/datafusion-comet/issues/350
+  test("cast DoubleType to IntegerType") {
     castTest(generateDoubles(), DataTypes.IntegerType)
   }
 
-  ignore("cast DoubleType to LongType") {
-    // https://github.com/apache/datafusion-comet/issues/350
+  test("cast DoubleType to LongType") {
     castTest(generateDoubles(), DataTypes.LongType)
   }
 
@@ -430,45 +422,57 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   ignore("cast DecimalType(10,2) to BooleanType") {
     // Arrow error: Cast error: Casting from Decimal128(38, 18) to Boolean not supported
-    castTest(generateDecimals(), DataTypes.BooleanType)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.BooleanType)
   }
 
-  ignore("cast DecimalType(10,2) to ByteType") {
-    // https://github.com/apache/datafusion-comet/issues/350
-    castTest(generateDecimals(), DataTypes.ByteType)
+  test("cast DecimalType(10,2) to ByteType") {
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.ByteType)
   }
 
-  ignore("cast DecimalType(10,2) to ShortType") {
-    // https://github.com/apache/datafusion-comet/issues/350
-    castTest(generateDecimals(), DataTypes.ShortType)
+  test("cast DecimalType(10,2) to ShortType") {
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
   }
 
-  ignore("cast DecimalType(10,2) to IntegerType") {
-    // https://github.com/apache/datafusion-comet/issues/350
-    castTest(generateDecimals(), DataTypes.IntegerType)
+  test("cast DecimalType(10,2) to IntegerType") {
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.IntegerType)
   }
 
-  ignore("cast DecimalType(10,2) to LongType") {
-    // https://github.com/apache/datafusion-comet/issues/350
-    castTest(generateDecimals(), DataTypes.LongType)
+  test("cast DecimalType(10,2) to LongType") {
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.LongType)
   }
 
   test("cast DecimalType(10,2) to FloatType") {
-    castTest(generateDecimals(), DataTypes.FloatType)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.FloatType)
   }
 
   test("cast DecimalType(10,2) to DoubleType") {
-    castTest(generateDecimals(), DataTypes.DoubleType)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.DoubleType)
+  }
+
+  test("cast DecimalType(38,18) to ByteType") {
+    castTest(generateDecimalsPrecision38Scale18(), DataTypes.ByteType)
+  }
+
+  test("cast DecimalType(38,18) to ShortType") {
+    castTest(generateDecimalsPrecision38Scale18(), DataTypes.ShortType)
+  }
+
+  test("cast DecimalType(38,18) to IntegerType") {
+    castTest(generateDecimalsPrecision38Scale18(), DataTypes.IntegerType)
+  }
+
+  test("cast DecimalType(38,18) to LongType") {
+    castTest(generateDecimalsPrecision38Scale18(), DataTypes.LongType)
   }
 
   ignore("cast DecimalType(10,2) to StringType") {
     // input: 0E-18, expected: 0E-18, actual: 0.000000000000000000
-    castTest(generateDecimals(), DataTypes.StringType)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.StringType)
   }
 
   ignore("cast DecimalType(10,2) to TimestampType") {
     // input: -123456.789000000000000000, expected: 1969-12-30 05:42:23.211, actual: 1969-12-31 15:59:59.876544
-    castTest(generateDecimals(), DataTypes.TimestampType)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.TimestampType)
   }
 
   // CAST from StringType
@@ -800,9 +804,47 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withNulls(values).toDF("a")
   }
 
-  private def generateDecimals(): DataFrame = {
-    // TODO improve this
-    val values = Seq(BigDecimal("123456.789"), BigDecimal("-123456.789"), BigDecimal("0.0"))
+  private def generateDecimalsPrecision10Scale2(): DataFrame = {
+    val values = Seq(
+      BigDecimal("-99999999.999"),
+      BigDecimal("-123456.789"),
+      BigDecimal("-32768.678"),
+      // Short Min
+      BigDecimal("-32767.123"),
+      BigDecimal("-128.12312"),
+      // Byte Min
+      BigDecimal("-127.123"),
+      BigDecimal("0.0"),
+      // Byte Max
+      BigDecimal("127.123"),
+      BigDecimal("128.12312"),
+      BigDecimal("32767.122"),
+      // Short Max
+      BigDecimal("32768.678"),
+      BigDecimal("123456.789"),
+      BigDecimal("99999999.999"))
+    withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b")
+  }
+
+  private def generateDecimalsPrecision38Scale18(): DataFrame = {
+    val values = Seq(
+      BigDecimal("-99999999999999999999.999999999999"),
+      BigDecimal("-9223372036854775808.234567"),
+      // Long Min
+      BigDecimal("-9223372036854775807.123123"),
+      BigDecimal("-2147483648.123123123"),
+      // Int Min
+      BigDecimal("-2147483647.123123123"),
+      BigDecimal("-123456.789"),
+      BigDecimal("0.00000000000"),
+      BigDecimal("123456.789"),
+      // Int Max
+      BigDecimal("2147483647.123123123"),
+      BigDecimal("2147483648.123123123"),
+      BigDecimal("9223372036854775807.123123"),
+      // Long Max
+      BigDecimal("9223372036854775808.234567"),
+      BigDecimal("99999999999999999999.999999999999"))
     withNulls(values).toDF("a")
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
