This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 56f57f4d feat: Implement Spark-compatible CAST from 
floating-point/double to decimal (#384)
56f57f4d is described below

commit 56f57f4dc9357fddc16072a0d93ed2bef2090fa1
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Thu May 9 06:29:43 2024 +0530

    feat: Implement Spark-compatible CAST from floating-point/double to decimal 
(#384)
    
    * support NumericValueOutOfRange error
    
    * adding ansi checks and code refactor
    
    * fmt fixes
    
    * Remove redundant comment
    
    * bug fix
    
    * adding cast for float32 as well
    
    * fix test case for spark 3.2 and 3.3
    
    * return error only in ansi mode
---
 core/src/errors.rs                                 | 11 +++
 core/src/execution/datafusion/expressions/cast.rs  | 90 +++++++++++++++++++++-
 docs/source/user-guide/compatibility.md            |  4 +-
 .../org/apache/comet/expressions/CometCast.scala   |  4 +-
 .../scala/org/apache/comet/CometCastSuite.scala    | 16 ++--
 5 files changed, 114 insertions(+), 11 deletions(-)

diff --git a/core/src/errors.rs b/core/src/errors.rs
index a06c613a..04a1629d 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -72,6 +72,13 @@ pub enum CometError {
         to_type: String,
     },
 
+    #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as 
Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to 
\"false\" to bypass this error, and return NULL instead.")]
+    NumericValueOutOfRange {
+        value: String,
+        precision: u8,
+        scale: i8,
+    },
+
     #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" 
cannot be cast to \"{to_type}\" \
         due to an overflow. Use `try_cast` to tolerate overflow and return 
NULL instead. If necessary \
         set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
@@ -208,6 +215,10 @@ impl jni::errors::ToException for CometError {
                 class: "org/apache/spark/SparkException".to_string(),
                 msg: self.to_string(),
             },
