This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 44cd972546 Support Decimal32/64 types (#17501)
44cd972546 is described below
commit 44cd972546eaf8e946fa9d4576abade7daf7efab
Author: Adam Gutglick <[email protected]>
AuthorDate: Fri Sep 19 20:06:45 2025 +0100
Support Decimal32/64 types (#17501)
* Support Decimal32/64 types
* Fix bugs, tests, handle more aggregate functions and schema
* Fill out more parts in expr, common and expr-common
* Some stragglers and overlooked corners
* Actually commit the avg_distinct support
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/common/src/cast.rs | 17 +-
datafusion/common/src/dfschema.rs | 39 ++
datafusion/common/src/scalar/mod.rs | 412 +++++++++++++++++++--
datafusion/common/src/types/native.rs | 11 +-
.../tests/fuzz_cases/record_batch_generator.rs | 55 ++-
datafusion/expr-common/src/casts.rs | 68 +++-
.../expr-common/src/type_coercion/aggregates.rs | 45 ++-
datafusion/expr-common/src/type_coercion/binary.rs | 124 ++++++-
.../src/type_coercion/binary/tests/arithmetic.rs | 57 ++-
datafusion/expr/src/logical_plan/builder.rs | 17 +-
datafusion/expr/src/test/function_stub.rs | 20 +-
datafusion/expr/src/type_coercion/functions.rs | 5 +-
datafusion/expr/src/type_coercion/mod.rs | 10 +-
.../src/aggregate/avg_distinct/decimal.rs | 96 ++++-
.../functions-aggregate-common/src/min_max.rs | 64 +++-
datafusion/functions-aggregate/src/average.rs | 98 ++++-
datafusion/functions-aggregate/src/first_last.rs | 20 +-
datafusion/functions-aggregate/src/median.rs | 8 +-
datafusion/functions-aggregate/src/min_max.rs | 24 +-
datafusion/functions-aggregate/src/sum.rs | 31 +-
datafusion/proto-common/src/from_proto/mod.rs | 13 +-
datafusion/proto-common/src/to_proto/mod.rs | 36 ++
datafusion/sql/src/unparser/expr.rs | 42 ++-
test-utils/src/array_gen/random_data.rs | 14 +-
24 files changed, 1188 insertions(+), 138 deletions(-)
diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs
index 68b753a667..791c2d16ae 100644
--- a/datafusion/common/src/cast.rs
+++ b/datafusion/common/src/cast.rs
@@ -22,9 +22,10 @@
use crate::{downcast_value, Result};
use arrow::array::{
- BinaryViewArray, DurationMicrosecondArray, DurationMillisecondArray,
- DurationNanosecondArray, DurationSecondArray, Float16Array, Int16Array,
Int8Array,
- LargeBinaryArray, LargeStringArray, StringViewArray, UInt16Array,
+ BinaryViewArray, Decimal32Array, Decimal64Array, DurationMicrosecondArray,
+ DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray,
Float16Array,
+ Int16Array, Int8Array, LargeBinaryArray, LargeStringArray, StringViewArray,
+ UInt16Array,
};
use arrow::{
array::{
@@ -97,6 +98,16 @@ pub fn as_uint64_array(array: &dyn Array) ->
Result<&UInt64Array> {
Ok(downcast_value!(array, UInt64Array))
}
+// Downcast Array to Decimal32Array
+pub fn as_decimal32_array(array: &dyn Array) -> Result<&Decimal32Array> {
+ Ok(downcast_value!(array, Decimal32Array))
+}
+
+// Downcast Array to Decimal64Array
+pub fn as_decimal64_array(array: &dyn Array) -> Result<&Decimal64Array> {
+ Ok(downcast_value!(array, Decimal64Array))
+}
+
// Downcast Array to Decimal128Array
pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> {
Ok(downcast_value!(array, Decimal128Array))
diff --git a/datafusion/common/src/dfschema.rs
b/datafusion/common/src/dfschema.rs
index d69810db19..bc4a9658e9 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -798,6 +798,14 @@ impl DFSchema {
.zip(iter2)
.all(|((t1, f1), (t2, f2))| t1 == t2 &&
Self::field_is_semantically_equal(f1, f2))
}
+ (
+ DataType::Decimal32(_l_precision, _l_scale),
+ DataType::Decimal32(_r_precision, _r_scale),
+ ) => true,
+ (
+ DataType::Decimal64(_l_precision, _l_scale),
+ DataType::Decimal64(_r_precision, _r_scale),
+ ) => true,
(
DataType::Decimal128(_l_precision, _l_scale),
DataType::Decimal128(_r_precision, _r_scale),
@@ -1056,6 +1064,12 @@ fn format_simple_data_type(data_type: &DataType) ->
String {
DataType::Dictionary(_, value_type) => {
format_simple_data_type(value_type.as_ref())
}
+ DataType::Decimal32(precision, scale) => {
+ format!("decimal32({precision}, {scale})")
+ }
+ DataType::Decimal64(precision, scale) => {
+ format!("decimal64({precision}, {scale})")
+ }
DataType::Decimal128(precision, scale) => {
format!("decimal128({precision}, {scale})")
}
@@ -1794,6 +1808,27 @@ mod tests {
&DataType::Int16
));
+ // Succeeds if decimal precision and scale are different
+ assert!(DFSchema::datatype_is_semantically_equal(
+ &DataType::Decimal32(1, 2),
+ &DataType::Decimal32(2, 1),
+ ));
+
+ assert!(DFSchema::datatype_is_semantically_equal(
+ &DataType::Decimal64(1, 2),
+ &DataType::Decimal64(2, 1),
+ ));
+
+ assert!(DFSchema::datatype_is_semantically_equal(
+ &DataType::Decimal128(1, 2),
+ &DataType::Decimal128(2, 1),
+ ));
+
+ assert!(DFSchema::datatype_is_semantically_equal(
+ &DataType::Decimal256(1, 2),
+ &DataType::Decimal256(2, 1),
+ ));
+
// Test lists
// Succeeds if both have the same element type, disregards names and
nullability
@@ -2377,6 +2412,8 @@ mod tests {
),
false,
),
+ Field::new("decimal32", DataType::Decimal32(9, 4), true),
+ Field::new("decimal64", DataType::Decimal64(9, 4), true),
Field::new("decimal128", DataType::Decimal128(18, 4), true),
Field::new("decimal256", DataType::Decimal256(38, 10), false),
Field::new("date32", DataType::Date32, true),
@@ -2408,6 +2445,8 @@ mod tests {
|-- fixed_size_binary: fixed_size_binary (nullable = true)
|-- fixed_size_list: fixed size list (nullable = false)
| |-- item: int32 (nullable = true)
+ |-- decimal32: decimal32(9, 4) (nullable = true)
+ |-- decimal64: decimal64(9, 4) (nullable = true)
|-- decimal128: decimal128(18, 4) (nullable = true)
|-- decimal256: decimal256(38, 10) (nullable = false)
|-- date32: date32 (nullable = true)
diff --git a/datafusion/common/src/scalar/mod.rs
b/datafusion/common/src/scalar/mod.rs
index 67287546f4..2a0b8b6ec9 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -35,13 +35,14 @@ use std::sync::Arc;
use crate::cast::{
as_binary_array, as_binary_view_array, as_boolean_array, as_date32_array,
- as_date64_array, as_decimal128_array, as_decimal256_array,
as_dictionary_array,
- as_duration_microsecond_array, as_duration_millisecond_array,
- as_duration_nanosecond_array, as_duration_second_array,
as_fixed_size_binary_array,
- as_fixed_size_list_array, as_float16_array, as_float32_array,
as_float64_array,
- as_int16_array, as_int32_array, as_int64_array, as_int8_array,
as_interval_dt_array,
- as_interval_mdn_array, as_interval_ym_array, as_large_binary_array,
- as_large_list_array, as_large_string_array, as_string_array,
as_string_view_array,
+ as_date64_array, as_decimal128_array, as_decimal256_array,
as_decimal32_array,
+ as_decimal64_array, as_dictionary_array, as_duration_microsecond_array,
+ as_duration_millisecond_array, as_duration_nanosecond_array,
+ as_duration_second_array, as_fixed_size_binary_array,
as_fixed_size_list_array,
+ as_float16_array, as_float32_array, as_float64_array, as_int16_array,
as_int32_array,
+ as_int64_array, as_int8_array, as_interval_dt_array, as_interval_mdn_array,
+ as_interval_ym_array, as_large_binary_array, as_large_list_array,
+ as_large_string_array, as_string_array, as_string_view_array,
as_time32_millisecond_array, as_time32_second_array,
as_time64_microsecond_array,
as_time64_nanosecond_array, as_timestamp_microsecond_array,
as_timestamp_millisecond_array, as_timestamp_nanosecond_array,
@@ -56,17 +57,17 @@ use crate::{_internal_datafusion_err, arrow_datafusion_err};
use arrow::array::{
new_empty_array, new_null_array, Array, ArrayData, ArrayRef,
ArrowNativeTypeOp,
ArrowPrimitiveType, AsArray, BinaryArray, BinaryViewArray, BooleanArray,
Date32Array,
- Date64Array, Decimal128Array, Decimal256Array, DictionaryArray,
- DurationMicrosecondArray, DurationMillisecondArray,
DurationNanosecondArray,
- DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray,
Float16Array,
- Float32Array, Float64Array, GenericListArray, Int16Array, Int32Array,
Int64Array,
- Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray,
IntervalYearMonthArray,
- LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, MapArray,
- MutableArrayData, PrimitiveArray, Scalar, StringArray, StringViewArray,
StructArray,
- Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
- Time64NanosecondArray, TimestampMicrosecondArray,
TimestampMillisecondArray,
- TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
- UInt64Array, UInt8Array, UnionArray,
+ Date64Array, Decimal128Array, Decimal256Array, Decimal32Array,
Decimal64Array,
+ DictionaryArray, DurationMicrosecondArray, DurationMillisecondArray,
+ DurationNanosecondArray, DurationSecondArray, FixedSizeBinaryArray,
+ FixedSizeListArray, Float16Array, Float32Array, Float64Array,
GenericListArray,
+ Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
+ IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
LargeListArray,
+ LargeStringArray, ListArray, MapArray, MutableArrayData, PrimitiveArray,
Scalar,
+ StringArray, StringViewArray, StructArray, Time32MillisecondArray,
Time32SecondArray,
+ Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
+ TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
+ UInt16Array, UInt32Array, UInt64Array, UInt8Array, UnionArray,
};
use arrow::buffer::ScalarBuffer;
use arrow::compute::kernels::cast::{cast_with_options, CastOptions};
@@ -75,12 +76,13 @@ use arrow::compute::kernels::numeric::{
};
use arrow::datatypes::{
i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType,
ArrowNativeType,
- ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type,
Field,
- Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime,
- IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType,
IntervalUnit,
- IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType,
- TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type,
UInt64Type,
- UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION,
+ ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type,
+ Decimal32Type, Decimal64Type, Field, Float32Type, Int16Type, Int32Type,
Int64Type,
+ Int8Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano,
+ IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit,
+ TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType,
+ TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
UnionFields,
+ UnionMode, DECIMAL128_MAX_PRECISION,
};
use arrow::util::display::{array_value_to_string, ArrayFormatter,
FormatOptions};
use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array};
@@ -231,6 +233,10 @@ pub enum ScalarValue {
Float32(Option<f32>),
/// 64bit float
Float64(Option<f64>),
+ /// 32bit decimal, using the i32 to represent the decimal, precision scale
+ Decimal32(Option<i32>, u8, i8),
+ /// 64bit decimal, using the i64 to represent the decimal, precision scale
+ Decimal64(Option<i64>, u8, i8),
/// 128bit decimal, using the i128 to represent the decimal, precision
scale
Decimal128(Option<i128>, u8, i8),
/// 256bit decimal, using the i256 to represent the decimal, precision
scale
@@ -340,6 +346,14 @@ impl PartialEq for ScalarValue {
// any newly added enum variant will require editing this list
// or else face a compile error
match (self, other) {
+ (Decimal32(v1, p1, s1), Decimal32(v2, p2, s2)) => {
+ v1.eq(v2) && p1.eq(p2) && s1.eq(s2)
+ }
+ (Decimal32(_, _, _), _) => false,
+ (Decimal64(v1, p1, s1), Decimal64(v2, p2, s2)) => {
+ v1.eq(v2) && p1.eq(p2) && s1.eq(s2)
+ }
+ (Decimal64(_, _, _), _) => false,
(Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => {
v1.eq(v2) && p1.eq(p2) && s1.eq(s2)
}
@@ -459,6 +473,24 @@ impl PartialOrd for ScalarValue {
// any newly added enum variant will require editing this list
// or else face a compile error
match (self, other) {
+ (Decimal32(v1, p1, s1), Decimal32(v2, p2, s2)) => {
+ if p1.eq(p2) && s1.eq(s2) {
+ v1.partial_cmp(v2)
+ } else {
+ // Two decimal values can be compared if they have the
same precision and scale.
+ None
+ }
+ }
+ (Decimal32(_, _, _), _) => None,
+ (Decimal64(v1, p1, s1), Decimal64(v2, p2, s2)) => {
+ if p1.eq(p2) && s1.eq(s2) {
+ v1.partial_cmp(v2)
+ } else {
+ // Two decimal values can be compared if they have the
same precision and scale.
+ None
+ }
+ }
+ (Decimal64(_, _, _), _) => None,
(Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => {
if p1.eq(p2) && s1.eq(s2) {
v1.partial_cmp(v2)
@@ -760,6 +792,16 @@ impl Hash for ScalarValue {
fn hash<H: Hasher>(&self, state: &mut H) {
use ScalarValue::*;
match self {
+ Decimal32(v, p, s) => {
+ v.hash(state);
+ p.hash(state);
+ s.hash(state)
+ }
+ Decimal64(v, p, s) => {
+ v.hash(state);
+ p.hash(state);
+ s.hash(state)
+ }
Decimal128(v, p, s) => {
v.hash(state);
p.hash(state);
@@ -1045,6 +1087,12 @@ impl ScalarValue {
DataType::UInt16 => ScalarValue::UInt16(None),
DataType::UInt32 => ScalarValue::UInt32(None),
DataType::UInt64 => ScalarValue::UInt64(None),
+ DataType::Decimal32(precision, scale) => {
+ ScalarValue::Decimal32(None, *precision, *scale)
+ }
+ DataType::Decimal64(precision, scale) => {
+ ScalarValue::Decimal64(None, *precision, *scale)
+ }
DataType::Decimal128(precision, scale) => {
ScalarValue::Decimal128(None, *precision, *scale)
}
@@ -1521,6 +1569,34 @@ impl ScalarValue {
DataType::Float16 =>
ScalarValue::Float16(Some(f16::from_f32(1.0))),
DataType::Float32 => ScalarValue::Float32(Some(1.0)),
DataType::Float64 => ScalarValue::Float64(Some(1.0)),
+ DataType::Decimal32(precision, scale) => {
+ validate_decimal_precision_and_scale::<Decimal32Type>(
+ *precision, *scale,
+ )?;
+ if *scale < 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match 10_i32.checked_pow(*scale as u32) {
+ Some(value) => {
+ ScalarValue::Decimal32(Some(value), *precision, *scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
+ DataType::Decimal64(precision, scale) => {
+ validate_decimal_precision_and_scale::<Decimal64Type>(
+ *precision, *scale,
+ )?;
+ if *scale < 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match i64::from(10).checked_pow(*scale as u32) {
+ Some(value) => {
+ ScalarValue::Decimal64(Some(value), *precision, *scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
DataType::Decimal128(precision, scale) => {
validate_decimal_precision_and_scale::<Decimal128Type>(
*precision, *scale,
@@ -1567,6 +1643,34 @@ impl ScalarValue {
DataType::Float16 =>
ScalarValue::Float16(Some(f16::from_f32(-1.0))),
DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
+ DataType::Decimal32(precision, scale) => {
+ validate_decimal_precision_and_scale::<Decimal32Type>(
+ *precision, *scale,
+ )?;
+ if *scale < 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match 10_i32.checked_pow(*scale as u32) {
+ Some(value) => {
+ ScalarValue::Decimal32(Some(-value), *precision,
*scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
+ DataType::Decimal64(precision, scale) => {
+ validate_decimal_precision_and_scale::<Decimal64Type>(
+ *precision, *scale,
+ )?;
+ if *scale < 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match i64::from(10).checked_pow(*scale as u32) {
+ Some(value) => {
+ ScalarValue::Decimal64(Some(-value), *precision,
*scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
DataType::Decimal128(precision, scale) => {
validate_decimal_precision_and_scale::<Decimal128Type>(
*precision, *scale,
@@ -1616,6 +1720,38 @@ impl ScalarValue {
DataType::Float16 =>
ScalarValue::Float16(Some(f16::from_f32(10.0))),
DataType::Float32 => ScalarValue::Float32(Some(10.0)),
DataType::Float64 => ScalarValue::Float64(Some(10.0)),
+ DataType::Decimal32(precision, scale) => {
+ if let Err(err) =
validate_decimal_precision_and_scale::<Decimal32Type>(
+ *precision, *scale,
+ ) {
+ return _internal_err!("Invalid precision and scale {err}");
+ }
+ if *scale <= 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match 10_i32.checked_pow((*scale + 1) as u32) {
+ Some(value) => {
+ ScalarValue::Decimal32(Some(value), *precision, *scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
+ DataType::Decimal64(precision, scale) => {
+ if let Err(err) =
validate_decimal_precision_and_scale::<Decimal64Type>(
+ *precision, *scale,
+ ) {
+ return _internal_err!("Invalid precision and scale {err}");
+ }
+ if *scale <= 0 {
+ return _internal_err!("Negative scale is not supported");
+ }
+ match i64::from(10).checked_pow((*scale + 1) as u32) {
+ Some(value) => {
+ ScalarValue::Decimal64(Some(value), *precision, *scale)
+ }
+ None => return _internal_err!("Unsupported scale {scale}"),
+ }
+ }
DataType::Decimal128(precision, scale) => {
if let Err(err) =
validate_decimal_precision_and_scale::<Decimal128Type>(
*precision, *scale,
@@ -1668,6 +1804,12 @@ impl ScalarValue {
ScalarValue::Int16(_) => DataType::Int16,
ScalarValue::Int32(_) => DataType::Int32,
ScalarValue::Int64(_) => DataType::Int64,
+ ScalarValue::Decimal32(_, precision, scale) => {
+ DataType::Decimal32(*precision, *scale)
+ }
+ ScalarValue::Decimal64(_, precision, scale) => {
+ DataType::Decimal64(*precision, *scale)
+ }
ScalarValue::Decimal128(_, precision, scale) => {
DataType::Decimal128(*precision, *scale)
}
@@ -1790,6 +1932,24 @@ impl ScalarValue {
);
Ok(ScalarValue::IntervalMonthDayNano(Some(val)))
}
+ ScalarValue::Decimal32(Some(v), precision, scale) => {
+ Ok(ScalarValue::Decimal32(
+ Some(neg_checked_with_ctx(*v, || {
+ format!("In negation of Decimal32({v}, {precision},
{scale})")
+ })?),
+ *precision,
+ *scale,
+ ))
+ }
+ ScalarValue::Decimal64(Some(v), precision, scale) => {
+ Ok(ScalarValue::Decimal64(
+ Some(neg_checked_with_ctx(*v, || {
+ format!("In negation of Decimal64({v}, {precision},
{scale})")
+ })?),
+ *precision,
+ *scale,
+ ))
+ }
ScalarValue::Decimal128(Some(v), precision, scale) => {
Ok(ScalarValue::Decimal128(
Some(neg_checked_with_ctx(*v, || {
@@ -1941,6 +2101,8 @@ impl ScalarValue {
ScalarValue::Float16(v) => v.is_none(),
ScalarValue::Float32(v) => v.is_none(),
ScalarValue::Float64(v) => v.is_none(),
+ ScalarValue::Decimal32(v, _, _) => v.is_none(),
+ ScalarValue::Decimal64(v, _, _) => v.is_none(),
ScalarValue::Decimal128(v, _, _) => v.is_none(),
ScalarValue::Decimal256(v, _, _) => v.is_none(),
ScalarValue::Int8(v) => v.is_none(),
@@ -2196,19 +2358,19 @@ impl ScalarValue {
}
let array: ArrayRef = match &data_type {
- DataType::Decimal32(_precision, _scale) => {
- return _not_impl_err!(
- "Decimal32 not supported in ScalarValue::iter_to_array"
- );
+ DataType::Decimal32(precision, scale) => {
+ let decimal_array =
+ ScalarValue::iter_to_decimal32_array(scalars, *precision,
*scale)?;
+ Arc::new(decimal_array)
}
- DataType::Decimal64(_precision, _scale) => {
- return _not_impl_err!(
- "Decimal64 not supported in ScalarValue::iter_to_array"
- );
+ DataType::Decimal64(precision, scale) => {
+ let decimal_array =
+ ScalarValue::iter_to_decimal64_array(scalars, *precision,
*scale)?;
+ Arc::new(decimal_array)
}
DataType::Decimal128(precision, scale) => {
let decimal_array =
- ScalarValue::iter_to_decimal_array(scalars, *precision,
*scale)?;
+ ScalarValue::iter_to_decimal128_array(scalars, *precision,
*scale)?;
Arc::new(decimal_array)
}
DataType::Decimal256(precision, scale) => {
@@ -2417,7 +2579,43 @@ impl ScalarValue {
Ok(new_null_array(&DataType::Null, length))
}
- fn iter_to_decimal_array(
+ fn iter_to_decimal32_array(
+ scalars: impl IntoIterator<Item = ScalarValue>,
+ precision: u8,
+ scale: i8,
+ ) -> Result<Decimal32Array> {
+ let array = scalars
+ .into_iter()
+ .map(|element: ScalarValue| match element {
+ ScalarValue::Decimal32(v1, _, _) => Ok(v1),
+ s => {
+ _internal_err!("Expected ScalarValue::Null element.
Received {s:?}")
+ }
+ })
+ .collect::<Result<Decimal32Array>>()?
+ .with_precision_and_scale(precision, scale)?;
+ Ok(array)
+ }
+
+ fn iter_to_decimal64_array(
+ scalars: impl IntoIterator<Item = ScalarValue>,
+ precision: u8,
+ scale: i8,
+ ) -> Result<Decimal64Array> {
+ let array = scalars
+ .into_iter()
+ .map(|element: ScalarValue| match element {
+ ScalarValue::Decimal64(v1, _, _) => Ok(v1),
+ s => {
+ _internal_err!("Expected ScalarValue::Null element.
Received {s:?}")
+ }
+ })
+ .collect::<Result<Decimal64Array>>()?
+ .with_precision_and_scale(precision, scale)?;
+ Ok(array)
+ }
+
+ fn iter_to_decimal128_array(
scalars: impl IntoIterator<Item = ScalarValue>,
precision: u8,
scale: i8,
@@ -2455,7 +2653,43 @@ impl ScalarValue {
Ok(array)
}
- fn build_decimal_array(
+ fn build_decimal32_array(
+ value: Option<i32>,
+ precision: u8,
+ scale: i8,
+ size: usize,
+ ) -> Result<Decimal32Array> {
+ Ok(match value {
+ Some(val) => Decimal32Array::from(vec![val; size])
+ .with_precision_and_scale(precision, scale)?,
+ None => {
+ let mut builder = Decimal32Array::builder(size)
+ .with_precision_and_scale(precision, scale)?;
+ builder.append_nulls(size);
+ builder.finish()
+ }
+ })
+ }
+
+ fn build_decimal64_array(
+ value: Option<i64>,
+ precision: u8,
+ scale: i8,
+ size: usize,
+ ) -> Result<Decimal64Array> {
+ Ok(match value {
+ Some(val) => Decimal64Array::from(vec![val; size])
+ .with_precision_and_scale(precision, scale)?,
+ None => {
+ let mut builder = Decimal64Array::builder(size)
+ .with_precision_and_scale(precision, scale)?;
+ builder.append_nulls(size);
+ builder.finish()
+ }
+ })
+ }
+
+ fn build_decimal128_array(
value: Option<i128>,
precision: u8,
scale: i8,
@@ -2634,8 +2868,14 @@ impl ScalarValue {
/// - a `Dictionary` that fails be converted to a dictionary array of size
pub fn to_array_of_size(&self, size: usize) -> Result<ArrayRef> {
Ok(match self {
+ ScalarValue::Decimal32(e, precision, scale) => Arc::new(
+ ScalarValue::build_decimal32_array(*e, *precision, *scale,
size)?,
+ ),
+ ScalarValue::Decimal64(e, precision, scale) => Arc::new(
+ ScalarValue::build_decimal64_array(*e, *precision, *scale,
size)?,
+ ),
ScalarValue::Decimal128(e, precision, scale) => Arc::new(
- ScalarValue::build_decimal_array(*e, *precision, *scale,
size)?,
+ ScalarValue::build_decimal128_array(*e, *precision, *scale,
size)?,
),
ScalarValue::Decimal256(e, precision, scale) => Arc::new(
ScalarValue::build_decimal256_array(*e, *precision, *scale,
size)?,
@@ -2945,6 +3185,24 @@ impl ScalarValue {
scale: i8,
) -> Result<ScalarValue> {
match array.data_type() {
+ DataType::Decimal32(_, _) => {
+ let array = as_decimal32_array(array)?;
+ if array.is_null(index) {
+ Ok(ScalarValue::Decimal32(None, precision, scale))
+ } else {
+ let value = array.value(index);
+ Ok(ScalarValue::Decimal32(Some(value), precision, scale))
+ }
+ }
+ DataType::Decimal64(_, _) => {
+ let array = as_decimal64_array(array)?;
+ if array.is_null(index) {
+ Ok(ScalarValue::Decimal64(None, precision, scale))
+ } else {
+ let value = array.value(index);
+ Ok(ScalarValue::Decimal64(Some(value), precision, scale))
+ }
+ }
DataType::Decimal128(_, _) => {
let array = as_decimal128_array(array)?;
if array.is_null(index) {
@@ -2963,7 +3221,9 @@ impl ScalarValue {
Ok(ScalarValue::Decimal256(Some(value), precision, scale))
}
}
- _ => _internal_err!("Unsupported decimal type"),
+ other => {
+ unreachable!("Invalid type isn't decimal: {other:?}")
+ }
}
}
@@ -3077,6 +3337,16 @@ impl ScalarValue {
Ok(match array.data_type() {
DataType::Null => ScalarValue::Null,
+ DataType::Decimal32(precision, scale) => {
+ ScalarValue::get_decimal_value_from_array(
+ array, index, *precision, *scale,
+ )?
+ }
+ DataType::Decimal64(precision, scale) => {
+ ScalarValue::get_decimal_value_from_array(
+ array, index, *precision, *scale,
+ )?
+ }
DataType::Decimal128(precision, scale) => {
ScalarValue::get_decimal_value_from_array(
array, index, *precision, *scale,
@@ -3337,6 +3607,44 @@ impl ScalarValue {
ScalarValue::try_from_array(&cast_arr, 0)
}
+ fn eq_array_decimal32(
+ array: &ArrayRef,
+ index: usize,
+ value: Option<&i32>,
+ precision: u8,
+ scale: i8,
+ ) -> Result<bool> {
+ let array = as_decimal32_array(array)?;
+ if array.precision() != precision || array.scale() != scale {
+ return Ok(false);
+ }
+ let is_null = array.is_null(index);
+ if let Some(v) = value {
+ Ok(!array.is_null(index) && array.value(index) == *v)
+ } else {
+ Ok(is_null)
+ }
+ }
+
+ fn eq_array_decimal64(
+ array: &ArrayRef,
+ index: usize,
+ value: Option<&i64>,
+ precision: u8,
+ scale: i8,
+ ) -> Result<bool> {
+ let array = as_decimal64_array(array)?;
+ if array.precision() != precision || array.scale() != scale {
+ return Ok(false);
+ }
+ let is_null = array.is_null(index);
+ if let Some(v) = value {
+ Ok(!array.is_null(index) && array.value(index) == *v)
+ } else {
+ Ok(is_null)
+ }
+ }
+
fn eq_array_decimal(
array: &ArrayRef,
index: usize,
@@ -3404,6 +3712,24 @@ impl ScalarValue {
#[inline]
pub fn eq_array(&self, array: &ArrayRef, index: usize) -> Result<bool> {
Ok(match self {
+ ScalarValue::Decimal32(v, precision, scale) => {
+ ScalarValue::eq_array_decimal32(
+ array,
+ index,
+ v.as_ref(),
+ *precision,
+ *scale,
+ )?
+ }
+ ScalarValue::Decimal64(v, precision, scale) => {
+ ScalarValue::eq_array_decimal64(
+ array,
+ index,
+ v.as_ref(),
+ *precision,
+ *scale,
+ )?
+ }
ScalarValue::Decimal128(v, precision, scale) => {
ScalarValue::eq_array_decimal(
array,
@@ -3602,6 +3928,8 @@ impl ScalarValue {
| ScalarValue::Float16(_)
| ScalarValue::Float32(_)
| ScalarValue::Float64(_)
+ | ScalarValue::Decimal32(_, _, _)
+ | ScalarValue::Decimal64(_, _, _)
| ScalarValue::Decimal128(_, _, _)
| ScalarValue::Decimal256(_, _, _)
| ScalarValue::Int8(_)
@@ -3711,6 +4039,8 @@ impl ScalarValue {
| ScalarValue::Float16(_)
| ScalarValue::Float32(_)
| ScalarValue::Float64(_)
+ | ScalarValue::Decimal32(_, _, _)
+ | ScalarValue::Decimal64(_, _, _)
| ScalarValue::Decimal128(_, _, _)
| ScalarValue::Decimal256(_, _, _)
| ScalarValue::Int8(_)
@@ -4224,6 +4554,12 @@ macro_rules! format_option {
impl fmt::Display for ScalarValue {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
+ ScalarValue::Decimal32(v, p, s) => {
+ write!(f, "{v:?},{p:?},{s:?}")?;
+ }
+ ScalarValue::Decimal64(v, p, s) => {
+ write!(f, "{v:?},{p:?},{s:?}")?;
+ }
ScalarValue::Decimal128(v, p, s) => {
write!(f, "{v:?},{p:?},{s:?}")?;
}
@@ -4413,6 +4749,8 @@ fn fmt_binary(data: &[u8], f: &mut fmt::Formatter) ->
fmt::Result {
impl fmt::Debug for ScalarValue {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
+ ScalarValue::Decimal32(_, _, _) => write!(f, "Decimal32({self})"),
+ ScalarValue::Decimal64(_, _, _) => write!(f, "Decimal64({self})"),
ScalarValue::Decimal128(_, _, _) => write!(f,
"Decimal128({self})"),
ScalarValue::Decimal256(_, _, _) => write!(f,
"Decimal256({self})"),
ScalarValue::Boolean(_) => write!(f, "Boolean({self})"),
diff --git a/datafusion/common/src/types/native.rs
b/datafusion/common/src/types/native.rs
index 76ff9bd095..5cef0adfbd 100644
--- a/datafusion/common/src/types/native.rs
+++ b/datafusion/common/src/types/native.rs
@@ -23,6 +23,7 @@ use crate::error::{Result, _internal_err};
use arrow::compute::can_cast_types;
use arrow::datatypes::{
DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields,
+ DECIMAL128_MAX_PRECISION, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
};
use std::{fmt::Display, sync::Arc};
@@ -228,7 +229,15 @@ impl LogicalType for NativeType {
(Self::Float16, _) => Float16,
(Self::Float32, _) => Float32,
(Self::Float64, _) => Float64,
- (Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s),
+ (Self::Decimal(p, s), _) if *p <= DECIMAL32_MAX_PRECISION => {
+ Decimal32(*p, *s)
+ }
+ (Self::Decimal(p, s), _) if *p <= DECIMAL64_MAX_PRECISION => {
+ Decimal64(*p, *s)
+ }
+ (Self::Decimal(p, s), _) if *p <= DECIMAL128_MAX_PRECISION => {
+ Decimal128(*p, *s)
+ }
(Self::Decimal(p, s), _) => Decimal256(*p, *s),
(Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()),
// If given type is Date, return the same type
diff --git a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs
b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs
index e7f63b5351..45dba5f786 100644
--- a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs
+++ b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs
@@ -20,18 +20,19 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch};
use arrow::datatypes::{
ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type,
Decimal128Type,
- Decimal256Type, DurationMicrosecondType, DurationMillisecondType,
- DurationNanosecondType, DurationSecondType, Field, Float32Type,
Float64Type,
- Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
- IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Schema,
- Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType,
- TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
+ Decimal256Type, Decimal32Type, Decimal64Type, DurationMicrosecondType,
+ DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field,
+ Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
+ IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType,
+ Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
+ Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type,
UInt64Type,
UInt8Type,
};
use arrow_schema::{
DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION,
- DECIMAL256_MAX_SCALE,
+ DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE,
+ DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
};
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result};
use rand::{rng, rngs::StdRng, Rng, SeedableRng};
@@ -104,6 +105,20 @@ pub fn get_supported_types_columns(rng_seed: u64) ->
Vec<ColumnDescr> {
"duration_nanosecond",
DataType::Duration(TimeUnit::Nanosecond),
),
+ ColumnDescr::new("decimal32", {
+ let precision: u8 = rng.random_range(1..=DECIMAL32_MAX_PRECISION);
+ let scale: i8 = rng.random_range(
+ i8::MIN..=std::cmp::min(precision as i8, DECIMAL32_MAX_SCALE),
+ );
+ DataType::Decimal32(precision, scale)
+ }),
+ ColumnDescr::new("decimal64", {
+ let precision: u8 = rng.random_range(1..=DECIMAL64_MAX_PRECISION);
+ let scale: i8 = rng.random_range(
+ i8::MIN..=std::cmp::min(precision as i8, DECIMAL64_MAX_SCALE),
+ );
+ DataType::Decimal64(precision, scale)
+ }),
ColumnDescr::new("decimal128", {
let precision: u8 = rng.random_range(1..=DECIMAL128_MAX_PRECISION);
let scale: i8 = rng.random_range(
@@ -682,6 +697,32 @@ impl RecordBatchGenerator {
_ => unreachable!(),
}
}
+ DataType::Decimal32(precision, scale) => {
+ generate_decimal_array!(
+ self,
+ num_rows,
+ max_num_distinct,
+ null_pct,
+ batch_gen_rng,
+ array_gen_rng,
+ precision,
+ scale,
+ Decimal32Type
+ )
+ }
+ DataType::Decimal64(precision, scale) => {
+ generate_decimal_array!(
+ self,
+ num_rows,
+ max_num_distinct,
+ null_pct,
+ batch_gen_rng,
+ array_gen_rng,
+ precision,
+ scale,
+ Decimal64Type
+ )
+ }
DataType::Decimal128(precision, scale) => {
generate_decimal_array!(
self,
diff --git a/datafusion/expr-common/src/casts.rs
b/datafusion/expr-common/src/casts.rs
index 684452d538..8939ff1371 100644
--- a/datafusion/expr-common/src/casts.rs
+++ b/datafusion/expr-common/src/casts.rs
@@ -25,7 +25,9 @@ use std::cmp::Ordering;
use arrow::datatypes::{
DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
- MIN_DECIMAL128_FOR_EACH_PRECISION,
+ MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION,
+ MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION,
+ MIN_DECIMAL64_FOR_EACH_PRECISION,
};
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
use datafusion_common::ScalarValue;
@@ -69,6 +71,8 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool {
| DataType::Int16
| DataType::Int32
| DataType::Int64
+ | DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
| DataType::Decimal128(_, _)
| DataType::Timestamp(_, _)
)
@@ -114,6 +118,8 @@ fn try_cast_numeric_literal(
| DataType::Int32
| DataType::Int64 => 1_i128,
DataType::Timestamp(_, _) => 1_i128,
+ DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32),
+ DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32),
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
_ => return None,
};
@@ -127,6 +133,20 @@ fn try_cast_numeric_literal(
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
+ DataType::Decimal32(precision, _) => (
+ // Different precision for decimal32 can store different range of
value.
+ // For example, the precision is 3, the max of value is `999` and
the min
+ // value is `-999`
+ MIN_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
+ MAX_DECIMAL32_FOR_EACH_PRECISION[*precision as usize] as i128,
+ ),
+ DataType::Decimal64(precision, _) => (
+ // Different precision for decimal64 can store different range of
value.
+ // For example, the precision is 3, the max of value is `999` and
the min
+ // value is `-999`
+ MIN_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
+ MAX_DECIMAL64_FOR_EACH_PRECISION[*precision as usize] as i128,
+ ),
DataType::Decimal128(precision, _) => (
// Different precision for decimal128 can store different range of
value.
// For example, the precision is 3, the max of value is `999` and
the min
@@ -149,6 +169,46 @@ fn try_cast_numeric_literal(
ScalarValue::TimestampMillisecond(Some(v), _) => (*v as
i128).checked_mul(mul),
ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as
i128).checked_mul(mul),
ScalarValue::TimestampNanosecond(Some(v), _) => (*v as
i128).checked_mul(mul),
+ ScalarValue::Decimal32(Some(v), _, scale) => {
+ let v = *v as i128;
+ let lit_scale_mul = 10_i128.pow(*scale as u32);
+ if mul >= lit_scale_mul {
+ // Example:
+ // lit is decimal(123,3,2)
+ // target type is decimal(5,3)
+ // the lit can be converted to the decimal(1230,5,3)
+ v.checked_mul(mul / lit_scale_mul)
+ } else if v % (lit_scale_mul / mul) == 0 {
+ // Example:
+ // lit is decimal(123000,10,3)
+ // target type is int32: the lit can be converted to INT32(123)
+ // target type is decimal(10,2): the lit can be converted to
decimal(12300,10,2)
+ Some(v / (lit_scale_mul / mul))
+ } else {
+ // can't convert the lit decimal to the target data type
+ None
+ }
+ }
+ ScalarValue::Decimal64(Some(v), _, scale) => {
+ let v = *v as i128;
+ let lit_scale_mul = 10_i128.pow(*scale as u32);
+ if mul >= lit_scale_mul {
+ // Example:
+ // lit is decimal(123,3,2)
+ // target type is decimal(5,3)
+ // the lit can be converted to the decimal(1230,5,3)
+ v.checked_mul(mul / lit_scale_mul)
+ } else if v % (lit_scale_mul / mul) == 0 {
+ // Example:
+ // lit is decimal(123000,10,3)
+ // target type is int32: the lit can be converted to INT32(123)
+ // target type is decimal(10,2): the lit can be converted to
decimal(12300,10,2)
+ Some(v / (lit_scale_mul / mul))
+ } else {
+ // can't convert the lit decimal to the target data type
+ None
+ }
+ }
ScalarValue::Decimal128(Some(v), _, scale) => {
let lit_scale_mul = 10_i128.pow(*scale as u32);
if mul >= lit_scale_mul {
@@ -218,6 +278,12 @@ fn try_cast_numeric_literal(
);
ScalarValue::TimestampNanosecond(value, tz.clone())
}
+ DataType::Decimal32(p, s) => {
+ ScalarValue::Decimal32(Some(value as i32), *p, *s)
+ }
+ DataType::Decimal64(p, s) => {
+ ScalarValue::Decimal64(Some(value as i64), *p, *s)
+ }
DataType::Decimal128(p, s) => {
ScalarValue::Decimal128(Some(value), *p, *s)
}
diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs
b/datafusion/expr-common/src/type_coercion/aggregates.rs
index b4f393eceb..e77a072a84 100644
--- a/datafusion/expr-common/src/type_coercion/aggregates.rs
+++ b/datafusion/expr-common/src/type_coercion/aggregates.rs
@@ -18,7 +18,8 @@
use crate::signature::TypeSignature;
use arrow::datatypes::{
DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE,
- DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
+ DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
+ DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
};
use datafusion_common::{internal_err, plan_err, Result};
@@ -150,6 +151,18 @@ pub fn sum_return_type(arg_type: &DataType) ->
Result<DataType> {
DataType::Int64 => Ok(DataType::Int64),
DataType::UInt64 => Ok(DataType::UInt64),
DataType::Float64 => Ok(DataType::Float64),
+ DataType::Decimal32(precision, scale) => {
+ // in the spark, the result type is DECIMAL(min(38,precision+10),
s); here the precision cap is DECIMAL32_MAX_PRECISION instead of 38
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
+ Ok(DataType::Decimal32(new_precision, *scale))
+ }
+ DataType::Decimal64(precision, scale) => {
+ // in the spark, the result type is DECIMAL(min(38,precision+10),
s); here the precision cap is DECIMAL64_MAX_PRECISION instead of 38
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
+ Ok(DataType::Decimal64(new_precision, *scale))
+ }
DataType::Decimal128(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+10),
s)
// Ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
@@ -196,6 +209,20 @@ pub fn correlation_return_type(arg_type: &DataType) ->
Result<DataType> {
/// Function return type of an average
pub fn avg_return_type(func_name: &str, arg_type: &DataType) ->
Result<DataType> {
match arg_type {
+ DataType::Decimal32(precision, scale) => {
+ // In the spark, the result type is DECIMAL(min(38,precision+4),
min(38,scale+4)); here the caps are DECIMAL32_MAX_PRECISION / DECIMAL32_MAX_SCALE instead of 38.
+ // Ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+ let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
+ let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
+ Ok(DataType::Decimal32(new_precision, new_scale))
+ }
+ DataType::Decimal64(precision, scale) => {
+ // In the spark, the result type is DECIMAL(min(38,precision+4),
min(38,scale+4)); here the caps are DECIMAL64_MAX_PRECISION / DECIMAL64_MAX_SCALE instead of 38.
+ // Ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+ let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
+ let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
+ Ok(DataType::Decimal64(new_precision, new_scale))
+ }
DataType::Decimal128(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+4),
min(38,scale+4)).
// Ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
@@ -222,6 +249,16 @@ pub fn avg_return_type(func_name: &str, arg_type:
&DataType) -> Result<DataType>
/// Internal sum type of an average
pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
+ DataType::Decimal32(precision, scale) => {
+ // In the spark, the sum type of avg is
DECIMAL(min(38,precision+10), s); here the precision cap is DECIMAL32_MAX_PRECISION instead of 38
+ let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
+ Ok(DataType::Decimal32(new_precision, *scale))
+ }
+ DataType::Decimal64(precision, scale) => {
+ // In the spark, the sum type of avg is
DECIMAL(min(38,precision+10), s); here the precision cap is DECIMAL64_MAX_PRECISION instead of 38
+ let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
+ Ok(DataType::Decimal64(new_precision, *scale))
+ }
DataType::Decimal128(precision, scale) => {
// In the spark, the sum type of avg is
DECIMAL(min(38,precision+10), s)
let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
@@ -249,7 +286,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool
{
_ => matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
- || matches!(arg_type, DataType::Decimal128(_, _) |
DataType::Decimal256(_, _))
+ || matches!(arg_type, DataType::Decimal32(_, _) |
DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_,
_))
),
}
}
@@ -262,7 +299,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool
{
_ => matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
- || matches!(arg_type, DataType::Decimal128(_, _)|
DataType::Decimal256(_, _))
+ || matches!(arg_type, DataType::Decimal32(_, _) |
DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_,
_))
),
}
}
@@ -297,6 +334,8 @@ pub fn coerce_avg_type(func_name: &str, arg_types:
&[DataType]) -> Result<Vec<Da
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html
doc
fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType>
{
match &data_type {
+ DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
+ DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
d if d.is_numeric() => Ok(DataType::Float64),
diff --git a/datafusion/expr-common/src/type_coercion/binary.rs
b/datafusion/expr-common/src/type_coercion/binary.rs
index a102523b21..1c99f49d26 100644
--- a/datafusion/expr-common/src/type_coercion/binary.rs
+++ b/datafusion/expr-common/src/type_coercion/binary.rs
@@ -27,6 +27,8 @@ use arrow::compute::can_cast_types;
use arrow::datatypes::{
DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
+ DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION,
+ DECIMAL64_MAX_SCALE,
};
use datafusion_common::types::NativeType;
use datafusion_common::{
@@ -341,22 +343,64 @@ fn math_decimal_coercion(
let (lhs_type, value_type) = math_decimal_coercion(lhs_type,
value_type)?;
Some((lhs_type, value_type))
}
- (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _),
Null) => {
- Some((dec_type.clone(), dec_type.clone()))
- }
- (Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _),
Decimal256(_, _)) => {
+ (
+ Null,
+ Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) |
Decimal256(_, _),
+ ) => Some((rhs_type.clone(), rhs_type.clone())),
+ (
+ Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) |
Decimal256(_, _),
+ Null,
+ ) => Some((lhs_type.clone(), lhs_type.clone())),
+ (Decimal32(_, _), Decimal32(_, _))
+ | (Decimal64(_, _), Decimal64(_, _))
+ | (Decimal128(_, _), Decimal128(_, _))
+ | (Decimal256(_, _), Decimal256(_, _)) => {
Some((lhs_type.clone(), rhs_type.clone()))
}
// Unlike with comparison we don't coerce to a decimal in the case of
floating point
// numbers, instead falling back to floating point arithmetic instead
+ (
+ Decimal32(_, _),
+ Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
+ ) => Some((
+ lhs_type.clone(),
+ coerce_numeric_type_to_decimal32(rhs_type)?,
+ )),
+ (
+ Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
+ Decimal32(_, _),
+ ) => Some((
+ coerce_numeric_type_to_decimal32(lhs_type)?,
+ rhs_type.clone(),
+ )),
+ (
+ Decimal64(_, _),
+ Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
+ ) => Some((
+ lhs_type.clone(),
+ coerce_numeric_type_to_decimal64(rhs_type)?,
+ )),
+ (
+ Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
+ Decimal64(_, _),
+ ) => Some((
+ coerce_numeric_type_to_decimal64(lhs_type)?,
+ rhs_type.clone(),
+ )),
(
Decimal128(_, _),
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
- ) => Some((lhs_type.clone(),
coerce_numeric_type_to_decimal(rhs_type)?)),
+ ) => Some((
+ lhs_type.clone(),
+ coerce_numeric_type_to_decimal128(rhs_type)?,
+ )),
(
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
Decimal128(_, _),
- ) => Some((coerce_numeric_type_to_decimal(lhs_type)?,
rhs_type.clone())),
+ ) => Some((
+ coerce_numeric_type_to_decimal128(lhs_type)?,
+ rhs_type.clone(),
+ )),
(
Decimal256(_, _),
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
@@ -932,8 +976,8 @@ fn get_common_decimal_type(
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match decimal_type {
- Decimal128(_, _) => {
- let other_decimal_type =
coerce_numeric_type_to_decimal(other_type)?;
+ Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) => {
+ let other_decimal_type =
coerce_numeric_type_to_decimal128(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
Decimal256(_, _) => {
@@ -953,11 +997,23 @@ fn get_wider_decimal_type(
rhs_type: &DataType,
) -> Option<DataType> {
match (lhs_decimal_type, rhs_type) {
+ (DataType::Decimal32(p1, s1), DataType::Decimal32(p2, s2)) => {
+ // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+ let s = *s1.max(s2);
+ let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
+ Some(create_decimal32_type((range + s) as u8, s))
+ }
+ (DataType::Decimal64(p1, s1), DataType::Decimal64(p2, s2)) => {
+ // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+ let s = *s1.max(s2);
+ let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
+ Some(create_decimal64_type((range + s) as u8, s))
+ }
(DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => {
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
let s = *s1.max(s2);
let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
- Some(create_decimal_type((range + s) as u8, s))
+ Some(create_decimal128_type((range + s) as u8, s))
}
(DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => {
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
@@ -971,7 +1027,39 @@ fn get_wider_decimal_type(
/// Convert the numeric data type to the decimal data type.
/// We support signed and unsigned integer types and floating-point type.
-fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType>
{
+fn coerce_numeric_type_to_decimal32(numeric_type: &DataType) ->
Option<DataType> {
+ use arrow::datatypes::DataType::*;
+ // This conversion rule is from spark
+ //
https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
+ // Only Int8/UInt8, Int16/UInt16 and Float16 are mapped here. Int32/Int64 and
+ // Float32/Float64 (Spark: Decimal(10,0), (20,0), (14,7), (30,15)) fall through
+ // to `None`, presumably because they exceed DECIMAL32_MAX_PRECISION.
+ match numeric_type {
+ Int8 | UInt8 => Some(Decimal32(3, 0)),
+ Int16 | UInt16 => Some(Decimal32(5, 0)),
+ // TODO if we convert the floating-point data to the decimal type, it
may overflow.
+ Float16 => Some(Decimal32(6, 3)),
+ _ => None,
+ }
+}
+
+/// Convert a numeric data type to its Spark-equivalent `Decimal64` type.
+/// Supports 8/16/32-bit signed and unsigned integers plus Float16/Float32;
+/// returns `None` for any other type (e.g. Int64, Float64).
+fn coerce_numeric_type_to_decimal64(numeric_type: &DataType) ->
Option<DataType> {
+ use arrow::datatypes::DataType::*;
+ // This conversion rule is from spark
+ //
https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
+ // Int64 and Float64 (Spark: Decimal(20,0) and Decimal(30,15)) fall through
+ // to `None`, presumably because they exceed DECIMAL64_MAX_PRECISION.
+ match numeric_type {
+ Int8 | UInt8 => Some(Decimal64(3, 0)),
+ Int16 | UInt16 => Some(Decimal64(5, 0)),
+ Int32 | UInt32 => Some(Decimal64(10, 0)),
+ // TODO if we convert the floating-point data to the decimal type, it
may overflow.
+ Float16 => Some(Decimal64(6, 3)),
+ Float32 => Some(Decimal64(14, 7)),
+ _ => None,
+ }
+}
+
+/// Convert the numeric data type to the decimal data type.
+/// We support signed and unsigned integer types and floating-point type.
+fn coerce_numeric_type_to_decimal128(numeric_type: &DataType) ->
Option<DataType> {
use arrow::datatypes::DataType::*;
// This conversion rule is from spark
//
https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
@@ -1120,7 +1208,21 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type:
&DataType) -> Option<DataTy
}
}
-fn create_decimal_type(precision: u8, scale: i8) -> DataType {
+/// Build a `Decimal32` type, clamping precision and scale to the Decimal32
+/// maximums. Must stay in the Decimal32 family (not Decimal128) so that
+/// `get_wider_decimal_type` widens (Decimal32, Decimal32) to a Decimal32.
+fn create_decimal32_type(precision: u8, scale: i8) -> DataType {
+ DataType::Decimal32(
+ DECIMAL32_MAX_PRECISION.min(precision),
+ DECIMAL32_MAX_SCALE.min(scale),
+ )
+}
+
+/// Build a `Decimal64` type, clamping precision and scale to the Decimal64
+/// maximums. Must stay in the Decimal64 family (not Decimal128) so that
+/// `get_wider_decimal_type` widens (Decimal64, Decimal64) to a Decimal64.
+fn create_decimal64_type(precision: u8, scale: i8) -> DataType {
+ DataType::Decimal64(
+ DECIMAL64_MAX_PRECISION.min(precision),
+ DECIMAL64_MAX_SCALE.min(scale),
+ )
+}
+
+fn create_decimal128_type(precision: u8, scale: i8) -> DataType {
DataType::Decimal128(
DECIMAL128_MAX_PRECISION.min(precision),
DECIMAL128_MAX_SCALE.min(scale),
diff --git
a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs
b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs
index fdd41ae2bb..e6238ba007 100644
--- a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs
+++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs
@@ -56,32 +56,75 @@ fn test_date_timestamp_arithmetic_error() -> Result<()> {
#[test]
fn test_decimal_mathematics_op_type() {
+ // Decimal32
assert_eq!(
- coerce_numeric_type_to_decimal(&DataType::Int8).unwrap(),
+ coerce_numeric_type_to_decimal32(&DataType::Int8).unwrap(),
+ DataType::Decimal32(3, 0)
+ );
+ assert_eq!(
+ coerce_numeric_type_to_decimal32(&DataType::Int16).unwrap(),
+ DataType::Decimal32(5, 0)
+ );
+ assert!(coerce_numeric_type_to_decimal32(&DataType::Int32).is_none());
+ assert!(coerce_numeric_type_to_decimal32(&DataType::Int64).is_none(),);
+ assert_eq!(
+ coerce_numeric_type_to_decimal32(&DataType::Float16).unwrap(),
+ DataType::Decimal32(6, 3)
+ );
+ assert!(coerce_numeric_type_to_decimal32(&DataType::Float32).is_none(),);
+ assert!(coerce_numeric_type_to_decimal32(&DataType::Float64).is_none());
+
+ // Decimal64
+ assert_eq!(
+ coerce_numeric_type_to_decimal64(&DataType::Int8).unwrap(),
+ DataType::Decimal64(3, 0)
+ );
+ assert_eq!(
+ coerce_numeric_type_to_decimal64(&DataType::Int16).unwrap(),
+ DataType::Decimal64(5, 0)
+ );
+ assert_eq!(
+ coerce_numeric_type_to_decimal64(&DataType::Int32).unwrap(),
+ DataType::Decimal64(10, 0)
+ );
+ assert!(coerce_numeric_type_to_decimal64(&DataType::Int64).is_none(),);
+ assert_eq!(
+ coerce_numeric_type_to_decimal64(&DataType::Float16).unwrap(),
+ DataType::Decimal64(6, 3)
+ );
+ assert_eq!(
+ coerce_numeric_type_to_decimal64(&DataType::Float32).unwrap(),
+ DataType::Decimal64(14, 7)
+ );
+ assert!(coerce_numeric_type_to_decimal64(&DataType::Float64).is_none());
+
+ // Decimal128
+ assert_eq!(
+ coerce_numeric_type_to_decimal128(&DataType::Int8).unwrap(),
DataType::Decimal128(3, 0)
);
assert_eq!(
- coerce_numeric_type_to_decimal(&DataType::Int16).unwrap(),
+ coerce_numeric_type_to_decimal128(&DataType::Int16).unwrap(),
DataType::Decimal128(5, 0)
);
assert_eq!(
- coerce_numeric_type_to_decimal(&DataType::Int32).unwrap(),
+ coerce_numeric_type_to_decimal128(&DataType::Int32).unwrap(),
DataType::Decimal128(10, 0)
);
assert_eq!(
- coerce_numeric_type_to_decimal(&DataType::Int64).unwrap(),
+ coerce_numeric_type_to_decimal128(&DataType::Int64).unwrap(),
DataType::Decimal128(20, 0)
);
assert_eq!(
- coerce_numeric_type_to_decimal(&DataType::Float16).unwrap(),
+ coerce_numeric_type_to_decimal128(&DataType::Float16).unwrap(),
DataType::Decimal128(6, 3)
);
assert_eq!(
- coerce_numeric_type_to_decimal(&DataType::Float32).unwrap(),
+ coerce_numeric_type_to_decimal128(&DataType::Float32).unwrap(),
DataType::Decimal128(14, 7)
);
assert_eq!(
- coerce_numeric_type_to_decimal(&DataType::Float64).unwrap(),
+ coerce_numeric_type_to_decimal128(&DataType::Float64).unwrap(),
DataType::Decimal128(30, 15)
);
}
diff --git a/datafusion/expr/src/logical_plan/builder.rs
b/datafusion/expr/src/logical_plan/builder.rs
index 3b4fd1aff9..709bb0d71f 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -282,15 +282,14 @@ impl LogicalPlanBuilder {
let value = &row[j];
let data_type = value.get_type(schema)?;
- if !data_type.equals_datatype(field_type) {
- if can_cast_types(&data_type, field_type) {
- } else {
- return exec_err!(
- "type mismatch and can't cast to got {} and {}",
- data_type,
- field_type
- );
- }
+ if !data_type.equals_datatype(field_type)
+ && !can_cast_types(&data_type, field_type)
+ {
+ return exec_err!(
+ "type mismatch and can't cast to got {} and {}",
+ data_type,
+ field_type
+ );
}
}
fields.push(field_type.to_owned(), field_nullable);
diff --git a/datafusion/expr/src/test/function_stub.rs
b/datafusion/expr/src/test/function_stub.rs
index c5b7c751b9..41bc645058 100644
--- a/datafusion/expr/src/test/function_stub.rs
+++ b/datafusion/expr/src/test/function_stub.rs
@@ -23,6 +23,7 @@ use std::any::Any;
use arrow::datatypes::{
DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
+ DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
};
use datafusion_common::{exec_err, not_impl_err, utils::take_function_args,
Result};
@@ -135,9 +136,10 @@ impl AggregateUDFImpl for Sum {
DataType::Dictionary(_, v) => coerced_type(v),
// in the spark, the result type is
DECIMAL(min(38,precision+10), s)
// ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
- DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
- Ok(data_type.clone())
- }
+ DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
+ | DataType::Decimal128(_, _)
+ | DataType::Decimal256(_, _) => Ok(data_type.clone()),
dt if dt.is_signed_integer() => Ok(DataType::Int64),
dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
dt if dt.is_floating() => Ok(DataType::Float64),
@@ -153,6 +155,18 @@ impl AggregateUDFImpl for Sum {
DataType::Int64 => Ok(DataType::Int64),
DataType::UInt64 => Ok(DataType::UInt64),
DataType::Float64 => Ok(DataType::Float64),
+ DataType::Decimal32(precision, scale) => {
+ // in the spark, the result type is
DECIMAL(min(38,precision+10), s); here the precision cap is DECIMAL32_MAX_PRECISION instead of 38
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ let new_precision = DECIMAL32_MAX_PRECISION.min(*precision +
10);
+ Ok(DataType::Decimal32(new_precision, *scale))
+ }
+ DataType::Decimal64(precision, scale) => {
+ // in the spark, the result type is
DECIMAL(min(38,precision+10), s); here the precision cap is DECIMAL64_MAX_PRECISION instead of 38
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ let new_precision = DECIMAL64_MAX_PRECISION.min(*precision +
10);
+ Ok(DataType::Decimal64(new_precision, *scale))
+ }
DataType::Decimal128(precision, scale) => {
// in the spark, the result type is
DECIMAL(min(38,precision+10), s)
// ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
diff --git a/datafusion/expr/src/type_coercion/functions.rs
b/datafusion/expr/src/type_coercion/functions.rs
index 9d15360ca8..bcaff11bcd 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -879,7 +879,10 @@ fn coerced_from<'a>(
| UInt64
| Float32
| Float64
- | Decimal128(_, _),
+ | Decimal32(_, _)
+ | Decimal64(_, _)
+ | Decimal128(_, _)
+ | Decimal256(_, _),
) => Some(type_into.clone()),
(
Timestamp(TimeUnit::Nanosecond, None),
diff --git a/datafusion/expr/src/type_coercion/mod.rs
b/datafusion/expr/src/type_coercion/mod.rs
index 4fc150ef29..bd1acd3f3a 100644
--- a/datafusion/expr/src/type_coercion/mod.rs
+++ b/datafusion/expr/src/type_coercion/mod.rs
@@ -51,6 +51,8 @@ pub fn is_signed_numeric(dt: &DataType) -> bool {
| DataType::Float16
| DataType::Float32
| DataType::Float64
+ | DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _),
)
@@ -89,5 +91,11 @@ pub fn is_utf8_or_utf8view_or_large_utf8(dt: &DataType) ->
bool {
/// Determine whether the given data type `dt` is a `Decimal`.
pub fn is_decimal(dt: &DataType) -> bool {
- matches!(dt, DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
+ matches!(
+ dt,
+ DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
+ | DataType::Decimal128(_, _)
+ | DataType::Decimal256(_, _)
+ )
}
diff --git
a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs
b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs
index a71871b9b4..9920bf5bf4 100644
---
a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs
+++
b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs
@@ -17,7 +17,9 @@
use arrow::{
array::{ArrayRef, ArrowNumericType},
- datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType},
+ datatypes::{
+ i256, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type,
DecimalType,
+ },
};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr_common::accumulator::Accumulator;
@@ -28,7 +30,7 @@ use crate::aggregate::sum_distinct::DistinctSumAccumulator;
use crate::utils::DecimalAverager;
/// Generic implementation of `AVG DISTINCT` for Decimal types.
-/// Handles both Decimal128Type and Decimal256Type.
+/// Handles all Arrow decimal types (32, 64, 128 and 256 bits).
#[derive(Debug)]
pub struct DecimalDistinctAvgAccumulator<T: DecimalType + Debug> {
sum_accumulator: DistinctSumAccumulator<T>,
@@ -80,6 +82,34 @@ impl<T: DecimalType + ArrowNumericType + Debug> Accumulator
let sum_scalar = self.sum_accumulator.evaluate()?;
match sum_scalar {
+ ScalarValue::Decimal32(Some(sum), _, _) => {
+ let decimal_averager =
DecimalAverager::<Decimal32Type>::try_new(
+ self.sum_scale,
+ self.target_precision,
+ self.target_scale,
+ )?;
+ let avg = decimal_averager
+ .avg(sum, self.sum_accumulator.distinct_count() as i32)?;
+ Ok(ScalarValue::Decimal32(
+ Some(avg),
+ self.target_precision,
+ self.target_scale,
+ ))
+ }
+ ScalarValue::Decimal64(Some(sum), _, _) => {
+ let decimal_averager =
DecimalAverager::<Decimal64Type>::try_new(
+ self.sum_scale,
+ self.target_precision,
+ self.target_scale,
+ )?;
+ let avg = decimal_averager
+ .avg(sum, self.sum_accumulator.distinct_count() as i64)?;
+ Ok(ScalarValue::Decimal64(
+ Some(avg),
+ self.target_precision,
+ self.target_scale,
+ ))
+ }
ScalarValue::Decimal128(Some(sum), _, _) => {
let decimal_averager =
DecimalAverager::<Decimal128Type>::try_new(
self.sum_scale,
@@ -127,9 +157,69 @@ impl<T: DecimalType + ArrowNumericType + Debug> Accumulator
#[cfg(test)]
mod tests {
use super::*;
- use arrow::array::{Decimal128Array, Decimal256Array};
+ use arrow::array::{
+ Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array,
+ };
use std::sync::Arc;
+ #[test]
+ fn test_decimal32_distinct_avg_accumulator() -> Result<()> {
+ let precision = 5_u8;
+ let scale = 2_i8;
+ let array = Decimal32Array::from(vec![
+ Some(10_00),
+ Some(12_50),
+ Some(17_50),
+ Some(20_00),
+ Some(20_00),
+ Some(30_00),
+ None,
+ None,
+ ])
+ .with_precision_and_scale(precision, scale)?;
+
+ let mut accumulator =
+
DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
+ scale, 9, 6,
+ );
+ accumulator.update_batch(&[Arc::new(array)])?;
+
+ let result = accumulator.evaluate()?;
+ let expected_result = ScalarValue::Decimal32(Some(18000000), 9, 6);
+ assert_eq!(result, expected_result);
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_decimal64_distinct_avg_accumulator() -> Result<()> {
+ let precision = 10_u8;
+ let scale = 4_i8;
+ let array = Decimal64Array::from(vec![
+ Some(100_0000),
+ Some(125_0000),
+ Some(175_0000),
+ Some(200_0000),
+ Some(200_0000),
+ Some(300_0000),
+ None,
+ None,
+ ])
+ .with_precision_and_scale(precision, scale)?;
+
+ let mut accumulator =
+
DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
+ scale, 14, 8,
+ );
+ accumulator.update_batch(&[Arc::new(array)])?;
+
+ let result = accumulator.evaluate()?;
+ let expected_result = ScalarValue::Decimal64(Some(180_00000000), 14,
8);
+ assert_eq!(result, expected_result);
+
+ Ok(())
+ }
+
#[test]
fn test_decimal128_distinct_avg_accumulator() -> Result<()> {
let precision = 10_u8;
diff --git a/datafusion/functions-aggregate-common/src/min_max.rs
b/datafusion/functions-aggregate-common/src/min_max.rs
index 0aad9b356f..7dd60e1c0e 100644
--- a/datafusion/functions-aggregate-common/src/min_max.rs
+++ b/datafusion/functions-aggregate-common/src/min_max.rs
@@ -19,15 +19,15 @@
use arrow::array::{
ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray,
Date32Array,
- Date64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray,
- DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray,
- FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, Int16Array,
- Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
IntervalMonthDayNanoArray,
- IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray,
- StringViewArray, Time32MillisecondArray, Time32SecondArray,
Time64MicrosecondArray,
- Time64NanosecondArray, TimestampMicrosecondArray,
TimestampMillisecondArray,
- TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
- UInt64Array, UInt8Array,
+ Date64Array, Decimal128Array, Decimal256Array, Decimal32Array,
Decimal64Array,
+ DurationMicrosecondArray, DurationMillisecondArray,
DurationNanosecondArray,
+ DurationSecondArray, FixedSizeBinaryArray, Float16Array, Float32Array,
Float64Array,
+ Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray,
+ IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray,
+ LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray,
+ Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
+ TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray,
+ TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::compute;
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
@@ -144,6 +144,32 @@ macro_rules! min_max {
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
Ok(match ($VALUE, $DELTA) {
(ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null,
+ (
+ lhs @ ScalarValue::Decimal32(lhsv, lhsp, lhss),
+ rhs @ ScalarValue::Decimal32(rhsv, rhsp, rhss)
+ ) => {
+ if lhsp.eq(rhsp) && lhss.eq(rhss) {
+ typed_min_max!(lhsv, rhsv, Decimal32, $OP, lhsp, lhss)
+ } else {
+ return internal_err!(
+ "MIN/MAX is not expected to receive scalars of
incompatible types {:?}",
+ (lhs, rhs)
+ );
+ }
+ }
+ (
+ lhs @ ScalarValue::Decimal64(lhsv, lhsp, lhss),
+ rhs @ ScalarValue::Decimal64(rhsv, rhsp, rhss)
+ ) => {
+ if lhsp.eq(rhsp) && lhss.eq(rhss) {
+ typed_min_max!(lhsv, rhsv, Decimal64, $OP, lhsp, lhss)
+ } else {
+ return internal_err!(
+ "MIN/MAX is not expected to receive scalars of
incompatible types {:?}",
+ (lhs, rhs)
+ );
+ }
+ }
(
lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss),
rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss)
@@ -513,6 +539,26 @@ macro_rules! min_max_batch {
($VALUES:expr, $OP:ident) => {{
match $VALUES.data_type() {
DataType::Null => ScalarValue::Null,
+ DataType::Decimal32(precision, scale) => {
+ typed_min_max_batch!(
+ $VALUES,
+ Decimal32Array,
+ Decimal32,
+ $OP,
+ precision,
+ scale
+ )
+ }
+ DataType::Decimal64(precision, scale) => {
+ typed_min_max_batch!(
+ $VALUES,
+ Decimal64Array,
+ Decimal64,
+ $OP,
+ precision,
+ scale
+ )
+ }
DataType::Decimal128(precision, scale) => {
typed_min_max_batch!(
$VALUES,
diff --git a/datafusion/functions-aggregate/src/average.rs
b/datafusion/functions-aggregate/src/average.rs
index 200828dffe..a6a83c24b0 100644
--- a/datafusion/functions-aggregate/src/average.rs
+++ b/datafusion/functions-aggregate/src/average.rs
@@ -24,10 +24,11 @@ use arrow::array::{
use arrow::compute::sum;
use arrow::datatypes::{
- i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type,
DecimalType,
- DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
- DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type,
- DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
+ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type,
Decimal32Type,
+ Decimal64Type, DecimalType, DurationMicrosecondType,
DurationMillisecondType,
+ DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type,
TimeUnit,
+ UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
+ DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
};
use datafusion_common::{
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
@@ -127,6 +128,22 @@ impl AggregateUDFImpl for Avg {
// Numeric types are converted to Float64 via
`coerce_avg_type` during logical plan creation
(Float64, _) =>
Ok(Box::new(Float64DistinctAvgAccumulator::default())),
+ (
+ Decimal32(_, scale),
+ Decimal32(target_precision, target_scale),
+ ) =>
Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal32Type>::with_decimal_params(
+ *scale,
+ *target_precision,
+ *target_scale,
+ ))),
+ (
+ Decimal64(_, scale),
+ Decimal64(target_precision, target_scale),
+ ) =>
Ok(Box::new(DecimalDistinctAvgAccumulator::<Decimal64Type>::with_decimal_params(
+ *scale,
+ *target_precision,
+ *target_scale,
+ ))),
(
Decimal128(_, scale),
Decimal128(target_precision, target_scale),
@@ -154,6 +171,28 @@ impl AggregateUDFImpl for Avg {
} else {
match (&data_type, acc_args.return_type()) {
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
+ (
+ Decimal32(sum_precision, sum_scale),
+ Decimal32(target_precision, target_scale),
+ ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal32Type> {
+ sum: None,
+ count: 0,
+ sum_scale: *sum_scale,
+ sum_precision: *sum_precision,
+ target_precision: *target_precision,
+ target_scale: *target_scale,
+ })),
+ (
+ Decimal64(sum_precision, sum_scale),
+ Decimal64(target_precision, target_scale),
+ ) => Ok(Box::new(DecimalAvgAccumulator::<Decimal64Type> {
+ sum: None,
+ count: 0,
+ sum_scale: *sum_scale,
+ sum_precision: *sum_precision,
+ target_precision: *target_precision,
+ target_scale: *target_scale,
+ })),
(
Decimal128(sum_precision, sum_scale),
Decimal128(target_precision, target_scale),
@@ -199,6 +238,12 @@ impl AggregateUDFImpl for Avg {
// Decimal accumulator actually uses a different precision during
accumulation,
// see DecimalDistinctAvgAccumulator::with_decimal_params
let dt = match args.input_fields[0].data_type() {
+ DataType::Decimal32(_, scale) => {
+ DataType::Decimal32(DECIMAL32_MAX_PRECISION, *scale)
+ }
+ DataType::Decimal64(_, scale) => {
+ DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale)
+ }
DataType::Decimal128(_, scale) => {
DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale)
}
@@ -237,7 +282,12 @@ impl AggregateUDFImpl for Avg {
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
matches!(
args.return_field.data_type(),
- DataType::Float64 | DataType::Decimal128(_, _) |
DataType::Duration(_)
+ DataType::Float64
+ | DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
+ | DataType::Decimal128(_, _)
+ | DataType::Decimal256(_, _)
+ | DataType::Duration(_)
) && !args.is_distinct
}
@@ -257,6 +307,44 @@ impl AggregateUDFImpl for Avg {
|sum: f64, count: u64| Ok(sum / count as f64),
)))
}
+ (
+ Decimal32(_sum_precision, sum_scale),
+ Decimal32(target_precision, target_scale),
+ ) => {
+ let decimal_averager =
DecimalAverager::<Decimal32Type>::try_new(
+ *sum_scale,
+ *target_precision,
+ *target_scale,
+ )?;
+
+ let avg_fn =
+ move |sum: i32, count: u64| decimal_averager.avg(sum,
count as i32);
+
+ Ok(Box::new(AvgGroupsAccumulator::<Decimal32Type, _>::new(
+ &data_type,
+ args.return_field.data_type(),
+ avg_fn,
+ )))
+ }
+ (
+ Decimal64(_sum_precision, sum_scale),
+ Decimal64(target_precision, target_scale),
+ ) => {
+ let decimal_averager =
DecimalAverager::<Decimal64Type>::try_new(
+ *sum_scale,
+ *target_precision,
+ *target_scale,
+ )?;
+
+ let avg_fn =
+ move |sum: i64, count: u64| decimal_averager.avg(sum,
count as i64);
+
+ Ok(Box::new(AvgGroupsAccumulator::<Decimal64Type, _>::new(
+ &data_type,
+ args.return_field.data_type(),
+ avg_fn,
+ )))
+ }
(
Decimal128(_sum_precision, sum_scale),
Decimal128(target_precision, target_scale),
diff --git a/datafusion/functions-aggregate/src/first_last.rs
b/datafusion/functions-aggregate/src/first_last.rs
index 6ef1332ba0..28755427c7 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -30,12 +30,12 @@ use arrow::array::{
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions};
use arrow::datatypes::{
- DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
FieldRef,
- Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
Int8Type,
- Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType,
- TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
- TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type,
UInt64Type,
- UInt8Type,
+ DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Decimal32Type,
+ Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type,
Int16Type,
+ Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType,
+ Time64MicrosecondType, Time64NanosecondType, TimeUnit,
TimestampMicrosecondType,
+ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
UInt16Type,
+ UInt32Type, UInt64Type, UInt8Type,
};
use datafusion_common::cast::as_boolean_array;
use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf,
get_row_at_idx};
@@ -185,6 +185,8 @@ impl AggregateUDFImpl for FirstValue {
| Float16
| Float32
| Float64
+ | Decimal32(_, _)
+ | Decimal64(_, _)
| Decimal128(_, _)
| Decimal256(_, _)
| Date32
@@ -234,6 +236,8 @@ impl AggregateUDFImpl for FirstValue {
DataType::Float32 => create_accumulator::<Float32Type>(args),
DataType::Float64 => create_accumulator::<Float64Type>(args),
+ DataType::Decimal32(_, _) =>
create_accumulator::<Decimal32Type>(args),
+ DataType::Decimal64(_, _) =>
create_accumulator::<Decimal64Type>(args),
DataType::Decimal128(_, _) =>
create_accumulator::<Decimal128Type>(args),
DataType::Decimal256(_, _) =>
create_accumulator::<Decimal256Type>(args),
@@ -1124,6 +1128,8 @@ impl AggregateUDFImpl for LastValue {
| Float16
| Float32
| Float64
+ | Decimal32(_, _)
+ | Decimal64(_, _)
| Decimal128(_, _)
| Decimal256(_, _)
| Date32
@@ -1175,6 +1181,8 @@ impl AggregateUDFImpl for LastValue {
DataType::Float32 => create_accumulator::<Float32Type>(args),
DataType::Float64 => create_accumulator::<Float64Type>(args),
+ DataType::Decimal32(_, _) =>
create_accumulator::<Decimal32Type>(args),
+ DataType::Decimal64(_, _) =>
create_accumulator::<Decimal64Type>(args),
DataType::Decimal128(_, _) =>
create_accumulator::<Decimal128Type>(args),
DataType::Decimal256(_, _) =>
create_accumulator::<Decimal256Type>(args),
diff --git a/datafusion/functions-aggregate/src/median.rs
b/datafusion/functions-aggregate/src/median.rs
index a73ccbd99b..a65759594e 100644
--- a/datafusion/functions-aggregate/src/median.rs
+++ b/datafusion/functions-aggregate/src/median.rs
@@ -35,7 +35,9 @@ use arrow::{
use arrow::array::Array;
use arrow::array::ArrowNativeTypeOp;
-use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, FieldRef};
+use arrow::datatypes::{
+ ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type,
FieldRef,
+};
use datafusion_common::{
internal_datafusion_err, internal_err, DataFusionError, HashSet, Result,
ScalarValue,
@@ -166,6 +168,8 @@ impl AggregateUDFImpl for Median {
DataType::Float16 => helper!(Float16Type, dt),
DataType::Float32 => helper!(Float32Type, dt),
DataType::Float64 => helper!(Float64Type, dt),
+ DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
+ DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
_ => Err(DataFusionError::NotImplemented(format!(
@@ -205,6 +209,8 @@ impl AggregateUDFImpl for Median {
DataType::Float16 => helper!(Float16Type, dt),
DataType::Float32 => helper!(Float32Type, dt),
DataType::Float64 => helper!(Float64Type, dt),
+ DataType::Decimal32(_, _) => helper!(Decimal32Type, dt),
+ DataType::Decimal64(_, _) => helper!(Decimal64Type, dt),
DataType::Decimal128(_, _) => helper!(Decimal128Type, dt),
DataType::Decimal256(_, _) => helper!(Decimal256Type, dt),
_ => Err(DataFusionError::NotImplemented(format!(
diff --git a/datafusion/functions-aggregate/src/min_max.rs
b/datafusion/functions-aggregate/src/min_max.rs
index 639c08706b..1a46afefff 100644
--- a/datafusion/functions-aggregate/src/min_max.rs
+++ b/datafusion/functions-aggregate/src/min_max.rs
@@ -23,10 +23,10 @@ mod min_max_struct;
use arrow::array::ArrayRef;
use arrow::datatypes::{
- DataType, Decimal128Type, Decimal256Type, DurationMicrosecondType,
- DurationMillisecondType, DurationNanosecondType, DurationSecondType,
Float16Type,
- Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
UInt16Type,
- UInt32Type, UInt64Type, UInt8Type,
+ DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type,
+ DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
+ DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type,
Int32Type,
+ Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use datafusion_common::stats::Precision;
use datafusion_common::{exec_err, internal_err, ColumnStatistics, Result};
@@ -239,6 +239,8 @@ impl AggregateUDFImpl for Max {
| Float16
| Float32
| Float64
+ | Decimal32(_, _)
+ | Decimal64(_, _)
| Decimal128(_, _)
| Decimal256(_, _)
| Date32
@@ -320,6 +322,12 @@ impl AggregateUDFImpl for Max {
Duration(Nanosecond) => {
primitive_max_accumulator!(data_type, i64,
DurationNanosecondType)
}
+ Decimal32(_, _) => {
+ primitive_max_accumulator!(data_type, i32, Decimal32Type)
+ }
+ Decimal64(_, _) => {
+ primitive_max_accumulator!(data_type, i64, Decimal64Type)
+ }
Decimal128(_, _) => {
primitive_max_accumulator!(data_type, i128, Decimal128Type)
}
@@ -518,6 +526,8 @@ impl AggregateUDFImpl for Min {
| Float16
| Float32
| Float64
+ | Decimal32(_, _)
+ | Decimal64(_, _)
| Decimal128(_, _)
| Decimal256(_, _)
| Date32
@@ -599,6 +609,12 @@ impl AggregateUDFImpl for Min {
Duration(Nanosecond) => {
primitive_min_accumulator!(data_type, i64,
DurationNanosecondType)
}
+ Decimal32(_, _) => {
+ primitive_min_accumulator!(data_type, i32, Decimal32Type)
+ }
+ Decimal64(_, _) => {
+ primitive_min_accumulator!(data_type, i64, Decimal64Type)
+ }
Decimal128(_, _) => {
primitive_min_accumulator!(data_type, i128, Decimal128Type)
}
diff --git a/datafusion/functions-aggregate/src/sum.rs
b/datafusion/functions-aggregate/src/sum.rs
index d974ca22b1..04339fc645 100644
--- a/datafusion/functions-aggregate/src/sum.rs
+++ b/datafusion/functions-aggregate/src/sum.rs
@@ -18,6 +18,8 @@
//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
use ahash::RandomState;
+use arrow::datatypes::DECIMAL32_MAX_PRECISION;
+use arrow::datatypes::DECIMAL64_MAX_PRECISION;
use datafusion_expr::utils::AggregateOrderSensitivity;
use std::any::Any;
use std::mem::size_of_val;
@@ -27,8 +29,8 @@ use arrow::array::ArrowNativeTypeOp;
use arrow::array::{ArrowNumericType, AsArray};
use arrow::datatypes::{ArrowNativeType, FieldRef};
use arrow::datatypes::{
- DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type,
UInt64Type,
- DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
+ DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type,
Float64Type,
+ Int64Type, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::{
@@ -71,6 +73,12 @@ macro_rules! downcast_sum {
DataType::Float64 => {
$helper!(Float64Type, $args.return_field.data_type().clone())
}
+ DataType::Decimal32(_, _) => {
+ $helper!(Decimal32Type, $args.return_field.data_type().clone())
+ }
+ DataType::Decimal64(_, _) => {
+ $helper!(Decimal64Type, $args.return_field.data_type().clone())
+ }
DataType::Decimal128(_, _) => {
$helper!(Decimal128Type,
$args.return_field.data_type().clone())
}
@@ -145,9 +153,10 @@ impl AggregateUDFImpl for Sum {
DataType::Dictionary(_, v) => coerced_type(v),
// in the spark, the result type is
DECIMAL(min(38,precision+10), s)
// ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
- DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
- Ok(data_type.clone())
- }
+ DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
+ | DataType::Decimal128(_, _)
+ | DataType::Decimal256(_, _) => Ok(data_type.clone()),
dt if dt.is_signed_integer() => Ok(DataType::Int64),
dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
dt if dt.is_floating() => Ok(DataType::Float64),
@@ -163,6 +172,18 @@ impl AggregateUDFImpl for Sum {
DataType::Int64 => Ok(DataType::Int64),
DataType::UInt64 => Ok(DataType::UInt64),
DataType::Float64 => Ok(DataType::Float64),
+ DataType::Decimal32(precision, scale) => {
+ // in the spark, the result type is
DECIMAL(min(38,precision+10), s)
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ let new_precision = DECIMAL32_MAX_PRECISION.min(*precision +
10);
+ Ok(DataType::Decimal32(new_precision, *scale))
+ }
+ DataType::Decimal64(precision, scale) => {
+ // in the spark, the result type is
DECIMAL(min(38,precision+10), s)
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ let new_precision = DECIMAL64_MAX_PRECISION.min(*precision +
10);
+ Ok(DataType::Decimal64(new_precision, *scale))
+ }
DataType::Decimal128(precision, scale) => {
// in the spark, the result type is
DECIMAL(min(38,precision+10), s)
// ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
diff --git a/datafusion/proto-common/src/from_proto/mod.rs
b/datafusion/proto-common/src/from_proto/mod.rs
index ea43de5c03..2d07fb8410 100644
--- a/datafusion/proto-common/src/from_proto/mod.rs
+++ b/datafusion/proto-common/src/from_proto/mod.rs
@@ -37,7 +37,6 @@ use datafusion_common::{
TableParquetOptions,
},
file_options::{csv_writer::CsvWriterOptions,
json_writer::JsonWriterOptions},
- not_impl_err,
parsers::CompressionTypeVariant,
plan_datafusion_err,
stats::Precision,
@@ -478,13 +477,13 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
let null_type: DataType = v.try_into()?;
null_type.try_into().map_err(Error::DataFusionError)?
}
- Value::Decimal32Value(_val) => {
- return not_impl_err!("Decimal32 protobuf deserialization")
- .map_err(Error::DataFusionError)
+ Value::Decimal32Value(val) => {
+ let array = vec_to_array(val.value.clone());
+ Self::Decimal32(Some(i32::from_be_bytes(array)), val.p as u8,
val.s as i8)
}
- Value::Decimal64Value(_val) => {
- return not_impl_err!("Decimal64 protobuf deserialization")
- .map_err(Error::DataFusionError)
+ Value::Decimal64Value(val) => {
+ let array = vec_to_array(val.value.clone());
+ Self::Decimal64(Some(i64::from_be_bytes(array)), val.p as u8,
val.s as i8)
}
Value::Decimal128Value(val) => {
let array = vec_to_array(val.value.clone());
diff --git a/datafusion/proto-common/src/to_proto/mod.rs
b/datafusion/proto-common/src/to_proto/mod.rs
index fae1e2b1c6..8e4131479e 100644
--- a/datafusion/proto-common/src/to_proto/mod.rs
+++ b/datafusion/proto-common/src/to_proto/mod.rs
@@ -405,6 +405,42 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
})
})
}
+ ScalarValue::Decimal32(val, p, s) => match *val {
+ Some(v) => {
+ let array = v.to_be_bytes();
+ let vec_val: Vec<u8> = array.to_vec();
+ Ok(protobuf::ScalarValue {
+ value: Some(Value::Decimal32Value(protobuf::Decimal32 {
+ value: vec_val,
+ p: *p as i64,
+ s: *s as i64,
+ })),
+ })
+ }
+ None => Ok(protobuf::ScalarValue {
+ value: Some(protobuf::scalar_value::Value::NullValue(
+ (&data_type).try_into()?,
+ )),
+ }),
+ },
+ ScalarValue::Decimal64(val, p, s) => match *val {
+ Some(v) => {
+ let array = v.to_be_bytes();
+ let vec_val: Vec<u8> = array.to_vec();
+ Ok(protobuf::ScalarValue {
+ value: Some(Value::Decimal64Value(protobuf::Decimal64 {
+ value: vec_val,
+ p: *p as i64,
+ s: *s as i64,
+ })),
+ })
+ }
+ None => Ok(protobuf::ScalarValue {
+ value: Some(protobuf::scalar_value::Value::NullValue(
+ (&data_type).try_into()?,
+ )),
+ }),
+ },
ScalarValue::Decimal128(val, p, s) => match *val {
Some(v) => {
let array = v.to_be_bytes();
diff --git a/datafusion/sql/src/unparser/expr.rs
b/datafusion/sql/src/unparser/expr.rs
index 05d232120d..493a9f2ae6 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -35,7 +35,9 @@ use arrow::array::{
},
ArrayRef, Date32Array, Date64Array, PrimitiveArray,
};
-use arrow::datatypes::{DataType, Decimal128Type, Decimal256Type, DecimalType};
+use arrow::datatypes::{
+ DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type,
DecimalType,
+};
use arrow::util::display::array_value_to_string;
use datafusion_common::{
internal_datafusion_err, internal_err, not_impl_err, plan_err, Column,
Result,
@@ -1182,6 +1184,20 @@ impl Unparser<'_> {
Ok(ast::Expr::value(ast::Value::Number(f_val, false)))
}
ScalarValue::Float64(None) =>
Ok(ast::Expr::value(ast::Value::Null)),
+ ScalarValue::Decimal32(Some(value), precision, scale) => {
+ Ok(ast::Expr::value(ast::Value::Number(
+ Decimal32Type::format_decimal(*value, *precision, *scale),
+ false,
+ )))
+ }
+ ScalarValue::Decimal32(None, ..) =>
Ok(ast::Expr::value(ast::Value::Null)),
+ ScalarValue::Decimal64(Some(value), precision, scale) => {
+ Ok(ast::Expr::value(ast::Value::Number(
+ Decimal64Type::format_decimal(*value, *precision, *scale),
+ false,
+ )))
+ }
+ ScalarValue::Decimal64(None, ..) =>
Ok(ast::Expr::value(ast::Value::Null)),
ScalarValue::Decimal128(Some(value), precision, scale) => {
Ok(ast::Expr::value(ast::Value::Number(
Decimal128Type::format_decimal(*value, *precision, *scale),
@@ -1726,13 +1742,9 @@ impl Unparser<'_> {
not_impl_err!("Unsupported DataType: conversion: {data_type}")
}
DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val),
- DataType::Decimal32(_precision, _scale) => {
- not_impl_err!("Unsupported DataType: conversion: {data_type}")
- }
- DataType::Decimal64(_precision, _scale) => {
- not_impl_err!("Unsupported DataType: conversion: {data_type}")
- }
- DataType::Decimal128(precision, scale)
+ DataType::Decimal32(precision, scale)
+ | DataType::Decimal64(precision, scale)
+ | DataType::Decimal128(precision, scale)
| DataType::Decimal256(precision, scale) => {
let mut new_precision = *precision as u64;
let mut new_scale = *scale as u64;
@@ -2179,6 +2191,20 @@ mod tests {
(col("need-quoted").eq(lit(1)), r#"("need-quoted" = 1)"#),
(col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#),
// See test_interval_scalar_to_expr for interval literals
+ (
+ (col("a") + col("b")).gt(Expr::Literal(
+ ScalarValue::Decimal32(Some(1123), 4, 3),
+ None,
+ )),
+ r#"((a + b) > 1.123)"#,
+ ),
+ (
+ (col("a") + col("b")).gt(Expr::Literal(
+ ScalarValue::Decimal64(Some(1123), 4, 3),
+ None,
+ )),
+ r#"((a + b) > 1.123)"#,
+ ),
(
(col("a") + col("b")).gt(Expr::Literal(
ScalarValue::Decimal128(Some(100123), 28, 3),
diff --git a/test-utils/src/array_gen/random_data.rs
b/test-utils/src/array_gen/random_data.rs
index 78518b7bf9..ea2b872f7d 100644
--- a/test-utils/src/array_gen/random_data.rs
+++ b/test-utils/src/array_gen/random_data.rs
@@ -17,12 +17,12 @@
use arrow::array::ArrowPrimitiveType;
use arrow::datatypes::{
- i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
- DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
- DurationSecondType, Float32Type, Float64Type, Int16Type, Int32Type,
Int64Type,
- Int8Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano,
- IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType,
- Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
+ i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Decimal32Type,
+ Decimal64Type, DurationMicrosecondType, DurationMillisecondType,
+ DurationNanosecondType, DurationSecondType, Float32Type, Float64Type,
Int16Type,
+ Int32Type, Int64Type, Int8Type, IntervalDayTime, IntervalDayTimeType,
+ IntervalMonthDayNano, IntervalMonthDayNanoType, IntervalYearMonthType,
+ Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
@@ -67,6 +67,8 @@ basic_random_data!(Time32MillisecondType);
basic_random_data!(Time64MicrosecondType);
basic_random_data!(Time64NanosecondType);
basic_random_data!(IntervalYearMonthType);
+basic_random_data!(Decimal32Type);
+basic_random_data!(Decimal64Type);
basic_random_data!(Decimal128Type);
basic_random_data!(TimestampSecondType);
basic_random_data!(TimestampMillisecondType);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]