This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 032117adfc More decimal 32/64 support - type coercion and misc gaps
(#17808)
032117adfc is described below
commit 032117adfc03d65a1245087ed03ef7b1998129ed
Author: Adam Gutglick <[email protected]>
AuthorDate: Tue Sep 30 15:00:37 2025 +0100
More decimal 32/64 support - type coercion and misc gaps (#17808)
* More small decimal support
* CR comments
Co-authored-by: Jeffrey Vo <[email protected]>
* Add tests and cleanup some code
---------
Co-authored-by: Jeffrey Vo <[email protected]>
---
datafusion/common/src/scalar/mod.rs | 6 +
datafusion/expr-common/src/type_coercion/binary.rs | 122 +++++++++++++++++--
.../src/type_coercion/binary/tests/arithmetic.rs | 130 +++++++++++++++++++++
.../src/type_coercion/binary/tests/comparison.rs | 88 ++++++++++++++
datafusion/functions/src/math/abs.rs | 12 +-
.../src/joins/sort_merge_join/stream.rs | 5 +
datafusion/spark/src/function/math/width_bucket.rs | 12 +-
7 files changed, 353 insertions(+), 22 deletions(-)
diff --git a/datafusion/common/src/scalar/mod.rs
b/datafusion/common/src/scalar/mod.rs
index 8c079056e2..bba994dd11 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -1362,6 +1362,12 @@ impl ScalarValue {
DataType::Float16 =>
ScalarValue::Float16(Some(f16::from_f32(0.0))),
DataType::Float32 => ScalarValue::Float32(Some(0.0)),
DataType::Float64 => ScalarValue::Float64(Some(0.0)),
+ DataType::Decimal32(precision, scale) => {
+ ScalarValue::Decimal32(Some(0), *precision, *scale)
+ }
+ DataType::Decimal64(precision, scale) => {
+ ScalarValue::Decimal64(Some(0), *precision, *scale)
+ }
DataType::Decimal128(precision, scale) => {
ScalarValue::Decimal128(Some(0), *precision, *scale)
}
diff --git a/datafusion/expr-common/src/type_coercion/binary.rs
b/datafusion/expr-common/src/type_coercion/binary.rs
index 1c99f49d26..52bb211d9b 100644
--- a/datafusion/expr-common/src/type_coercion/binary.rs
+++ b/datafusion/expr-common/src/type_coercion/binary.rs
@@ -327,6 +327,16 @@ impl<'a> BinaryTypeCoercer<'a> {
// TODO Move the rest inside of BinaryTypeCoercer
+fn is_decimal(data_type: &DataType) -> bool {
+ matches!(
+ data_type,
+ DataType::Decimal32(..)
+ | DataType::Decimal64(..)
+ | DataType::Decimal128(..)
+ | DataType::Decimal256(..)
+ )
+}
+
/// Coercion rules for mathematics operators between decimal and non-decimal
types.
fn math_decimal_coercion(
lhs_type: &DataType,
@@ -357,6 +367,15 @@ fn math_decimal_coercion(
| (Decimal256(_, _), Decimal256(_, _)) => {
Some((lhs_type.clone(), rhs_type.clone()))
}
+ // Cross-variant decimal coercion - choose larger variant with
appropriate precision/scale
+ (lhs, rhs)
+ if is_decimal(lhs)
+ && is_decimal(rhs)
+ && std::mem::discriminant(lhs) != std::mem::discriminant(rhs)
=>
+ {
+ let coerced_type = get_wider_decimal_type_cross_variant(lhs_type,
rhs_type)?;
+ Some((coerced_type.clone(), coerced_type))
+ }
// Unlike with comparison we don't coerce to a decimal in the case of
floating point
// numbers, instead falling back to floating point arithmetic instead
(
@@ -953,21 +972,92 @@ pub fn binary_numeric_coercion(
pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) ->
Option<DataType> {
use arrow::datatypes::DataType::*;
+ // Prefer decimal data type over floating point for comparison operation
match (lhs_type, rhs_type) {
- // Prefer decimal data type over floating point for comparison
operation
- (Decimal128(_, _), Decimal128(_, _)) => {
+ // Same decimal types
+ (lhs_type, rhs_type)
+ if is_decimal(lhs_type)
+ && is_decimal(rhs_type)
+ && std::mem::discriminant(lhs_type)
+ == std::mem::discriminant(rhs_type) =>
+ {
get_wider_decimal_type(lhs_type, rhs_type)
}
- (Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
- (_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
- (Decimal256(_, _), Decimal256(_, _)) => {
- get_wider_decimal_type(lhs_type, rhs_type)
+ // Mismatched decimal types
+ (lhs_type, rhs_type)
+ if is_decimal(lhs_type)
+ && is_decimal(rhs_type)
+ && std::mem::discriminant(lhs_type)
+ != std::mem::discriminant(rhs_type) =>
+ {
+ get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
+ }
+ // Decimal + non-decimal types
+ (Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_,
_), _) => {
+ get_common_decimal_type(lhs_type, rhs_type)
+ }
+ (_, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) |
Decimal256(_, _)) => {
+ get_common_decimal_type(rhs_type, lhs_type)
}
- (Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
- (_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
(_, _) => None,
}
}
+/// Handle cross-variant decimal widening by choosing the larger variant
+fn get_wider_decimal_type_cross_variant(
+ lhs_type: &DataType,
+ rhs_type: &DataType,
+) -> Option<DataType> {
+ use arrow::datatypes::DataType::*;
+
+ let (p1, s1) = match lhs_type {
+ Decimal32(p, s) => (*p, *s),
+ Decimal64(p, s) => (*p, *s),
+ Decimal128(p, s) => (*p, *s),
+ Decimal256(p, s) => (*p, *s),
+ _ => return None,
+ };
+
+ let (p2, s2) = match rhs_type {
+ Decimal32(p, s) => (*p, *s),
+ Decimal64(p, s) => (*p, *s),
+ Decimal128(p, s) => (*p, *s),
+ Decimal256(p, s) => (*p, *s),
+ _ => return None,
+ };
+
+ // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+ let s = s1.max(s2);
+ let range = (p1 as i8 - s1).max(p2 as i8 - s2);
+ let required_precision = (range + s) as u8;
+
+ // Choose the larger variant between the two input types, while making
sure we don't overflow the precision.
+ match (lhs_type, rhs_type) {
+ (Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _))
+ if required_precision <= DECIMAL64_MAX_PRECISION =>
+ {
+ Some(Decimal64(required_precision, s))
+ }
+ (Decimal32(_, _), Decimal128(_, _))
+ | (Decimal128(_, _), Decimal32(_, _))
+ | (Decimal64(_, _), Decimal128(_, _))
+ | (Decimal128(_, _), Decimal64(_, _))
+ if required_precision <= DECIMAL128_MAX_PRECISION =>
+ {
+ Some(Decimal128(required_precision, s))
+ }
+ (Decimal32(_, _), Decimal256(_, _))
+ | (Decimal256(_, _), Decimal32(_, _))
+ | (Decimal64(_, _), Decimal256(_, _))
+ | (Decimal256(_, _), Decimal64(_, _))
+ | (Decimal128(_, _), Decimal256(_, _))
+ | (Decimal256(_, _), Decimal128(_, _))
+ if required_precision <= DECIMAL256_MAX_PRECISION =>
+ {
+ Some(Decimal256(required_precision, s))
+ }
+ _ => None,
+ }
+}
/// Coerce `lhs_type` and `rhs_type` to a common type.
fn get_common_decimal_type(
@@ -976,7 +1066,15 @@ fn get_common_decimal_type(
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match decimal_type {
- Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) => {
+ Decimal32(_, _) => {
+ let other_decimal_type =
coerce_numeric_type_to_decimal32(other_type)?;
+ get_wider_decimal_type(decimal_type, &other_decimal_type)
+ }
+ Decimal64(_, _) => {
+ let other_decimal_type =
coerce_numeric_type_to_decimal64(other_type)?;
+ get_wider_decimal_type(decimal_type, &other_decimal_type)
+ }
+ Decimal128(_, _) => {
let other_decimal_type =
coerce_numeric_type_to_decimal128(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
@@ -988,7 +1086,7 @@ fn get_common_decimal_type(
}
}
-/// Returns a `DataType::Decimal128` that can store any value from either
+/// Returns a decimal [`DataType`] variant that can store any value from either
/// `lhs_decimal_type` and `rhs_decimal_type`
///
/// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1,
s2))`.
@@ -1209,14 +1307,14 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type:
&DataType) -> Option<DataTy
}
fn create_decimal32_type(precision: u8, scale: i8) -> DataType {
- DataType::Decimal128(
+ DataType::Decimal32(
DECIMAL32_MAX_PRECISION.min(precision),
DECIMAL32_MAX_SCALE.min(scale),
)
}
fn create_decimal64_type(precision: u8, scale: i8) -> DataType {
- DataType::Decimal128(
+ DataType::Decimal64(
DECIMAL64_MAX_PRECISION.min(precision),
DECIMAL64_MAX_SCALE.min(scale),
)
diff --git
a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs
b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs
index e6238ba007..bfedcf0713 100644
--- a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs
+++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs
@@ -291,3 +291,133 @@ fn test_coercion_arithmetic_decimal() -> Result<()> {
Ok(())
}
+
+#[test]
+fn test_coercion_arithmetic_decimal_cross_variant() -> Result<()> {
+ let test_cases = [
+ (
+ DataType::Decimal32(5, 2),
+ DataType::Decimal64(10, 3),
+ DataType::Decimal64(10, 3),
+ DataType::Decimal64(10, 3),
+ ),
+ (
+ DataType::Decimal32(7, 1),
+ DataType::Decimal128(15, 4),
+ DataType::Decimal128(15, 4),
+ DataType::Decimal128(15, 4),
+ ),
+ (
+ DataType::Decimal32(9, 0),
+ DataType::Decimal256(20, 5),
+ DataType::Decimal256(20, 5),
+ DataType::Decimal256(20, 5),
+ ),
+ (
+ DataType::Decimal64(12, 3),
+ DataType::Decimal128(18, 2),
+ DataType::Decimal128(19, 3),
+ DataType::Decimal128(19, 3),
+ ),
+ (
+ DataType::Decimal64(15, 4),
+ DataType::Decimal256(25, 6),
+ DataType::Decimal256(25, 6),
+ DataType::Decimal256(25, 6),
+ ),
+ (
+ DataType::Decimal128(20, 5),
+ DataType::Decimal256(30, 8),
+ DataType::Decimal256(30, 8),
+ DataType::Decimal256(30, 8),
+ ),
+ // Reverse order cases
+ (
+ DataType::Decimal64(10, 3),
+ DataType::Decimal32(5, 2),
+ DataType::Decimal64(10, 3),
+ DataType::Decimal64(10, 3),
+ ),
+ (
+ DataType::Decimal128(15, 4),
+ DataType::Decimal32(7, 1),
+ DataType::Decimal128(15, 4),
+ DataType::Decimal128(15, 4),
+ ),
+ (
+ DataType::Decimal256(20, 5),
+ DataType::Decimal32(9, 0),
+ DataType::Decimal256(20, 5),
+ DataType::Decimal256(20, 5),
+ ),
+ (
+ DataType::Decimal128(18, 2),
+ DataType::Decimal64(12, 3),
+ DataType::Decimal128(19, 3),
+ DataType::Decimal128(19, 3),
+ ),
+ (
+ DataType::Decimal256(25, 6),
+ DataType::Decimal64(15, 4),
+ DataType::Decimal256(25, 6),
+ DataType::Decimal256(25, 6),
+ ),
+ (
+ DataType::Decimal256(30, 8),
+ DataType::Decimal128(20, 5),
+ DataType::Decimal256(30, 8),
+ DataType::Decimal256(30, 8),
+ ),
+ ];
+
+ for (lhs_type, rhs_type, expected_lhs_type, expected_rhs_type) in
test_cases {
+ test_math_decimal_coercion_rule(
+ lhs_type,
+ rhs_type,
+ expected_lhs_type,
+ expected_rhs_type,
+ );
+ }
+
+ Ok(())
+}
+
+#[test]
+fn test_decimal_precision_overflow_cross_variant() -> Result<()> {
+ // s = max(0, 1) = 1, range = max(76-0, 38-1) = 76, required_precision =
76 + 1 = 77 (overflow)
+ let result = get_wider_decimal_type_cross_variant(
+ &DataType::Decimal256(76, 0),
+ &DataType::Decimal128(38, 1),
+ );
+ assert!(result.is_none());
+
+ // s = max(0, 10) = 10, range = max(9-0, 18-10) = 9, required_precision =
9 + 10 = 19 (overflow > 18)
+ let result = get_wider_decimal_type_cross_variant(
+ &DataType::Decimal32(9, 0),
+ &DataType::Decimal64(18, 10),
+ );
+ assert!(result.is_none());
+
+ // s = max(5, 26) = 26, range = max(18-5, 38-26) = 13, required_precision
= 13 + 26 = 39 (overflow > 38)
+ let result = get_wider_decimal_type_cross_variant(
+ &DataType::Decimal64(18, 5),
+ &DataType::Decimal128(38, 26),
+ );
+ assert!(result.is_none());
+
+ // s = max(10, 49) = 49, range = max(38-10, 76-49) = 28,
required_precision = 28 + 49 = 77 (overflow > 76)
+ let result = get_wider_decimal_type_cross_variant(
+ &DataType::Decimal128(38, 10),
+ &DataType::Decimal256(76, 49),
+ );
+ assert!(result.is_none());
+
+ // s = max(2, 3) = 3, range = max(5-2, 10-3) = 7, required_precision = 7 +
3 = 10 (valid <= 18)
+ let result = get_wider_decimal_type_cross_variant(
+ &DataType::Decimal32(5, 2),
+ &DataType::Decimal64(10, 3),
+ );
+ assert!(result.is_some());
+
+ Ok(())
+}
diff --git
a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs
b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs
index 208edae4ff..5401264e43 100644
--- a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs
+++ b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs
@@ -697,3 +697,91 @@ fn test_map_coercion() -> Result<()> {
);
Ok(())
}
+
+#[test]
+fn test_decimal_cross_variant_comparison_coercion() -> Result<()> {
+ let test_cases = [
+ // (lhs, rhs, expected_result)
+ (
+ DataType::Decimal32(5, 2),
+ DataType::Decimal64(10, 3),
+ DataType::Decimal64(10, 3),
+ ),
+ (
+ DataType::Decimal32(7, 1),
+ DataType::Decimal128(15, 4),
+ DataType::Decimal128(15, 4),
+ ),
+ (
+ DataType::Decimal32(9, 0),
+ DataType::Decimal256(20, 5),
+ DataType::Decimal256(20, 5),
+ ),
+ (
+ DataType::Decimal64(12, 3),
+ DataType::Decimal128(18, 2),
+ DataType::Decimal128(19, 3),
+ ),
+ (
+ DataType::Decimal64(15, 4),
+ DataType::Decimal256(25, 6),
+ DataType::Decimal256(25, 6),
+ ),
+ (
+ DataType::Decimal128(20, 5),
+ DataType::Decimal256(30, 8),
+ DataType::Decimal256(30, 8),
+ ),
+ // Reverse order cases
+ (
+ DataType::Decimal64(10, 3),
+ DataType::Decimal32(5, 2),
+ DataType::Decimal64(10, 3),
+ ),
+ (
+ DataType::Decimal128(15, 4),
+ DataType::Decimal32(7, 1),
+ DataType::Decimal128(15, 4),
+ ),
+ (
+ DataType::Decimal256(20, 5),
+ DataType::Decimal32(9, 0),
+ DataType::Decimal256(20, 5),
+ ),
+ (
+ DataType::Decimal128(18, 2),
+ DataType::Decimal64(12, 3),
+ DataType::Decimal128(19, 3),
+ ),
+ (
+ DataType::Decimal256(25, 6),
+ DataType::Decimal64(15, 4),
+ DataType::Decimal256(25, 6),
+ ),
+ (
+ DataType::Decimal256(30, 8),
+ DataType::Decimal128(20, 5),
+ DataType::Decimal256(30, 8),
+ ),
+ ];
+
+ let comparison_op_types = [
+ Operator::NotEq,
+ Operator::Eq,
+ Operator::Gt,
+ Operator::GtEq,
+ Operator::Lt,
+ Operator::LtEq,
+ ];
+
+ for (lhs_type, rhs_type, expected_type) in test_cases {
+ for op in comparison_op_types {
+ let (lhs, rhs) =
+ BinaryTypeCoercer::new(&lhs_type, &op,
&rhs_type).get_input_types()?;
+ assert_eq!(expected_type, lhs, "Coercion of type {lhs_type:?} with
{rhs_type:?} resulted in unexpected type: {lhs:?}");
+ assert_eq!(expected_type, rhs, "Coercion of type {rhs_type:?} with
{lhs_type:?} resulted in unexpected type: {rhs:?}");
+ }
+ }
+
+ Ok(())
+}
diff --git a/datafusion/functions/src/math/abs.rs
b/datafusion/functions/src/math/abs.rs
index 8af8e4c2c8..040f13c014 100644
--- a/datafusion/functions/src/math/abs.rs
+++ b/datafusion/functions/src/math/abs.rs
@@ -21,8 +21,8 @@ use std::any::Any;
use std::sync::Arc;
use arrow::array::{
- ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array,
Int16Array,
- Int32Array, Int64Array, Int8Array,
+ ArrayRef, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array,
+ Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
};
use arrow::datatypes::DataType;
use arrow::error::ArrowError;
@@ -98,6 +98,8 @@ fn create_abs_function(input_data_type: &DataType) ->
Result<MathArrayFunction>
| DataType::UInt64 => Ok(|input: &ArrayRef| Ok(Arc::clone(input))),
// Decimal types
+ DataType::Decimal32(_, _) =>
Ok(make_decimal_abs_function!(Decimal32Array)),
+ DataType::Decimal64(_, _) =>
Ok(make_decimal_abs_function!(Decimal64Array)),
DataType::Decimal128(_, _) =>
Ok(make_decimal_abs_function!(Decimal128Array)),
DataType::Decimal256(_, _) =>
Ok(make_decimal_abs_function!(Decimal256Array)),
@@ -162,6 +164,12 @@ impl ScalarUDFImpl for AbsFunc {
DataType::UInt16 => Ok(DataType::UInt16),
DataType::UInt32 => Ok(DataType::UInt32),
DataType::UInt64 => Ok(DataType::UInt64),
+ DataType::Decimal32(precision, scale) => {
+ Ok(DataType::Decimal32(precision, scale))
+ }
+ DataType::Decimal64(precision, scale) => {
+ Ok(DataType::Decimal64(precision, scale))
+ }
DataType::Decimal128(precision, scale) => {
Ok(DataType::Decimal128(precision, scale))
}
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
index d28a9bad17..f16ef24fd1 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
@@ -1922,6 +1922,8 @@ fn compare_join_arrays(
DataType::BinaryView => compare_value!(BinaryViewArray),
DataType::FixedSizeBinary(_) =>
compare_value!(FixedSizeBinaryArray),
DataType::LargeBinary => compare_value!(LargeBinaryArray),
+ DataType::Decimal32(..) => compare_value!(Decimal32Array),
+ DataType::Decimal64(..) => compare_value!(Decimal64Array),
DataType::Decimal128(..) => compare_value!(Decimal128Array),
DataType::Timestamp(time_unit, None) => match time_unit {
TimeUnit::Second => compare_value!(TimestampSecondArray),
@@ -1994,7 +1996,10 @@ fn is_join_arrays_equal(
DataType::BinaryView => compare_value!(BinaryViewArray),
DataType::FixedSizeBinary(_) =>
compare_value!(FixedSizeBinaryArray),
DataType::LargeBinary => compare_value!(LargeBinaryArray),
+ DataType::Decimal32(..) => compare_value!(Decimal32Array),
+ DataType::Decimal64(..) => compare_value!(Decimal64Array),
DataType::Decimal128(..) => compare_value!(Decimal128Array),
+ DataType::Decimal256(..) => compare_value!(Decimal256Array),
DataType::Timestamp(time_unit, None) => match time_unit {
TimeUnit::Second => compare_value!(TimestampSecondArray),
TimeUnit::Millisecond =>
compare_value!(TimestampMillisecondArray),
diff --git a/datafusion/spark/src/function/math/width_bucket.rs
b/datafusion/spark/src/function/math/width_bucket.rs
index 24f8fe6b24..45a0d843b7 100644
--- a/datafusion/spark/src/function/math/width_bucket.rs
+++ b/datafusion/spark/src/function/math/width_bucket.rs
@@ -32,6 +32,7 @@ use datafusion_common::cast::{
};
use datafusion_common::{exec_err, Result};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
+use datafusion_expr::type_coercion::is_signed_numeric;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl,
Signature};
use datafusion_functions::utils::make_scalar_function;
@@ -93,16 +94,11 @@ impl ScalarUDFImpl for SparkWidthBucket {
let (v, lo, hi, n) = (&types[0], &types[1], &types[2], &types[3]);
- let is_num = |t: &DataType| {
- matches!(
- t,
- Int8 | Int16 | Int32 | Int64 | Float32 | Float64 |
Decimal128(_, _)
- )
- };
-
match (v, lo, hi, n) {
(a, b, c, &(Int8 | Int16 | Int32 | Int64))
- if is_num(a) && is_num(b) && is_num(c) =>
+ if is_signed_numeric(a)
+ && is_signed_numeric(b)
+ && is_signed_numeric(c) =>
{
Ok(vec![Float64, Float64, Float64, Int32])
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]