This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 56f57f4d feat: Implement Spark-compatible CAST from floating-point/double to decimal (#384)
56f57f4d is described below
commit 56f57f4dc9357fddc16072a0d93ed2bef2090fa1
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Thu May 9 06:29:43 2024 +0530
feat: Implement Spark-compatible CAST from floating-point/double to decimal (#384)
* support NumericValueOutOfRange error
* adding ansi checks and code refactor
* fmt fixes
* Remove redundant comment
* bug fix
* adding cast for float32 as well
* fix test case for spark 3.2 and 3.3
* return error only in ansi mode
---
core/src/errors.rs | 11 +++
core/src/execution/datafusion/expressions/cast.rs | 90 +++++++++++++++++++++-
docs/source/user-guide/compatibility.md | 4 +-
.../org/apache/comet/expressions/CometCast.scala | 4 +-
.../scala/org/apache/comet/CometCastSuite.scala | 16 ++--
5 files changed, 114 insertions(+), 11 deletions(-)
diff --git a/core/src/errors.rs b/core/src/errors.rs
index a06c613a..04a1629d 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -72,6 +72,13 @@ pub enum CometError {
to_type: String,
},
+ #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
+ NumericValueOutOfRange {
+ value: String,
+ precision: u8,
+ scale: i8,
+ },
+
#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
@@ -208,6 +215,10 @@ impl jni::errors::ToException for CometError {
class: "org/apache/spark/SparkException".to_string(),
msg: self.to_string(),
},
+ CometError::NumericValueOutOfRange { .. } => Exception {
+ class: "org/apache/spark/SparkException".to_string(),
+ msg: self.to_string(),
+ },
CometError::NumberIntFormat { source: s } => Exception {
class: "java/lang/NumberFormatException".to_string(),
msg: s.to_string(),
diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 2ad9c40d..35ab23a7 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -25,7 +25,10 @@ use std::{
use crate::errors::{CometError, CometResult};
use arrow::{
compute::{cast_with_options, CastOptions},
- datatypes::TimestampMicrosecondType,
+ datatypes::{
+ ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type,
+ TimestampMicrosecondType,
+ },
record_batch::RecordBatch,
util::display::FormatOptions,
};
@@ -39,7 +42,7 @@ use chrono::{TimeZone, Timelike};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
-use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
+use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive};
use regex::Regex;
use crate::execution::datafusion::expressions::utils::{
@@ -566,6 +569,12 @@ impl Cast {
(DataType::Float32, DataType::LargeUtf8) => {
Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
}
+ (DataType::Float32, DataType::Decimal128(precision, scale)) => {
+ Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)?
+ }
+ (DataType::Float64, DataType::Decimal128(precision, scale)) => {
+ Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)?
+ }
(DataType::Float32, DataType::Int8)
| (DataType::Float32, DataType::Int16)
| (DataType::Float32, DataType::Int32)
@@ -650,6 +659,83 @@ impl Cast {
Ok(cast_array)
}
+ fn cast_float64_to_decimal128(
+ array: &dyn Array,
+ precision: u8,
+ scale: i8,
+ eval_mode: EvalMode,
+ ) -> CometResult<ArrayRef> {
+ Self::cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
+ }
+
+ fn cast_float32_to_decimal128(
+ array: &dyn Array,
+ precision: u8,
+ scale: i8,
+ eval_mode: EvalMode,
+ ) -> CometResult<ArrayRef> {
+ Self::cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
+ }
+
+ fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
+ array: &dyn Array,
+ precision: u8,
+ scale: i8,
+ eval_mode: EvalMode,
+ ) -> CometResult<ArrayRef>
+ where
+ <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
+ {
+ let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+ let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());
+
+ let mul = 10_f64.powi(scale as i32);
+
+ for i in 0..input.len() {
+ if input.is_null(i) {
+ cast_array.append_null();
+ } else {
+ let input_value = input.value(i).as_();
+ let value = (input_value * mul).round().to_i128();
+
+ match value {
+ Some(v) => {
+ if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
+ if eval_mode == EvalMode::Ansi {
+ return Err(CometError::NumericValueOutOfRange {
+ value: input_value.to_string(),
+ precision,
+ scale,
+ });
+ } else {
+ cast_array.append_null();
+ }
+ }
+ cast_array.append_value(v);
+ }
+ None => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(CometError::NumericValueOutOfRange {
+ value: input_value.to_string(),
+ precision,
+ scale,
+ });
+ } else {
+ cast_array.append_null();
+ }
+ }
+ }
+ }
+ }
+
+ let res = Arc::new(
+ cast_array
+ .with_precision_and_scale(precision, scale)?
+ .finish(),
+ ) as ArrayRef;
+ Ok(res)
+ }
+
fn spark_cast_float64_to_utf8<OffsetSize>(
from: &dyn Array,
_eval_mode: EvalMode,
diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md
index 2fd4b09b..a4ed9289 100644
--- a/docs/source/user-guide/compatibility.md
+++ b/docs/source/user-guide/compatibility.md
@@ -93,6 +93,7 @@ The following cast operations are generally compatible with Spark except for the
| float | integer | |
| float | long | |
| float | double | |
+| float | decimal | |
| float | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
| double | boolean | |
| double | byte | |
@@ -100,6 +101,7 @@ The following cast operations are generally compatible with Spark except for the
| double | integer | |
| double | long | |
| double | float | |
+| double | decimal | |
| double | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
| decimal | byte | |
| decimal | short | |
@@ -127,8 +129,6 @@ The following cast operations are not compatible with Spark for all inputs and a
|-|-|-|
| integer | decimal | No overflow check |
| long | decimal | No overflow check |
-| float | decimal | No overflow check |
-| double | decimal | No overflow check |
| string | timestamp | Not all valid formats are supported |
| binary | string | Only works for binary data representing valid UTF-8 strings |
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 5c225e3b..795bdb42 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -229,7 +229,7 @@ object CometCast {
case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
- case _: DecimalType => Incompatible(Some("No overflow check"))
+ case _: DecimalType => Compatible()
case _ => Unsupported
}
@@ -237,7 +237,7 @@ object CometCast {
case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | DataTypes.ShortType |
DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
- case _: DecimalType => Incompatible(Some("No overflow check"))
+ case _: DecimalType => Compatible()
case _ => Unsupported
}
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 827f4238..1881c561 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -340,8 +340,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateFloats(), DataTypes.DoubleType)
}
- ignore("cast FloatType to DecimalType(10,2)") {
- // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
+ test("cast FloatType to DecimalType(10,2)") {
castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
}
@@ -394,8 +393,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateDoubles(), DataTypes.FloatType)
}
- ignore("cast DoubleType to DecimalType(10,2)") {
- // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
+ test("cast DoubleType to DecimalType(10,2)") {
castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
}
@@ -1003,11 +1001,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val cometMessageModified = cometMessage
.replace("[CAST_INVALID_INPUT] ", "")
.replace("[CAST_OVERFLOW] ", "")
- assert(cometMessageModified == sparkMessage)
+ .replace("[NUMERIC_VALUE_OUT_OF_RANGE] ", "")
+
+ if (sparkMessage.contains("cannot be represented as")) {
+ assert(cometMessage.contains("cannot be represented as"))
+ } else {
+ assert(cometMessageModified == sparkMessage)
+ }
} else {
// for Spark 3.2 we just make sure we are seeing a similar type of error
if (sparkMessage.contains("causes overflow")) {
assert(cometMessage.contains("due to an overflow"))
+ } else if (sparkMessage.contains("cannot be represented as")) {
+ assert(cometMessage.contains("cannot be represented as"))
} else {
// assume that this is an invalid input message in the form:
// `invalid input syntax for type numeric: -9223372036854775809`
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]