Jefffrey commented on code in PR #19384:
URL: https://github.com/apache/datafusion/pull/19384#discussion_r2636774108
##########
datafusion/functions/src/math/round.rs:
##########
@@ -123,119 +154,246 @@ impl ScalarUDFImpl for RoundFunc {
}
}
-/// Round SQL function
-fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
- if args.len() != 1 && args.len() != 2 {
- return exec_err!(
- "round function requires one or two arguments, got {}",
- args.len()
- );
+fn round_columnar(
+ value: &ColumnarValue,
+ decimal_places: &ColumnarValue,
+ number_rows: usize,
+) -> Result<ColumnarValue> {
+ let value_array = value.to_array(number_rows)?;
+ let decimal_places_is_scalar = matches!(decimal_places,
ColumnarValue::Scalar(_));
+ let value_is_scalar = matches!(value, ColumnarValue::Scalar(_));
+
+ let arr: ArrayRef = match value_array.data_type() {
+ Float64 => {
+ let result = calculate_binary_math::<Float64Type, Int64Type,
Float64Type, _>(
+ value_array.as_ref(),
+ decimal_places,
+ round_float::<f64>,
+ )?;
+ result as _
+ }
+ Float32 => {
+ let result = calculate_binary_math::<Float32Type, Int64Type,
Float32Type, _>(
+ value_array.as_ref(),
+ decimal_places,
+ round_float::<f32>,
+ )?;
+ result as _
+ }
+ Decimal32(precision, scale) => {
+ let result = calculate_binary_decimal_math::<
+ Decimal32Type,
+ Int64Type,
+ Decimal32Type,
+ _,
+ >(
+ value_array.as_ref(),
+ decimal_places,
+ |v, dp| round_decimal32(v, *scale, dp),
+ *precision,
+ *scale,
+ )?;
+ result as _
+ }
+ Decimal64(precision, scale) => {
+ let result = calculate_binary_decimal_math::<
+ Decimal64Type,
+ Int64Type,
+ Decimal64Type,
+ _,
+ >(
+ value_array.as_ref(),
+ decimal_places,
+ |v, dp| round_decimal64(v, *scale, dp),
+ *precision,
+ *scale,
+ )?;
+ result as _
+ }
+ Decimal128(precision, scale) => {
+ let result = calculate_binary_decimal_math::<
+ Decimal128Type,
+ Int64Type,
+ Decimal128Type,
+ _,
+ >(
+ value_array.as_ref(),
+ decimal_places,
+ |v, dp| round_decimal128(v, *scale, dp),
+ *precision,
+ *scale,
+ )?;
+ result as _
+ }
+ Decimal256(precision, scale) => {
+ let result = calculate_binary_decimal_math::<
+ Decimal256Type,
+ Int64Type,
+ Decimal256Type,
+ _,
+ >(
+ value_array.as_ref(),
+ decimal_places,
+ |v, dp| round_decimal256(v, *scale, dp),
+ *precision,
+ *scale,
+ )?;
+ result as _
+ }
+ other => exec_err!("Unsupported data type {other:?} for function
round")?,
+ };
+
+ if value_is_scalar && decimal_places_is_scalar {
+ ScalarValue::try_from_array(&arr, 0).map(ColumnarValue::Scalar)
+ } else {
+ Ok(ColumnarValue::Array(arr))
}
+}
- let mut decimal_places =
ColumnarValue::Scalar(ScalarValue::Int64(Some(0)));
+fn round_float<T>(value: T, decimal_places: i64) -> Result<T, ArrowError>
+where
+ T: num_traits::Float,
+{
+ let places: i32 = decimal_places.try_into().map_err(|e| {
+ ArrowError::ComputeError(format!(
+ "Invalid value for decimal places: {decimal_places}: {e}"
+ ))
+ })?;
+
+ let factor = T::from(10_f64.powi(places)).ok_or_else(|| {
+ ArrowError::ComputeError(format!(
+ "Invalid value for decimal places: {decimal_places}"
+ ))
+ })?;
+ Ok((value * factor).round() / factor)
+}
+
+fn round_decimal32(
+ value: i32,
+ scale: i8,
+ decimal_places: i64,
+) -> Result<i32, ArrowError> {
+ let rounded =
+ round_decimal_i256(i256::from_i128(i128::from(value)), scale,
decimal_places)?;
+ rounded
+ .to_i128()
Review Comment:
Is there a better way to handle this than upcasting to i256 to perform the
round before downcasting back to i32?
##########
datafusion/functions/src/math/round.rs:
##########
@@ -93,14 +104,34 @@ impl ScalarUDFImpl for RoundFunc {
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- match arg_types[0] {
- Float32 => Ok(Float32),
- _ => Ok(Float64),
- }
+ Ok(match arg_types[0].clone() {
+ Float32 => Float32,
+ dt @ Decimal128(_, _)
+ | dt @ Decimal256(_, _)
+ | dt @ Decimal32(_, _)
+ | dt @ Decimal64(_, _) => dt,
+ _ => Float64,
+ })
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
- make_scalar_function(round, vec![])(&args.args)
+ if args.args.len() != 1 && args.args.len() != 2 {
Review Comment:
Honestly, we could just omit this argument-count check, since the signature already guards us.
##########
datafusion/functions/src/math/round.rs:
##########
@@ -123,119 +154,246 @@ impl ScalarUDFImpl for RoundFunc {
}
}
-/// Round SQL function
-fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
- if args.len() != 1 && args.len() != 2 {
- return exec_err!(
- "round function requires one or two arguments, got {}",
- args.len()
- );
+fn round_columnar(
+ value: &ColumnarValue,
+ decimal_places: &ColumnarValue,
+ number_rows: usize,
+) -> Result<ColumnarValue> {
+ let value_array = value.to_array(number_rows)?;
+ let decimal_places_is_scalar = matches!(decimal_places,
ColumnarValue::Scalar(_));
+ let value_is_scalar = matches!(value, ColumnarValue::Scalar(_));
Review Comment:
nit: there is no need for these to be separate booleans; they can be combined
into a single `both_scalars`.
##########
datafusion/functions/src/math/round.rs:
##########
@@ -310,9 +472,61 @@ mod test {
Arc::new(Int64Array::from(vec![2147483648])), // decimal_places
];
- let result = round(&args);
+ let result = round_arrays(Arc::clone(&args[0]),
Some(Arc::clone(&args[1])));
assert!(result.is_err());
- assert!(matches!(result, Err(DataFusionError::Execution(_))));
+ assert!(matches!(
+ result,
+ Err(DataFusionError::ArrowError(_, _)) |
Err(DataFusionError::Execution(_))
+ ));
+ }
+
+ #[test]
+ fn test_round_decimal128_scalar_places() {
Review Comment:
Is it possible to have these tests in SLTs (sqllogictest files) instead?
##########
datafusion/functions/src/math/round.rs:
##########
@@ -123,119 +154,246 @@ impl ScalarUDFImpl for RoundFunc {
}
}
-/// Round SQL function
-fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
- if args.len() != 1 && args.len() != 2 {
- return exec_err!(
- "round function requires one or two arguments, got {}",
- args.len()
- );
+fn round_columnar(
+ value: &ColumnarValue,
+ decimal_places: &ColumnarValue,
+ number_rows: usize,
+) -> Result<ColumnarValue> {
+ let value_array = value.to_array(number_rows)?;
+ let decimal_places_is_scalar = matches!(decimal_places,
ColumnarValue::Scalar(_));
+ let value_is_scalar = matches!(value, ColumnarValue::Scalar(_));
+
+ let arr: ArrayRef = match value_array.data_type() {
+ Float64 => {
+ let result = calculate_binary_math::<Float64Type, Int64Type,
Float64Type, _>(
+ value_array.as_ref(),
+ decimal_places,
+ round_float::<f64>,
+ )?;
+ result as _
+ }
+ Float32 => {
+ let result = calculate_binary_math::<Float32Type, Int64Type,
Float32Type, _>(
+ value_array.as_ref(),
+ decimal_places,
+ round_float::<f32>,
+ )?;
+ result as _
+ }
+ Decimal32(precision, scale) => {
+ let result = calculate_binary_decimal_math::<
+ Decimal32Type,
+ Int64Type,
+ Decimal32Type,
+ _,
+ >(
+ value_array.as_ref(),
+ decimal_places,
+ |v, dp| round_decimal32(v, *scale, dp),
+ *precision,
+ *scale,
+ )?;
+ result as _
+ }
+ Decimal64(precision, scale) => {
+ let result = calculate_binary_decimal_math::<
+ Decimal64Type,
+ Int64Type,
+ Decimal64Type,
+ _,
+ >(
+ value_array.as_ref(),
+ decimal_places,
+ |v, dp| round_decimal64(v, *scale, dp),
+ *precision,
+ *scale,
+ )?;
+ result as _
+ }
+ Decimal128(precision, scale) => {
+ let result = calculate_binary_decimal_math::<
+ Decimal128Type,
+ Int64Type,
+ Decimal128Type,
+ _,
+ >(
+ value_array.as_ref(),
+ decimal_places,
+ |v, dp| round_decimal128(v, *scale, dp),
+ *precision,
+ *scale,
+ )?;
+ result as _
+ }
+ Decimal256(precision, scale) => {
+ let result = calculate_binary_decimal_math::<
+ Decimal256Type,
+ Int64Type,
+ Decimal256Type,
+ _,
+ >(
+ value_array.as_ref(),
+ decimal_places,
+ |v, dp| round_decimal256(v, *scale, dp),
+ *precision,
+ *scale,
+ )?;
+ result as _
+ }
+ other => exec_err!("Unsupported data type {other:?} for function
round")?,
+ };
+
+ if value_is_scalar && decimal_places_is_scalar {
+ ScalarValue::try_from_array(&arr, 0).map(ColumnarValue::Scalar)
+ } else {
+ Ok(ColumnarValue::Array(arr))
}
+}
- let mut decimal_places =
ColumnarValue::Scalar(ScalarValue::Int64(Some(0)));
+fn round_float<T>(value: T, decimal_places: i64) -> Result<T, ArrowError>
+where
+ T: num_traits::Float,
+{
+ let places: i32 = decimal_places.try_into().map_err(|e| {
+ ArrowError::ComputeError(format!(
+ "Invalid value for decimal places: {decimal_places}: {e}"
+ ))
+ })?;
Review Comment:
I know it was the existing behaviour, but I do wonder if we're not better
off encoding this in the signature; that is, instead of accepting i64 arrays
then casting to i32 internally, explicitly only allow i32 input arrays at the
signature level 🤔
##########
datafusion/functions/src/math/round.rs:
##########
@@ -64,14 +69,20 @@ impl Default for RoundFunc {
impl RoundFunc {
pub fn new() -> Self {
- use DataType::*;
+ let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
+ let integer = Coercion::new_exact(TypeSignatureClass::Integer);
+ let float = Coercion::new_implicit(
+ TypeSignatureClass::Float,
Review Comment:
This also lets in f16 types now, so we need to consider that.
##########
datafusion/functions/src/math/round.rs:
##########
@@ -64,14 +69,20 @@ impl Default for RoundFunc {
impl RoundFunc {
pub fn new() -> Self {
- use DataType::*;
+ let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
+ let integer = Coercion::new_exact(TypeSignatureClass::Integer);
Review Comment:
Be careful with this, as it accepts any integer input: i16, i32, u32, etc.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]