This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new a6cfadb2a feat: Improve compatibility of string to decimal cast (#2925)
a6cfadb2a is described below
commit a6cfadb2a32ab2f02e47462e03219f89459edb5c
Author: B Vadlamani <[email protected]>
AuthorDate: Mon Dec 22 12:32:35 2025 -0800
feat: Improve compatibility of string to decimal cast (#2925)
---
docs/source/user-guide/latest/compatibility.md | 3 +-
native/spark-expr/src/conversion_funcs/cast.rs | 319 ++++++++++++++++++++-
.../org/apache/comet/expressions/CometCast.scala | 5 +-
.../scala/org/apache/comet/CometCastSuite.scala | 92 +++++-
4 files changed, 395 insertions(+), 24 deletions(-)
diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md
index 60e2234f5..58dd8d6ab 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -183,7 +183,8 @@ The following cast operations are not compatible with Spark for all inputs and a
| double | decimal | There can be rounding differences |
| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
-| string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits |
+| string | decimal | Does not support fullwidth unicode digits (e.g. \\uFF10)
+or strings containing null bytes (e.g. \\u0000) |
| string | timestamp | Not all valid formats are supported |
<!-- prettier-ignore-end -->
<!--END:INCOMPAT_CAST_TABLE-->
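
For context on the remaining incompatibility: Spark appears to accept fullwidth
Unicode digits because java.math.BigDecimal falls back to Character.digit for
non-ASCII characters, while Rust's standard integer parsing (used by the new
parser below) is ASCII-only. A minimal sketch illustrating the gap, not part of
this commit:

    fn main() {
        // ASCII digits parse fine.
        assert_eq!("123".parse::<i128>().ok(), Some(123));
        // Fullwidth "１２３" (U+FF11..U+FF13) is rejected by Rust's parser.
        assert_eq!("\u{FF11}\u{FF12}\u{FF13}".parse::<i128>().ok(), None);
        // Fullwidth '０' (U+FF10) is not an ASCII digit.
        assert!(!'\u{FF10}'.is_ascii_digit());
    }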
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 12a147c6e..6b69c7288 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -20,12 +20,13 @@ use crate::{timezone, BinaryOutputStyle};
use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::builder::StringBuilder;
use arrow::array::{
- BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray,
- StructArray,
+ BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
+ PrimitiveBuilder, StringArray, StructArray,
};
use arrow::compute::can_cast_types;
use arrow::datatypes::{
- ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, Schema,
+ i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
+ Schema,
};
use arrow::{
array::{
@@ -224,9 +225,7 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool
}
Decimal128(_, _) => {
// https://github.com/apache/datafusion-comet/issues/325
- // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
- // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits
-
+ // Does not support fullwidth unicode digits or strings containing null bytes.
options.allow_incompat
}
Date32 | Date64 => {
@@ -976,6 +975,12 @@ fn cast_array(
cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
}
(Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+ (Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
+ cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+ }
+ (Utf8 | LargeUtf8, Decimal256(precision, scale)) => {
+ cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+ }
(Int64, Int32)
| (Int64, Int16)
| (Int64, Int8)
@@ -1187,7 +1192,7 @@ fn is_datafusion_spark_compatible(
),
DataType::Utf8 if allow_incompat => matches!(
to_type,
- DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
+ DataType::Binary | DataType::Float32 | DataType::Float64
),
DataType::Utf8 => matches!(to_type, DataType::Binary),
DataType::Date32 => matches!(to_type, DataType::Utf8),
@@ -1976,6 +1981,306 @@ fn do_cast_string_to_int<
Ok(Some(result))
}
+fn cast_string_to_decimal(
+ array: &ArrayRef,
+ to_type: &DataType,
+ precision: &u8,
+ scale: &i8,
+ eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+ match to_type {
+ DataType::Decimal128(_, _) => {
+ cast_string_to_decimal128_impl(array, eval_mode, *precision, *scale)
+ }
+ DataType::Decimal256(_, _) => {
+ cast_string_to_decimal256_impl(array, eval_mode, *precision, *scale)
+ }
+ _ => Err(SparkError::Internal(format!(
+ "Unexpected type in cast_string_to_decimal: {:?}",
+ to_type
+ ))),
+ }
+}
+
+fn cast_string_to_decimal128_impl(
+ array: &ArrayRef,
+ eval_mode: EvalMode,
+ precision: u8,
+ scale: i8,
+) -> SparkResult<ArrayRef> {
+ let string_array = array
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+ let mut decimal_builder = Decimal128Builder::with_capacity(string_array.len());
+
+ for i in 0..string_array.len() {
+ if string_array.is_null(i) {
+ decimal_builder.append_null();
+ } else {
+ let str_value = string_array.value(i);
+ match parse_string_to_decimal(str_value, precision, scale) {
+ Ok(Some(decimal_value)) => {
+ decimal_builder.append_value(decimal_value);
+ }
+ Ok(None) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(invalid_value(
str_value,
+ "STRING",
+ &format!("DECIMAL({},{})", precision, scale),
+ ));
+ }
+ decimal_builder.append_null();
+ }
+ Err(e) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(e);
+ }
+ decimal_builder.append_null();
+ }
+ }
+ }
+ }
+
+ Ok(Arc::new(
+ decimal_builder
+ .with_precision_and_scale(precision, scale)?
+ .finish(),
+ ))
+}
+
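
The builder above defers precision/scale validation to arrow's
with_precision_and_scale. A minimal sketch of that arrow-rs API in isolation
(illustrative, not taken from this commit):

    use arrow::array::Decimal128Builder;

    fn demo() {
        // The unscaled value 12345 at scale 2 represents 123.45.
        let mut builder = Decimal128Builder::with_capacity(2)
            .with_precision_and_scale(10, 2)
            .unwrap();
        builder.append_value(12345); // 123.45
        builder.append_null(); // e.g. an unparsable input in legacy mode
        let array = builder.finish();
        assert_eq!(array.value_as_string(0), "123.45");
    }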
+fn cast_string_to_decimal256_impl(
+ array: &ArrayRef,
+ eval_mode: EvalMode,
+ precision: u8,
+ scale: i8,
+) -> SparkResult<ArrayRef> {
+ let string_array = array
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+ let mut decimal_builder = PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+ for i in 0..string_array.len() {
+ if string_array.is_null(i) {
+ decimal_builder.append_null();
+ } else {
+ let str_value = string_array.value(i);
+ match parse_string_to_decimal(str_value, precision, scale) {
+ Ok(Some(decimal_value)) => {
+ // Convert i128 to i256
+ let i256_value = i256::from_i128(decimal_value);
+ decimal_builder.append_value(i256_value);
+ }
+ Ok(None) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(invalid_value(
+ str_value,
+ "STRING",
+ &format!("DECIMAL({},{})", precision, scale),
+ ));
+ }
+ decimal_builder.append_null();
+ }
+ Err(e) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(e);
+ }
+ decimal_builder.append_null();
+ }
+ }
+ }
+ }
+
+ Ok(Arc::new(
+ decimal_builder
+ .with_precision_and_scale(precision, scale)?
+ .finish(),
+ ))
+}
+
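
Note that both impls share the same i128-based parser; the Decimal256 path only
widens the mantissa afterwards, so unscaled values outside the i128 range still
come back as null even though the 256-bit type could represent them. A sketch of
the widening, assuming arrow's i256 as imported above:

    use arrow::datatypes::i256;

    fn demo() {
        // Widening preserves the value, but the usable range stays bounded by
        // i128 (roughly 38 significant digits), not Decimal256's full 76.
        let widened = i256::from_i128(i128::MAX);
        assert_eq!(widened.to_string(), i128::MAX.to_string());
    }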
+/// Parse a string to decimal following Spark's behavior
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
+ let string_bytes = s.as_bytes();
+ let mut start = 0;
+ let mut end = string_bytes.len();
+
+ // trim leading and trailing ASCII whitespace
+ while start < end && string_bytes[start].is_ascii_whitespace() {
+ start += 1;
+ }
+ while end > start && string_bytes[end - 1].is_ascii_whitespace() {
+ end -= 1;
+ }
+
+ let trimmed = &s[start..end];
+
+ if trimmed.is_empty() {
+ return Ok(None);
+ }
+ // Handle special values (inf, nan, etc.)
+ if trimmed.eq_ignore_ascii_case("inf")
+ || trimmed.eq_ignore_ascii_case("+inf")
+ || trimmed.eq_ignore_ascii_case("infinity")
+ || trimmed.eq_ignore_ascii_case("+infinity")
+ || trimmed.eq_ignore_ascii_case("-inf")
+ || trimmed.eq_ignore_ascii_case("-infinity")
+ || trimmed.eq_ignore_ascii_case("nan")
+ {
+ return Ok(None);
+ }
+
+ // validate and parse the mantissa and its scale
+ match parse_decimal_str(trimmed) {
+ Ok((mantissa, parsed_scale)) => {
+ // Convert to the target scale
+ let target_scale = scale as i32;
+ let scale_adjustment = target_scale - parsed_scale;
+
+ let scaled_value = if scale_adjustment >= 0 {
+ // Need to multiply (increase scale) but return None if scale is too high to fit i128
+ if scale_adjustment > 38 {
+ return Ok(None);
+ }
+ mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+ } else {
+ // Need to divide (decrease scale), rounding half up on the discarded digits
+ let abs_scale_adjustment = (-scale_adjustment) as u32;
+ if abs_scale_adjustment > 38 {
+ // Dividing by more than 10^38 rounds any i128 mantissa to zero
+ return Ok(Some(0));
+ }
+
+ let divisor = 10_i128.pow(abs_scale_adjustment);
+ let quotient_opt = mantissa.checked_div(divisor);
+ // Defensive: checked_div only fails on division by zero, which cannot happen for a positive power of ten
+ if quotient_opt.is_none() {
+ return Ok(None);
+ }
+ let quotient = quotient_opt.unwrap();
+ let remainder = mantissa % divisor;
+
+ // Round half up: if abs(remainder) >= divisor/2, round away from zero
+ let half_divisor = divisor / 2;
+ let rounded = if remainder.abs() >= half_divisor {
+ if mantissa >= 0 {
+ quotient + 1
+ } else {
+ quotient - 1
+ }
+ } else {
+ quotient
+ };
+ Some(rounded)
+ };
+
+ match scaled_value {
+ Some(value) => {
+ // Check if it fits target precision
+ if is_validate_decimal_precision(value, precision) {
+ Ok(Some(value))
+ } else {
+ Ok(None)
+ }
+ }
+ None => {
+ // Overflow while scaling
+ Ok(None)
+ }
+ }
+ }
+ Err(_) => Ok(None),
+ }
+}
+
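
A worked example of the divide-and-round branch above (illustrative, not from
the commit): casting "67.895" to DECIMAL(10,2). parse_decimal_str yields
mantissa 67895 at scale 3; the target scale 2 gives scale_adjustment = -1, so
the mantissa is divided by 10 and rounded half up:

    fn demo() {
        let (mantissa, divisor): (i128, i128) = (67895, 10);
        let quotient = mantissa / divisor; // 6789
        let remainder = mantissa % divisor; // 5
        // remainder.abs() >= divisor / 2, so round away from zero.
        let rounded = if remainder.abs() >= divisor / 2 {
            if mantissa >= 0 { quotient + 1 } else { quotient - 1 }
        } else {
            quotient
        };
        assert_eq!(rounded, 6790); // the stored value for 67.90 at scale 2
    }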
+/// Parse a decimal string into mantissa and scale
+/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
+fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
+ if s.is_empty() {
+ return Err("Empty string".to_string());
+ }
+
+ let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
+ let mantissa_part = &s[..e_pos];
+ let exponent_part = &s[e_pos + 1..];
+ // Parse exponent
+ let exp: i32 = exponent_part
+ .parse()
+ .map_err(|e| format!("Invalid exponent: {}", e))?;
+
+ (mantissa_part, exp)
+ } else {
+ (s, 0)
+ };
+
+ let negative = mantissa_str.starts_with('-');
+ let mantissa_str = if negative || mantissa_str.starts_with('+') {
+ &mantissa_str[1..]
+ } else {
+ mantissa_str
+ };
+
+ if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
+ return Err("Invalid sign format".to_string());
+ }
+
+ let (integral_part, fractional_part) = match mantissa_str.find('.') {
+ Some(dot_pos) => {
+ if mantissa_str[dot_pos + 1..].contains('.') {
+ return Err("Multiple decimal points".to_string());
+ }
+ (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
+ }
+ None => (mantissa_str, ""),
+ };
+
+ if integral_part.is_empty() && fractional_part.is_empty() {
+ return Err("No digits found".to_string());
+ }
+
+ if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
+ return Err("Invalid integral part".to_string());
+ }
+
+ if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
+ return Err("Invalid fractional part".to_string());
+ }
+
+ // Parse integral part
+ let integral_value: i128 = if integral_part.is_empty() {
+ // Empty integral part is valid (e.g., ".5" or "-.7e9")
+ 0
+ } else {
+ integral_part
+ .parse()
+ .map_err(|_| "Invalid integral part".to_string())?
+ };
+
+ // Parse fractional part
+ let fractional_scale = fractional_part.len() as i32;
+ let fractional_value: i128 = if fractional_part.is_empty() {
+ 0
+ } else {
+ fractional_part
+ .parse()
+ .map_err(|_| "Invalid fractional part".to_string())?
+ };
+
+ // Combine: value = integral * 10^fractional_scale + fractional
+ let mantissa = integral_value
+ .checked_mul(10_i128.pow(fractional_scale as u32))
+ .and_then(|v| v.checked_add(fractional_value))
+ .ok_or("Overflow in mantissa calculation")?;
+
+ let final_mantissa = if negative { -mantissa } else { mantissa };
+ // final scale = fractional_scale - exponent
+ // For example: "1.23E-5" has fractional_scale=2, exponent=-5, so scale = 2 - (-5) = 7
+ let final_scale = fractional_scale - exponent;
+ Ok((final_mantissa, final_scale))
+}
+
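
To make the (mantissa, scale) convention concrete, a few expected results from
parse_decimal_str (a sketch, not part of the commit's test suite):

    fn demo() {
        // "123.45" == 12345 * 10^-2
        assert_eq!(parse_decimal_str("123.45"), Ok((12345, 2)));
        // "-0.001" == -1 * 10^-3
        assert_eq!(parse_decimal_str("-0.001"), Ok((-1, 3)));
        // "1.23E-5": fractional_scale 2 minus exponent -5 gives scale 7
        assert_eq!(parse_decimal_str("1.23E-5"), Ok((123, 7)));
        // An empty integral part is accepted.
        assert_eq!(parse_decimal_str(".5"), Ok((5, 1)));
        // Multiple decimal points are rejected.
        assert!(parse_decimal_str("1.2.3").is_err());
    }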
/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
#[inline]
fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 98ce8ac44..14db7c278 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -192,9 +192,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
"Does not support ANSI mode."))
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/325
- Incompatible(
- Some("Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
- "Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits"))
+ Incompatible(Some("""Does not support fullwidth unicode digits (e.g. \\uFF10)
+ |or strings containing null bytes (e.g. \\u0000)""".stripMargin))
case DataTypes.DateType =>
// https://github.com/apache/datafusion-comet/issues/327
Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 90386a979..a7bd6febf 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -661,7 +661,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// https://github.com/apache/datafusion-comet/issues/326
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
}
-
test("cast StringType to DoubleType (partial support)") {
withSQLConf(
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
@@ -673,21 +672,88 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+// Kept as ignored so that the `all cast combinations are covered` check still passes
ignore("cast StringType to DecimalType(10,2)") {
- // https://github.com/apache/datafusion-comet/issues/325
- val values = gen.generateStrings(dataSize, numericPattern, 8).toDF("a")
- castTest(values, DataTypes.createDecimalType(10, 2))
+ val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+ castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
}
- test("cast StringType to DecimalType(10,2) (partial support)") {
- withSQLConf(
- CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
- SQLConf.ANSI_ENABLED.key -> "false") {
- val values = gen
- .generateStrings(dataSize, "0123456789.", 8)
- .filter(_.exists(_.isDigit))
- .toDF("a")
- castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
+ test("cast StringType to DecimalType(10,2) (does not support fullwidth
unicode digits)") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) ->
"true") {
+ // TODO fix for Spark 4.0.0
+ assume(!isSpark40Plus)
+ val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+ Seq(true, false).foreach(ansiEnabled =>
+ castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
+ }
+ }
+
+ test("cast StringType to DecimalType(2,2)") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+ // TODO fix for Spark 4.0.0
+ assume(!isSpark40Plus)
+ val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+ Seq(true, false).foreach(ansiEnabled =>
+ castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
+ }
+ }
+
+ test("cast StringType to DecimalType(38,10) high precision") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+ // TODO fix for Spark 4.0.0
+ assume(!isSpark40Plus)
+ val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
+ Seq(true, false).foreach(ansiEnabled =>
+ castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
+ }
+ }
+
+ test("cast StringType to DecimalType(10,2) basic values") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+ // TODO fix for Spark 4.0.0
+ assume(!isSpark40Plus)
+ val values = Seq(
+ "123.45",
+ "-67.89",
+ "-67.89",
+ "-67.895",
+ "67.895",
+ "0.001",
+ "999.99",
+ "123.456",
+ "123.45D",
+ ".5",
+ "5.",
+ "+123.45",
+ " 123.45 ",
+ "inf",
+ "",
+ "abc",
+ null).toDF("a")
+ Seq(true, false).foreach(ansiEnabled =>
+ castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
+ }
+ }
+
+ test("cast StringType to Decimal type scientific notation") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+ // TODO fix for Spark 4.0.0
+ assume(!isSpark40Plus)
+ val values = Seq(
+ "1.23E-5",
+ "1.23e10",
+ "1.23E+10",
+ "-1.23e-5",
+ "1e5",
+ "1E-2",
+ "-1.5e3",
+ "1.23E0",
+ "0e0",
+ "1.23e",
+ "e5",
+ null).toDF("a")
+ Seq(true, false).foreach(ansiEnabled =>
+ castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = ansiEnabled))
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]