This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5be8dbe0e5 Minor: reduce code duplication in `date_bin_impl` (#8528)
5be8dbe0e5 is described below

commit 5be8dbe0e5f45984b5e6480d8766373f3bbff93d
Author: Alex Huang <[email protected]>
AuthorDate: Thu Dec 14 21:27:36 2023 +0100

    Minor: reduce code duplication in `date_bin_impl` (#8528)
    
    * reduce code duplication in date_bin_impl
---
 .../physical-expr/src/datetime_expressions.rs      | 143 +++++++++++----------
 1 file changed, 78 insertions(+), 65 deletions(-)

diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs
index bbeb2b0dce..f6373d40d9 100644
--- a/datafusion/physical-expr/src/datetime_expressions.rs
+++ b/datafusion/physical-expr/src/datetime_expressions.rs
@@ -21,12 +21,6 @@ use crate::datetime_expressions;
 use crate::expressions::cast_column;
 use arrow::array::Float64Builder;
 use arrow::compute::cast;
-use arrow::{
-    array::TimestampNanosecondArray,
-    compute::kernels::temporal,
-    datatypes::TimeUnit,
-    temporal_conversions::{as_datetime_with_timezone, timestamp_ns_to_datetime},
-};
 use arrow::{
     array::{Array, ArrayRef, Float64Array, OffsetSizeTrait, PrimitiveArray},
     compute::kernels::cast_utils::string_to_timestamp_nanos,
@@ -36,11 +30,14 @@ use arrow::{
         TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
     },
 };
-use arrow_array::types::ArrowTimestampType;
-use arrow_array::{
-    timezone::Tz, TimestampMicrosecondArray, TimestampMillisecondArray,
-    TimestampSecondArray,
+use arrow::{
+    compute::kernels::temporal,
+    datatypes::TimeUnit,
+    temporal_conversions::{as_datetime_with_timezone, timestamp_ns_to_datetime},
 };
+use arrow_array::temporal_conversions::NANOSECONDS;
+use arrow_array::timezone::Tz;
+use arrow_array::types::ArrowTimestampType;
 use chrono::prelude::*;
 use chrono::{Duration, Months, NaiveDate};
 use datafusion_common::cast::{
@@ -647,89 +644,104 @@ fn date_bin_impl(
         return exec_err!("DATE_BIN stride must be non-zero");
     }
 
-    let f_nanos = |x: Option<i64>| x.map(|x| stride_fn(stride, x, origin));
-    let f_micros = |x: Option<i64>| {
-        let scale = 1_000;
-        x.map(|x| stride_fn(stride, x * scale, origin) / scale)
-    };
-    let f_millis = |x: Option<i64>| {
-        let scale = 1_000_000;
-        x.map(|x| stride_fn(stride, x * scale, origin) / scale)
-    };
-    let f_secs = |x: Option<i64>| {
-        let scale = 1_000_000_000;
-        x.map(|x| stride_fn(stride, x * scale, origin) / scale)
-    };
+    fn stride_map_fn<T: ArrowTimestampType>(
+        origin: i64,
+        stride: i64,
+        stride_fn: fn(i64, i64, i64) -> i64,
+    ) -> impl Fn(Option<i64>) -> Option<i64> {
+        let scale = match T::UNIT {
+            TimeUnit::Nanosecond => 1,
+            TimeUnit::Microsecond => NANOSECONDS / 1_000_000,
+            TimeUnit::Millisecond => NANOSECONDS / 1_000,
+            TimeUnit::Second => NANOSECONDS,
+        };
+        move |x: Option<i64>| x.map(|x| stride_fn(stride, x * scale, origin) / scale)
+    }
 
     Ok(match array {
         ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => {
+            let apply_stride_fn =
+                stride_map_fn::<TimestampNanosecondType>(origin, stride, stride_fn);
             ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
-                f_nanos(*v),
+                apply_stride_fn(*v),
                 tz_opt.clone(),
             ))
         }
         ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => {
+            let apply_stride_fn =
+                stride_map_fn::<TimestampMicrosecondType>(origin, stride, stride_fn);
             ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
-                f_micros(*v),
+                apply_stride_fn(*v),
                 tz_opt.clone(),
             ))
         }
         ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => {
+            let apply_stride_fn =
+                stride_map_fn::<TimestampMillisecondType>(origin, stride, stride_fn);
             ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
-                f_millis(*v),
+                apply_stride_fn(*v),
                 tz_opt.clone(),
             ))
         }
         ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => {
+            let apply_stride_fn =
+                stride_map_fn::<TimestampSecondType>(origin, stride, stride_fn);
             ColumnarValue::Scalar(ScalarValue::TimestampSecond(
-                f_secs(*v),
+                apply_stride_fn(*v),
                 tz_opt.clone(),
             ))
         }