+            CometError::NumericValueOutOfRange { .. } => Exception {
+                class: "org/apache/spark/SparkException".to_string(),
+                msg: self.to_string(),
+            },
             CometError::NumberIntFormat { source: s } => Exception {
                 class: "java/lang/NumberFormatException".to_string(),
                 msg: s.to_string(),
diff --git a/core/src/execution/datafusion/expressions/cast.rs 
b/core/src/execution/datafusion/expressions/cast.rs
index 2ad9c40d..35ab23a7 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -25,7 +25,10 @@ use std::{
 use crate::errors::{CometError, CometResult};
 use arrow::{
     compute::{cast_with_options, CastOptions},
-    datatypes::TimestampMicrosecondType,
+    datatypes::{
+        ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, 
Float64Type,
+        TimestampMicrosecondType,
+    },
     record_batch::RecordBatch,
     util::display::FormatOptions,
 };
@@ -39,7 +42,7 @@ use chrono::{TimeZone, Timelike};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
 use datafusion_physical_expr::PhysicalExpr;
-use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
+use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, 
ToPrimitive};
 use regex::Regex;
 
 use crate::execution::datafusion::expressions::utils::{
@@ -566,6 +569,12 @@ impl Cast {
             (DataType::Float32, DataType::LargeUtf8) => {
                 Self::spark_cast_float32_to_utf8::<i64>(&array, 
self.eval_mode)?
             }
+            (DataType::Float32, DataType::Decimal128(precision, scale)) => {
+                Self::cast_float32_to_decimal128(&array, *precision, *scale, 
self.eval_mode)?
+            }
+            (DataType::Float64, DataType::Decimal128(precision, scale)) => {
+                Self::cast_float64_to_decimal128(&array, *precision, *scale, 
self.eval_mode)?
+            }
             (DataType::Float32, DataType::Int8)
             | (DataType::Float32, DataType::Int16)
             | (DataType::Float32, DataType::Int32)
@@ -650,6 +659,83 @@ impl Cast {
         Ok(cast_array)
     }
 
+    /// Casts a `Float64` array to `Decimal128(precision, scale)`.
+    ///
+    /// Thin wrapper that forwards every argument unchanged to the generic
+    /// floating-point implementation, instantiated for `Float64Type`.
+    fn cast_float64_to_decimal128(
+        array: &dyn Array,
+        precision: u8,
+        scale: i8,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef> {
+        let convert = Self::cast_floating_point_to_decimal128::<Float64Type>;
+        convert(array, precision, scale, eval_mode)
+    }
+
+    /// Casts a `Float32` array to `Decimal128(precision, scale)`.
+    ///
+    /// Thin wrapper that forwards every argument unchanged to the generic
+    /// floating-point implementation, instantiated for `Float32Type`.
+    fn cast_float32_to_decimal128(
+        array: &dyn Array,
+        precision: u8,
+        scale: i8,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef> {
+        let convert = Self::cast_floating_point_to_decimal128::<Float32Type>;
+        convert(array, precision, scale, eval_mode)
+    }
+
+    /// Casts a floating-point array to `Decimal128(precision, scale)`.
+    ///
+    /// Each non-null input is widened to `f64`, multiplied by `10^scale`, rounded and
+    /// converted to `i128`. A value that cannot be converted to `i128`, or that does not
+    /// fit in `precision` digits, is treated as out of range.
+    ///
+    /// # Errors
+    ///
+    /// Returns `CometError::NumericValueOutOfRange` for the first out-of-range input when
+    /// `eval_mode` is `EvalMode::Ansi`. In every other mode such inputs become null.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `array` is not a `PrimitiveArray<T>`.
+    fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
+        array: &dyn Array,
+        precision: u8,
+        scale: i8,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef>
+    where
+        <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
+    {
+        let input = array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<T>>()
+            .expect("array must be a PrimitiveArray of the requested float type");
+        let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());
+
+        let mul = 10_f64.powi(scale as i32);
+
+        for i in 0..input.len() {
+            if input.is_null(i) {
+                cast_array.append_null();
+                continue;
+            }
+
+            let input_value: f64 = input.value(i).as_();
+            // `to_i128` yields `None` when the scaled value has no `i128` representation;
+            // the filter additionally rejects values that do not fit in `precision` digits.
+            let value = (input_value * mul)
+                .round()
+                .to_i128()
+                .filter(|v| Decimal128Type::validate_decimal_precision(*v, precision).is_ok());
+
+            match value {
+                Some(v) => cast_array.append_value(v),
+                None if eval_mode == EvalMode::Ansi => {
+                    return Err(CometError::NumericValueOutOfRange {
+                        value: input_value.to_string(),
+                        precision,
+                        scale,
+                    });
+                }
+                // Exactly one slot is appended per input row. Appending both a null and
+                // the overflowing value would make the output longer than the input.
+                None => cast_array.append_null(),
+            }
+        }
+
+        let res = Arc::new(
+            cast_array
+                .with_precision_and_scale(precision, scale)?
+                .finish(),
+        ) as ArrayRef;
+        Ok(res)
+    }
+
     fn spark_cast_float64_to_utf8<OffsetSize>(
         from: &dyn Array,
         _eval_mode: EvalMode,
diff --git a/docs/source/user-guide/compatibility.md 
b/docs/source/user-guide/compatibility.md
index 2fd4b09b..a4ed9289 100644
--- a/docs/source/user-guide/compatibility.md
+++ b/docs/source/user-guide/compatibility.md
@@ -93,6 +93,7 @@ The following cast operations are generally compatible with 
Spark except for the
 | float | integer |  |
 | float | long |  |
 | float | double |  |
+| float | decimal |  |
 | float | string | There can be differences in precision. For example, the 
input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
 | double | boolean |  |
 | double | byte |  |
@@ -100,6 +101,7 @@ The following cast operations are generally compatible with 
Spark except for the
 | double | integer |  |
 | double | long |  |
 | double | float |  |
+| double | decimal |  |
 | double | string | There can be differences in precision. For example, the 
input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
 | decimal | byte |  |
 | decimal | short |  |
@@ -127,8 +129,6 @@ The following cast operations are not compatible with Spark 
for all inputs and a
 |-|-|-|
 | integer | decimal  | No overflow check |
 | long | decimal  | No overflow check |
-| float | decimal  | No overflow check |
-| double | decimal  | No overflow check |
 | string | timestamp  | Not all valid formats are supported |
 | binary | string  | Only works for binary data representing valid UTF-8 
strings |
 
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala 
b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 5c225e3b..795bdb42 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -229,7 +229,7 @@ object CometCast {
     case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | 
DataTypes.ShortType |
         DataTypes.IntegerType | DataTypes.LongType =>
       Compatible()
-    case _: DecimalType => Incompatible(Some("No overflow check"))
+    case _: DecimalType => Compatible()
     case _ => Unsupported
   }
 
@@ -237,7 +237,7 @@ object CometCast {
     case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | 
DataTypes.ShortType |
         DataTypes.IntegerType | DataTypes.LongType =>
       Compatible()
-    case _: DecimalType => Incompatible(Some("No overflow check"))
+    case _: DecimalType => Compatible()
     case _ => Unsupported
   }
 
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 827f4238..1881c561 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -340,8 +340,7 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(generateFloats(), DataTypes.DoubleType)
   }
 
-  ignore("cast FloatType to DecimalType(10,2)") {
-    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
+  test("cast FloatType to DecimalType(10,2)") {
     castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
   }
 
@@ -394,8 +393,7 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(generateDoubles(), DataTypes.FloatType)
   }
 
-  ignore("cast DoubleType to DecimalType(10,2)") {
-    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
+  test("cast DoubleType to DecimalType(10,2)") {
     castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
   }
 
@@ -1003,11 +1001,19 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
               val cometMessageModified = cometMessage
                 .replace("[CAST_INVALID_INPUT] ", "")
                 .replace("[CAST_OVERFLOW] ", "")
-              assert(cometMessageModified == sparkMessage)
+                .replace("[NUMERIC_VALUE_OUT_OF_RANGE] ", "")
+
+              if (sparkMessage.contains("cannot be represented as")) {
+                assert(cometMessage.contains("cannot be represented as"))
+              } else {
+                assert(cometMessageModified == sparkMessage)
+              }
             } else {
               // for Spark 3.2 we just make sure we are seeing a similar type 
of error
               if (sparkMessage.contains("causes overflow")) {
                 assert(cometMessage.contains("due to an overflow"))
+              } else if (sparkMessage.contains("cannot be represented as")) {
+                assert(cometMessage.contains("cannot be represented as"))
               } else {
                 // assume that this is an invalid input message in the form:
                 // `invalid input syntax for type numeric: 
-9223372036854775809`


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to