This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ed045d9895 Add Decimal256 to `ScalarValue` (#7048)
ed045d9895 is described below

commit ed045d989501946d9a73d8e1c3b884f279a0a00d
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Jul 25 05:22:57 2023 -0700

    Add Decimal256 to `ScalarValue` (#7048)
    
    * Initial support for Decimal256 ScalarValue
    
    * Add Decimal256 to proto
    
    * Update protobuf code
    
    * Add Decimal256 to from_proto
    
    * Update datafusion/expr/src/type_coercion/aggregates.rs
    
    Co-authored-by: Daniël Heres <[email protected]>
    
    ---------
    
    Co-authored-by: Daniël Heres <[email protected]>
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/common/src/cast.rs                      |   6 +
 datafusion/common/src/scalar.rs                    | 161 +++++++++++++++++++--
 .../sqllogictests/test_files/arrow_typeof.slt      |  22 ++-
 .../tests/sqllogictests/test_files/decimal.slt     |   9 ++
 datafusion/expr/src/type_coercion/aggregates.rs    |  19 +++
 datafusion/physical-expr/src/aggregate/average.rs  |   4 +-
 datafusion/physical-expr/src/aggregate/sum.rs      |  22 ++-
 datafusion/proto/proto/datafusion.proto            |   8 +
 datafusion/proto/src/generated/pbjson.rs           | 145 +++++++++++++++++++
 datafusion/proto/src/generated/prost.rs            |  14 +-
 datafusion/proto/src/logical_plan/from_proto.rs    |  10 +-
 datafusion/proto/src/logical_plan/to_proto.rs      |  18 +++
 12 files changed, 411 insertions(+), 27 deletions(-)

diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs
index 04ae32ec35..4356f36b18 100644
--- a/datafusion/common/src/cast.rs
+++ b/datafusion/common/src/cast.rs
@@ -34,6 +34,7 @@ use arrow::{
     },
     datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType},
 };
