theirix commented on code in PR #22655:
URL: https://github.com/apache/datafusion/pull/22655#discussion_r3397462237
##########
datafusion/sqllogictest/test_files/math.slt:
##########
@@ -686,6 +686,22 @@ select gcd(-9223372036854775808, 0);
query error DataFusion error: Arrow error: Compute error: Signed integer
overflow in GCD\(0, \-9223372036854775808\)
select gcd(0, -9223372036854775808);
+# gcd decimal
Review Comment:
Done
##########
datafusion/functions/src/utils.rs:
##########
@@ -133,6 +134,69 @@ pub fn calculate_binary_math<L, R, O, F>(
right: &ColumnarValue,
fun: F,
) -> Result<Arc<PrimitiveArray<O>>>
+where
+ L: ArrowPrimitiveType,
+ R: ArrowPrimitiveType,
+ O: ArrowPrimitiveType,
+ F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
+ R::Native: TryFrom<ScalarValue>,
+{
+ calculate_binary_math_cast::<L, R, O, F>(left, right, fun, &R::DATA_TYPE)
+}
+
+/// Computes a binary math function for input arrays using a specified function
+/// and applies rescaling to given precision and scale.
+/// Deprecated, use [`calculate_binary_decimal_math_cast`] instead.
Review Comment:
Since the original function `calculate_binary_math_cast` is public, we
cannot track its usage outside this crate, so it can only be removed after a
deprecation period. It shouldn't have been declared public in the first
place, and the same applies to the new function.
Here is the plan as I see it:
- Deprecate it explicitly with the `#[deprecated]` attribute, as you suggest
- Port some of its usages in the `datafusion` repo in the next PR
- Mark the new function as `pub(crate)`, since it is intended for local use
only
##########
datafusion/functions/src/math/gcd.rs:
##########
@@ -76,37 +76,123 @@ impl ScalarUDFImpl for GcdFunc {
&self.signature
}
- fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
- Ok(DataType::Int64)
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ Ok(arg_types[0].clone())
+ }
+
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ let [arg1, arg2] = take_function_args(self.name(), arg_types)?;
+
+ let coerced_type = match (arg1, arg2) {
+ (DataType::Null, _) | (_, DataType::Null) => Ok(DataType::Int64),
+ (lhs, rhs) if lhs.is_integer() && rhs.is_integer() =>
Ok(DataType::Int64),
+ (lhs, rhs) if lhs.is_decimal() || rhs.is_decimal() => {
+ decimal_coercion(lhs, rhs).map(Ok).unwrap_or_else(|| {
+ exec_err!(
Review Comment:
Done
##########
datafusion/functions/src/math/lcm.rs:
##########
@@ -15,25 +15,22 @@
// specific language governing permissions and limitations
Review Comment:
Fixed
##########
datafusion/functions/src/math/gcd.rs:
##########
@@ -141,44 +227,271 @@ fn compute_gcd_with_scalar(arr: &ArrayRef, scalar:
Option<i64>) -> Result<Column
}
Some(scalar_value) => {
let result: PrimitiveArray<Int64Type> =
- prim.try_unary(|val| compute_gcd(val, scalar_value))?;
+ prim.try_unary(|val| gcd_signed_int(val, scalar_value))?;
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
}
None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
}
}
-/// Computes gcd of two unsigned integers using Binary GCD algorithm.
-pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
- if a == 0 {
- return b;
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::math::common::gcd_signed;
+ use arrow::array::{Array, Decimal128Array, Int64Array};
+ use arrow::datatypes::{DECIMAL128_MAX_PRECISION, Field};
+ use arrow_buffer::i256;
+ use datafusion_common::ScalarValue;
+ use datafusion_common::cast::{as_decimal128_array, as_int64_array};
+ use datafusion_common::config::ConfigOptions;
+ use std::sync::Arc;
+
+ #[test]
+ fn test_i64_array() {
+ let arg_fields = vec![
+ Field::new("a", DataType::Int64, true).into(),
+ Field::new("b", DataType::Int64, true).into(),
+ ];
+ let args = ScalarFunctionArgs {
+ args: vec![
+ ColumnarValue::Array(Arc::new(Int64Array::from(vec![
+ 0, 2, 0, 2, 15, 20,
+ ]))),
+ ColumnarValue::Array(Arc::new(Int64Array::from(vec![
+ 0, 0, 2, 3, 10, 1000,
+ ]))),
+ ],
+ arg_fields,
+ number_rows: 6,
+ return_field: Field::new("f", DataType::Int64, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ };
+ let result = GcdFunc::new()
+ .invoke_with_args(args)
+ .expect("failed to initialize function");
+
+ match result {
+ ColumnarValue::Array(arr) => {
+ let values =
+ as_int64_array(&arr).expect("failed to convert result to
an array");
+ assert_eq!(values.len(), 6);
+ assert_eq!(values.value(0), 0);
+ assert_eq!(values.value(1), 2);
+ assert_eq!(values.value(2), 2);
+ assert_eq!(values.value(3), 1);
+ assert_eq!(values.value(4), 5);
+ assert_eq!(values.value(5), 20);
+ }
+ ColumnarValue::Scalar(_) => {
+ panic!("Expected an array value")
+ }
+ }
}
- if b == 0 {
- return a;
+
+ #[test]
+ fn test_decimal_scalar() {
+ let arg_fields = vec![
+ Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
true)
+ .into(),
+ Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
true)
+ .into(),
+ ];
+ let args = ScalarFunctionArgs {
+ args: vec![
+ ColumnarValue::Scalar(ScalarValue::Decimal128(
+ Some(i128::from(2)),
+ DECIMAL128_MAX_PRECISION,
+ 0,
+ )),
+ ColumnarValue::Scalar(ScalarValue::Decimal128(
+ Some(i128::from(3)),
+ DECIMAL128_MAX_PRECISION,
+ 0,
+ )),
+ ],
+ arg_fields,
+ number_rows: 1,
+ return_field: Field::new(
+ "f",
+ DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
+ true,
+ )
+ .into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ };
+ let result = GcdFunc::new()
+ .invoke_with_args(args)
+ .expect("failed to initialize function power");
+
+ match result {
+ ColumnarValue::Array(arr) => {
+ let ints = as_decimal128_array(&arr)
+ .expect("failed to convert result to an array");
+
+ assert_eq!(ints.len(), 1);
+ assert_eq!(ints.value(0), i128::from(1));
+ // Signature stays the same as input
+ assert_eq!(
+ *arr.data_type(),
+ DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0)
+ );
+ }
+ ColumnarValue::Scalar(_) => {
+ panic!("Expected an array value")
+ }
+ }
}
- let shift = (a | b).trailing_zeros();
- a >>= a.trailing_zeros();
- loop {
- b >>= b.trailing_zeros();
- if a > b {
- swap(&mut a, &mut b);
+ #[test]
+ fn test_decimal_array_scalar() {
+ let arg_fields = vec![
+ Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
true)
+ .into(),
+ Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
true)
+ .into(),
+ ];
+ let args = ScalarFunctionArgs {
+ args: vec![
+ ColumnarValue::Array(Arc::new(
+ Decimal128Array::from(vec![2, 15])
+ .with_precision_and_scale(DECIMAL128_MAX_PRECISION, 0)
+ .unwrap(),
+ )),
+ ColumnarValue::Scalar(ScalarValue::Decimal128(
+ Some(i128::from(3)),
+ DECIMAL128_MAX_PRECISION,
+ 0,
+ )),
+ ],
+ arg_fields,
+ number_rows: 2,
+ return_field: Field::new(
+ "f",
+ DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
+ true,
+ )
+ .into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ };
+ let result = GcdFunc::new()
+ .invoke_with_args(args)
+ .expect("failed to initialize function power");
+
+ match result {
+ ColumnarValue::Array(arr) => {
+ let ints = as_decimal128_array(&arr)
+ .expect("failed to convert result to an array");
+
+ assert_eq!(ints.len(), 2);
+ assert_eq!(ints.value(0), i128::from(1));
+ assert_eq!(ints.value(1), i128::from(3));
+ // Signature stays the same as input
+ assert_eq!(
+ *arr.data_type(),
+ DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0)
+ );
+ }
+ ColumnarValue::Scalar(_) => {
+ panic!("Expected an array value")
+ }
}
- b -= a;
- if b == 0 {
- return a << shift;
+ }
+
+ #[test]
+ fn test_coercion() {
+ let mut coerced = GcdFunc::new()
+ .coerce_types(&[DataType::Int64, DataType::Int32])
+ .expect("coercion should succeed");
+ assert_eq!(coerced, vec![DataType::Int64, DataType::Int64]);
+
+ coerced = GcdFunc::new()
+ .coerce_types(&[DataType::Decimal128(10, 2), DataType::Int32])
+ .expect("coercion should succeed");
+
+ assert_eq!(
+ coerced,
+ vec![DataType::Decimal128(12, 2), DataType::Decimal128(12, 2)]
+ );
+
+ coerced = GcdFunc::new()
+ .coerce_types(&[DataType::Decimal128(10, 2), DataType::Null])
+ .expect("coercion should succeed");
+
+ assert_eq!(coerced, vec![DataType::Int64, DataType::Int64]);
+ }
+
+ const GCD_COMMON_TEST_CASES: [(i64, i64, i64); 18] = [
+ // Basic cases
+ (48, 18, 6),
+ (54, 24, 6),
+ (100, 50, 50),
+ (17, 19, 1),
+ (21, 14, 7),
+ // Edge cases with 0
+ (0, 0, 0),
+ (0, 5, 5),
+ (10, 0, 10),
+ // Same numbers
+ (7, 7, 7),
+ (100, 100, 100),
+ // One is 1
+ (1, 1, 1),
+ (1, 100, 1),
+ (999, 1, 1),
+ // Large numbers
+ (1000000, 500000, 500000),
+ (123456, 789012, 12),
+ (999999, 111111, 111111),
+ // Powers of 2
+ (64, 128, 64),
+ (1024, 2048, 1024),
+ ];
+
+ #[test]
+ fn test_gcd_i64() {
+ let test_cases: Vec<(i64, i64, i64)> = [
+ GCD_COMMON_TEST_CASES.into(),
+ vec![
+ // Max value cases
+ (1, i64::MAX, 1),
+ (i64::MAX, 1, 1),
+ (i64::MAX, i64::MAX, i64::MAX),
+ ],
+ ]
+ .concat();
+
+ // Success cases
+ for (a, b, expected) in test_cases {
+ let actual = gcd_signed(a, b).expect("should succeed");
+ assert_eq!(
+ actual, expected,
+ "euclid_gcd({a}, {b}) expected {expected}, actual {actual}"
+ );
}
}
-}
-/// Computes greatest common divisor using Binary GCD algorithm.
-pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
- let a = x.unsigned_abs();
- let b = y.unsigned_abs();
- let r = unsigned_gcd(a, b);
- // The result can be up to 2^63 (e.g. gcd(i64::MIN, 0) or
- // gcd(i64::MIN, i64::MIN)), which does not fit into i64.
- r.try_into().map_err(|_| {
- ArrowError::ComputeError(format!("Signed integer overflow in GCD({x},
{y})"))
- })
+ #[test]
+ fn test_gcd_decimal128() {
Review Comment:
I forgot to remove it; it's cleaned up now.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]