Jefffrey commented on code in PR #19788:
URL: https://github.com/apache/datafusion/pull/19788#discussion_r2690951185


##########
datafusion/functions/src/math/trunc.rs:
##########
@@ -110,7 +110,57 @@ impl ScalarUDFImpl for TruncFunc {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
-        make_scalar_function(trunc, vec![])(&args.args)
+        // Extract precision from second argument (default 0)
+        let precision = match args.args.get(1) {
+            Some(ColumnarValue::Scalar(Int64(Some(p)))) => Some(*p),
+            Some(ColumnarValue::Scalar(Int64(None))) => None, // null precision
+            Some(ColumnarValue::Array(_)) => {
+                // Precision is an array - use array path
+                return make_scalar_function(trunc, vec![])(&args.args);
+            }
+            None => Some(0), // default precision
+            _ => Some(0),

Review Comment:
   This catch-all arm should return an internal error, unless there's a case I'm
missing?



##########
datafusion/functions/src/math/trunc.rs:
##########
@@ -110,7 +110,57 @@ impl ScalarUDFImpl for TruncFunc {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
-        make_scalar_function(trunc, vec![])(&args.args)
+        // Extract precision from second argument (default 0)
+        let precision = match args.args.get(1) {
+            Some(ColumnarValue::Scalar(Int64(Some(p)))) => Some(*p),
+            Some(ColumnarValue::Scalar(Int64(None))) => None, // null precision
+            Some(ColumnarValue::Array(_)) => {
+                // Precision is an array - use array path
+                return make_scalar_function(trunc, vec![])(&args.args);
+            }
+            None => Some(0), // default precision
+            _ => Some(0),
+        };
+
+        // Scalar fast path using tuple matching for (value, precision)
+        match (&args.args[0], precision) {
+            // Null precision returns null with same type as input
+            (ColumnarValue::Scalar(ScalarValue::Float32(_)), None) => {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float32(None)))
+            }
+            (ColumnarValue::Scalar(ScalarValue::Float64(_)), None)
+            | (ColumnarValue::Scalar(ScalarValue::Null), None) => {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)))
+            }
+            // Float64 scalar with precision
+            (ColumnarValue::Scalar(ScalarValue::Float64(v)), Some(p)) => {
+                let result = v.map(|x| {
+                    if p == 0 {
+                        x.trunc()
+                    } else {
+                        compute_truncate64(x, p)
+                    }
+                });
+                Ok(ColumnarValue::Scalar(ScalarValue::Float64(result)))
+            }
+            // Float32 scalar with precision
+            (ColumnarValue::Scalar(ScalarValue::Float32(v)), Some(p)) => {
+                let result = v.map(|x| {
+                    if p == 0 {
+                        x.trunc()
+                    } else {
+                        compute_truncate32(x, p)
+                    }
+                });
+                Ok(ColumnarValue::Scalar(ScalarValue::Float32(result)))
+            }
+            // Null scalar
+            (ColumnarValue::Scalar(ScalarValue::Null), Some(_)) => {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)))
+            }
+            // Array path for everything else
+            _ => make_scalar_function(trunc, vec![])(&args.args),

Review Comment:
   ```suggestion
               // Null cases
               (ColumnarValue::Scalar(sv), _) if sv.is_null() => {
                   
ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None)
               }
               (_, None) => {
                   
ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None)
               }
               // Scalar cases
               (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), Some(p)) 
=> Ok(
                   ColumnarValue::Scalar(ScalarValue::Float64(Some(if p == 0 {
                       v.trunc()
                   } else {
                       compute_truncate64(*v, p)
                   }))),
               ),
               (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), Some(p)) 
=> Ok(
                   ColumnarValue::Scalar(ScalarValue::Float32(Some(if p == 0 {
                       v.trunc()
                   } else {
                       compute_truncate32(*v, p)
                   }))),
               ),
               // Array path for everything else
               _ => make_scalar_function(trunc, vec![])(&args.args),
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to