coderfender commented on code in PR #2835:
URL: https://github.com/apache/datafusion-comet/pull/2835#discussion_r2609222917
##########
native/spark-expr/src/conversion_funcs/cast.rs:
##########
@@ -1058,6 +1055,351 @@ fn cast_array(
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}
+fn cast_string_to_decimal(
+ array: &ArrayRef,
+ to_type: &DataType,
+ precision: &u8,
+ scale: &i8,
+ eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+ match to_type {
+ DataType::Decimal128(_, _) => {
+ cast_string_to_decimal128_impl(array, eval_mode, *precision,
*scale)
+ }
+ DataType::Decimal256(_, _) => {
+ cast_string_to_decimal256_impl(array, eval_mode, *precision,
*scale)
+ }
+ _ => Err(SparkError::Internal(format!(
+ "Unexpected type in cast_string_to_decimal: {:?}",
+ to_type
+ ))),
+ }
+}
+
+fn cast_string_to_decimal128_impl(
+ array: &ArrayRef,
+ eval_mode: EvalMode,
+ precision: u8,
+ scale: i8,
+) -> SparkResult<ArrayRef> {
+ let string_array = array
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .ok_or_else(|| SparkError::Internal("Expected string
array".to_string()))?;
+
+ let mut decimal_builder =
Decimal128Builder::with_capacity(string_array.len());
+
+ for i in 0..string_array.len() {
+ if string_array.is_null(i) {
+ decimal_builder.append_null();
+ } else {
+ let str_value = string_array.value(i).trim();
+ match parse_string_to_decimal(str_value, precision, scale) {
+ Ok(Some(decimal_value)) => {
+ decimal_builder.append_value(decimal_value);
+ }
+ Ok(None) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(invalid_value(
+ str_value,
+ "STRING",
+ &format!("DECIMAL({},{})", precision, scale),
+ ));
+ }
+ decimal_builder.append_null();
+ }
+ Err(e) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(e);
+ }
+ decimal_builder.append_null();
+ }
+ }
+ }
+ }
+
+ Ok(Arc::new(
+ decimal_builder
+ .with_precision_and_scale(precision, scale)?
+ .finish(),
+ ))
+}
+
+fn cast_string_to_decimal256_impl(
+ array: &ArrayRef,
+ eval_mode: EvalMode,
+ precision: u8,
+ scale: i8,
+) -> SparkResult<ArrayRef> {
+ let string_array = array
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .ok_or_else(|| SparkError::Internal("Expected string
array".to_string()))?;
+
+ let mut decimal_builder =
PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+ for i in 0..string_array.len() {
+ if string_array.is_null(i) {
+ decimal_builder.append_null();
+ } else {
+ let str_value = string_array.value(i).trim();
+ match parse_string_to_decimal(str_value, precision, scale) {
+ Ok(Some(decimal_value)) => {
+ // Convert i128 to i256
+ let i256_value = i256::from_i128(decimal_value);
+ decimal_builder.append_value(i256_value);
+ }
+ Ok(None) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(invalid_value(
+ str_value,
+ "STRING",
+ &format!("DECIMAL({},{})", precision, scale),
+ ));
+ }
+ decimal_builder.append_null();
+ }
+ Err(e) => {
+ if eval_mode == EvalMode::Ansi {
+ return Err(e);
+ }
+ decimal_builder.append_null();
+ }
+ }
+ }
+ }
+
+ Ok(Arc::new(
+ decimal_builder
+ .with_precision_and_scale(precision, scale)?
+ .finish(),
+ ))
+}
+
+/// Parse a string to decimal following Spark's behavior
+/// Returns Ok(Some(value)) if successful, Ok(None) if null, Err if invalid in
ANSI mode
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) ->
SparkResult<Option<i128>> {
+ if s.is_empty() {
+ return Ok(None);
+ }
+
+ // Handle special values (inf, nan, etc.)
+ let s_lower = s.to_lowercase();
Review Comment:
Done! I also added fuzz tests covering various random inputs, along with
specific tests that include scientific notation, invalid values, etc.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]