This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 50517f6a perf: Optimize decimal precision check in decimal aggregates (sum and avg) (#952)
50517f6a is described below

commit 50517f6a8064005191e12b3ef34bceacf0650967
Author: Andy Grove <agr...@apache.org>
AuthorDate: Tue Sep 24 08:35:09 2024 -0600

    perf: Optimize decimal precision check in decimal aggregates (sum and avg) (#952)
    
    * agg bench
    
    * fix
    
    * fix
    
    * refactor
    
    * avg
    
    * optimized decimal aggregates with a more efficient version of validate_decimal_precision
    
    * simplify function to remove branch
    
    * address feedback
    
    * format
    
    * Revert a change
    
    * add rust unit test
    
    * format
    
    * code cleanup
    
    * fix
    
    * fix
    
    * fix
    
    * fmt
    
    * update bench
    
    * clippy
---
 native/core/benches/aggregate.rs                   |   9 +-
 .../datafusion/expressions/avg_decimal.rs          |  27 ++--
 .../datafusion/expressions/checkoverflow.rs        |  15 +-
 .../datafusion/expressions/sum_decimal.rs          | 151 +++++++++++++++++----
 native/core/src/execution/datafusion/planner.rs    |  53 +++-----
 5 files changed, 171 insertions(+), 84 deletions(-)

diff --git a/native/core/benches/aggregate.rs b/native/core/benches/aggregate.rs
index e6b3e315..14425f76 100644
--- a/native/core/benches/aggregate.rs
+++ b/native/core/benches/aggregate.rs
@@ -67,7 +67,6 @@ fn criterion_benchmark(c: &mut Criterion) {
     group.bench_function("avg_decimal_comet", |b| {
         let comet_avg_decimal = 
Arc::new(AggregateUDF::new_from_impl(AvgDecimal::new(
             Arc::clone(&c1),
-            "avg",
             DataType::Decimal128(38, 10),
             DataType::Decimal128(38, 10),
         )));
@@ -96,11 +95,9 @@ fn criterion_benchmark(c: &mut Criterion) {
     });
 
     group.bench_function("sum_decimal_comet", |b| {
-        let comet_sum_decimal = 
Arc::new(AggregateUDF::new_from_impl(SumDecimal::new(
-            "sum",
-            Arc::clone(&c1),
-            DataType::Decimal128(38, 10),
-        )));
+        let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(
+            SumDecimal::try_new(Arc::clone(&c1), DataType::Decimal128(38, 
10)).unwrap(),
+        ));
         b.to_async(&rt).iter(|| {
             black_box(agg_test(
                 partitions,
diff --git a/native/core/src/execution/datafusion/expressions/avg_decimal.rs 
b/native/core/src/execution/datafusion/expressions/avg_decimal.rs
index 0462f2d3..a265fdc2 100644
--- a/native/core/src/execution/datafusion/expressions/avg_decimal.rs
+++ b/native/core/src/execution/datafusion/expressions/avg_decimal.rs
@@ -28,10 +28,9 @@ use datafusion_common::{not_impl_err, Result, ScalarValue};
 use datafusion_physical_expr::{expressions::format_state_name, PhysicalExpr};
 use std::{any::Any, sync::Arc};
 
+use 
crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision;
 use arrow_array::ArrowNativeTypeOp;
-use arrow_data::decimal::{
-    validate_decimal_precision, MAX_DECIMAL_FOR_EACH_PRECISION, 
MIN_DECIMAL_FOR_EACH_PRECISION,
-};
+use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, 
MIN_DECIMAL_FOR_EACH_PRECISION};
 use datafusion::logical_expr::Volatility::Immutable;
 use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
@@ -43,7 +42,6 @@ use DataType::*;
 /// AVG aggregate expression
 #[derive(Debug, Clone)]
 pub struct AvgDecimal {
-    name: String,
     signature: Signature,
     expr: Arc<dyn PhysicalExpr>,
     sum_data_type: DataType,
@@ -52,14 +50,8 @@ pub struct AvgDecimal {
 
 impl AvgDecimal {
     /// Create a new AVG aggregate function
-    pub fn new(
-        expr: Arc<dyn PhysicalExpr>,
-        name: impl Into<String>,
-        result_type: DataType,
-        sum_type: DataType,
-    ) -> Self {
+    pub fn new(expr: Arc<dyn PhysicalExpr>, result_type: DataType, sum_type: 
DataType) -> Self {
         Self {
-            name: name.into(),
             signature: Signature::user_defined(Immutable),
             expr,
             result_data_type: result_type,
@@ -95,12 +87,12 @@ impl AggregateUDFImpl for AvgDecimal {
     fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
         Ok(vec![
             Field::new(
-                format_state_name(&self.name, "sum"),
+                format_state_name(self.name(), "sum"),
                 self.sum_data_type.clone(),
                 true,
             ),
             Field::new(
-                format_state_name(&self.name, "count"),
+                format_state_name(self.name(), "count"),
                 DataType::Int64,
                 true,
             ),
@@ -108,7 +100,7 @@ impl AggregateUDFImpl for AvgDecimal {
     }
 
     fn name(&self) -> &str {
-        &self.name
+        "avg"
     }
 
     fn reverse_expr(&self) -> ReversedUDAF {
@@ -169,8 +161,7 @@ impl PartialEq<dyn Any> for AvgDecimal {
         down_cast_any_ref(other)
             .downcast_ref::<Self>()
             .map(|x| {
-                self.name == x.name
-                    && self.sum_data_type == x.sum_data_type
+                self.sum_data_type == x.sum_data_type
                     && self.result_data_type == x.result_data_type
                     && self.expr.eq(&x.expr)
             })
@@ -212,7 +203,7 @@ impl AvgDecimalAccumulator {
             None => (v, false),
         };
 
-        if is_overflow || validate_decimal_precision(new_sum, 
self.sum_precision).is_err() {
+        if is_overflow || !is_valid_decimal_precision(new_sum, 
self.sum_precision) {
             // Overflow: set buffer accumulator to null
             self.is_not_null = false;
             return;
@@ -380,7 +371,7 @@ impl AvgDecimalGroupsAccumulator {
         let (new_sum, is_overflow) = 
self.sums[group_index].overflowing_add(value);
         self.counts[group_index] += 1;
 
-        if is_overflow || validate_decimal_precision(new_sum, 
self.sum_precision).is_err() {
+        if is_overflow || !is_valid_decimal_precision(new_sum, 
self.sum_precision) {
             // Overflow: set buffer accumulator to null
             self.is_not_null.set_bit(group_index, false);
             return;
diff --git a/native/core/src/execution/datafusion/expressions/checkoverflow.rs 
b/native/core/src/execution/datafusion/expressions/checkoverflow.rs
index e922171b..ed03ab66 100644
--- a/native/core/src/execution/datafusion/expressions/checkoverflow.rs
+++ b/native/core/src/execution/datafusion/expressions/checkoverflow.rs
@@ -27,7 +27,8 @@ use arrow::{
     datatypes::{Decimal128Type, DecimalType},
     record_batch::RecordBatch,
 };
-use arrow_schema::{DataType, Schema};
+use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, 
MIN_DECIMAL_FOR_EACH_PRECISION};
+use arrow_schema::{DataType, Schema, DECIMAL128_MAX_PRECISION};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
 use datafusion_common::{DataFusionError, ScalarValue};
@@ -171,3 +172,15 @@ impl PhysicalExpr for CheckOverflow {
         self.hash(&mut s);
     }
 }
+
+/// Adapted from arrow-rs `validate_decimal_precision`, but returns a bool instead
+/// of an `Err` to avoid the cost of formatting error strings, and optimized to
+/// remove a memcpy that exists in the original function. `precision` must be at
+/// least 1 (0 panics on the table index). We can remove this code once we upgrade
+/// to a version of arrow-rs that includes https://github.com/apache/arrow-rs/pull/6419
+#[inline]
+pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
+    precision <= DECIMAL128_MAX_PRECISION
+        && value >= MIN_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
+        && value <= MAX_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
+}
diff --git a/native/core/src/execution/datafusion/expressions/sum_decimal.rs 
b/native/core/src/execution/datafusion/expressions/sum_decimal.rs
index e957bd25..a3ce96b6 100644
--- a/native/core/src/execution/datafusion/expressions/sum_decimal.rs
+++ b/native/core/src/execution/datafusion/expressions/sum_decimal.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use 
crate::execution::datafusion::expressions::checkoverflow::is_valid_decimal_precision;
 use crate::unlikely;
 use arrow::{
     array::BooleanBufferBuilder,
@@ -23,11 +24,10 @@ use arrow::{
 use arrow_array::{
     cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, 
Decimal128Array,
 };
-use arrow_data::decimal::validate_decimal_precision;
 use arrow_schema::{DataType, Field};
 use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator};
 use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
-use datafusion_common::{Result as DFResult, ScalarValue};
+use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue};
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::Volatility::Immutable;
 use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature};
@@ -36,37 +36,37 @@ use std::{any::Any, ops::BitAnd, sync::Arc};
 
 #[derive(Debug)]
 pub struct SumDecimal {
-    name: String,
+    /// Aggregate function signature
     signature: Signature,
+    /// The expression that provides the input decimal values to be summed
     expr: Arc<dyn PhysicalExpr>,
-
-    /// The data type of the SUM result
+    /// The data type of the SUM result. This will always be a decimal type
+    /// with the same precision and scale as specified in this struct
     result_type: DataType,
-
-    /// Decimal precision and scale
+    /// Decimal precision
     precision: u8,
+    /// Decimal scale
     scale: i8,
-
-    /// Whether the result is nullable
-    nullable: bool,
 }
 
 impl SumDecimal {
-    pub fn new(name: impl Into<String>, expr: Arc<dyn PhysicalExpr>, 
data_type: DataType) -> Self {
+    pub fn try_new(expr: Arc<dyn PhysicalExpr>, data_type: DataType) -> 
DFResult<Self> {
         // The `data_type` is the SUM result type passed from Spark side
         let (precision, scale) = match data_type {
             DataType::Decimal128(p, s) => (p, s),
-            _ => unreachable!(),
+            _ => {
+                return Err(DataFusionError::Internal(
+                    "Invalid data type for SumDecimal".into(),
+                ))
+            }
         };
-        Self {
-            name: name.into(),
+        Ok(Self {
             signature: Signature::user_defined(Immutable),
             expr,
             result_type: data_type,
             precision,
             scale,
-            nullable: true,
-        }
+        })
     }
 }
 
@@ -84,14 +84,14 @@ impl AggregateUDFImpl for SumDecimal {
 
     fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<Field>> {
         let fields = vec![
-            Field::new(&self.name, self.result_type.clone(), self.nullable),
+            Field::new(self.name(), self.result_type.clone(), 
self.is_nullable()),
             Field::new("is_empty", DataType::Boolean, false),
         ];
         Ok(fields)
     }
 
     fn name(&self) -> &str {
-        &self.name
+        "sum"
     }
 
     fn signature(&self) -> &Signature {
@@ -127,6 +127,11 @@ impl AggregateUDFImpl for SumDecimal {
     fn reverse_expr(&self) -> ReversedUDAF {
         ReversedUDAF::Identical
     }
+
+    fn is_nullable(&self) -> bool {
+        // SumDecimal is always nullable because overflows can cause null values
+        true
+    }
 }
 
 impl PartialEq<dyn Any> for SumDecimal {
@@ -134,12 +139,10 @@ impl PartialEq<dyn Any> for SumDecimal {
         down_cast_any_ref(other)
             .downcast_ref::<Self>()
             .map(|x| {
-                self.name == x.name
-                    && self.precision == x.precision
-                    && self.scale == x.scale
-                    && self.nullable == x.nullable
-                    && self.result_type == x.result_type
-                    && self.expr.eq(&x.expr)
+                // note that we do not compare result_type because this
+                // is guaranteed to match if the precision and scale
+                // match
+                self.precision == x.precision && self.scale == x.scale && 
self.expr.eq(&x.expr)
             })
             .unwrap_or(false)
     }
@@ -170,7 +173,7 @@ impl SumDecimalAccumulator {
         let v = unsafe { values.value_unchecked(idx) };
         let (new_sum, is_overflow) = self.sum.overflowing_add(v);
 
-        if is_overflow || validate_decimal_precision(new_sum, 
self.precision).is_err() {
+        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) 
{
             // Overflow: set buffer accumulator to null
             self.is_not_null = false;
             return;
@@ -312,7 +315,7 @@ impl SumDecimalGroupsAccumulator {
         self.is_empty.set_bit(group_index, false);
         let (new_sum, is_overflow) = 
self.sum[group_index].overflowing_add(value);
 
-        if is_overflow || validate_decimal_precision(new_sum, 
self.precision).is_err() {
+        if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) 
{
             // Overflow: set buffer accumulator to null
             self.is_not_null.set_bit(group_index, false);
             return;
@@ -478,3 +481,99 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
             + self.is_not_null.capacity() / 8
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::*;
+    use arrow_array::builder::{Decimal128Builder, StringBuilder};
+    use arrow_array::RecordBatch;
+    use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, 
PhysicalGroupBy};
+    use datafusion::physical_plan::memory::MemoryExec;
+    use datafusion::physical_plan::ExecutionPlan;
+    use datafusion_common::Result;
+    use datafusion_execution::TaskContext;
+    use datafusion_expr::AggregateUDF;
+    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+    use datafusion_physical_expr::expressions::{Column, Literal};
+    use futures::StreamExt;
+
+    #[test]
+    fn invalid_data_type() {
+        let expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
+        assert!(SumDecimal::try_new(expr, DataType::Int32).is_err());
+    }
+
+    #[tokio::test]
+    async fn sum_no_overflow() -> Result<()> {
+        let num_rows = 8192;
+        let batch = create_record_batch(num_rows);
+        let mut batches = Vec::new();
+        for _ in 0..10 {
+            batches.push(batch.clone());
+        }
+        let partitions = &[batches];
+        let c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
+        let c1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));
+
+        let data_type = DataType::Decimal128(8, 2);
+        let schema = Arc::clone(&partitions[0][0].schema());
+        let scan: Arc<dyn ExecutionPlan> =
+            Arc::new(MemoryExec::try_new(partitions, Arc::clone(&schema), 
None).unwrap());
+
+        let aggregate_udf = 
Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new(
+            Arc::clone(&c1),
+            data_type.clone(),
+        )?));
+
+        let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1])
+            .schema(Arc::clone(&schema))
+            .alias("sum")
+            .with_ignore_nulls(false)
+            .with_distinct(false)
+            .build()?;
+
+        let aggregate = Arc::new(AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]),
+            vec![aggr_expr],
+            vec![None], // no filter expressions
+            scan,
+            Arc::clone(&schema),
+        )?);
+
+        let mut stream = aggregate
+            .execute(0, Arc::new(TaskContext::default()))
+            .unwrap();
+        while let Some(batch) = stream.next().await {
+            let _batch = batch?;
+        }
+
+        Ok(())
+    }
+
+    fn create_record_batch(num_rows: usize) -> RecordBatch {
+        let mut decimal_builder = Decimal128Builder::with_capacity(num_rows);
+        let mut string_builder = StringBuilder::with_capacity(num_rows, 
num_rows * 32);
+        for i in 0..num_rows {
+            decimal_builder.append_value(i as i128);
+            string_builder.append_value(format!("this is string #{}", i % 
1024));
+        }
+        let decimal_array = Arc::new(decimal_builder.finish());
+        let string_array = Arc::new(string_builder.finish());
+
+        let mut fields = vec![];
+        let mut columns: Vec<ArrayRef> = vec![];
+
+        // string column
+        fields.push(Field::new("c0", DataType::Utf8, false));
+        columns.push(string_array);
+
+        // decimal column
+        fields.push(Field::new("c1", DataType::Decimal128(38, 10), false));
+        columns.push(decimal_array);
+
+        let schema = Schema::new(fields);
+        RecordBatch::try_new(Arc::new(schema), columns).unwrap()
+    }
+}
diff --git a/native/core/src/execution/datafusion/planner.rs 
b/native/core/src/execution/datafusion/planner.rs
index 663db0d1..9000db61 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -1365,56 +1365,42 @@ impl PhysicalPlanner {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), 
Arc::clone(&schema))?;
                 let datatype = 
to_arrow_datatype(expr.datatype.as_ref().unwrap());
 
-                match datatype {
+                let builder = match datatype {
                     DataType::Decimal128(_, _) => {
-                        let func = AggregateUDF::new_from_impl(SumDecimal::new(
-                            "sum",
+                        let func = 
AggregateUDF::new_from_impl(SumDecimal::try_new(
                             Arc::clone(&child),
                             datatype,
-                        ));
+                        )?);
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
-                            .schema(schema)
-                            .alias("sum")
-                            .with_ignore_nulls(false)
-                            .with_distinct(false)
-                            .build()
-                            .map_err(|e| 
ExecutionError::DataFusionError(e.to_string()))
                     }
                     _ => {
                         // cast to the result data type of SUM if necessary, 
we should not expect
                         // a cast failure since it should have already been 
checked at Spark side
                         let child =
                             Arc::new(CastExpr::new(Arc::clone(&child), 
datatype.clone(), None));
-
                         AggregateExprBuilder::new(sum_udaf(), vec![child])
-                            .schema(schema)
-                            .alias("sum")
-                            .with_ignore_nulls(false)
-                            .with_distinct(false)
-                            .build()
-                            .map_err(|e| 
ExecutionError::DataFusionError(e.to_string()))
                     }
-                }
+                };
+                builder
+                    .schema(schema)
+                    .alias("sum")
+                    .with_ignore_nulls(false)
+                    .with_distinct(false)
+                    .build()
+                    .map_err(|e| e.into())
             }
             AggExprStruct::Avg(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), 
Arc::clone(&schema))?;
                 let datatype = 
to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 let input_datatype = 
to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
-                match datatype {
+                let builder = match datatype {
                     DataType::Decimal128(_, _) => {
                         let func = AggregateUDF::new_from_impl(AvgDecimal::new(
                             Arc::clone(&child),
-                            "avg",
                             datatype,
                             input_datatype,
                         ));
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
-                            .schema(schema)
-                            .alias("avg")
-                            .with_ignore_nulls(false)
-                            .with_distinct(false)
-                            .build()
-                            .map_err(|e| 
ExecutionError::DataFusionError(e.to_string()))
                     }
                     _ => {
                         // cast to the result data type of AVG if the result 
data type is different
@@ -1428,14 +1414,15 @@ impl PhysicalPlanner {
                             datatype,
                         ));
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
-                            .schema(schema)
-                            .alias("avg")
-                            .with_ignore_nulls(false)
-                            .with_distinct(false)
-                            .build()
-                            .map_err(|e| 
ExecutionError::DataFusionError(e.to_string()))
                     }
-                }
+                };
+                builder
+                    .schema(schema)
+                    .alias("avg")
+                    .with_ignore_nulls(false)
+                    .with_distinct(false)
+                    .build()
+                    .map_err(|e| e.into())
             }
             AggExprStruct::First(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), 
Arc::clone(&schema))?;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to