This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e9f8503a chore: Cast module refactor boolean module (#3491)
3e9f8503a is described below

commit 3e9f8503a3d80b856025604dcbf7323258299cb7
Author: B Vadlamani <[email protected]>
AuthorDate: Thu Feb 19 13:57:01 2026 -0800

    chore: Cast module refactor boolean module (#3491)
---
 native/spark-expr/Cargo.toml                      |   4 +
 native/spark-expr/benches/cast_from_boolean.rs    |  89 ++++++++++
 native/spark-expr/src/conversion_funcs/boolean.rs | 196 ++++++++++++++++++++++
 native/spark-expr/src/conversion_funcs/cast.rs    | 138 ++-------------
 native/spark-expr/src/conversion_funcs/mod.rs     |   2 +
 native/spark-expr/src/conversion_funcs/utils.rs   | 128 ++++++++++++++
 6 files changed, 430 insertions(+), 127 deletions(-)

diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 63e1c0476..e7c238f7e 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -99,3 +99,7 @@ harness = false
 [[test]]
 name = "test_udf_registration"
 path = "tests/spark_expr_reg.rs"
+
+[[bench]]
+name = "cast_from_boolean"
+harness = false
diff --git a/native/spark-expr/benches/cast_from_boolean.rs b/native/spark-expr/benches/cast_from_boolean.rs
new file mode 100644
index 000000000..dbb986df9
--- /dev/null
+++ b/native/spark-expr/benches/cast_from_boolean.rs
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{BooleanBuilder, RecordBatch};
+use arrow::datatypes::{DataType, Field, Schema};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::expressions::Column;
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
+use std::sync::Arc;
+
+/// Benchmarks the Spark `Cast` expression from a nullable Boolean column (see
+/// `create_boolean_batch`) to each target type listed in `targets`, using Legacy eval
+/// mode and the "UTC" timezone.
+fn criterion_benchmark(c: &mut Criterion) {
+    let expr = Arc::new(Column::new("a", 0));
+    let boolean_batch = create_boolean_batch();
+    let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+
+    // (benchmark id, cast target type) pairs; ids are kept identical to the
+    // previously hand-written benchmarks so historical results stay comparable.
+    let targets = [
+        ("i8", DataType::Int8),
+        ("i16", DataType::Int16),
+        ("i32", DataType::Int32),
+        ("i64", DataType::Int64),
+        ("f32", DataType::Float32),
+        ("f64", DataType::Float64),
+        ("str", DataType::Utf8),
+        ("decimal", DataType::Decimal128(10, 4)),
+    ];
+
+    let mut group = c.benchmark_group("cast_bool");
+    for (name, to_type) in targets {
+        // Build the expression outside the timed closure so only `evaluate` is measured.
+        let cast = Cast::new(expr.clone(), to_type, spark_cast_options.clone());
+        group.bench_function(name, |b| {
+            b.iter(|| cast.evaluate(&boolean_batch).unwrap());
+        });
+    }
+    // Finish explicitly (as Criterion recommends) instead of relying on Drop.
+    group.finish();
+}
+
+/// Builds the benchmark input: a one-column batch ("a", nullable Boolean) of 1000 rows
+/// in which every 10th row is null and every other row is a random `true`/`false`.
+fn create_boolean_batch() -> RecordBatch {
+    const ROWS: i32 = 1000;
+    let mut builder = BooleanBuilder::with_capacity(ROWS as usize);
+    for row in 0..ROWS {
+        // `then` only draws a random value for non-null rows.
+        builder.append_option((row % 10 != 0).then(rand::random::<bool>));
+    }
+    let column = builder.finish();
+    let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
+    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column)]).unwrap()
+}
+
+// Register the boolean-cast benchmarks with Criterion's default configuration
+// (command-line arguments are still applied by the macro).
+criterion_group! {
+    name = benches;
+    config = Criterion::default();
+    targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs
new file mode 100644
index 000000000..db288fa32
--- /dev/null
+++ b/native/spark-expr/src/conversion_funcs/boolean.rs
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::SparkResult;
+use arrow::array::{ArrayRef, AsArray, Decimal128Array};
+use arrow::datatypes::DataType;
+use std::sync::Arc;
+
+/// Returns `true` for the Boolean cast targets whose DataFusion cast is treated as
+/// Spark-compatible: the signed integer types, the floating-point types and `Utf8`.
+/// Every other target (e.g. `Decimal128`) returns `false`.
+pub fn is_df_cast_from_bool_spark_compatible(to_type: &DataType) -> bool {
+    matches!(
+        to_type,
+        // integral targets
+        DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            // floating-point targets
+            | DataType::Float32
+            | DataType::Float64
+            // string target
+            | DataType::Utf8
+    )
+}
+
+/// Casts a Boolean array to `Decimal128(precision, scale)`: `true` becomes 1 and
+/// `false` becomes 0 (both expressed at the target `scale`); nulls stay null.
+///
+/// Decimal is not in the set accepted by `is_df_cast_from_bool_spark_compatible`, so
+/// this Boolean cast is implemented here rather than delegated to DataFusion.
+///
+/// # Errors
+/// Returns an error if `precision`/`scale` do not form a valid `Decimal128` type.
+///
+/// NOTE(review): when `scale == precision`, the unscaled value `10^scale` used for
+/// `true` needs `precision + 1` digits. Spark's cast would yield null (legacy) or raise
+/// (ANSI) for that value, but no overflow check is done here — TODO confirm callers.
+pub fn cast_boolean_to_decimal(
+    array: &ArrayRef,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let bool_array = array.as_boolean();
+    // Unscaled representation of 1 at the target scale.
+    // - Negative scale: values are multiples of 10^-scale (>= 10), so 1 rounds to 0.
+    //   The previous `scale as u32` wrapped here and overflowed `pow`.
+    // - Scale > 38: `10^scale` overflows i128, but such scales are rejected by
+    //   `with_precision_and_scale` below, so the 0 fallback is never returned.
+    let one = match u32::try_from(scale) {
+        Ok(exp) => 10_i128.checked_pow(exp).unwrap_or(0),
+        Err(_) => 0,
+    };
+    let result: Decimal128Array = bool_array
+        .iter()
+        .map(|v| v.map(|b| if b { one } else { 0 }))
+        .collect();
+    Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::cast::cast_array;
+    use crate::{EvalMode, SparkCastOptions};
+    use arrow::array::{
+        Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
+        Int64Array, Int8Array, StringArray,
+    };
+    use arrow::datatypes::DataType::Decimal128;
+    use std::sync::Arc;
+
+    /// Casts the fixed input `[true, false, null]` to `to_type` (Legacy mode,
+    /// "Asia/Kolkata" timezone), panicking if the cast fails.
+    fn cast_test_input(to_type: &DataType) -> ArrayRef {
+        let input: ArrayRef = Arc::new(BooleanArray::from(vec![Some(true), Some(false), None]));
+        let options = SparkCastOptions::new(EvalMode::Legacy, "Asia/Kolkata", false);
+        cast_array(input, to_type, &options).unwrap()
+    }
+
+    /// Downcasts `array` to the concrete array type `T`, panicking on a mismatch.
+    fn downcast<T: 'static>(array: &ArrayRef) -> &T {
+        array.as_any().downcast_ref::<T>().unwrap()
+    }
+
+    #[test]
+    fn test_is_df_cast_from_bool_spark_compatible() {
+        let compatible = [
+            DataType::Int8,
+            DataType::Int16,
+            DataType::Int32,
+            DataType::Int64,
+            DataType::Float32,
+            DataType::Float64,
+            DataType::Utf8,
+        ];
+        for to_type in &compatible {
+            assert!(is_df_cast_from_bool_spark_compatible(to_type), "{to_type:?}");
+        }
+
+        let incompatible = [DataType::Boolean, DataType::Decimal128(10, 4), DataType::Null];
+        for to_type in &incompatible {
+            assert!(!is_df_cast_from_bool_spark_compatible(to_type), "{to_type:?}");
+        }
+    }
+
+    #[test]
+    fn test_bool_to_int8_cast() {
+        let result = cast_test_input(&DataType::Int8);
+        let arr = downcast::<Int8Array>(&result);
+        assert_eq!(arr.value(0), 1);
+        assert_eq!(arr.value(1), 0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_int16_cast() {
+        let result = cast_test_input(&DataType::Int16);
+        let arr = downcast::<Int16Array>(&result);
+        assert_eq!(arr.value(0), 1);
+        assert_eq!(arr.value(1), 0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_int32_cast() {
+        let result = cast_test_input(&DataType::Int32);
+        let arr = downcast::<Int32Array>(&result);
+        assert_eq!(arr.value(0), 1);
+        assert_eq!(arr.value(1), 0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_int64_cast() {
+        let result = cast_test_input(&DataType::Int64);
+        let arr = downcast::<Int64Array>(&result);
+        assert_eq!(arr.value(0), 1);
+        assert_eq!(arr.value(1), 0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_float32_cast() {
+        let result = cast_test_input(&DataType::Float32);
+        let arr = downcast::<Float32Array>(&result);
+        assert_eq!(arr.value(0), 1.0);
+        assert_eq!(arr.value(1), 0.0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_float64_cast() {
+        let result = cast_test_input(&DataType::Float64);
+        let arr = downcast::<Float64Array>(&result);
+        assert_eq!(arr.value(0), 1.0);
+        assert_eq!(arr.value(1), 0.0);
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_string_cast() {
+        let result = cast_test_input(&DataType::Utf8);
+        let arr = downcast::<StringArray>(&result);
+        assert_eq!(arr.value(0), "true");
+        assert_eq!(arr.value(1), "false");
+        assert!(arr.is_null(2));
+    }
+
+    #[test]
+    fn test_bool_to_decimal_cast() {
+        let result = cast_test_input(&Decimal128(10, 4));
+        let expected = Decimal128Array::from(vec![10000_i128, 0_i128])
+            .with_precision_and_scale(10, 4)
+            .unwrap();
+        let arr = downcast::<Decimal128Array>(&result);
+        assert_eq!(arr.value(0), expected.value(0));
+        assert_eq!(arr.value(1), expected.value(1));
+        assert!(arr.is_null(2));
+    }
+}
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index f5ab83b8a..004668b8f 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -15,6 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::conversion_funcs::boolean::{
+    cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible,
+};
+use crate::conversion_funcs::utils::spark_cast_postprocess;
+use crate::conversion_funcs::utils::{cast_overflow, invalid_value};
 use crate::utils::array_with_timezone;
 use crate::EvalMode::Legacy;
 use crate::{timezone, BinaryOutputStyle};
@@ -37,7 +42,7 @@ use arrow::{
         GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, 
OffsetSizeTrait,
         PrimitiveArray,
     },
-    compute::{cast_with_options, take, unary, CastOptions},
+    compute::{cast_with_options, take, CastOptions},
     datatypes::{
         is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, 
Float32Type,
         Float64Type, Int64Type, TimestampMicrosecondType,
@@ -48,16 +53,10 @@ use arrow::{
 };
 use base64::prelude::*;
 use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
-use datafusion::common::{
-    cast::as_generic_string_array, internal_err, DataFusionError, Result as 
DataFusionResult,
-    ScalarValue,
-};
+use datafusion::common::{internal_err, DataFusionError, Result as 
DataFusionResult, ScalarValue};
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::ColumnarValue;
-use num::{
-    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, 
Integer, ToPrimitive,
-    Zero,
-};
+use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, 
ToPrimitive, Zero};
 use regex::Regex;
 use std::str::FromStr;
 use std::{
@@ -70,7 +69,7 @@ use std::{
 
 static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
 
-const MICROS_PER_SECOND: i64 = 1000000;
+pub(crate) const MICROS_PER_SECOND: i64 = 1000000;
 
 static CAST_OPTIONS: CastOptions = CastOptions {
     safe: true,
@@ -776,7 +775,7 @@ fn dict_from_values<K: ArrowDictionaryKeyType>(
     Ok(Arc::new(dict_array))
 }
 
-fn cast_array(
+pub(crate) fn cast_array(
     array: ArrayRef,
     to_type: &DataType,
     cast_options: &SparkCastOptions,
@@ -1018,16 +1017,6 @@ fn cast_date_to_timestamp(
     ))
 }
 
-fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) -> 
SparkResult<ArrayRef> {
-    let bool_array = array.as_boolean();
-    let scaled_val = 10_i128.pow(scale as u32);
-    let result: Decimal128Array = bool_array
-        .iter()
-        .map(|v| v.map(|b| if b { scaled_val } else { 0 }))
-        .collect();
-    Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
-}
-
 fn cast_string_to_float(
     array: &ArrayRef,
     to_type: &DataType,
@@ -1186,16 +1175,7 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
         DataType::Null => {
             matches!(to_type, DataType::List(_))
         }
-        DataType::Boolean => matches!(
-            to_type,
-            DataType::Int8
-                | DataType::Int16
-                | DataType::Int32
-                | DataType::Int64
-                | DataType::Float32
-                | DataType::Float64
-                | DataType::Utf8
-        ),
+        DataType::Boolean => is_df_cast_from_bool_spark_compatible(to_type),
         DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 
=> {
             matches!(
                 to_type,
@@ -2437,24 +2417,6 @@ fn parse_decimal_str(
     Ok((final_mantissa, final_scale))
 }
 
-#[inline]
-fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
-    SparkError::CastInvalidValue {
-        value: value.to_string(),
-        from_type: from_type.to_string(),
-        to_type: to_type.to_string(),
-    }
-}
-
-#[inline]
-fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
-    SparkError::CastOverFlow {
-        value: value.to_string(),
-        from_type: from_type.to_string(),
-        to_type: to_type.to_string(),
-    }
-}
-
 impl Display for Cast {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         write!(
@@ -2852,84 +2814,6 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>>
     }
 }
 
-/// This takes for special casting cases of Spark. E.g., Timestamp to Long.
-/// This function runs as a post process of the DataFusion cast(). By the time 
it arrives here,
-/// Dictionary arrays are already unpacked by the DataFusion cast() since 
Spark cannot specify
-/// Dictionary as to_type. The from_type is taken before the DataFusion cast() 
runs in
-/// expressions/cast.rs, so it can be still Dictionary.
-fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: 
&DataType) -> ArrayRef {
-    match (from_type, to_type) {
-        (DataType::Timestamp(_, _), DataType::Int64) => {
-            // See Spark's `Cast` expression
-            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, 
MICROS_PER_SECOND)).unwrap()
-        }
-        (DataType::Dictionary(_, value_type), DataType::Int64)
-            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
-        {
-            // See Spark's `Cast` expression
-            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, 
MICROS_PER_SECOND)).unwrap()
-        }
-        (DataType::Timestamp(_, _), DataType::Utf8) => 
remove_trailing_zeroes(array),
-        (DataType::Dictionary(_, value_type), DataType::Utf8)
-            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
-        {
-            remove_trailing_zeroes(array)
-        }
-        _ => array,
-    }
-}
-
-/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated
-fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
-where
-    T: ArrowPrimitiveType,
-    F: Fn(T::Native) -> T::Native,
-{
-    if let Some(d) = array.as_any_dictionary_opt() {
-        let new_values = unary_dyn::<F, T>(d.values(), op)?;
-        return Ok(Arc::new(d.with_values(Arc::new(new_values))));
-    }
-
-    match array.as_primitive_opt::<T>() {
-        Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
-            Ok(Arc::new(unary::<T, F, T>(
-                array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
-                op,
-            )))
-        }
-        _ => Err(ArrowError::NotYetImplemented(format!(
-            "Cannot perform unary operation of type {} on array of type {}",
-            T::DATA_TYPE,
-            array.data_type()
-        ))),
-    }
-}
-
-/// Remove any trailing zeroes in the string if they occur after in the 
fractional seconds,
-/// to match Spark behavior
-/// example:
-/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
-/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
-/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
-/// "1970-01-01 05:30:00"     => "1970-01-01 05:30:00"
-/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
-fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
-    let string_array = as_generic_string_array::<i32>(&array).unwrap();
-    let result = string_array
-        .iter()
-        .map(|s| s.map(trim_end))
-        .collect::<GenericStringArray<i32>>();
-    Arc::new(result) as ArrayRef
-}
-
-fn trim_end(s: &str) -> &str {
-    if s.rfind('.').is_some() {
-        s.trim_end_matches('0')
-    } else {
-        s
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use arrow::array::StringArray;
diff --git a/native/spark-expr/src/conversion_funcs/mod.rs b/native/spark-expr/src/conversion_funcs/mod.rs
index f2c6f7ca3..190c11520 100644
--- a/native/spark-expr/src/conversion_funcs/mod.rs
+++ b/native/spark-expr/src/conversion_funcs/mod.rs
@@ -15,4 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod boolean;
 pub mod cast;
+mod utils;
diff --git a/native/spark-expr/src/conversion_funcs/utils.rs b/native/spark-expr/src/conversion_funcs/utils.rs
new file mode 100644
index 000000000..8b8d974ff
--- /dev/null
+++ b/native/spark-expr/src/conversion_funcs/utils.rs
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::cast::MICROS_PER_SECOND;
+use crate::SparkError;
+use arrow::array::{
+    Array, ArrayRef, ArrowPrimitiveType, AsArray, GenericStringArray, 
PrimitiveArray,
+};
+use arrow::compute::unary;
+use arrow::datatypes::{DataType, Int64Type};
+use arrow::error::ArrowError;
+use datafusion::common::cast::as_generic_string_array;
+use num::integer::div_floor;
+use std::sync::Arc;
+
+/// A fork & modified version of Arrow's `unary_dyn`, which is being deprecated.
+///
+/// Applies `op` element-wise to a `PrimitiveArray<T>` (the null buffer is kept). If
+/// `array` is a dictionary, `op` is applied to the dictionary's values and the keys are
+/// reused.
+///
+/// # Errors
+/// Returns `ArrowError::NotYetImplemented` if `array` (or the dictionary's values) is not
+/// a primitive array of type `T`.
+pub fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
+where
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native) -> T::Native,
+{
+    if let Some(d) = array.as_any_dictionary_opt() {
+        let new_values = unary_dyn::<F, T>(d.values(), op)?;
+        // `with_values` takes and returns an `ArrayRef` already; wrapping either side in
+        // another `Arc` only adds extra allocations and layers of indirection.
+        return Ok(d.with_values(new_values));
+    }
+
+    match array.as_primitive_opt::<T>() {
+        // `a` is already the downcast array; no need to downcast (and unwrap) again.
+        Some(a) if PrimitiveArray::<T>::is_compatible(a.data_type()) => {
+            Ok(Arc::new(unary::<T, F, T>(a, op)))
+        }
+        _ => Err(ArrowError::NotYetImplemented(format!(
+            "Cannot perform unary operation of type {} on array of type {}",
+            T::DATA_TYPE,
+            array.data_type()
+        ))),
+    }
+}
+
+/// Applies Spark-specific fix-ups to the output of the DataFusion `cast()` for the
+/// special cases where it differs from Spark (e.g. Timestamp to Long).
+///
+/// This runs after the DataFusion cast, which has already unpacked Dictionary arrays
+/// (Spark cannot request a Dictionary `to_type`). `from_type` is captured before that
+/// cast, so it may still be a Dictionary.
+pub fn spark_cast_postprocess(
+    array: ArrayRef,
+    from_type: &DataType,
+    to_type: &DataType,
+) -> ArrayRef {
+    // Only timestamp sources, plain or dictionary-encoded, need post-processing.
+    let from_timestamp = match from_type {
+        DataType::Timestamp(_, _) => true,
+        DataType::Dictionary(_, value_type) => {
+            matches!(value_type.as_ref(), &DataType::Timestamp(_, _))
+        }
+        _ => false,
+    };
+    if !from_timestamp {
+        return array;
+    }
+
+    match to_type {
+        // See Spark's `Cast` expression: floor-divide by MICROS_PER_SECOND.
+        DataType::Int64 => {
+            unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap()
+        }
+        DataType::Utf8 => remove_trailing_zeroes(array),
+        _ => array,
+    }
+}
+
+/// Removes trailing zeroes from the fractional seconds of each timestamp string, to
+/// match Spark's formatting. Nulls are preserved. Examples:
+/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
+/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
+/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
+/// "1970-01-01 05:30:00"     => "1970-01-01 05:30:00"
+/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
+///
+/// # Panics
+/// Panics if `array` is not a string array with i32 offsets (`Utf8`).
+fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef {
+    let strings = as_generic_string_array::<i32>(&array).unwrap();
+    let trimmed: GenericStringArray<i32> = strings.iter().map(|s| s.map(trim_end)).collect();
+    Arc::new(trimmed)
+}
+
+/// Trims trailing `'0'`s from `s` when it contains a decimal point. If no fractional
+/// digits remain, it also drops the point, so "1970-01-01 05:30:00.000" becomes
+/// "1970-01-01 05:30:00" (previously "1970-01-01 05:30:00."). Strings without a
+/// `'.'` are returned unchanged.
+fn trim_end(s: &str) -> &str {
+    if !s.contains('.') {
+        return s;
+    }
+    let trimmed = s.trim_end_matches('0');
+    trimmed.strip_suffix('.').unwrap_or(trimmed)
+}
+
+/// Builds a `SparkError::CastOverFlow` holding the offending `value` and the names
+/// of the source and target types (all copied into owned `String`s).
+#[inline]
+pub fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
+    let value = value.to_owned();
+    let from_type = from_type.to_owned();
+    let to_type = to_type.to_owned();
+    SparkError::CastOverFlow {
+        value,
+        from_type,
+        to_type,
+    }
+}
+
+/// Builds a `SparkError::CastInvalidValue` holding the offending `value` and the
+/// names of the source and target types (all copied into owned `String`s).
+#[inline]
+pub fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
+    let value = value.to_owned();
+    let from_type = from_type.to_owned();
+    let to_type = to_type.to_owned();
+    SparkError::CastInvalidValue {
+        value,
+        from_type,
+        to_type,
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to