This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 086513cb fix: Optimize some functions to rewrite dictionary-encoded
strings (#627)
086513cb is described below
commit 086513cb5f4bd136a904b2cf0e086951e3025106
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Tue Jul 16 01:03:10 2024 +0530
fix: Optimize some functions to rewrite dictionary-encoded strings (#627)
* dedup code
* transforming the dict directly
* code optimization for cast string to timestamp
* minor optimizations
* fmt fixes and casting to dict array without unpacking to array first
* bug fixes
* revert unrelated change
* Added test case and code refactor
* minor optimization
* minor optimization again
* convert the cast to array
* Revert "convert the cast to array"
This reverts commit 9270aedeafa12dacabc664ca9df7c85236e05d85.
* bug fixes
* rename the test to cast_dict_to_timestamp_arr
---
.../datafusion/expressions/scalar_funcs/hex.rs | 79 ++++++-----------
native/spark-expr/src/cast.rs | 98 ++++++++++++++++------
2 files changed, 96 insertions(+), 81 deletions(-)
diff --git
a/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs
b/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs
index 5191e53f..e6059818 100644
--- a/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs
+++ b/native/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs
@@ -118,65 +118,36 @@ pub(super) fn spark_hex(args: &[ColumnarValue]) ->
Result<ColumnarValue, DataFus
Ok(ColumnarValue::Array(Arc::new(hexed)))
}
- DataType::Dictionary(_, value_type) if matches!(**value_type,
DataType::Int64) => {
+ DataType::Dictionary(_, value_type) => {
let dict = as_dictionary_array::<Int32Type>(&array);
- let hexed_values = as_int64_array(dict.values())?;
- let values = hexed_values
+ let values = match **value_type {
+ DataType::Int64 => as_int64_array(dict.values())?
+ .iter()
+ .map(|v| v.map(hex_int64))
+ .collect::<Vec<_>>(),
+ DataType::Utf8 => as_string_array(dict.values())
+ .iter()
+ .map(|v| v.map(hex_bytes).transpose())
+ .collect::<Result<_, _>>()?,
+ DataType::Binary => as_binary_array(dict.values())?
+ .iter()
+ .map(|v| v.map(hex_bytes).transpose())
+ .collect::<Result<_, _>>()?,
+ _ => exec_err!(
+ "hex got an unexpected argument type: {:?}",
+ array.data_type()
+ )?,
+ };
+
+ let new_values: Vec<Option<String>> = dict
+ .keys()
.iter()
- .map(|v| v.map(hex_int64))
- .collect::<Vec<_>>();
+ .map(|key| key.map(|k| values[k as
usize].clone()).unwrap_or(None))
+ .collect();
- let keys = dict.keys().clone();
- let mut new_keys = Vec::with_capacity(values.len());
+ let string_array_values = StringArray::from(new_values);
- for key in keys.iter() {
- let key = key.map(|k| values[k as
usize].clone()).unwrap_or(None);
- new_keys.push(key);
- }
-
- let string_array_values = StringArray::from(new_keys);
- Ok(ColumnarValue::Array(Arc::new(string_array_values)))
- }
- DataType::Dictionary(_, value_type) if matches!(**value_type,
DataType::Utf8) => {
- let dict = as_dictionary_array::<Int32Type>(&array);
-
- let hexed_values = as_string_array(dict.values());
- let values: Vec<Option<String>> = hexed_values
- .iter()
- .map(|v| v.map(hex_bytes).transpose())
- .collect::<Result<_, _>>()?;
-
- let keys = dict.keys().clone();
-
- let mut new_keys = Vec::with_capacity(values.len());
-
- for key in keys.iter() {
- let key = key.map(|k| values[k as
usize].clone()).unwrap_or(None);
- new_keys.push(key);
- }
-
- let string_array_values = StringArray::from(new_keys);
- Ok(ColumnarValue::Array(Arc::new(string_array_values)))
- }
- DataType::Dictionary(_, value_type) if matches!(**value_type,
DataType::Binary) => {
- let dict = as_dictionary_array::<Int32Type>(&array);
-
- let hexed_values = as_binary_array(dict.values())?;
- let values: Vec<Option<String>> = hexed_values
- .iter()
- .map(|v| v.map(hex_bytes).transpose())
- .collect::<Result<_, _>>()?;
-
- let keys = dict.keys().clone();
- let mut new_keys = Vec::with_capacity(values.len());
-
- for key in keys.iter() {
- let key = key.map(|k| values[k as
usize].clone()).unwrap_or(None);
- new_keys.push(key);
- }
-
- let string_array_values = StringArray::from(new_keys);
Ok(ColumnarValue::Array(Arc::new(string_array_values)))
}
_ => exec_err!(
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index 7f53583e..8702ce70 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -31,7 +31,7 @@ use arrow::{
GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array,
OffsetSizeTrait,
PrimitiveArray,
},
- compute::{cast_with_options, unary, CastOptions},
+ compute::{cast_with_options, take, unary, CastOptions},
datatypes::{
ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type,
Float64Type, Int64Type,
TimestampMicrosecondType,
@@ -40,6 +40,7 @@ use arrow::{
record_batch::RecordBatch,
util::display::FormatOptions,
};
+use arrow_array::DictionaryArray;
use arrow_schema::{DataType, Schema};
use datafusion_common::{
@@ -98,7 +99,6 @@ macro_rules! cast_utf8_to_int {
result
}};
}
-
macro_rules! cast_utf8_to_timestamp {
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
let len = $array.len();
@@ -507,19 +507,27 @@ impl Cast {
let to_type = &self.data_type;
let array = array_with_timezone(array, self.timezone.clone(),
Some(to_type))?;
let from_type = array.data_type().clone();
-
- // unpack dictionary string arrays first
- // TODO: we are unpacking a dictionary-encoded array and then
performing
- // the cast. We could potentially improve performance here by casting
the
- // dictionary values directly without unpacking the array first,
although this
- // would add more complexity to the code
let array = match &from_type {
DataType::Dictionary(key_type, value_type)
if key_type.as_ref() == &DataType::Int32
&& (value_type.as_ref() == &DataType::Utf8
|| value_type.as_ref() == &DataType::LargeUtf8) =>
{
- cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)?
+ let dict_array = array
+ .as_any()
+ .downcast_ref::<DictionaryArray<Int32Type>>()
+ .expect("Expected a dictionary array");
+
+ let casted_dictionary = DictionaryArray::<Int32Type>::new(
+ dict_array.keys().clone(),
+ self.cast_array(dict_array.values().clone())?,
+ );
+
+ let casted_result = match to_type {
+ DataType::Dictionary(_, _) =>
Arc::new(casted_dictionary.clone()),
+ _ => take(casted_dictionary.values().as_ref(),
dict_array.keys(), None)?,
+ };
+ return Ok(spark_cast(casted_result, &from_type, to_type));
}
_ => array,
};
@@ -724,26 +732,31 @@ impl Cast {
.downcast_ref::<GenericStringArray<i32>>()
.expect("Expected a string array");
- let cast_array: ArrayRef = match to_type {
- DataType::Date32 => {
- let len = string_array.len();
- let mut cast_array =
PrimitiveArray::<Date32Type>::builder(len);
- for i in 0..len {
- if !string_array.is_null(i) {
- match date_parser(string_array.value(i), eval_mode) {
- Ok(Some(cast_value)) =>
cast_array.append_value(cast_value),
- Ok(None) => cast_array.append_null(),
- Err(e) => return Err(e),
- }
- } else {
- cast_array.append_null()
- }
+ if to_type != &DataType::Date32 {
+ unreachable!("Invalid data type {:?} in cast from string",
to_type);
+ }
+
+ let len = string_array.len();
+ let mut cast_array = PrimitiveArray::<Date32Type>::builder(len);
+
+ for i in 0..len {
+ let value = if string_array.is_null(i) {
+ None
+ } else {
+ match date_parser(string_array.value(i), eval_mode) {
+ Ok(Some(cast_value)) => Some(cast_value),
+ Ok(None) => None,
+ Err(e) => return Err(e),
}
- Arc::new(cast_array.finish()) as ArrayRef
+ };
+
+ match value {
+ Some(cast_value) => cast_array.append_value(cast_value),
+ None => cast_array.append_null(),
}
- _ => unreachable!("Invalid data type {:?} in cast from string",
to_type),
- };
- Ok(cast_array)
+ }
+
+ Ok(Arc::new(cast_array.finish()) as ArrayRef)
}
fn cast_string_to_timestamp(
@@ -1796,6 +1809,37 @@ mod tests {
assert_eq!(result.len(), 2);
}
+ #[test]
+ fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
+ // prepare input data
+ let keys = Int32Array::from(vec![0, 1]);
+ let values: ArrayRef = Arc::new(StringArray::from(vec![
+ Some("2020-01-01T12:34:56.123456"),
+ Some("T2"),
+ ]));
+ let dict_array = Arc::new(DictionaryArray::new(keys, values));
+
+ // prepare cast expression
+ let timezone = "UTC".to_string();
+ let expr = Arc::new(Column::new("a", 0)); // this is not used by the
test
+ let cast = Cast::new(
+ expr,
+ DataType::Timestamp(TimeUnit::Microsecond,
Some(timezone.clone().into())),
+ EvalMode::Legacy,
+ timezone.clone(),
+ );
+
+ // test casting string dictionary array to timestamp array
+ let result = cast.cast_array(dict_array)?;
+ assert_eq!(
+ *result.data_type(),
+ DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into()))
+ );
+ assert_eq!(result.len(), 2);
+
+ Ok(())
+ }
+
#[test]
fn date_parser_test() {
for date in &[
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]