This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 086513cb fix: Optimize some functions to rewrite dictionary-encoded strings (#627)
086513cb is described below

commit 086513cb5f4bd136a904b2cf0e086951e3025106
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Tue Jul 16 01:03:10 2024 +0530

    fix: Optimize some functions to rewrite dictionary-encoded strings (#627)
    
    * dedup code
    
    * transforming the dict directly
    
    * code optimization for cast string to timestamp
    
    * minor optimizations
    
    * fmt fixes and casting to dict array without unpacking to array first
    
    * bug fixes
    
    * revert unrelated change
    
    * Added test case and code refactor
    
    * minor optimization
    
    * minor optimization again
    
    * convert the cast to array
    
    * Revert "convert the cast to array"
    
    This reverts commit 9270aedeafa12dacabc664ca9df7c85236e05d85.
    
    * bug fixes
    
    * rename the test to cast_dict_to_timestamp arr
---
 .../datafusion/expressions/scalar_funcs/hex.rs     | 79 ++++++-----------
 native/spark-expr/src/cast.rs                      | 98 ++++++++++++++++------
 2 files changed, 96 insertions(+), 81 deletions(-)

diff --git a/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs
index 5191e53f..e6059818 100644
--- a/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs
+++ b/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs
@@ -118,65 +118,36 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFus
 
                 Ok(ColumnarValue::Array(Arc::new(hexed)))
             }
-            DataType::Dictionary(_, value_type) if matches!(**value_type, 
DataType::Int64) => {
+            DataType::Dictionary(_, value_type) => {
                 let dict = as_dictionary_array::<Int32Type>(&array);
 
-                let hexed_values = as_int64_array(dict.values())?;
-                let values = hexed_values
+                let values = match **value_type {
+                    DataType::Int64 => as_int64_array(dict.values())?
+                        .iter()
+                        .map(|v| v.map(hex_int64))
+                        .collect::<Vec<_>>(),
+                    DataType::Utf8 => as_string_array(dict.values())
+                        .iter()
+                        .map(|v| v.map(hex_bytes).transpose())
+                        .collect::<Result<_, _>>()?,
+                    DataType::Binary => as_binary_array(dict.values())?
+                        .iter()
+                        .map(|v| v.map(hex_bytes).transpose())
+                        .collect::<Result<_, _>>()?,
+                    _ => exec_err!(
+                        "hex got an unexpected argument type: {:?}",
+                        array.data_type()
+                    )?,
+                };
+
+                let new_values: Vec<Option<String>> = dict
+                    .keys()
                     .iter()
-                    .map(|v| v.map(hex_int64))
-                    .collect::<Vec<_>>();
+                    .map(|key| key.map(|k| values[k as 
usize].clone()).unwrap_or(None))
+                    .collect();
 
-                let keys = dict.keys().clone();
-                let mut new_keys = Vec::with_capacity(values.len());
+                let string_array_values = StringArray::from(new_values);
 
-                for key in keys.iter() {
-                    let key = key.map(|k| values[k as 
usize].clone()).unwrap_or(None);
-                    new_keys.push(key);
-                }
-
-                let string_array_values = StringArray::from(new_keys);
-                Ok(ColumnarValue::Array(Arc::new(string_array_values)))
-            }
-            DataType::Dictionary(_, value_type) if matches!(**value_type, 
DataType::Utf8) => {
-                let dict = as_dictionary_array::<Int32Type>(&array);
-
-                let hexed_values = as_string_array(dict.values());
-                let values: Vec<Option<String>> = hexed_values
-                    .iter()
-                    .map(|v| v.map(hex_bytes).transpose())
-                    .collect::<Result<_, _>>()?;
-
-                let keys = dict.keys().clone();
-
-                let mut new_keys = Vec::with_capacity(values.len());
-
-                for key in keys.iter() {
-                    let key = key.map(|k| values[k as 
usize].clone()).unwrap_or(None);
-                    new_keys.push(key);
-                }
-
-                let string_array_values = StringArray::from(new_keys);
-                Ok(ColumnarValue::Array(Arc::new(string_array_values)))
-            }
-            DataType::Dictionary(_, value_type) if matches!(**value_type, 
DataType::Binary) => {
-                let dict = as_dictionary_array::<Int32Type>(&array);
-
-                let hexed_values = as_binary_array(dict.values())?;
-                let values: Vec<Option<String>> = hexed_values
-                    .iter()
-                    .map(|v| v.map(hex_bytes).transpose())
-                    .collect::<Result<_, _>>()?;
-
-                let keys = dict.keys().clone();
-                let mut new_keys = Vec::with_capacity(values.len());
-
-                for key in keys.iter() {
-                    let key = key.map(|k| values[k as 
usize].clone()).unwrap_or(None);
-                    new_keys.push(key);
-                }
-
-                let string_array_values = StringArray::from(new_keys);
                 Ok(ColumnarValue::Array(Arc::new(string_array_values)))
             }
             _ => exec_err!(
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index 7f53583e..8702ce70 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -31,7 +31,7 @@ use arrow::{
         GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, 
OffsetSizeTrait,
         PrimitiveArray,
     },
-    compute::{cast_with_options, unary, CastOptions},
+    compute::{cast_with_options, take, unary, CastOptions},
     datatypes::{
         ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, 
Float64Type, Int64Type,
         TimestampMicrosecondType,
@@ -40,6 +40,7 @@ use arrow::{
     record_batch::RecordBatch,
     util::display::FormatOptions,
 };
+use arrow_array::DictionaryArray;
 use arrow_schema::{DataType, Schema};
 
 use datafusion_common::{
@@ -98,7 +99,6 @@ macro_rules! cast_utf8_to_int {
         result
     }};
 }
-
 macro_rules! cast_utf8_to_timestamp {
     ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
         let len = $array.len();
@@ -507,19 +507,27 @@ impl Cast {
         let to_type = &self.data_type;
         let array = array_with_timezone(array, self.timezone.clone(), 
Some(to_type))?;
         let from_type = array.data_type().clone();
-
-        // unpack dictionary string arrays first
-        // TODO: we are unpacking a dictionary-encoded array and then 
performing
-        // the cast. We could potentially improve performance here by casting 
the
-        // dictionary values directly without unpacking the array first, 
although this
-        // would add more complexity to the code
         let array = match &from_type {
             DataType::Dictionary(key_type, value_type)
                 if key_type.as_ref() == &DataType::Int32
                     && (value_type.as_ref() == &DataType::Utf8
                         || value_type.as_ref() == &DataType::LargeUtf8) =>
             {
-                cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)?
+                let dict_array = array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<Int32Type>>()
+                    .expect("Expected a dictionary array");
+
+                let casted_dictionary = DictionaryArray::<Int32Type>::new(
+                    dict_array.keys().clone(),
+                    self.cast_array(dict_array.values().clone())?,
+                );
+
+                let casted_result = match to_type {
+                    DataType::Dictionary(_, _) => 
Arc::new(casted_dictionary.clone()),
+                    _ => take(casted_dictionary.values().as_ref(), 
dict_array.keys(), None)?,
+                };
+                return Ok(spark_cast(casted_result, &from_type, to_type));
             }
             _ => array,
         };
@@ -724,26 +732,31 @@ impl Cast {
             .downcast_ref::<GenericStringArray<i32>>()
             .expect("Expected a string array");
 
-        let cast_array: ArrayRef = match to_type {
-            DataType::Date32 => {
-                let len = string_array.len();
-                let mut cast_array = 
PrimitiveArray::<Date32Type>::builder(len);
-                for i in 0..len {
-                    if !string_array.is_null(i) {
-                        match date_parser(string_array.value(i), eval_mode) {
-                            Ok(Some(cast_value)) => 
cast_array.append_value(cast_value),
-                            Ok(None) => cast_array.append_null(),
-                            Err(e) => return Err(e),
-                        }
-                    } else {
-                        cast_array.append_null()
-                    }
+        if to_type != &DataType::Date32 {
+            unreachable!("Invalid data type {:?} in cast from string", 
to_type);
+        }
+
+        let len = string_array.len();
+        let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);
+
+        for i in 0..len {
+            let value = if string_array.is_null(i) {
+                None
+            } else {
+                match date_parser(string_array.value(i), eval_mode) {
+                    Ok(Some(cast_value)) => Some(cast_value),
+                    Ok(None) => None,
+                    Err(e) => return Err(e),
                 }
-                Arc::new(cast_array.finish()) as ArrayRef
+            };
+
+            match value {
+                Some(cast_value) => cast_array.append_value(cast_value),
+                None => cast_array.append_null(),
             }
-            _ => unreachable!("Invalid data type {:?} in cast from string", 
to_type),
-        };
-        Ok(cast_array)
+        }
+
+        Ok(Arc::new(cast_array.finish()) as ArrayRef)
     }
 
     fn cast_string_to_timestamp(
@@ -1796,6 +1809,37 @@ mod tests {
         assert_eq!(result.len(), 2);
     }
 
+    #[test]
+    fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
+        // prepare input data
+        let keys = Int32Array::from(vec![0, 1]);
+        let values: ArrayRef = Arc::new(StringArray::from(vec![
+            Some("2020-01-01T12:34:56.123456"),
+            Some("T2"),
+        ]));
+        let dict_array = Arc::new(DictionaryArray::new(keys, values));
+
+        // prepare cast expression
+        let timezone = "UTC".to_string();
+        let expr = Arc::new(Column::new("a", 0)); // this is not used by the 
test
+        let cast = Cast::new(
+            expr,
+            DataType::Timestamp(TimeUnit::Microsecond, 
Some(timezone.clone().into())),
+            EvalMode::Legacy,
+            timezone.clone(),
+        );
+
+        // test casting string dictionary array to timestamp array
+        let result = cast.cast_array(dict_array)?;
+        assert_eq!(
+            *result.data_type(),
+            DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into()))
+        );
+        assert_eq!(result.len(), 2);
+
+        Ok(())
+    }
+
     #[test]
     fn date_parser_test() {
         for date in &[


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to