martin-g commented on code in PR #2835:
URL: https://github.com/apache/datafusion-comet/pull/2835#discussion_r2581626417
##########
native/spark-expr/src/conversion_funcs/cast.rs:
##########
@@ -976,6 +961,13 @@ fn cast_array(
cast_string_to_timestamp(&array, to_type, eval_mode,
&cast_options.timezone)
}
(Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+ (Utf8, Float16 | Float32 | Float64) => cast_string_to_float(&array,
to_type, eval_mode),
Review Comment:
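`cast_string_to_float()` does not support `Float16`, so the `Float16` alternative in this match arm would not be handled correctly; either drop it from the pattern or add explicit `Float16` support.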
##########
native/spark-expr/src/conversion_funcs/cast.rs:
##########
@@ -1058,6 +1050,363 @@ fn cast_array(
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}
+fn cast_string_to_decimal(
+ array: &ArrayRef,
+ to_type: &DataType,
+ precision: &u8,
+ scale: &i8,
+ eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+ match to_type {
+ DataType::Decimal128(_, _) => {
+ cast_string_to_decimal128_impl(array, eval_mode, *precision,
*scale)
+ }
+ DataType::Decimal256(_, _) => {
+ cast_string_to_decimal256_impl(array, eval_mode, *precision,
*scale)
+ }
+ _ => Err(SparkError::Internal(format!(
+ "Unexpected type in cast_string_to_decimal: {:?}",
+ to_type
+ ))),
+ }
+}
+
+fn cast_string_to_decimal128_impl(
+ array: &ArrayRef,
+ eval_mode: EvalMode,
+ precision: u8,
+ scale: i8,
+) -> SparkResult<ArrayRef> {
+ let string_array = array
+ .as_any()
+ .downcast_ref::<StringArray>()
Review Comment:
```suggestion
.downcast_ref::<GenericStringArray>()
```
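To support `LargeUtf8` as well as `Utf8`. Note that `GenericStringArray` is generic over its offset type, so the function needs a type parameter (e.g. `GenericStringArray<O>` with `O: OffsetSizeTrait`) rather than the bare name shown in the suggestion.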
##########
spark/src/test/scala/org/apache/comet/CometCastSuite.scala:
##########
@@ -652,35 +652,42 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"),
DataTypes.LongType)
}
- ignore("cast StringType to FloatType") {
- // https://github.com/apache/datafusion-comet/issues/326
- castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"),
DataTypes.FloatType)
- }
+ def specialValues: Seq[String] = Seq(
+ "1.5f",
+ "1.5F",
+ "2.0d",
+ "2.0D",
+ "3.14159265358979d",
+ "inf",
+ "Inf",
+ "INF",
+ "+inf",
+ "+Infinity",
+ "-inf",
+ "-Infinity",
+ "NaN",
+ "nan",
+ "NAN",
+ "1.23e4",
+ "1.23E4",
+ "-1.23e-4",
+ " 123.456789 ",
+ "0.0",
+ "-0.0",
+ "",
+ "xyz",
+ null)
- test("cast StringType to FloatType (partial support)") {
- withSQLConf(
- CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
- SQLConf.ANSI_ENABLED.key -> "false") {
- castTest(
- gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
- DataTypes.FloatType,
- testAnsi = false)
+ test("cast StringType to FloatType") {
+ Seq(true, false).foreach { v =>
+ castTest(specialValues.toDF("a"), DataTypes.FloatType, testAnsi = v)
}
- }
- ignore("cast StringType to DoubleType") {
- // https://github.com/apache/datafusion-comet/issues/326
- castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"),
DataTypes.DoubleType)
}
- test("cast StringType to DoubleType (partial support)") {
- withSQLConf(
- CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
- SQLConf.ANSI_ENABLED.key -> "false") {
- castTest(
- gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
- DataTypes.DoubleType,
- testAnsi = false)
+ test("cast StringType to DoubleType") {
+ Seq(true, false).foreach { v =>
+ castTest(specialValues.toDF("a"), DataTypes.FloatType, testAnsi = v)
Review Comment:
```suggestion
castTest(specialValues.toDF("a"), DataTypes.DoubleType, testAnsi = v)
```
##########
native/spark-expr/src/conversion_funcs/cast.rs:
##########
@@ -1058,6 +1050,363 @@ fn cast_array(
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}
+fn cast_string_to_decimal(
+ array: &ArrayRef,
+ to_type: &DataType,
+ precision: &u8,
+ scale: &i8,
+ eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+ match to_type {
+ DataType::Decimal128(_, _) => {
+ cast_string_to_decimal128_impl(array, eval_mode, *precision,
*scale)
+ }
+ DataType::Decimal256(_, _) => {
+ cast_string_to_decimal256_impl(array, eval_mode, *precision,
*scale)
+ }
+ _ => Err(SparkError::Internal(format!(
+ "Unexpected type in cast_string_to_decimal: {:?}",
+ to_type
+ ))),
+ }
+}
+
+fn cast_string_to_decimal128_impl(
+ array: &ArrayRef,
+ eval_mode: EvalMode,
+ precision: u8,
+ scale: i8,
+) -> SparkResult<ArrayRef> {
+ let string_array = array
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .ok_or_else(|| SparkError::Internal("Expected string
array".to_string()))?;
+
+ let mut decimal_builder =
Decimal128Builder::with_capacity(string_array.len());
+
+ for i in 0..string_array.len() {
+ if string_array.is_null(i) {
+ decimal_builder.append_null();
+ } else {
+ let str_value = string_array.value(i).trim();
+ match parse_string_to_decimal(str_value, precision, scale) {
+ Ok(Some(decimal_value)) => {
+ decimal_builder.append_value(decimal_value);
+ }
+ Ok(None) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(invalid_value(
+ str_value,
+ "STRING",
+ &format!("DECIMAL({},{})", precision, scale),
+ ));
+ }
+ decimal_builder.append_null();
+ }
+ Err(e) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(e);
+ }
+ decimal_builder.append_null();
+ }
+ }
+ }
+ }
+
+ Ok(Arc::new(
+ decimal_builder
+ .with_precision_and_scale(precision, scale)?
+ .finish(),
+ ))
+}
+
+fn cast_string_to_decimal256_impl(
+ array: &ArrayRef,
+ eval_mode: EvalMode,
+ precision: u8,
+ scale: i8,
+) -> SparkResult<ArrayRef> {
+ let string_array = array
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .ok_or_else(|| SparkError::Internal("Expected string
array".to_string()))?;
+
+ let mut decimal_builder =
PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+ for i in 0..string_array.len() {
+ if string_array.is_null(i) {
+ decimal_builder.append_null();
+ } else {
+ let str_value = string_array.value(i).trim();
+ match parse_string_to_decimal(str_value, precision, scale) {
+ Ok(Some(decimal_value)) => {
+ // Convert i128 to i256
+ let i256_value = i256::from_i128(decimal_value);
+ decimal_builder.append_value(i256_value);
+ }
+ Ok(None) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(invalid_value(
+ str_value,
+ "STRING",
+ &format!("DECIMAL({},{})", precision, scale),
+ ));
+ }
+ decimal_builder.append_null();
+ }
+ Err(e) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(e);
+ }
+ decimal_builder.append_null();
+ }
+ }
+ }
+ }
+
+ Ok(Arc::new(
+ decimal_builder
+ .with_precision_and_scale(precision, scale)?
+ .finish(),
+ ))
+}
+
+/// Parse a string to decimal following Spark's behavior
+/// Returns Ok(Some(value)) if successful, Ok(None) if null, Err if invalid in
ANSI mode
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) ->
SparkResult<Option<i128>> {
+ if s.is_empty() {
+ return Ok(None);
+ }
+
+ // Handle special values (inf, nan, etc.)
+ let s_lower = s.to_lowercase();
+ if s_lower == "inf"
+ || s_lower == "+inf"
+ || s_lower == "infinity"
+ || s_lower == "+infinity"
+ || s_lower == "-inf"
+ || s_lower == "-infinity"
+ || s_lower == "nan"
+ {
+ return Ok(None);
+ }
+
+ // Parse the string as a decimal number
+ // Note: We do NOT strip 'D' or 'F' suffixes - let parsing fail naturally
+ // This matches Spark's behavior which uses JavaBigDecimal(string)
+ match parse_decimal_str(s) {
+ Ok((mantissa, exponent)) => {
+ // Convert to target scale
+ let target_scale = scale as i32;
+ let scale_adjustment = target_scale - exponent;
+
+ let scaled_value = if scale_adjustment >= 0 {
+ // Need to multiply (increase scale)
+ mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+ } else {
+ // Need to divide (decrease scale) - use rounding half up
+ let divisor = 10_i128.pow((-scale_adjustment) as u32);
+ let quotient = mantissa / divisor;
+ let remainder = mantissa % divisor;
+
+ // Round half up: if abs(remainder) >= divisor/2, round away
from zero
+ let half_divisor = divisor / 2;
+ let rounded = if remainder.abs() >= half_divisor {
+ if mantissa >= 0 {
+ quotient + 1
+ } else {
+ quotient - 1
+ }
+ } else {
+ quotient
+ };
+ Some(rounded)
+ };
+
+ match scaled_value {
+ Some(value) => {
+ // Check if it fits target precision
+ if is_validate_decimal_precision(value, precision) {
+ Ok(Some(value))
+ } else {
+ // Overflow
+ Ok(None)
+ }
+ }
+ None => {
+ // Overflow during scaling
+ Ok(None)
+ }
+ }
+ }
+ Err(_) => Ok(None),
+ }
+}
+
+/// Parse a decimal string into (mantissa, scale)
+/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
Review Comment:
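Should it also support scientific notation (e.g. "1.23e4", "1E-5")? The code comment above says this should match Spark's use of `JavaBigDecimal(string)`, and Java's `BigDecimal(String)` constructor accepts an exponent part.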
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]