+use arrow_array::Decimal256Array;
 
 // Downcast ArrayRef to Date32Array
 pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> {
@@ -65,6 +66,11 @@ pub fn as_decimal128_array(array: &dyn Array) -> Result<&Decimal128Array> {
     Ok(downcast_value!(array, Decimal128Array))
 }
 
+// Downcast ArrayRef to Decimal256Array
+pub fn as_decimal256_array(array: &dyn Array) -> Result<&Decimal256Array> {
+    Ok(downcast_value!(array, Decimal256Array))
+}
+
 // Downcast ArrayRef to Float32Array
 pub fn as_float32_array(array: &dyn Array) -> Result<&Float32Array> {
     Ok(downcast_value!(array, Float32Array))
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 99ff5f3384..4a7767023f 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -26,14 +26,14 @@ use std::str::FromStr;
 use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
 
 use crate::cast::{
-    as_decimal128_array, as_dictionary_array, as_fixed_size_binary_array,
-    as_fixed_size_list_array, as_list_array, as_struct_array,
+    as_decimal128_array, as_decimal256_array, as_dictionary_array,
+    as_fixed_size_binary_array, as_fixed_size_list_array, as_list_array, as_struct_array,
 };
 use crate::delta::shift_months;
 use crate::error::{DataFusionError, Result};
 use arrow::buffer::NullBuffer;
 use arrow::compute::nullif;
-use arrow::datatypes::{FieldRef, Fields, SchemaBuilder};
+use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder};
 use arrow::{
     array::*,
     compute::kernels::cast::{cast_with_options, CastOptions},
@@ -47,6 +47,7 @@ use arrow::{
     },
 };
 use arrow_array::timezone::Tz;
+use arrow_array::ArrowNativeTypeOp;
 use chrono::{Datelike, Duration, NaiveDate, NaiveDateTime};
 
 // Constants we use throughout this file:
@@ -75,6 +76,8 @@ pub enum ScalarValue {
     Float64(Option<f64>),
     /// 128bit decimal, using the i128 to represent the decimal, precision scale
     Decimal128(Option<i128>, u8, i8),
+    /// 256bit decimal, using the i256 to represent the decimal, precision scale
+    Decimal256(Option<i256>, u8, i8),
     /// signed 8bit int
     Int8(Option<i8>),
     /// signed 16bit int
@@ -160,6 +163,10 @@ impl PartialEq for ScalarValue {
                 v1.eq(v2) && p1.eq(p2) && s1.eq(s2)
             }
             (Decimal128(_, _, _), _) => false,
+            (Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => {
+                v1.eq(v2) && p1.eq(p2) && s1.eq(s2)
+            }
+            (Decimal256(_, _, _), _) => false,
             (Boolean(v1), Boolean(v2)) => v1.eq(v2),
             (Boolean(_), _) => false,
             (Float32(v1), Float32(v2)) => match (v1, v2) {
@@ -283,6 +290,15 @@ impl PartialOrd for ScalarValue {
                 }
             }
             (Decimal128(_, _, _), _) => None,
+            (Decimal256(v1, p1, s1), Decimal256(v2, p2, s2)) => {
+                if p1.eq(p2) && s1.eq(s2) {
+                    v1.partial_cmp(v2)
+                } else {
+                    // Two decimal values can be compared if they have the same precision and scale.
+                    None
+                }
+            }
+            (Decimal256(_, _, _), _) => None,
             (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2),
             (Boolean(_), _) => None,
             (Float32(v1), Float32(v2)) => match (v1, v2) {
@@ -1038,6 +1054,7 @@ macro_rules! impl_op_arithmetic {
                 get_sign!($OPERATION),
                 true,
             )))),
+            // todo: Add Decimal256 support
             _ => Err(DataFusionError::Internal(format!(
                 "Operator {} is not implemented for types {:?} and {:?}",
                 stringify!($OPERATION),
@@ -1516,6 +1533,11 @@ impl std::hash::Hash for ScalarValue {
                 p.hash(state);
                 s.hash(state)
             }
+            Decimal256(v, p, s) => {
+                v.hash(state);
+                p.hash(state);
+                s.hash(state)
+            }
             Boolean(v) => v.hash(state),
             Float32(v) => v.map(Fl).hash(state),
             Float64(v) => v.map(Fl).hash(state),
@@ -1994,6 +2016,9 @@ impl ScalarValue {
             ScalarValue::Decimal128(_, precision, scale) => {
                 DataType::Decimal128(*precision, *scale)
             }
+            ScalarValue::Decimal256(_, precision, scale) => {
+                DataType::Decimal256(*precision, *scale)
+            }
             ScalarValue::TimestampSecond(_, tz_opt) => {
                 DataType::Timestamp(TimeUnit::Second, tz_opt.clone())
             }
@@ -2083,6 +2108,9 @@ impl ScalarValue {
             ScalarValue::Decimal128(Some(v), precision, scale) => {
                 Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale))
             }
+            ScalarValue::Decimal256(Some(v), precision, scale) => Ok(
+                ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale),
+            ),
             value => Err(DataFusionError::Internal(format!(
                 "Can not run arithmetic negative on scalar value {value:?}"
             ))),
@@ -2154,6 +2182,7 @@ impl ScalarValue {
             ScalarValue::Float32(v) => v.is_none(),
             ScalarValue::Float64(v) => v.is_none(),
             ScalarValue::Decimal128(v, _, _) => v.is_none(),
+            ScalarValue::Decimal256(v, _, _) => v.is_none(),
             ScalarValue::Int8(v) => v.is_none(),
             ScalarValue::Int16(v) => v.is_none(),
             ScalarValue::Int32(v) => v.is_none(),
@@ -2415,10 +2444,10 @@ impl ScalarValue {
                     ScalarValue::iter_to_decimal_array(scalars, *precision, *scale)?;
                 Arc::new(decimal_array)
             }
-            DataType::Decimal256(_, _) => {
-                return Err(DataFusionError::Internal(
-                    "Decimal256 is not supported for ScalarValue".to_string(),
-                ));
+            DataType::Decimal256(precision, scale) => {
+                let decimal_array =
+                    ScalarValue::iter_to_decimal256_array(scalars, *precision, *scale)?;
+                Arc::new(decimal_array)
             }
             DataType::Null => ScalarValue::iter_to_null_array(scalars),
             DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
@@ -2680,6 +2709,22 @@ impl ScalarValue {
         Ok(array)
     }
 
+    fn iter_to_decimal256_array(
+        scalars: impl IntoIterator<Item = ScalarValue>,
+        precision: u8,
+        scale: i8,
+    ) -> Result<Decimal256Array> {
+        let array = scalars
+            .into_iter()
+            .map(|element: ScalarValue| match element {
+                ScalarValue::Decimal256(v1, _, _) => v1,
+                _ => unreachable!(),
+            })
+            .collect::<Decimal256Array>()
+            .with_precision_and_scale(precision, scale)?;
+        Ok(array)
+    }
+
     fn iter_to_array_list(
         scalars: impl IntoIterator<Item = ScalarValue>,
         data_type: &DataType,
@@ -2764,12 +2809,28 @@ impl ScalarValue {
         }
     }
 
+    fn build_decimal256_array(
+        value: Option<i256>,
+        precision: u8,
+        scale: i8,
+        size: usize,
+    ) -> Decimal256Array {
+        std::iter::repeat(value)
+            .take(size)
+            .collect::<Decimal256Array>()
+            .with_precision_and_scale(precision, scale)
+            .unwrap()
+    }
+
     /// Converts a scalar value into an array of `size` rows.
     pub fn to_array_of_size(&self, size: usize) -> ArrayRef {
         match self {
             ScalarValue::Decimal128(e, precision, scale) => Arc::new(
                 ScalarValue::build_decimal_array(*e, *precision, *scale, size),
             ),
+            ScalarValue::Decimal256(e, precision, scale) => Arc::new(
+                ScalarValue::build_decimal256_array(*e, *precision, *scale, size),
+            ),
             ScalarValue::Boolean(e) => {
                 Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef
             }
@@ -3044,12 +3105,28 @@ impl ScalarValue {
         precision: u8,
         scale: i8,
     ) -> Result<ScalarValue> {
-        let array = as_decimal128_array(array)?;
-        if array.is_null(index) {
-            Ok(ScalarValue::Decimal128(None, precision, scale))
-        } else {
-            let value = array.value(index);
-            Ok(ScalarValue::Decimal128(Some(value), precision, scale))
+        match array.data_type() {
+            DataType::Decimal128(_, _) => {
+                let array = as_decimal128_array(array)?;
+                if array.is_null(index) {
+                    Ok(ScalarValue::Decimal128(None, precision, scale))
+                } else {
+                    let value = array.value(index);
+                    Ok(ScalarValue::Decimal128(Some(value), precision, scale))
+                }
+            }
+            DataType::Decimal256(_, _) => {
+                let array = as_decimal256_array(array)?;
+                if array.is_null(index) {
+                    Ok(ScalarValue::Decimal256(None, precision, scale))
+                } else {
+                    let value = array.value(index);
+                    Ok(ScalarValue::Decimal256(Some(value), precision, scale))
+                }
+            }
+            _ => Err(DataFusionError::Internal(
+                "Unsupported decimal type".to_string(),
+            )),
         }
     }
 
@@ -3067,6 +3144,11 @@ impl ScalarValue {
                     array, index, *precision, *scale,
                 )?
             }
+            DataType::Decimal256(precision, scale) => {
+                ScalarValue::get_decimal_value_from_array(
+                    array, index, *precision, *scale,
+                )?
+            }
             DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean),
             DataType::Float64 => typed_cast!(array, index, Float64Array, Float64),
             DataType::Float32 => typed_cast!(array, index, Float32Array, Float32),
@@ -3265,6 +3347,25 @@ impl ScalarValue {
         }
     }
 
+    fn eq_array_decimal256(
+        array: &ArrayRef,
+        index: usize,
+        value: Option<&i256>,
+        precision: u8,
+        scale: i8,
+    ) -> Result<bool> {
+        let array = as_decimal256_array(array)?;
+        if array.precision() != precision || array.scale() != scale {
+            return Ok(false);
+        }
+        let is_null = array.is_null(index);
+        if let Some(v) = value {
+            Ok(!array.is_null(index) && array.value(index) == *v)
+        } else {
+            Ok(is_null)
+        }
+    }
+
     /// Compares a single row of array @ index for equality with self,
     /// in an optimized fashion.
     ///
@@ -3294,6 +3395,16 @@ impl ScalarValue {
                 )
                 .unwrap()
             }
+            ScalarValue::Decimal256(v, precision, scale) => {
+                ScalarValue::eq_array_decimal256(
+                    array,
+                    index,
+                    v.as_ref(),
+                    *precision,
+                    *scale,
+                )
+                .unwrap()
+            }
             ScalarValue::Boolean(val) => {
                 eq_array_primitive!(array, index, BooleanArray, val)
             }
@@ -3416,6 +3527,7 @@ impl ScalarValue {
                 | ScalarValue::Float32(_)
                 | ScalarValue::Float64(_)
                 | ScalarValue::Decimal128(_, _, _)
+                | ScalarValue::Decimal256(_, _, _)
                 | ScalarValue::Int8(_)
                 | ScalarValue::Int16(_)
                 | ScalarValue::Int32(_)
@@ -3647,6 +3759,22 @@ impl TryFrom<ScalarValue> for i128 {
     }
 }
 
+// special implementation for i256 because of Decimal128
+impl TryFrom<ScalarValue> for i256 {
+    type Error = DataFusionError;
+
+    fn try_from(value: ScalarValue) -> Result<Self> {
+        match value {
+            ScalarValue::Decimal256(Some(inner_value), _, _) => Ok(inner_value),
+            _ => Err(DataFusionError::Internal(format!(
+                "Cannot convert {:?} to {}",
+                value,
+                std::any::type_name::<Self>()
+            ))),
+        }
+    }
+}
+
 impl_try_from!(UInt8, u8);
 impl_try_from!(UInt16, u16);
 impl_try_from!(UInt32, u32);
@@ -3684,6 +3812,9 @@ impl TryFrom<&DataType> for ScalarValue {
             DataType::Decimal128(precision, scale) => {
                 ScalarValue::Decimal128(None, *precision, *scale)
             }
+            DataType::Decimal256(precision, scale) => {
+                ScalarValue::Decimal256(None, *precision, *scale)
+            }
             DataType::Utf8 => ScalarValue::Utf8(None),
             DataType::LargeUtf8 => ScalarValue::LargeUtf8(None),
             DataType::Binary => ScalarValue::Binary(None),
@@ -3753,6 +3884,9 @@ impl fmt::Display for ScalarValue {
             ScalarValue::Decimal128(v, p, s) => {
                 write!(f, "{v:?},{p:?},{s:?}")?;
             }
+            ScalarValue::Decimal256(v, p, s) => {
+                write!(f, "{v:?},{p:?},{s:?}")?;
+            }
             ScalarValue::Boolean(e) => format_option!(f, e)?,
             ScalarValue::Float32(e) => format_option!(f, e)?,
             ScalarValue::Float64(e) => format_option!(f, e)?,
@@ -3830,6 +3964,7 @@ impl fmt::Debug for ScalarValue {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
             ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"),
+            ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"),
             ScalarValue::Boolean(_) => write!(f, "Boolean({self})"),
             ScalarValue::Float32(_) => write!(f, "Float32({self})"),
             ScalarValue::Float64(_) => write!(f, "Float64({self})"),
diff --git a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
index 4a3d39bdeb..5c82c7e009 100644
--- a/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/arrow_typeof.slt
@@ -180,23 +180,29 @@ drop table foo
 
 statement ok
 create table foo as select
-  arrow_cast(100, 'Decimal128(5,2)') as col_d128
-  -- Can't make a decimal 156:
-  -- This feature is not implemented: Can't create a scalar from array of type "Decimal256(3, 2)"
-  --arrow_cast(100, 'Decimal256(5,2)') as col_d256
+  arrow_cast(100, 'Decimal128(5,2)') as col_d128,
+  arrow_cast(100, 'Decimal256(5,2)') as col_d256
 ;
 
 
 ## Ensure each column in the table has the expected type
 
-query T
+query TT
 SELECT
-  arrow_typeof(col_d128)
-  -- arrow_typeof(col_d256),
+  arrow_typeof(col_d128),
+  arrow_typeof(col_d256)
   FROM foo;
 ----
-Decimal128(5, 2)
+Decimal128(5, 2) Decimal256(5, 2)
+
 
+query RR
+SELECT
+  col_d128,
+  col_d256
+  FROM foo;
+----
+100 100.00
 
 statement ok
 drop table foo
diff --git a/datafusion/core/tests/sqllogictests/test_files/decimal.slt b/datafusion/core/tests/sqllogictests/test_files/decimal.slt
index f413517741..8fd08f87c8 100644
--- a/datafusion/core/tests/sqllogictests/test_files/decimal.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/decimal.slt
@@ -612,3 +612,12 @@ insert into foo VALUES (1, 5);
 
 query error DataFusion error: Arrow error: Compute error: Overflow happened on: 100000000000000000000 \* 100000000000000000000000000000000000000
 select a / b from foo;
+
+statement ok
+create table t as values (arrow_cast(123, 'Decimal256(5,2)'));
+
+query error DataFusion error: Internal error: Operator \+ is not implemented for types Decimal256\(None,15,2\) and Decimal256\(Some\(12300\),15,2\)\. This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker
+select AVG(column1) from t;
+
+statement ok
+drop table t;
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index 1fccdcbd2c..dec2eb7f12 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -17,6 +17,7 @@
 
 use arrow::datatypes::{
     DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
+    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
 };
 use datafusion_common::{DataFusionError, Result};
 use std::ops::Deref;
@@ -360,6 +361,12 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
             let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
             Ok(DataType::Decimal128(new_precision, *scale))
         }
+        DataType::Decimal256(precision, scale) => {
+            // in the spark, the result type is DECIMAL(min(38,precision+10), s)
+            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
+            Ok(DataType::Decimal256(new_precision, *scale))
+        }
         DataType::Dictionary(_, dict_value_type) => {
             sum_return_type(dict_value_type.as_ref())
         }
@@ -423,6 +430,13 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
             let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
             Ok(DataType::Decimal128(new_precision, new_scale))
         }
+        DataType::Decimal256(precision, scale) => {
+            // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
+            let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
+            Ok(DataType::Decimal256(new_precision, new_scale))
+        }
         arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
         DataType::Dictionary(_, dict_value_type) => {
             avg_return_type(dict_value_type.as_ref())
@@ -441,6 +455,11 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
             let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
             Ok(DataType::Decimal128(new_precision, *scale))
         }
+        DataType::Decimal256(precision, scale) => {
+            // in Spark the sum type of avg is DECIMAL(min(38,precision+10), s)
+            let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
+            Ok(DataType::Decimal256(new_precision, *scale))
+        }
         arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
         DataType::Dictionary(_, dict_value_type) => {
             avg_sum_type(dict_value_type.as_ref())
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index a1d77a2d88..9c01093edf 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -77,12 +77,12 @@ impl Avg {
         // the internal sum data type of avg just support FLOAT64 and Decimal data type.
         assert!(matches!(
             sum_data_type,
-            DataType::Float64 | DataType::Decimal128(_, _)
+            DataType::Float64 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _)
         ));
         // the result of avg just support FLOAT64 and Decimal data type.
         assert!(matches!(
             rt_data_type,
-            DataType::Float64 | DataType::Decimal128(_, _)
+            DataType::Float64 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _)
         ));
         Self {
             name: name.into(),
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index 45e2be7fb4..9ac90cef4b 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -28,6 +28,7 @@ use crate::expressions::format_state_name;
 use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
 use arrow::array::Array;
 use arrow::array::Decimal128Array;
+use arrow::array::Decimal256Array;
 use arrow::compute;
 use arrow::compute::kernels::cast;
 use arrow::datatypes::DataType;
@@ -39,8 +40,8 @@ use arrow::{
     datatypes::Field,
 };
 use arrow_array::types::{
-    Decimal128Type, Float32Type, Float64Type, Int32Type, Int64Type, UInt32Type,
-    UInt64Type,
+    Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int32Type, Int64Type,
+    UInt32Type, UInt64Type,
 };
 use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue};
 use datafusion_expr::Accumulator;
@@ -169,6 +170,10 @@ impl AggregateExpr for Sum {
                 instantiate_primitive_accumulator!(self, Decimal128Type, |x, y| x
                     .add_assign(y))
             }
+            DataType::Decimal256(_, _) => {
+                instantiate_primitive_accumulator!(self, Decimal256Type, |x, y| *x =
+                    *x + y)
+            }
             _ => Err(DataFusionError::NotImplemented(format!(
                 "GroupsAccumulator not supported for {}: {}",
                 self.name, self.data_type
@@ -250,6 +255,16 @@ fn sum_decimal_batch(values: &ArrayRef, precision: u8, scale: i8) -> Result<Scal
     Ok(ScalarValue::Decimal128(result, precision, scale))
 }
 
+fn sum_decimal256_batch(
+    values: &ArrayRef,
+    precision: u8,
+    scale: i8,
+) -> Result<ScalarValue> {
+    let array = downcast_value!(values, Decimal256Array);
+    let result = compute::sum(array);
+    Ok(ScalarValue::Decimal256(result, precision, scale))
+}
+
 // sums the array and returns a ScalarValue of its corresponding type.
 pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result<ScalarValue> {
     // TODO refine the cast kernel in arrow-rs
@@ -263,6 +278,9 @@ pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result<Scalar
         DataType::Decimal128(precision, scale) => {
             sum_decimal_batch(values, *precision, *scale)?
         }
+        DataType::Decimal256(precision, scale) => {
+            sum_decimal256_batch(values, *precision, *scale)?
+        }
         DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64),
         DataType::Float32 => typed_sum_delta_batch!(values, Float32Array, Float32),
         DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64),
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 8192a403d3..f7247effdd 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -908,6 +908,8 @@ message ScalarValue{
     //WAS: ScalarType null_list_value = 18;
 
     Decimal128 decimal128_value = 20;
+    Decimal256 decimal256_value = 39;
+
     int64 date_64_value = 21;
     int32 interval_yearmonth_value = 24;
     int64 interval_daytime_value = 25;
@@ -934,6 +936,12 @@ message Decimal128{
   int64 s = 3;
 }
 
+message Decimal256{
+  bytes value = 1;
+  int64 p = 2;
+  int64 s = 3;
+}
+
 // Serialized data type
 message ArrowType{
   oneof arrow_type_enum {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 05bfbd089d..aaf6bb97bb 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -4983,6 +4983,137 @@ impl<'de> serde::Deserialize<'de> for Decimal128 {
         deserializer.deserialize_struct("datafusion.Decimal128", FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for Decimal256 {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if !self.value.is_empty() {
+            len += 1;
+        }
+        if self.p != 0 {
+            len += 1;
+        }
+        if self.s != 0 {
+            len += 1;
+        }
+        let mut struct_ser = serializer.serialize_struct("datafusion.Decimal256", len)?;
+        if !self.value.is_empty() {
+            struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?;
+        }
+        if self.p != 0 {
+            struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?;
+        }
+        if self.s != 0 {
+            struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for Decimal256 {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "value",
+            "p",
+            "s",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Value,
+            P,
+            S,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "value" => Ok(GeneratedField::Value),
+                            "p" => Ok(GeneratedField::P),
+                            "s" => Ok(GeneratedField::S),
+                            _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = Decimal256;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                formatter.write_str("struct datafusion.Decimal256")
+            }
+
+            fn visit_map<V>(self, mut map: V) -> std::result::Result<Decimal256, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut value__ = None;
+                let mut p__ = None;
+                let mut s__ = None;
+                while let Some(k) = map.next_key()? {
+                    match k {
+                        GeneratedField::Value => {
+                            if value__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("value"));
+                            }
+                            value__ = 
+                                Some(map.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0)
+                            ;
+                        }
+                        GeneratedField::P => {
+                            if p__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("p"));
+                            }
+                            p__ = 
+                                Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
+                        GeneratedField::S => {
+                            if s__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("s"));
+                            }
+                            s__ = 
+                                Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+                            ;
+                        }
+                    }
+                }
+                Ok(Decimal256 {
+                    value: value__.unwrap_or_default(),
+                    p: p__.unwrap_or_default(),
+                    s: s__.unwrap_or_default(),
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.Decimal256", FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for DfField {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
@@ -19125,6 +19256,9 @@ impl serde::Serialize for ScalarValue {
                 scalar_value::Value::Decimal128Value(v) => {
                     struct_ser.serialize_field("decimal128Value", v)?;
                 }
+                scalar_value::Value::Decimal256Value(v) => {
+                    struct_ser.serialize_field("decimal256Value", v)?;
+                }
                 scalar_value::Value::Date64Value(v) => {
                     struct_ser.serialize_field("date64Value", ToString::to_string(&v).as_str())?;
                 }
@@ -19218,6 +19352,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
             "listValue",
             "decimal128_value",
             "decimal128Value",
+            "decimal256_value",
+            "decimal256Value",
             "date_64_value",
             "date64Value",
             "interval_yearmonth_value",
@@ -19270,6 +19406,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
             Time32Value,
             ListValue,
             Decimal128Value,
+            Decimal256Value,
             Date64Value,
             IntervalYearmonthValue,
             IntervalDaytimeValue,
@@ -19324,6 +19461,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
                             "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value),
                             "listValue" | "list_value" => Ok(GeneratedField::ListValue),
                             "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value),
+                            "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value),
                             "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value),
                             "intervalYearmonthValue" | "interval_yearmonth_value" => Ok(GeneratedField::IntervalYearmonthValue),
                             "intervalDaytimeValue" | "interval_daytime_value" => Ok(GeneratedField::IntervalDaytimeValue),
@@ -19471,6 +19609,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue {
                                 return Err(serde::de::Error::duplicate_field("decimal128Value"));
                             }
                             value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal128Value)
+;
+                        }
+                        GeneratedField::Decimal256Value => {
+                            if value__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("decimal256Value"));
+                            }
+                            value__ = map.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal256Value)
 ;
                         }
                         GeneratedField::Date64Value => {
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index f50754494d..e1ad6acec8 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1097,7 +1097,7 @@ pub struct ScalarFixedSizeBinary {
 pub struct ScalarValue {
     #[prost(
         oneof = "scalar_value::Value",
-        tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 
21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34"
+        tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 20, 
39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34"
     )]
     pub value: ::core::option::Option<scalar_value::Value>,
 }
@@ -1146,6 +1146,8 @@ pub mod scalar_value {
         ListValue(super::ScalarListValue),
         #[prost(message, tag = "20")]
         Decimal128Value(super::Decimal128),
+        #[prost(message, tag = "39")]
+        Decimal256Value(super::Decimal256),
         #[prost(int64, tag = "21")]
         Date64Value(i64),
         #[prost(int32, tag = "24")]
@@ -1188,6 +1190,16 @@ pub struct Decimal128 {
     #[prost(int64, tag = "3")]
     pub s: i64,
 }
+/// Protobuf encoding of a `ScalarValue::Decimal256` literal.
+///
+/// `value` carries the 256-bit integer as raw bytes. `to_proto.rs` fills it from
+/// `i256::to_be_bytes`, and `from_proto.rs` decodes it with `i256::from_be_bytes`,
+/// so the byte order is big-endian.
+/// NOTE(review): the decoder passes the bytes through `vec_to_array`. Presumably it
+/// requires exactly the `i256` width (32 bytes); confirm how malformed lengths are handled.
+/// NOTE(review): this file is generated by prost. Durable doc changes belong in
+/// `datafusion.proto`.
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Decimal256 {
+    /// Big-endian bytes of the `i256` value.
+    #[prost(bytes = "vec", tag = "1")]
+    pub value: ::prost::alloc::vec::Vec<u8>,
+    /// Decimal precision; narrowed with `as u8` when decoded back into a `ScalarValue`.
+    #[prost(int64, tag = "2")]
+    pub p: i64,
+    /// Decimal scale; narrowed with `as i8` when decoded back into a `ScalarValue`.
+    #[prost(int64, tag = "3")]
+    pub s: i64,
+}
 /// Serialized data type
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs
index 674588692d..71a1bf87db 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -26,7 +26,7 @@ use crate::protobuf::{
     OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
 };
 use arrow::datatypes::{
-    DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit,
+    i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, 
TimeUnit,
     UnionFields, UnionMode,
 };
 use datafusion::execution::registry::FunctionRegistry;
@@ -648,6 +648,14 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
                     val.s as i8,
                 )
             }
+            Value::Decimal256Value(val) => {
+                let array = vec_to_array(val.value.clone());
+                Self::Decimal256(
+                    Some(i256::from_be_bytes(array)),
+                    val.p as u8,
+                    val.s as i8,
+                )
+            }
             Value::Date64Value(v) => Self::Date64(Some(*v)),
             Value::Time32Value(v) => {
                 let time_value =
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs 
b/datafusion/proto/src/logical_plan/to_proto.rs
index 072bc84d54..f1a9615761 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1148,6 +1148,24 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
                     )),
                 }),
             },
+            ScalarValue::Decimal256(val, p, s) => match *val {
+                Some(v) => {
+                    let array = v.to_be_bytes();
+                    let vec_val: Vec<u8> = array.to_vec();
+                    Ok(protobuf::ScalarValue {
+                        value: 
Some(Value::Decimal256Value(protobuf::Decimal256 {
+                            value: vec_val,
+                            p: *p as i64,
+                            s: *s as i64,
+                        })),
+                    })
+                }
+                None => Ok(protobuf::ScalarValue {
+                    value: Some(protobuf::scalar_value::Value::NullValue(
+                        (&data_type).try_into()?,
+                    )),
+                }),
+            },
             ScalarValue::Date64(val) => {
                 create_proto_scalar(val.as_ref(), &data_type, |s| 
Value::Date64Value(*s))
             }


Reply via email to