This is an automated email from the ASF dual-hosted git repository. jeffreyvo pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push: new 572c204394 feat: Support log for Decimal128 and Decimal256 (#17023) 572c204394 is described below commit 572c204394786f013b5c49648e3fdb153b7f9f73 Author: theirix <thei...@gmail.com> AuthorDate: Sun Sep 14 03:14:09 2025 +0100 feat: Support log for Decimal128 and Decimal256 (#17023) * Enable env_logger for datafusion-functions crate * Fixup ScalarValue decimal constructors * Support decimals in log UDF * Add sqllogic test for log on Decimals * Fix test for scalar new_ten * Loosen requirements on a return type * Remove extra logging * Improve handling scale and mix of base/value types Signed-off-by: theirix <thei...@gmail.com> * Format * Adjust ScalarFunctionArgs construction * Apply suggestions from code review Refactoring decimal128 conversions Co-authored-by: Jeffrey Vo <jeffrey.vo.austra...@gmail.com> * Update tests * Update tests and SLT * Improve test for decimal128_to_i128 * Improve type signature for log UDF * Fix clippy --------- Signed-off-by: theirix <thei...@gmail.com> --- Cargo.lock | 2 + datafusion/common/src/scalar/mod.rs | 10 +- datafusion/functions/Cargo.toml | 2 + datafusion/functions/src/lib.rs | 7 + datafusion/functions/src/math/log.rs | 511 +++++++++++++++++++++---- datafusion/functions/src/utils.rs | 105 ++++- datafusion/sqllogictest/test_files/array.slt | 2 +- datafusion/sqllogictest/test_files/decimal.slt | 124 ++++++ datafusion/sqllogictest/test_files/scalar.slt | 2 +- 9 files changed, 683 insertions(+), 82 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 272d9d4550..8a4b94f514 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2313,12 +2313,14 @@ dependencies = [ "blake3", "chrono", "criterion", + "ctor", "datafusion-common", "datafusion-doc", "datafusion-execution", "datafusion-expr", "datafusion-expr-common", "datafusion-macros", + "env_logger", "hex", "itertools 0.14.0", "log", diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 4d88f5a667..c5e764272b 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1628,7 +1628,7 @@ impl ScalarValue { ) { return _internal_err!("Invalid precision and scale {err}"); } - if *scale <= 0 { + if *scale < 0 { return _internal_err!("Negative scale is not supported"); } match i128::from(10).checked_pow((*scale + 1) as u32) { @@ -1644,7 +1644,7 @@ impl ScalarValue { ) { return _internal_err!("Invalid precision and scale {err}"); } - if *scale <= 0 { + if *scale < 0 { return _internal_err!("Negative scale is not supported"); } match i256::from(10).checked_pow((*scale + 1) as u32) { @@ -5429,8 +5429,7 @@ mod tests { ScalarValue::new_ten(&DataType::Decimal128(7, 2)).unwrap(), ScalarValue::Decimal128(Some(1000), 7, 2) ); - // No negative or zero scale - assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 0)).is_err()); + // No negative scale assert!(ScalarValue::new_ten(&DataType::Decimal128(5, -1)).is_err()); // Invalid combination assert!(ScalarValue::new_ten(&DataType::Decimal128(0, 2)).is_err()); @@ -5452,8 +5451,7 @@ mod tests { ScalarValue::new_ten(&DataType::Decimal256(7, 2)).unwrap(), ScalarValue::Decimal256(Some(1000.into()), 7, 2) ); - // No negative or zero scale - assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 0)).is_err()); + // No negative scale assert!(ScalarValue::new_ten(&DataType::Decimal256(5, -1)).is_err()); // Invalid combination assert!(ScalarValue::new_ten(&DataType::Decimal256(0, 2)).is_err()); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index b557fe832e..90331fbcca 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -87,6 +87,8 @@ uuid = { version = "1.18", features = ["v4"], optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } criterion = { workspace = true } +ctor = { workspace = true } +env_logger = { workspace = true } rand = { workspace = true } tokio = { workspace = true, features = ["macros", "rt", "sync"] } diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 51cd5df806..e28003606b 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -191,6 +191,13 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { Ok(()) } +#[cfg(test)] +#[ctor::ctor] +fn init() { + // Enable RUST_LOG logging configuration for test + let _ = env_logger::try_init(); +} + #[cfg(test)] mod tests { use crate::all_default_functions; diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 6342501a86..6604f9ee22 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -22,8 +22,14 @@ use std::sync::Arc; use super::power::PowerFunc; -use arrow::array::{ArrayRef, AsArray}; -use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use crate::utils::{calculate_binary_math, decimal128_to_i128}; +use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, + Int64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, +}; +use arrow::error::ArrowError; +use arrow_buffer::i256; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; @@ -58,14 +64,37 @@ impl Default for LogFunc { impl LogFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( vec![ - Exact(vec![Float32]), - Exact(vec![Float64]), - Exact(vec![Float32, Float32]), - Exact(vec![Float64, Float64]), + Numeric(1), + Numeric(2), + Exact(vec![DataType::Float32, DataType::Float32]), + Exact(vec![DataType::Float64, DataType::Float64]), + Exact(vec![ + DataType::Int64, + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), + ]), + Exact(vec![ + DataType::Float32, + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), + ]), + Exact(vec![ + DataType::Float64, + DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0), + ]), + Exact(vec![ + DataType::Int64, + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), + ]), + Exact(vec![ + DataType::Float32, + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), + ]), + Exact(vec![ + DataType::Float64, + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), + ]), ], Volatility::Immutable, ), @@ -73,6 +102,41 @@ impl LogFunc { } } +/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base +/// Returns error if base is invalid +fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError> { + if !base.is_finite() || base.trunc() != base { + return Err(ArrowError::ComputeError(format!( + "Log cannot use non-integer base: {base}" + ))); + } + if (base as u32) < 2 { + return Err(ArrowError::ComputeError(format!( + "Log base must be greater than 1: {base}" + ))); + } + + let unscaled_value = decimal128_to_i128(value, scale)?; + if unscaled_value > 0 { + let log_value: u32 = unscaled_value.ilog(base as i128); + Ok(log_value as f64) + } else { + // Reflect f64::log behaviour + Ok(f64::NAN) + } +} + +/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base +/// Returns error if base is invalid or if value is out of bounds of Decimal128 +fn log_decimal256(value: i256, scale: i8, base: f64) -> Result<f64, ArrowError> { + match value.to_i128() { + Some(value) => log_decimal128(value, scale, base), + None => Err(ArrowError::NotYetImplemented(format!( + "Log of Decimal256 larger than Decimal128 is not yet supported: {value}" + ))), + } +} + impl ScalarUDFImpl for LogFunc { fn as_any(&self) -> &dyn Any { self @@ -86,7 +150,8 @@ impl ScalarUDFImpl for LogFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { - match &arg_types[0] { + // Check last argument (value) + match &arg_types.last().ok_or(plan_datafusion_err!("No args"))? { DataType::Float32 => Ok(DataType::Float32), _ => Ok(DataType::Float64), } @@ -121,55 +186,68 @@ impl ScalarUDFImpl for LogFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { let args = ColumnarValue::values_to_arrays(&args.args)?; - let mut base = ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))); - - let mut x = &args[0]; - if args.len() == 2 { - x = &args[1]; - base = ColumnarValue::Array(Arc::clone(&args[0])); - } - // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(x.as_primitive::<Float64Type>().unary::<_, Float64Type>( - |value: f64| f64::log(value, base as f64), - )) - } - ColumnarValue::Array(base) => { - let x = x.as_primitive::<Float64Type>(); - let base = base.as_primitive::<Float64Type>(); - let result = arrow::compute::binary::<_, _, _, Float64Type>( - x, - base, - f64::log, - )?; - Arc::new(result) as _ - } - _ => { - return exec_err!("log function requires a scalar or array for base") - } - }, - - DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Arc::new( - x.as_primitive::<Float32Type>() - .unary::<_, Float32Type>(|value: f32| f32::log(value, base)), + let (base, value) = if args.len() == 2 { + // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) + (ColumnarValue::Array(Arc::clone(&args[0])), &args[1]) + } else { + // log(num) - assume base is 10 + let ret_type = if args[0].data_type().is_null() { + &DataType::Float64 + } else { + args[0].data_type() + }; + ( + ColumnarValue::Array( + ScalarValue::new_ten(ret_type)?.to_array_of_size(args[0].len())?, ), - ColumnarValue::Array(base) => { - let x = x.as_primitive::<Float32Type>(); - let base = base.as_primitive::<Float32Type>(); - let result = arrow::compute::binary::<_, _, _, Float32Type>( - x, - base, - f32::log, - )?; - Arc::new(result) as _ - } - _ => { - return exec_err!("log function requires a scalar or array for base") - } - }, + &args[0], + ) + }; + + // All log functors have format 'log(value, base)' + // Therefore, for `calculate_binary_math` the first type means a type of main array + // The second type is the type of the base array (even if derived from main) + let arr: ArrayRef = match value.data_type() { + DataType::Float32 => calculate_binary_math::< + Float32Type, + Float32Type, + Float32Type, + _, + >(value, &base, |x, b| Ok(f32::log(x, b)))?, + DataType::Float64 => calculate_binary_math::< + Float64Type, + Float64Type, + Float64Type, + _, + >(value, &base, |x, b| Ok(f64::log(x, b)))?, + DataType::Int32 => { + calculate_binary_math::<Int32Type, Float64Type, Float64Type, _>( + value, + &base, + |x, b| Ok(f64::log(x as f64, b)), + )? + } + DataType::Int64 => { + calculate_binary_math::<Int64Type, Float64Type, Float64Type, _>( + value, + &base, + |x, b| Ok(f64::log(x as f64, b)), + )? + } + DataType::Decimal128(_precision, scale) => { + calculate_binary_math::<Decimal128Type, Float64Type, Float64Type, _>( + value, + &base, + |x, b| log_decimal128(x, *scale, b), + )? + } + DataType::Decimal256(_precision, scale) => { + calculate_binary_math::<Decimal256Type, Float64Type, Float64Type, _>( + value, + &base, + |x, b| log_decimal256(x, *scale, b), + )? + } other => { return exec_err!("Unsupported data type {other:?} for function log") } @@ -256,9 +334,11 @@ mod tests { use super::*; - use arrow::array::{Float32Array, Float64Array, Int64Array}; + use arrow::array::{ + Date32Array, Decimal128Array, Decimal256Array, Float32Array, Float64Array, + }; use arrow::compute::SortOptions; - use arrow::datatypes::Field; + use arrow::datatypes::{Field, DECIMAL256_MAX_PRECISION}; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::config::ConfigOptions; use datafusion_common::DFSchema; @@ -266,33 +346,37 @@ mod tests { use datafusion_expr::simplify::SimplifyContext; #[test] - #[should_panic] fn test_log_invalid_base_type() { let arg_fields = vec![ - Field::new("a", DataType::Float64, false).into(), - Field::new("a", DataType::Int64, false).into(), + Field::new("b", DataType::Date32, false).into(), + Field::new("n", DataType::Float64, false).into(), ]; let args = ScalarFunctionArgs { args: vec![ + ColumnarValue::Array(Arc::new(Date32Array::from(vec![5, 10, 15, 20]))), // base ColumnarValue::Array(Arc::new(Float64Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num - ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ], arg_fields, number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; - let _ = LogFunc::new().invoke_with_args(args); + let result = LogFunc::new().invoke_with_args(args); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string().lines().next().unwrap(), + "Arrow error: Cast error: Casting from Date32 to Float64 not supported" + ); } #[test] fn test_log_invalid_value() { - let arg_field = Field::new("a", DataType::Int64, false).into(); + let arg_field = Field::new("a", DataType::Date32, false).into(); let args = ScalarFunctionArgs { args: vec![ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num + ColumnarValue::Array(Arc::new(Date32Array::from(vec![10]))), // num ], arg_fields: vec![arg_field], number_rows: 1, @@ -372,7 +456,7 @@ mod tests { ]; let args = ScalarFunctionArgs { args: vec![ - ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // base ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ], arg_fields, @@ -406,7 +490,7 @@ mod tests { ]; let args = ScalarFunctionArgs { args: vec![ - ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // base ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ], arg_fields, @@ -511,14 +595,14 @@ mod tests { let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 2.0, 2.0, 3.0, 5.0, + 2.0, 2.0, 3.0, 5.0, 5.0, ]))), // base ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 8.0, 4.0, 81.0, 625.0, + 8.0, 4.0, 81.0, 625.0, -123.0, ]))), // num ], arg_fields, - number_rows: 4, + number_rows: 5, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; @@ -531,11 +615,12 @@ mod tests { let floats = as_float64_array(&arr) .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 4); + assert_eq!(floats.len(), 5); assert!((floats.value(0) - 3.0).abs() < 1e-10); assert!((floats.value(1) - 2.0).abs() < 1e-10); assert!((floats.value(2) - 4.0).abs() < 1e-10); assert!((floats.value(3) - 4.0).abs() < 1e-10); + assert!(floats.value(4).is_nan()); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") @@ -731,4 +816,288 @@ mod tests { SortProperties::Unordered ); } + + #[test] + fn test_log_scalar_decimal128_unary() { + let arg_field = Field::new("a", DataType::Decimal128(38, 0), false).into(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(10), 38, 0)), // num + ], + arg_fields: vec![arg_field], + number_rows: 1, + return_field: Field::new("f", DataType::Decimal128(38, 0), true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Decimal128Array"); + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_decimal128() { + let arg_fields = vec![ + Field::new("b", DataType::Float64, false).into(), + Field::new("x", DataType::Decimal128(38, 0), false).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // base + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(64), 38, 0)), // num + ], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 6.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_decimal128_unary() { + let arg_field = Field::new("a", DataType::Decimal128(38, 0), false).into(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new( + Decimal128Array::from(vec![10, 100, 1000, 10000, 12600, -123]) + .with_precision_and_scale(38, 0) + .unwrap(), + )), // num + ], + arg_fields: vec![arg_field], + number_rows: 6, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 6); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding + assert!(floats.value(5).is_nan()); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_decimal128_base_decimal() { + // Base stays 2 despite scaling + for base in [ + ScalarValue::Decimal128(Some(i128::from(2)), 38, 0), + ScalarValue::Decimal128(Some(i128::from(2000)), 38, 3), + ] { + let arg_fields = vec![ + Field::new("b", DataType::Decimal128(38, 0), false).into(), + Field::new("x", DataType::Decimal128(38, 0), false).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(base), // base + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(64), 38, 0)), // num + ], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 6.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + } + + #[test] + fn test_log_decimal128_value_scale() { + // Value stays 1000 despite scaling + for value in [ + ScalarValue::Decimal128(Some(i128::from(1000)), 38, 0), + ScalarValue::Decimal128(Some(i128::from(10000)), 38, 1), + ScalarValue::Decimal128(Some(i128::from(1000000)), 38, 3), + ] { + let arg_fields = vec![ + Field::new("b", DataType::Decimal128(38, 0), false).into(), + Field::new("x", DataType::Decimal128(38, 0), false).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(value), // base + ], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 3.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + } + + #[test] + fn test_log_decimal256_unary() { + let arg_field = Field::new( + "a", + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), + false, + ) + .into(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new( + Decimal256Array::from(vec![ + Some(i256::from(10)), + Some(i256::from(100)), + Some(i256::from(1000)), + Some(i256::from(10000)), + Some(i256::from(12600)), + // Slightly lower than i128 max - can calculate + Some(i256::from_i128(i128::MAX) - i256::from(1000)), + // Give NaN for incorrect inputs, as in f64::log + Some(i256::from(-123)), + ]) + .with_precision_and_scale(DECIMAL256_MAX_PRECISION, 0) + .unwrap(), + )), // num + ], + arg_fields: vec![arg_field], + number_rows: 7, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 7); + eprintln!("floats {:?}", &floats); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding for float log + assert!((floats.value(5) - 38.0).abs() < 1e-10); + assert!(floats.value(6).is_nan()); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_decimal128_wrong_base() { + let arg_fields = vec![ + Field::new("b", DataType::Float64, false).into(), + Field::new("x", DataType::Decimal128(38, 0), false).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(-2.0))), // base + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(64), 38, 0)), // num + ], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new().invoke_with_args(args); + assert!(result.is_err()); + assert_eq!( + "Arrow error: Compute error: Log base must be greater than 1: -2", + result.unwrap_err().to_string().lines().next().unwrap() + ); + } + + #[test] + fn test_log_decimal256_error() { + let arg_field = Field::new("a", DataType::Decimal256(38, 0), false).into(); + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Decimal256Array::from(vec![ + // Slightly larger than i128 + Some(i256::from_i128(i128::MAX) + i256::from(1000)), + ]))), // num + ], + arg_fields: vec![arg_field], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = LogFunc::new().invoke_with_args(args); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string().lines().next().unwrap(), + "Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727" + ); + } } diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 5294d071a4..932d61e800 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; +use arrow::compute::try_binary; use arrow::datatypes::DataType; - -use datafusion_common::{Result, ScalarValue}; +use arrow::error::ArrowError; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; +use std::sync::Arc; /// Creates a function to identify the optimal return type of a string function given /// the type of its first argument. @@ -120,6 +122,76 @@ where } } +/// Computes a binary math function for input arrays using a specified function. +/// Generic types: +/// - `L`: Left array primitive type +/// - `R`: Right array primitive type +/// - `O`: Output array primitive type +/// - `F`: Functor computing `fun(l: L, r: R) -> Result<OutputType>` +pub fn calculate_binary_math<L, R, O, F>( + left: &dyn Array, + right: &ColumnarValue, + fun: F, +) -> Result<Arc<PrimitiveArray<O>>> +where + R: ArrowPrimitiveType, + L: ArrowPrimitiveType, + O: ArrowPrimitiveType, + F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>, + R::Native: TryFrom<ScalarValue>, +{ + Ok(match right { + ColumnarValue::Scalar(scalar) => { + let right_value: R::Native = + R::Native::try_from(scalar.clone()).map_err(|_| { + DataFusionError::NotImplemented(format!( + "Cannot convert scalar value {} to {}", + &scalar, + R::DATA_TYPE + )) + })?; + let left_array = left.as_primitive::<L>(); + // Bind right value + let result = + left_array.try_unary::<_, O, _>(|lvalue| fun(lvalue, right_value))?; + Arc::new(result) as _ + } + ColumnarValue::Array(right) => { + let right_casted = arrow::compute::cast(&right, &R::DATA_TYPE)?; + let right_array = right_casted.as_primitive::<R>(); + + // Types are compatible even they are decimals with different scale or precision + let result = if PrimitiveArray::<L>::is_compatible(&L::DATA_TYPE) { + let left_array = left.as_primitive::<L>(); + try_binary::<_, _, _, O>(left_array, right_array, &fun)? + } else { + let left_casted = arrow::compute::cast(left, &L::DATA_TYPE)?; + let left_array = left_casted.as_primitive::<L>(); + try_binary::<_, _, _, O>(left_array, right_array, &fun)? + }; + Arc::new(result) as _ + } + }) +} + +/// Converts Decimal128 components (value and scale) to an unscaled i128 +pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> { + if scale < 0 { + Err(ArrowError::ComputeError( + "Negative scale is not supported".into(), + )) + } else if scale == 0 { + Ok(value) + } else { + match i128::from(10).checked_pow(scale as u32) { + Some(divisor) => Ok(value / divisor), + None => Err(ArrowError::ComputeError(format!( + "Cannot get a power of {scale}" + ))), + } + } +} + #[cfg(test)] pub mod test { /// $FUNC ScalarUDFImpl to test @@ -251,4 +323,31 @@ pub mod test { let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap(); assert_eq!(v, DataType::Int64); } + + #[test] + fn test_decimal128_to_i128() { + let cases = [ + (123, 0, Some(123)), + (1230, 1, Some(123)), + (123000, 3, Some(123)), + (1, 0, Some(1)), + (123, -3, None), + (123, i8::MAX, None), + (i128::MAX, 0, Some(i128::MAX)), + (i128::MAX, 3, Some(i128::MAX / 1000)), + ]; + + for (value, scale, expected) in cases { + match decimal128_to_i128(value, scale) { + Ok(actual) => { + assert_eq!( + actual, + expected.expect("Got value but expected none"), + "{value} and {scale} vs {expected:?}" + ); + } + Err(_) => assert!(expected.is_none()), + } + } + } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 96ab84ab90..e720491712 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -695,7 +695,7 @@ SELECT array_length([now()]) query ? select [abs(-1.2), sin(-1), log(2), ceil(3.141)] ---- -[1.2, -0.8414709848078965, 0.3010299801826477, 4.0] +[1.2, -0.8414709848078965, 0.30102999566398114, 4.0] ## array literal with nested types query ??? diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index bd19991ec3..502821fcc3 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -783,3 +783,127 @@ with tt as ( ---- Float64 133333333333333330000000000000000000000000000 +# Following tests only make sense if numbers are parsed as decimals +# Remove when `parse_float_as_decimal` is true by default (#14612) +statement ok +set datafusion.sql_parser.parse_float_as_decimal = true; + +# smoke test for decimal parsing +query RT +select 100000000000000000000000000000000000::decimal(38,0), arrow_typeof(100000000000000000000000000000000000::decimal(38,0)); +---- +100000000000000000000000000000000000 Decimal128(38, 0) + +# log for small decimal128 +query R +select log(100::decimal(38,0)); +---- +2 + +# log for small decimal256 +query R +select log(100::decimal(76,0)); +---- +2 + +# log(10^21) for large decimal128 +query R +select log(10, 1000000000000000000000::decimal(38,0)); +---- +21 + +# log(10^35) for large decimal128 +# Must be 35 if parsed as decimal; 34 for floats +query R +select log(100000000000000000000000000000000000::decimal(38,0)) +---- +35 + +# Decimal overflow for 10^38 +query error Arrow error: Invalid argument error: .* is too large to store in a Decimal128 of precision 38. Max is +select log(100000000000000000000000000000000000000::decimal(38,0)) + +# log(10^35) for decimal256 for a value able to fit i128 +query R +select log(100000000000000000000000000000000000::decimal(76,0)); +---- +35 + +# log(10^50) for decimal256 for a value larger than i128 +query error Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported +select log(100000000000000000000000000000000000000000000000000::decimal(76,0)); + +# log(10^35) for decimal128 with explicit base +query R +select log(10, 100000000000000000000000000000000000::decimal(38,0)); +---- +35 + +# log(10^35) for decimal256 with explicit base - only float as a base +query R +select log(10.0, 100000000000000000000000000000000000::decimal(76,0)); +---- +35 + +# log(10^35) for decimal128 with explicit decimal base +query R +select log(10::decimal(38, 0), 100000000000000000000000000000000000::decimal(38,0)); +---- +35 + +# log(10^35) for decimal128 with another base +query R +select log(2, 100000000000000000000000000000000000::decimal(38,0)); +---- +116 + +# log(10^35) for decimal128 with another base +query R +select log(2.0, 100000000000000000000000000000000000::decimal(38,0)); +---- +116.267483321058 + +# null cases +query R +select log(null, 100); +---- +NULL + +query R +select log(null, 100000000000000000000000000000000000::decimal(38,0)); +---- +NULL + +query R +select log(null); +---- +NULL + +query R +select log(2.0, null); +---- +NULL + +# Set parse_float_as_decimal to false to test float parsing +statement ok +set datafusion.sql_parser.parse_float_as_decimal = false; + +# smoke test for decimal parsing +query R +select 100000000000000000000000000000000000::decimal(38,0) +---- +99999999999999996863366107917975552 + +# log(10^35) for decimal128 with explicit decimal base +# Float parsing is rounding down +query R +select log(10, 100000000000000000000000000000000000::decimal(38,0)); +---- +34 + +# log(10^35) for large decimal128 if parsed as float +# Float parsing is rounding down +query R +select log(100000000000000000000000000000000000::decimal(38,0)) +---- +34 diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index ca0b472de9..b0e200015d 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -523,7 +523,7 @@ query RRR rowsort select log(a, 64) a, log(b), log(10, b) from unsigned_integers; ---- 3 NULL NULL -3.7855785 4 4 +3.785578521429 4 4 6 3 3 Infinity 2 2 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org