-        ColumnarValue::Array(array) => match array.data_type() {
-            DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
-                let array = as_timestamp_nanosecond_array(array)?
-                    .iter()
-                    .map(f_nanos)
-                    .collect::<TimestampNanosecondArray>()
-                    .with_timezone_opt(tz_opt.clone());
-
-                ColumnarValue::Array(Arc::new(array))
-            }
-            DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
-                let array = as_timestamp_microsecond_array(array)?
-                    .iter()
-                    .map(f_micros)
-                    .collect::<TimestampMicrosecondArray>()
-                    .with_timezone_opt(tz_opt.clone());
-
-                ColumnarValue::Array(Arc::new(array))
-            }
-            DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
-                let array = as_timestamp_millisecond_array(array)?
-                    .iter()
-                    .map(f_millis)
-                    .collect::<TimestampMillisecondArray>()
-                    .with_timezone_opt(tz_opt.clone());
 
-                ColumnarValue::Array(Arc::new(array))
-            }
-            DataType::Timestamp(TimeUnit::Second, tz_opt) => {
-                let array = as_timestamp_second_array(array)?
+        ColumnarValue::Array(array) => {
+            fn transform_array_with_stride<T>(
+                origin: i64,
+                stride: i64,
+                stride_fn: fn(i64, i64, i64) -> i64,
+                array: &ArrayRef,
+                tz_opt: &Option<Arc<str>>,
+            ) -> Result<ColumnarValue>
+            where
+                T: ArrowTimestampType,
+            {
+                let array = as_primitive_array::<T>(array)?;
+                let apply_stride_fn = stride_map_fn::<T>(origin, stride, stride_fn);
+                let array = array
                     .iter()
-                    .map(f_secs)
-                    .collect::<TimestampSecondArray>()
+                    .map(apply_stride_fn)
+                    .collect::<PrimitiveArray<T>>()
                     .with_timezone_opt(tz_opt.clone());
 
-                ColumnarValue::Array(Arc::new(array))
+                Ok(ColumnarValue::Array(Arc::new(array)))
             }
-            _ => {
-                return exec_err!(
-                    "DATE_BIN expects source argument to be a TIMESTAMP but got {}",
-                    array.data_type()
-                )
+            match array.data_type() {
+                DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
+                    transform_array_with_stride::<TimestampNanosecondType>(
+                        origin, stride, stride_fn, array, tz_opt,
+                    )?
+                }
+                DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
+                    transform_array_with_stride::<TimestampMicrosecondType>(
+                        origin, stride, stride_fn, array, tz_opt,
+                    )?
+                }
+                DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
+                    transform_array_with_stride::<TimestampMillisecondType>(
+                        origin, stride, stride_fn, array, tz_opt,
+                    )?
+                }
+                DataType::Timestamp(TimeUnit::Second, tz_opt) => {
+                    transform_array_with_stride::<TimestampSecondType>(
+                        origin, stride, stride_fn, array, tz_opt,
+                    )?
+                }
+                _ => {
+                    return exec_err!(
+                        "DATE_BIN expects source argument to be a TIMESTAMP but got {}",
+                        array.data_type()
+                    )
+                }
             }
-        },
+        }
         _ => {
             return exec_err!(
                 "DATE_BIN expects source argument to be a TIMESTAMP scalar or array"
@@ -1061,6 +1073,7 @@ mod tests {
     use arrow::array::{
         as_primitive_array, ArrayRef, Int64Array, IntervalDayTimeArray, StringBuilder,
     };
+    use arrow_array::TimestampNanosecondArray;
 
     use super::*;
 

Reply via email to