Jefffrey commented on code in PR #18707:
URL: https://github.com/apache/datafusion/pull/18707#discussion_r2540636866


##########
datafusion/functions/src/datetime/to_local_time.rs:
##########
@@ -111,133 +112,152 @@ impl Default for ToLocalTimeFunc {
 impl ToLocalTimeFunc {
     pub fn new() -> Self {
         Self {
-            signature: Signature::user_defined(Volatility::Immutable),
+            signature: Signature::coercible(
+                vec![Coercion::new_exact(TypeSignatureClass::Timestamp)],
+                Volatility::Immutable,
+            ),
         }
     }
+}
 
-    fn to_local_time(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
-        let [time_value] = take_function_args(self.name(), args)?;
+impl ScalarUDFImpl for ToLocalTimeFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
 
-        let arg_type = time_value.data_type();
-        match arg_type {
-            Timestamp(_, None) => {
-                // if no timezone specified, just return the input
-                Ok(time_value.clone())
-            }
-            // If has timezone, adjust the underlying time value. The current 
time value
-            // is stored as i64 in UTC, even though the timezone may not be in 
UTC. Therefore,
-            // we need to adjust the time value to the local time. See 
[`adjust_to_local_time`]
-            // for more details.
-            //
-            // Then remove the timezone in return type, i.e. return None
-            Timestamp(_, Some(timezone)) => {
-                let tz: Tz = timezone.parse()?;
-
-                match time_value {
-                    ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
-                        Some(ts),
-                        Some(_),
-                    )) => {
-                        let adjusted_ts =
-                            
adjust_to_local_time::<TimestampNanosecondType>(*ts, tz)?;
-                        
Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
-                            Some(adjusted_ts),
-                            None,
-                        )))
-                    }
-                    ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
-                        Some(ts),
-                        Some(_),
-                    )) => {
-                        let adjusted_ts =
-                            
adjust_to_local_time::<TimestampMicrosecondType>(*ts, tz)?;
-                        
Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
-                            Some(adjusted_ts),
-                            None,
-                        )))
-                    }
-                    ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
-                        Some(ts),
-                        Some(_),
-                    )) => {
-                        let adjusted_ts =
-                            
adjust_to_local_time::<TimestampMillisecondType>(*ts, tz)?;
-                        
Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
-                            Some(adjusted_ts),
-                            None,
-                        )))
-                    }
-                    ColumnarValue::Scalar(ScalarValue::TimestampSecond(
-                        Some(ts),
-                        Some(_),
-                    )) => {
-                        let adjusted_ts =
-                            adjust_to_local_time::<TimestampSecondType>(*ts, 
tz)?;
-                        Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond(
-                            Some(adjusted_ts),
-                            None,
-                        )))
-                    }
-                    ColumnarValue::Array(array) => {
-                        fn transform_array<T: ArrowTimestampType>(
-                            array: &ArrayRef,
-                            tz: Tz,
-                        ) -> Result<ColumnarValue> {
-                            let mut builder = PrimitiveBuilder::<T>::new();
-
-                            let primitive_array = 
as_primitive_array::<T>(array)?;
-                            for ts_opt in primitive_array.iter() {
-                                match ts_opt {
-                                    None => builder.append_null(),
-                                    Some(ts) => {
-                                        let adjusted_ts: i64 =
-                                            adjust_to_local_time::<T>(ts, tz)?;
-                                        builder.append_value(adjusted_ts)
-                                    }
-                                }
-                            }
-
-                            
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
-                        }
-
-                        match array.data_type() {
-                            Timestamp(_, None) => {
-                                // if no timezone specified, just return the 
input
-                                Ok(time_value.clone())
-                            }
-                            Timestamp(Nanosecond, Some(_)) => {
-                                
transform_array::<TimestampNanosecondType>(array, tz)
-                            }
-                            Timestamp(Microsecond, Some(_)) => {
-                                
transform_array::<TimestampMicrosecondType>(array, tz)
-                            }
-                            Timestamp(Millisecond, Some(_)) => {
-                                
transform_array::<TimestampMillisecondType>(array, tz)
-                            }
-                            Timestamp(Second, Some(_)) => {
-                                transform_array::<TimestampSecondType>(array, 
tz)
-                            }
-                            _ => {
-                                exec_err!("to_local_time function requires 
timestamp argument in array, got {:?}", array.data_type())
-                            }
-                        }
-                    }
-                    _ => {
-                        exec_err!(
-                        "to_local_time function requires timestamp argument, 
got {:?}",
-                        time_value.data_type()
-                    )
-                    }
-                }
-            }
-            _ => {
-                exec_err!(
-                    "to_local_time function requires timestamp argument, got 
{:?}",
-                    arg_type
-                )
+    fn name(&self) -> &str {
+        "to_local_time"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        match &arg_types[0] {
+            DataType::Null => Ok(Timestamp(Nanosecond, None)),
+            Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)),
+            dt => internal_err!(
+                "The to_local_time function can only accept timestamp as the 
arg, got {dt}"
+            ),
+        }
+    }
+
+    fn invoke_with_args(
+        &self,
+        args: datafusion_expr::ScalarFunctionArgs,
+    ) -> Result<ColumnarValue> {
+        let [time_value] = take_function_args(self.name(), &args.args)?;
+        to_local_time(time_value)
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+}
+
+fn transform_array<T: ArrowTimestampType>(
+    array: &ArrayRef,
+    tz: Tz,
+) -> Result<ColumnarValue> {
+    let mut builder = PrimitiveBuilder::<T>::new();
+
+    let primitive_array = as_primitive_array::<T>(array)?;
+    for ts_opt in primitive_array.iter() {
+        match ts_opt {
+            None => builder.append_null(),
+            Some(ts) => {
+                let adjusted_ts: i64 = adjust_to_local_time::<T>(ts, tz)?;
+                builder.append_value(adjusted_ts)
             }
         }
     }
+
+    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+}
+
+fn to_local_time(time_value: &ColumnarValue) -> Result<ColumnarValue> {
+    let arg_type = time_value.data_type();
+
+    let tz: Tz = match &arg_type {
+        Timestamp(_, Some(timezone)) => timezone.parse()?,
+        Timestamp(_, None) => {
+            // if no timezone specified, just return the input
+            return Ok(time_value.clone());
+        }
+        DataType::Null => {

Review Comment:
   Yeah, this is a small quirk of the coercion API, where `Null` types are 
passed through as-is; I plan to look into it later to see if we can handle it 
better (I need to check whether there is an existing GitHub issue for this).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to