This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 919c9ad7c perf: Optimize some decimal expressions (#3619)
919c9ad7c is described below

commit 919c9ad7c54a93f2d6b2980b2b672760aacc12a8
Author: Andy Grove <[email protected]>
AuthorDate: Thu Mar 12 07:56:58 2026 -0600

    perf: Optimize some decimal expressions (#3619)
    
    * feat: fused WideDecimalBinaryExpr for Decimal128 add/sub/mul
    
    Replace the 4-node expression tree (two input Casts feeding a BinaryExpr,
    followed by an output Cast) used for
    that performs i256 register arithmetic directly. This reduces per-batch
    allocation from 4 intermediate arrays (112 bytes/elem) to 1 output array
    (16 bytes/elem).
    
    The new WideDecimalBinaryExpr evaluates children, performs add/sub/mul
    using i256 intermediates via try_binary, applies scale adjustment with
    HALF_UP rounding, checks precision bounds, and outputs a single
    Decimal128 array. Follows the same pattern as decimal_div.
    
    * feat: add criterion benchmark for wide decimal binary expr
    
    Add benchmark comparing old Cast->BinaryExpr->Cast chain vs fused
    WideDecimalBinaryExpr for Decimal128 add/sub/mul. Covers four cases:
    add with same scale, add with different scales, multiply, and subtract.
    
    * feat: fuse CheckOverflow with Cast and WideDecimalBinaryExpr
    
    Eliminate redundant CheckOverflow when wrapping WideDecimalBinaryExpr
    (which already handles overflow). Fuse Cast(Decimal128→Decimal128) +
    CheckOverflow into a single DecimalRescaleCheckOverflow expression that
    rescales and validates precision in one pass.
    
    * fix: address PR review feedback for decimal optimizations
    
    - Handle scale-up when s_out > max(s1, s2) in add/subtract
    - Propagate errors in scalar path when fail_on_error=true
    - Guard against large scale delta (>38) overflow in rescale
    - Assert precision <= 38 in precision_bound
    - Assert exp <= 76 in i256_pow10
    - Remove unnecessary _ prefix on used variables in planner
    - Use value.signum() instead of manual sign check
    - Verify Cast target type matches before fusing with CheckOverflow
    - Validate children count in with_new_children for both expressions
    - Add tests for scale-up, scalar error propagation, and large delta
    
    * style: apply cargo fmt
    
    * fix: add defensive checks for CheckOverflow bypass and multiply scale-up
    
    - Validate WideDecimalBinaryExpr output type matches CheckOverflow
      data_type before bypassing the overflow check
    - Handle s_out > natural_scale (scale-up) in multiply path for
      consistency with add/subtract
---
 native/core/src/execution/planner.rs               |  81 ++-
 native/spark-expr/Cargo.toml                       |   4 +
 native/spark-expr/benches/wide_decimal.rs          | 166 ++++++
 native/spark-expr/src/lib.rs                       |   3 +-
 .../math_funcs/internal/decimal_rescale_check.rs   | 482 ++++++++++++++++++
 native/spark-expr/src/math_funcs/internal/mod.rs   |   2 +
 native/spark-expr/src/math_funcs/mod.rs            |   2 +
 .../src/math_funcs/wide_decimal_binary_expr.rs     | 560 +++++++++++++++++++++
 8 files changed, 1273 insertions(+), 27 deletions(-)

diff --git a/native/core/src/execution/planner.rs 
b/native/core/src/execution/planner.rs
index b79b43f6c..15bbabe88 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -126,8 +126,9 @@ use datafusion_comet_proto::{
 use 
datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId;
 use datafusion_comet_spark_expr::{
     ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, 
Covariance, CreateNamedStruct,
-    GetArrayStructFields, GetStructField, IfExpr, ListExtract, 
NormalizeNaNAndZero, RandExpr,
-    RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, 
Variance,
+    DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, 
ListExtract,
+    NormalizeNaNAndZero, RandExpr, RandnExpr, SparkCastOptions, Stddev, 
SumDecimal, ToJson,
+    UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
 };
 use itertools::Itertools;
 use jni::objects::GlobalRef;
@@ -408,10 +409,45 @@ impl PhysicalPlanner {
                 )))
             }
             ExprStruct::CheckOverflow(expr) => {
-                let child = self.create_expr(expr.child.as_ref().unwrap(), 
input_schema)?;
+                let child =
+                    self.create_expr(expr.child.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
                 let data_type = 
to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 let fail_on_error = expr.fail_on_error;
 
+                // WideDecimalBinaryExpr already handles overflow — skip 
redundant check
+                // but only if its output type matches CheckOverflow's 
declared type
+                if child
+                    .as_any()
+                    .downcast_ref::<WideDecimalBinaryExpr>()
+                    .is_some()
+                {
+                    let child_type = child.data_type(&input_schema)?;
+                    if child_type == data_type {
+                        return Ok(child);
+                    }
+                }
+
+                // Fuse Cast(Decimal128→Decimal128) + CheckOverflow into 
single rescale+check
+                // Only fuse when the Cast target type matches the 
CheckOverflow output type
+                if let Some(cast) = child.as_any().downcast_ref::<Cast>() {
+                    if let (
+                        DataType::Decimal128(p_out, s_out),
+                        Ok(DataType::Decimal128(_p_in, s_in)),
+                    ) = (&data_type, cast.child.data_type(&input_schema))
+                    {
+                        let cast_target = cast.data_type(&input_schema)?;
+                        if cast_target == data_type {
+                            return 
Ok(Arc::new(DecimalRescaleCheckOverflow::new(
+                                Arc::clone(&cast.child),
+                                s_in,
+                                *p_out,
+                                *s_out,
+                                fail_on_error,
+                            )));
+                        }
+                    }
+                }
+
                 // Look up query context from registry if expr_id is present
                 let query_context = spark_expr.expr_id.and_then(|expr_id| {
                     let registry = &self.query_context_registry;
@@ -740,29 +776,22 @@ impl PhysicalPlanner {
                 || (op == DataFusionOperator::Multiply && p1 + p2 >= 
DECIMAL128_MAX_PRECISION) =>
             {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
-                // For some Decimal128 operations, we need wider internal 
digits.
-                // Cast left and right to Decimal256 and cast the result back 
to Decimal128
-                let left = Arc::new(Cast::new(
-                    left,
-                    DataType::Decimal256(p1, s1),
-                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, 
false),
-                    None,
-                    None,
-                ));
-                let right = Arc::new(Cast::new(
-                    right,
-                    DataType::Decimal256(p2, s2),
-                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, 
false),
-                    None,
-                    None,
-                ));
-                let child = Arc::new(BinaryExpr::new(left, op, right));
-                Ok(Arc::new(Cast::new(
-                    child,
-                    data_type,
-                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, 
false),
-                    None,
-                    None,
+                let (p_out, s_out) = match &data_type {
+                    DataType::Decimal128(p, s) => (*p, *s),
+                    dt => {
+                        return Err(ExecutionError::GeneralError(format!(
+                            "Expected Decimal128 return type, got {dt:?}"
+                        )))
+                    }
+                };
+                let wide_op = match op {
+                    DataFusionOperator::Plus => WideDecimalOp::Add,
+                    DataFusionOperator::Minus => WideDecimalOp::Subtract,
+                    DataFusionOperator::Multiply => WideDecimalOp::Multiply,
+                    _ => unreachable!(),
+                };
+                Ok(Arc::new(WideDecimalBinaryExpr::new(
+                    left, right, wide_op, p_out, s_out, eval_mode,
                 )))
             }
             (
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 9f08e480f..d4639c86e 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -105,6 +105,10 @@ path = "tests/spark_expr_reg.rs"
 name = "cast_from_boolean"
 harness = false
 
+[[bench]]
+name = "wide_decimal"
+harness = false
+
 [[bench]]
 name = "cast_non_int_numeric_timestamp"
 harness = false
diff --git a/native/spark-expr/benches/wide_decimal.rs 
b/native/spark-expr/benches/wide_decimal.rs
new file mode 100644
index 000000000..ec932ae68
--- /dev/null
+++ b/native/spark-expr/benches/wide_decimal.rs
@@ -0,0 +1,166 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Benchmarks comparing the old Cast->BinaryExpr->Cast chain vs the fused 
WideDecimalBinaryExpr
+//! for Decimal128 arithmetic that requires wider intermediate precision.
+
+use arrow::array::builder::Decimal128Builder;
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Field, Schema};
+use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+use datafusion::logical_expr::Operator;
+use datafusion::physical_expr::expressions::{BinaryExpr, Column};
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_comet_spark_expr::{
+    Cast, EvalMode, SparkCastOptions, WideDecimalBinaryExpr, WideDecimalOp,
+};
+use std::sync::Arc;
+
+const BATCH_SIZE: usize = 8192;
+
+/// Build a RecordBatch with two Decimal128 columns.
+fn make_decimal_batch(p1: u8, s1: i8, p2: u8, s2: i8) -> RecordBatch {
+    let mut left = Decimal128Builder::new();
+    let mut right = Decimal128Builder::new();
+    for i in 0..BATCH_SIZE as i128 {
+        left.append_value(123456789012345_i128 + i * 1000);
+        right.append_value(987654321098765_i128 - i * 1000);
+    }
+    let left = left.finish().with_data_type(DataType::Decimal128(p1, s1));
+    let right = right.finish().with_data_type(DataType::Decimal128(p2, s2));
+    let schema = Schema::new(vec![
+        Field::new("left", DataType::Decimal128(p1, s1), false),
+        Field::new("right", DataType::Decimal128(p2, s2), false),
+    ]);
+    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(left), 
Arc::new(right)]).unwrap()
+}
+
+/// Old approach: Cast(Decimal128->Decimal256) both sides, BinaryExpr, 
Cast(Decimal256->Decimal128).
+fn build_old_expr(
+    p1: u8,
+    s1: i8,
+    p2: u8,
+    s2: i8,
+    op: Operator,
+    out_type: DataType,
+) -> Arc<dyn PhysicalExpr> {
+    let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
+    let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
+    let cast_opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy, 
false);
+    let left_cast = Arc::new(Cast::new(
+        left_col,
+        DataType::Decimal256(p1, s1),
+        cast_opts.clone(),
+        None,
+        None,
+    ));
+    let right_cast = Arc::new(Cast::new(
+        right_col,
+        DataType::Decimal256(p2, s2),
+        cast_opts.clone(),
+        None,
+        None,
+    ));
+    let binary = Arc::new(BinaryExpr::new(left_cast, op, right_cast));
+    Arc::new(Cast::new(binary, out_type, cast_opts, None, None))
+}
+
+/// New approach: single fused WideDecimalBinaryExpr.
+fn build_new_expr(op: WideDecimalOp, p_out: u8, s_out: i8) -> Arc<dyn 
PhysicalExpr> {
+    let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
+    let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
+    Arc::new(WideDecimalBinaryExpr::new(
+        left_col,
+        right_col,
+        op,
+        p_out,
+        s_out,
+        EvalMode::Legacy,
+    ))
+}
+
+fn bench_case(
+    group: &mut criterion::BenchmarkGroup<criterion::measurement::WallTime>,
+    name: &str,
+    batch: &RecordBatch,
+    old_expr: &Arc<dyn PhysicalExpr>,
+    new_expr: &Arc<dyn PhysicalExpr>,
+) {
+    group.bench_with_input(BenchmarkId::new("old", name), batch, |b, batch| {
+        b.iter(|| old_expr.evaluate(batch).unwrap());
+    });
+    group.bench_with_input(BenchmarkId::new("fused", name), batch, |b, batch| {
+        b.iter(|| new_expr.evaluate(batch).unwrap());
+    });
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let mut group = c.benchmark_group("wide_decimal");
+
+    // Case 1: Add with same scale - Decimal128(38,10) + Decimal128(38,10) -> 
Decimal128(38,10)
+    // Triggers wide path because max(s1,s2) + max(p1-s1, p2-s2) = 10 + 28 = 
38 >= 38
+    {
+        let batch = make_decimal_batch(38, 10, 38, 10);
+        let old = build_old_expr(38, 10, 38, 10, Operator::Plus, 
DataType::Decimal128(38, 10));
+        let new = build_new_expr(WideDecimalOp::Add, 38, 10);
+        bench_case(&mut group, "add_same_scale", &batch, &old, &new);
+    }
+
+    // Case 2: Add with different scales - Decimal128(38,6) + Decimal128(38,4) 
-> Decimal128(38,6)
+    {
+        let batch = make_decimal_batch(38, 6, 38, 4);
+        let old = build_old_expr(38, 6, 38, 4, Operator::Plus, 
DataType::Decimal128(38, 6));
+        let new = build_new_expr(WideDecimalOp::Add, 38, 6);
+        bench_case(&mut group, "add_diff_scale", &batch, &old, &new);
+    }
+
+    // Case 3: Multiply - Decimal128(20,10) * Decimal128(20,10) -> 
Decimal128(38,6)
+    // Triggers wide path because p1 + p2 = 40 >= 38
+    {
+        let batch = make_decimal_batch(20, 10, 20, 10);
+        let old = build_old_expr(
+            20,
+            10,
+            20,
+            10,
+            Operator::Multiply,
+            DataType::Decimal128(38, 6),
+        );
+        let new = build_new_expr(WideDecimalOp::Multiply, 38, 6);
+        bench_case(&mut group, "multiply", &batch, &old, &new);
+    }
+
+    // Case 4: Subtract with same scale - Decimal128(38,18) - 
Decimal128(38,18) -> Decimal128(38,18)
+    {
+        let batch = make_decimal_batch(38, 18, 38, 18);
+        let old = build_old_expr(
+            38,
+            18,
+            38,
+            18,
+            Operator::Minus,
+            DataType::Decimal128(38, 18),
+        );
+        let new = build_new_expr(WideDecimalOp::Subtract, 38, 18);
+        bench_case(&mut group, "subtract", &batch, &old, &new);
+    }
+
+    group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 072fa1fad..ba19d6a9b 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -80,7 +80,8 @@ pub use json_funcs::{FromJson, ToJson};
 pub use math_funcs::{
     create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
     spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, 
spark_unhex,
-    spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero,
+    spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, 
NegativeExpr,
+    NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
 };
 pub use query_context::{create_query_context_map, QueryContext, 
QueryContextMap};
 pub use string_funcs::*;
diff --git a/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs 
b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs
new file mode 100644
index 000000000..132240495
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs
@@ -0,0 +1,482 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Fused decimal rescale + overflow check expression.
+//!
+//! Replaces the pattern `CheckOverflow(Cast(expr, Decimal128(p2,s2)), 
Decimal128(p2,s2))`
+//! with a single expression that rescales and validates precision in one pass.
+
+use arrow::array::{as_primitive_array, Array, ArrayRef, Decimal128Array};
+use arrow::datatypes::{DataType, Decimal128Type, Schema};
+use arrow::error::ArrowError;
+use arrow::record_batch::RecordBatch;
+use datafusion::common::{DataFusionError, ScalarValue};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::hash::Hash;
+use std::{
+    any::Any,
+    fmt::{Display, Formatter},
+    sync::Arc,
+};
+
+/// A fused expression that rescales a Decimal128 value (changing scale) and 
checks
+/// for precision overflow in a single pass. Replaces the two-step
+/// `CheckOverflow(Cast(expr, Decimal128(p,s)))` pattern.
+#[derive(Debug, Eq)]
+pub struct DecimalRescaleCheckOverflow {
+    child: Arc<dyn PhysicalExpr>,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+    fail_on_error: bool,
+}
+
+impl Hash for DecimalRescaleCheckOverflow {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.child.hash(state);
+        self.input_scale.hash(state);
+        self.output_precision.hash(state);
+        self.output_scale.hash(state);
+        self.fail_on_error.hash(state);
+    }
+}
+
+impl PartialEq for DecimalRescaleCheckOverflow {
+    fn eq(&self, other: &Self) -> bool {
+        self.child.eq(&other.child)
+            && self.input_scale == other.input_scale
+            && self.output_precision == other.output_precision
+            && self.output_scale == other.output_scale
+            && self.fail_on_error == other.fail_on_error
+    }
+}
+
+impl DecimalRescaleCheckOverflow {
+    pub fn new(
+        child: Arc<dyn PhysicalExpr>,
+        input_scale: i8,
+        output_precision: u8,
+        output_scale: i8,
+        fail_on_error: bool,
+    ) -> Self {
+        Self {
+            child,
+            input_scale,
+            output_precision,
+            output_scale,
+            fail_on_error,
+        }
+    }
+}
+
+impl Display for DecimalRescaleCheckOverflow {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "DecimalRescaleCheckOverflow [child: {}, input_scale: {}, output: 
Decimal128({}, {}), fail_on_error: {}]",
+            self.child, self.input_scale, self.output_precision, 
self.output_scale, self.fail_on_error
+        )
+    }
+}
+
+/// Maximum absolute value for a given decimal precision: 10^p - 1.
+/// Precision must be <= 38 (max for Decimal128).
+#[inline]
+fn precision_bound(precision: u8) -> i128 {
+    assert!(
+        precision <= 38,
+        "precision_bound: precision {precision} exceeds maximum 38"
+    );
+    10i128.pow(precision as u32) - 1
+}
+
+/// Rescale a single i128 value by the given delta (output_scale - input_scale)
+/// and check precision bounds. Returns `Ok(value)` or `Ok(i128::MAX)` as 
sentinel
+/// for overflow in legacy mode, or `Err` in ANSI mode.
+#[inline]
+fn rescale_and_check(
+    value: i128,
+    delta: i8,
+    scale_factor: i128,
+    bound: i128,
+    fail_on_error: bool,
+) -> Result<i128, ArrowError> {
+    let rescaled = if delta > 0 {
+        // Scale up: multiply. Check for overflow.
+        match value.checked_mul(scale_factor) {
+            Some(v) => v,
+            None => {
+                if fail_on_error {
+                    return Err(ArrowError::ComputeError(
+                        "Decimal overflow during rescale".to_string(),
+                    ));
+                }
+                return Ok(i128::MAX); // sentinel
+            }
+        }
+    } else if delta < 0 {
+        // Scale down with HALF_UP rounding
+        // divisor = 10^(-delta), half = divisor / 2
+        let divisor = scale_factor; // already 10^abs(delta)
+        let half = divisor / 2;
+        let sign = value.signum();
+        (value + sign * half) / divisor
+    } else {
+        value
+    };
+
+    // Precision check
+    if rescaled.abs() > bound {
+        if fail_on_error {
+            return Err(ArrowError::ComputeError(
+                "Decimal overflow: value does not fit in 
precision".to_string(),
+            ));
+        }
+        Ok(i128::MAX) // sentinel for null_if_overflow_precision
+    } else {
+        Ok(rescaled)
+    }
+}
+
+impl PhysicalExpr for DecimalRescaleCheckOverflow {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        Display::fmt(self, f)
+    }
+
+    fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
+        Ok(DataType::Decimal128(
+            self.output_precision,
+            self.output_scale,
+        ))
+    }
+
+    fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> 
datafusion::common::Result<ColumnarValue> {
+        let arg = self.child.evaluate(batch)?;
+        let delta = self.output_scale - self.input_scale;
+        let abs_delta = delta.unsigned_abs();
+        // If abs_delta > 38, the scale factor overflows i128. In that case,
+        // any non-zero value will overflow the output precision, so we treat
+        // it as an immediate overflow condition.
+        if abs_delta > 38 {
+            return Err(DataFusionError::Execution(format!(
+                "DecimalRescaleCheckOverflow: scale delta {delta} exceeds 
maximum supported range"
+            )));
+        }
+        let scale_factor = 10i128.pow(abs_delta as u32);
+        let bound = precision_bound(self.output_precision);
+        let fail_on_error = self.fail_on_error;
+        let p_out = self.output_precision;
+        let s_out = self.output_scale;
+
+        match arg {
+            ColumnarValue::Array(array)
+                if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
+            {
+                let decimal_array = 
as_primitive_array::<Decimal128Type>(&array);
+
+                let result: Decimal128Array =
+                    arrow::compute::kernels::arity::try_unary(decimal_array, 
|value| {
+                        rescale_and_check(value, delta, scale_factor, bound, 
fail_on_error)
+                    })?;
+
+                let result = if !fail_on_error {
+                    result.null_if_overflow_precision(p_out)
+                } else {
+                    result
+                };
+
+                let result = result
+                    .with_precision_and_scale(p_out, s_out)
+                    .map(|a| Arc::new(a) as ArrayRef)?;
+
+                Ok(ColumnarValue::Array(result))
+            }
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, _precision, 
_scale)) => {
+                let new_v = match v {
+                    Some(val) => {
+                        let r = rescale_and_check(val, delta, scale_factor, 
bound, fail_on_error)
+                            .map_err(|e| 
DataFusionError::ArrowError(Box::new(e), None))?;
+                        if r == i128::MAX {
+                            None
+                        } else {
+                            Some(r)
+                        }
+                    }
+                    None => None,
+                };
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                    new_v, p_out, s_out,
+                )))
+            }
+            v => Err(DataFusionError::Execution(format!(
+                "DecimalRescaleCheckOverflow expects Decimal128, but found 
{v:?}"
+            ))),
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.child]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+        if children.len() != 1 {
+            return Err(DataFusionError::Internal(format!(
+                "DecimalRescaleCheckOverflow expects 1 child, got {}",
+                children.len()
+            )));
+        }
+        Ok(Arc::new(DecimalRescaleCheckOverflow::new(
+            Arc::clone(&children[0]),
+            self.input_scale,
+            self.output_precision,
+            self.output_scale,
+            self.fail_on_error,
+        )))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{AsArray, Decimal128Array};
+    use arrow::datatypes::{Field, Schema};
+    use arrow::record_batch::RecordBatch;
+    use datafusion::physical_expr::expressions::Column;
+
+    fn make_batch(values: Vec<Option<i128>>, precision: u8, scale: i8) -> 
RecordBatch {
+        let arr =
+            
Decimal128Array::from(values).with_data_type(DataType::Decimal128(precision, 
scale));
+        let schema = Schema::new(vec![Field::new("col", 
arr.data_type().clone(), true)]);
+        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arr)]).unwrap()
+    }
+
+    fn eval_expr(
+        batch: &RecordBatch,
+        input_scale: i8,
+        output_precision: u8,
+        output_scale: i8,
+        fail_on_error: bool,
+    ) -> datafusion::common::Result<ArrayRef> {
+        let child: Arc<dyn PhysicalExpr> = Arc::new(Column::new("col", 0));
+        let expr = DecimalRescaleCheckOverflow::new(
+            child,
+            input_scale,
+            output_precision,
+            output_scale,
+            fail_on_error,
+        );
+        match expr.evaluate(batch)? {
+            ColumnarValue::Array(arr) => Ok(arr),
+            _ => panic!("expected array"),
+        }
+    }
+
+    #[test]
+    fn test_scale_up() {
+        // Decimal128(10,2) -> Decimal128(10,4): 1.50 (150) -> 1.5000 (15000)
+        let batch = make_batch(vec![Some(150), Some(-300)], 10, 2);
+        let result = eval_expr(&batch, 2, 10, 4, false).unwrap();
+        let arr = result.as_primitive::<Decimal128Type>();
+        assert_eq!(arr.value(0), 15000); // 1.5000
+        assert_eq!(arr.value(1), -30000); // -3.0000
+    }
+
+    #[test]
+    fn test_scale_down_with_half_up_rounding() {
+        // Decimal128(10,4) -> Decimal128(10,2)
+        // 1.2350 (12350) -> round to 1.24 (124) with HALF_UP
+        // 1.2349 (12349) -> round to 1.23 (123) with HALF_UP
+        // -1.2350 (-12350) -> round to -1.24 (-124) with HALF_UP
+        let batch = make_batch(vec![Some(12350), Some(12349), Some(-12350)], 
10, 4);
+        let result = eval_expr(&batch, 4, 10, 2, false).unwrap();
+        let arr = result.as_primitive::<Decimal128Type>();
+        assert_eq!(arr.value(0), 124); // 1.24
+        assert_eq!(arr.value(1), 123); // 1.23
+        assert_eq!(arr.value(2), -124); // -1.24
+    }
+
+    #[test]
+    fn test_same_scale_precision_check_only() {
+        // Same scale, just check precision. Value 999 fits in precision 3, 
1000 does not.
+        let batch = make_batch(vec![Some(999), Some(1000)], 38, 0);
+        let result = eval_expr(&batch, 0, 3, 0, false).unwrap();
+        let arr = result.as_primitive::<Decimal128Type>();
+        assert_eq!(arr.value(0), 999);
+        assert!(arr.is_null(1)); // overflow -> null in legacy mode
+    }
+
+    #[test]
+    fn test_overflow_null_in_legacy_mode() {
+        // Scale up causes overflow: 10^37 * 100 > i128::MAX range for 
precision 38
+        // Use precision 3, value 10 (which is 10 at scale 0), scale up to 
scale 2 -> 1000, which overflows precision 3
+        let batch = make_batch(vec![Some(10)], 38, 0);
+        let result = eval_expr(&batch, 0, 3, 2, false).unwrap();
+        let arr = result.as_primitive::<Decimal128Type>();
+        assert!(arr.is_null(0)); // 10 * 100 = 1000 > 999 (max for precision 3)
+    }
+
+    #[test]
+    fn test_overflow_error_in_ansi_mode() {
+        let batch = make_batch(vec![Some(10)], 38, 0);
+        let result = eval_expr(&batch, 0, 3, 2, true);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_null_propagation() {
+        let batch = make_batch(vec![Some(100), None, Some(200)], 10, 2);
+        let result = eval_expr(&batch, 2, 10, 4, false).unwrap();
+        let arr = result.as_primitive::<Decimal128Type>();
+        assert!(!arr.is_null(0));
+        assert!(arr.is_null(1));
+        assert!(!arr.is_null(2));
+    }
+
+    #[test]
+    fn test_scalar_path() {
+        let schema = Schema::new(vec![Field::new("col", 
DataType::Decimal128(10, 2), true)]);
+        let batch = RecordBatch::new_empty(Arc::new(schema));
+
+        let scalar_expr = DecimalRescaleCheckOverflow::new(
+            Arc::new(ScalarChild(Some(150), 10, 2)),
+            2,
+            10,
+            4,
+            false,
+        );
+        let result = scalar_expr.evaluate(&batch).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, p, s)) => {
+                assert_eq!(v, Some(15000));
+                assert_eq!(p, 10);
+                assert_eq!(s, 4);
+            }
+            _ => panic!("expected decimal scalar"),
+        }
+    }
+
+    /// Helper expression that always returns a Decimal128 scalar.
+    #[derive(Debug, Eq, PartialEq, Hash)]
+    struct ScalarChild(Option<i128>, u8, i8);
+
+    impl Display for ScalarChild {
+        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            write!(f, "ScalarChild({:?})", self.0)
+        }
+    }
+
+    impl PhysicalExpr for ScalarChild {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+        fn data_type(&self, _: &Schema) -> 
datafusion::common::Result<DataType> {
+            Ok(DataType::Decimal128(self.1, self.2))
+        }
+        fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
+            Ok(true)
+        }
+        fn evaluate(&self, _batch: &RecordBatch) -> 
datafusion::common::Result<ColumnarValue> {
+            Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                self.0, self.1, self.2,
+            )))
+        }
+        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+            vec![]
+        }
+        fn with_new_children(
+            self: Arc<Self>,
+            _children: Vec<Arc<dyn PhysicalExpr>>,
+        ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+            Ok(self)
+        }
+        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            Display::fmt(self, f)
+        }
+    }
+
+    #[test]
+    fn test_scalar_null() {
+        let schema = Schema::new(vec![Field::new("col", 
DataType::Decimal128(10, 2), true)]);
+        let batch = RecordBatch::new_empty(Arc::new(schema));
+        let expr =
+            DecimalRescaleCheckOverflow::new(Arc::new(ScalarChild(None, 10, 
2)), 2, 10, 4, false);
+        let result = expr.evaluate(&batch).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => {
+                assert_eq!(v, None);
+            }
+            _ => panic!("expected decimal scalar"),
+        }
+    }
+
+    #[test]
+    fn test_scalar_overflow_legacy() {
+        let schema = Schema::new(vec![Field::new("col", 
DataType::Decimal128(38, 0), true)]);
+        let batch = RecordBatch::new_empty(Arc::new(schema));
+        let expr = DecimalRescaleCheckOverflow::new(
+            Arc::new(ScalarChild(Some(10), 38, 0)),
+            0,
+            3,
+            2,
+            false,
+        );
+        let result = expr.evaluate(&batch).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => {
+                assert_eq!(v, None); // 10 * 100 = 1000 > 999
+            }
+            _ => panic!("expected decimal scalar"),
+        }
+    }
+
+    #[test]
+    fn test_scalar_overflow_ansi_returns_error() {
+        // fail_on_error=true must propagate the error, not silently return 
None
+        let schema = Schema::new(vec![Field::new("col", 
DataType::Decimal128(38, 0), true)]);
+        let batch = RecordBatch::new_empty(Arc::new(schema));
+        let expr = DecimalRescaleCheckOverflow::new(
+            Arc::new(ScalarChild(Some(10), 38, 0)),
+            0,
+            3,
+            2,
+            true, // fail_on_error = true
+        );
+        let result = expr.evaluate(&batch);
+        assert!(result.is_err()); // must be error, not Ok(None)
+    }
+
+    #[test]
+    fn test_large_scale_delta_returns_error() {
+        // delta = output_scale - input_scale = 38 - (-1) = 39
+        // 10i128.pow(39) would overflow, so we must reject gracefully
+        let batch = make_batch(vec![Some(1)], 38, -1);
+        let result = eval_expr(&batch, -1, 38, 38, false);
+        assert!(result.is_err());
+    }
+}
diff --git a/native/spark-expr/src/math_funcs/internal/mod.rs 
b/native/spark-expr/src/math_funcs/internal/mod.rs
index 29295f0d5..dff26146e 100644
--- a/native/spark-expr/src/math_funcs/internal/mod.rs
+++ b/native/spark-expr/src/math_funcs/internal/mod.rs
@@ -16,11 +16,13 @@
 // under the License.
 
 mod checkoverflow;
+mod decimal_rescale_check;
 mod make_decimal;
 mod normalize_nan;
 mod unscaled_value;
 
 pub use checkoverflow::CheckOverflow;
+pub use decimal_rescale_check::DecimalRescaleCheckOverflow;
 pub use make_decimal::spark_make_decimal;
 pub use normalize_nan::NormalizeNaNAndZero;
 pub use unscaled_value::spark_unscaled_value;
diff --git a/native/spark-expr/src/math_funcs/mod.rs 
b/native/spark-expr/src/math_funcs/mod.rs
index 35c1dc650..1219bc720 100644
--- a/native/spark-expr/src/math_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -26,6 +26,7 @@ mod negative;
 mod round;
 pub(crate) mod unhex;
 mod utils;
+mod wide_decimal_binary_expr;
 
 pub use ceil::spark_ceil;
 pub use div::spark_decimal_div;
@@ -36,3 +37,4 @@ pub use modulo_expr::create_modulo_expr;
 pub use negative::{create_negate_expr, NegativeExpr};
 pub use round::spark_round;
 pub use unhex::spark_unhex;
+pub use wide_decimal_binary_expr::{WideDecimalBinaryExpr, WideDecimalOp};
diff --git a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs 
b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs
new file mode 100644
index 000000000..644252b46
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs
@@ -0,0 +1,560 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Fused wide-decimal binary expression for Decimal128 add/sub/mul that may 
overflow.
+//!
+//! Instead of building a 4-node expression tree (Cast→BinaryExpr→Cast→Cast), 
this performs
+//! i256 intermediate arithmetic in a single expression, producing only one 
output array.
+
+use crate::math_funcs::utils::get_precision_scale;
+use crate::EvalMode;
+use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array};
+use arrow::datatypes::{i256, DataType, Decimal128Type, Schema};
+use arrow::error::ArrowError;
+use arrow::record_batch::RecordBatch;
+use datafusion::common::Result;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::fmt::{Display, Formatter};
+use std::hash::Hash;
+use std::{any::Any, sync::Arc};
+
/// The arithmetic operation to perform.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub enum WideDecimalOp {
    Add,
    Subtract,
    Multiply,
}

impl Display for WideDecimalOp {
    /// Renders the operation as its conventional infix symbol (`+`, `-`, `*`),
    /// matching how the parent expression prints itself.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            WideDecimalOp::Add => "+",
            WideDecimalOp::Subtract => "-",
            WideDecimalOp::Multiply => "*",
        };
        f.write_str(symbol)
    }
}
+
/// A fused expression that evaluates Decimal128 add/sub/mul using i256 intermediate arithmetic,
/// applies scale adjustment with HALF_UP rounding, checks precision bounds, and outputs
/// a single Decimal128 array.
#[derive(Debug, Eq)]
pub struct WideDecimalBinaryExpr {
    // Left operand; expected to evaluate to a Decimal128 column or scalar.
    left: Arc<dyn PhysicalExpr>,
    // Right operand; expected to evaluate to a Decimal128 column or scalar.
    right: Arc<dyn PhysicalExpr>,
    // Which arithmetic operation to fuse (add/subtract/multiply).
    op: WideDecimalOp,
    // Precision of the resulting Decimal128 type.
    output_precision: u8,
    // Scale of the resulting Decimal128 type.
    output_scale: i8,
    // Spark evaluation mode: Ansi errors on overflow; other modes yield null.
    eval_mode: EvalMode,
}
+
impl Hash for WideDecimalBinaryExpr {
    // Manual impl because `Arc<dyn PhysicalExpr>` prevents `#[derive(Hash)]`.
    // Hashes exactly the fields that `PartialEq` compares, preserving the
    // Hash/Eq consistency contract.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.left.hash(state);
        self.right.hash(state);
        self.op.hash(state);
        self.output_precision.hash(state);
        self.output_scale.hash(state);
        self.eval_mode.hash(state);
    }
}
+
impl PartialEq for WideDecimalBinaryExpr {
    // Structural equality over all fields. `left`/`right` delegate to the
    // child expressions' own `eq` (dyn PhysicalExpr comparison).
    fn eq(&self, other: &Self) -> bool {
        self.left.eq(&other.left)
            && self.right.eq(&other.right)
            && self.op == other.op
            && self.output_precision == other.output_precision
            && self.output_scale == other.output_scale
            && self.eval_mode == other.eval_mode
    }
}
+
impl WideDecimalBinaryExpr {
    /// Creates a fused wide-decimal binary expression.
    ///
    /// * `left`, `right` — child expressions producing Decimal128 values
    /// * `op` — the arithmetic operation (add/subtract/multiply)
    /// * `output_precision`, `output_scale` — Decimal128 type of the result
    /// * `eval_mode` — Ansi errors on overflow; Legacy/Try produce null
    pub fn new(
        left: Arc<dyn PhysicalExpr>,
        right: Arc<dyn PhysicalExpr>,
        op: WideDecimalOp,
        output_precision: u8,
        output_scale: i8,
        eval_mode: EvalMode,
    ) -> Self {
        Self {
            left,
            right,
            op,
            output_precision,
            output_scale,
            eval_mode,
        }
    }
}
+
+impl Display for WideDecimalBinaryExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "WideDecimalBinaryExpr [{} {} {}, output: Decimal128({}, {})]",
+            self.left, self.op, self.right, self.output_precision, 
self.output_scale
+        )
+    }
+}
+
+/// Compute `value / divisor` with HALF_UP rounding.
+#[inline]
+fn div_round_half_up(value: i256, divisor: i256) -> i256 {
+    let (quot, rem) = (value / divisor, value % divisor);
+    // HALF_UP: if |remainder| * 2 >= |divisor|, round away from zero
+    let abs_rem_x2 = if rem < i256::ZERO {
+        rem.wrapping_neg()
+    } else {
+        rem
+    }
+    .wrapping_mul(i256::from_i128(2));
+    let abs_divisor = if divisor < i256::ZERO {
+        divisor.wrapping_neg()
+    } else {
+        divisor
+    };
+    if abs_rem_x2 >= abs_divisor {
+        if (value < i256::ZERO) != (divisor < i256::ZERO) {
+            quot.wrapping_sub(i256::ONE)
+        } else {
+            quot.wrapping_add(i256::ONE)
+        }
+    } else {
+        quot
+    }
+}
+
+/// i256 constant for 10.
+const I256_TEN: i256 = i256::from_i128(10);
+
+/// Compute 10^exp as i256. Panics if exp > 76 (max representable power of 10 
in i256).
+#[inline]
+fn i256_pow10(exp: u32) -> i256 {
+    assert!(exp <= 76, "i256_pow10: exponent {exp} exceeds maximum 76");
+    let mut result = i256::ONE;
+    for _ in 0..exp {
+        result = result.wrapping_mul(I256_TEN);
+    }
+    result
+}
+
/// Largest unscaled value, as i256, representable by a decimal of the given
/// precision: a precision-p decimal holds values in [-(10^p - 1), 10^p - 1].
/// The result is deliberately i256 (not i128) so it can be compared directly
/// against wide intermediates; callers pass precision <= 38, well under
/// `i256_pow10`'s cap of 76.
#[inline]
fn max_for_precision(precision: u8) -> i256 {
    i256_pow10(precision as u32).wrapping_sub(i256::ONE)
}
+
impl PhysicalExpr for WideDecimalBinaryExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output type is fixed by the planner-supplied precision and scale.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::Decimal128(
            self.output_precision,
            self.output_scale,
        ))
    }

    /// Always nullable: in non-Ansi modes an overflowing element becomes null.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(true)
    }

    /// Evaluates both children, performs the arithmetic on i256 intermediates,
    /// rescales to the output scale (HALF_UP on scale-down), bounds-checks
    /// against the output precision, and returns one Decimal128 array.
    ///
    /// NOTE(review): scalar inputs are materialized to arrays, so even a
    /// scalar/scalar pair returns `ColumnarValue::Array` (length 1) — confirm
    /// downstream consumers accept an Array where a Scalar went in.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        let left_val = self.left.evaluate(batch)?;
        let right_val = self.right.evaluate(batch)?;

        // Normalize both sides to arrays; scalars are broadcast to the
        // partner array's length.
        let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (&left_val, &right_val) {
            (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
            (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
                (l.to_array_of_size(r.len())?, Arc::clone(r))
            }
            (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
                (Arc::clone(l), r.to_array_of_size(l.len())?)
            }
            (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
        };

        let left = left_arr.as_primitive::<Decimal128Type>();
        let right = right_arr.as_primitive::<Decimal128Type>();
        let (_p1, s1) = get_precision_scale(left.data_type());
        let (_p2, s2) = get_precision_scale(right.data_type());

        let p_out = self.output_precision;
        let s_out = self.output_scale;
        let op = self.op;
        let eval_mode = self.eval_mode;

        // Valid output range for precision p_out: [-(10^p_out - 1), 10^p_out - 1].
        let bound = max_for_precision(p_out);
        let neg_bound = i256::ZERO.wrapping_sub(bound);

        let result: Decimal128Array = match op {
            WideDecimalOp::Add | WideDecimalOp::Subtract => {
                // Bring both operands to the larger of the two input scales
                // before adding/subtracting.
                let max_scale = std::cmp::max(s1, s2);
                let l_scale_up = i256_pow10((max_scale - s1) as u32);
                let r_scale_up = i256_pow10((max_scale - s2) as u32);
                // After add/sub at max_scale, we may need to rescale to s_out
                let scale_diff = max_scale as i16 - s_out as i16;
                let (need_scale_down, need_scale_up) = (scale_diff > 0, scale_diff < 0);
                let rescale_divisor = if need_scale_down {
                    i256_pow10(scale_diff as u32)
                } else {
                    i256::ONE
                };
                let scale_up_factor = if need_scale_up {
                    i256_pow10((-scale_diff) as u32)
                } else {
                    i256::ONE
                };

                arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
                    let l256 = i256::from_i128(l).wrapping_mul(l_scale_up);
                    let r256 = i256::from_i128(r).wrapping_mul(r_scale_up);
                    let raw = match op {
                        WideDecimalOp::Add => l256.wrapping_add(r256),
                        WideDecimalOp::Subtract => l256.wrapping_sub(r256),
                        _ => unreachable!(),
                    };
                    // Scale-down uses HALF_UP rounding; scale-up is exact.
                    let result = if need_scale_down {
                        div_round_half_up(raw, rescale_divisor)
                    } else if need_scale_up {
                        raw.wrapping_mul(scale_up_factor)
                    } else {
                        raw
                    };
                    check_overflow_and_convert(result, bound, neg_bound, eval_mode)
                })?
            }
            WideDecimalOp::Multiply => {
                // The product of two scaled integers naturally has scale
                // s1 + s2; adjust that to the requested output scale.
                let natural_scale = s1 + s2;
                let scale_diff = natural_scale as i16 - s_out as i16;
                let (need_scale_down, need_scale_up) = (scale_diff > 0, scale_diff < 0);
                let rescale_divisor = if need_scale_down {
                    i256_pow10(scale_diff as u32)
                } else {
                    i256::ONE
                };
                let scale_up_factor = if need_scale_up {
                    i256_pow10((-scale_diff) as u32)
                } else {
                    i256::ONE
                };

                arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
                    let raw = i256::from_i128(l).wrapping_mul(i256::from_i128(r));
                    let result = if need_scale_down {
                        div_round_half_up(raw, rescale_divisor)
                    } else if need_scale_up {
                        raw.wrapping_mul(scale_up_factor)
                    } else {
                        raw
                    };
                    check_overflow_and_convert(result, bound, neg_bound, eval_mode)
                })?
            }
        };

        // Non-Ansi modes: overflow sentinels (i128::MAX) exceed p_out's bound
        // and are converted to null here. Ansi already errored in the closure.
        let result = if eval_mode != EvalMode::Ansi {
            result.null_if_overflow_precision(p_out)
        } else {
            result
        };
        let result = result.with_data_type(DataType::Decimal128(p_out, s_out));
        Ok(ColumnarValue::Array(Arc::new(result)))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.left, &self.right]
    }

    /// Rebuilds the expression with replacement children (required by the
    /// physical-plan rewriting machinery); all scalar config is carried over.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        if children.len() != 2 {
            return Err(datafusion::common::DataFusionError::Internal(format!(
                "WideDecimalBinaryExpr expects 2 children, got {}",
                children.len()
            )));
        }
        Ok(Arc::new(WideDecimalBinaryExpr::new(
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
            self.op,
            self.output_precision,
            self.output_scale,
            self.eval_mode,
        )))
    }

    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Display::fmt(self, f)
    }
}
+
+/// Check if the i256 result fits in the output precision. In Ansi mode, 
return an error
+/// on overflow. In Legacy/Try mode, return i128::MAX as a sentinel value that 
will be
+/// nullified by `null_if_overflow_precision`.
+#[inline]
+fn check_overflow_and_convert(
+    result: i256,
+    bound: i256,
+    neg_bound: i256,
+    eval_mode: EvalMode,
+) -> Result<i128, ArrowError> {
+    if result > bound || result < neg_bound {
+        if eval_mode == EvalMode::Ansi {
+            return Err(ArrowError::ComputeError("Arithmetic 
overflow".to_string()));
+        }
+        // Sentinel value — will be nullified by null_if_overflow_precision
+        Ok(i128::MAX)
+    } else {
+        Ok(result.to_i128().unwrap())
+    }
+}
+
#[cfg(test)]
mod tests {
    //! Unit tests for `WideDecimalBinaryExpr`: scale alignment, HALF_UP
    //! rounding, overflow behavior in Legacy vs Ansi mode, null propagation,
    //! and scale-up/scale-down paths for all three operations.
    use super::*;
    use arrow::array::Decimal128Array;
    use arrow::datatypes::{Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::physical_expr::expressions::Column;

    // Builds a two-column RecordBatch of Decimal128 arrays ("left", "right")
    // with the given unscaled values and precision/scale per side.
    fn make_batch(
        left_values: Vec<Option<i128>>,
        left_precision: u8,
        left_scale: i8,
        right_values: Vec<Option<i128>>,
        right_precision: u8,
        right_scale: i8,
    ) -> RecordBatch {
        let left_arr = Decimal128Array::from(left_values)
            .with_data_type(DataType::Decimal128(left_precision, left_scale));
        let right_arr = Decimal128Array::from(right_values)
            .with_data_type(DataType::Decimal128(right_precision, right_scale));
        let schema = Schema::new(vec![
            Field::new("left", left_arr.data_type().clone(), true),
            Field::new("right", right_arr.data_type().clone(), true),
        ]);
        RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(left_arr), Arc::new(right_arr)],
        )
        .unwrap()
    }

    // Evaluates `left <op> right` over the batch and unwraps the array result.
    fn eval_expr(
        batch: &RecordBatch,
        op: WideDecimalOp,
        output_precision: u8,
        output_scale: i8,
        eval_mode: EvalMode,
    ) -> Result<ArrayRef> {
        let left: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
        let right: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
        let expr =
            WideDecimalBinaryExpr::new(left, right, op, output_precision, output_scale, eval_mode);
        match expr.evaluate(batch)? {
            ColumnarValue::Array(arr) => Ok(arr),
            _ => panic!("expected array"),
        }
    }

    #[test]
    fn test_add_same_scale() {
        // Decimal128(38, 10) + Decimal128(38, 10) -> Decimal128(38, 10)
        let batch = make_batch(
            vec![Some(1000000000), Some(2500000000)], // unscaled; 0.1 and 0.25 at scale 10
            38,
            10,
            vec![Some(2000000000), Some(7500000000)], // 0.2 and 0.75 at scale 10
            38,
            10,
        );
        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 10, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 3000000000); // 0.1 + 0.2
        assert_eq!(arr.value(1), 10000000000); // 0.25 + 0.75
    }

    #[test]
    fn test_subtract_same_scale() {
        let batch = make_batch(
            vec![Some(5000), Some(1000)],
            38,
            2,
            vec![Some(3000), Some(2000)],
            38,
            2,
        );
        let result = eval_expr(&batch, WideDecimalOp::Subtract, 38, 2, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 2000); // 50.00 - 30.00
        assert_eq!(arr.value(1), -1000); // 10.00 - 20.00
    }

    #[test]
    fn test_add_different_scales() {
        // Decimal128(10, 2) + Decimal128(10, 4) -> output scale 4;
        // the scale-2 side must be aligned up to scale 4 before adding.
        let batch = make_batch(
            vec![Some(150)], // 1.50
            10,
            2,
            vec![Some(2500)], // 0.2500
            10,
            4,
        );
        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 4, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 17500); // 1.5000 + 0.2500 = 1.7500
    }

    #[test]
    fn test_multiply_with_scale_reduction() {
        // Decimal128(20, 5) * Decimal128(20, 5) -> natural scale 10, output scale 6
        // 1.00000 * 2.00000 = 2.000000
        let batch = make_batch(
            vec![Some(100000)], // 1.00000
            20,
            5,
            vec![Some(200000)], // 2.00000
            20,
            5,
        );
        let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 6, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 2000000); // 2.000000
    }

    #[test]
    fn test_multiply_half_up_rounding() {
        // Test HALF_UP rounding: 1.5 * 1.5 = 2.25, but if output scale=1, should round to 2.3
        // Input: scale 1, values 15 (1.5) * 15 (1.5) = natural scale 2, raw = 225
        // Output scale 1: 225 / 10 = 22 remainder 5 -> HALF_UP rounds to 23
        let batch = make_batch(
            vec![Some(15)], // 1.5
            10,
            1,
            vec![Some(15)], // 1.5
            10,
            1,
        );
        let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 1, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 23); // 2.3
    }

    #[test]
    fn test_multiply_half_up_rounding_negative() {
        // -1.5 * 1.5 = -2.25, output scale 1: -225/10 => -22 rem -5 -> HALF_UP rounds to -23
        // (HALF_UP rounds away from zero, so negatives round down in value)
        let batch = make_batch(
            vec![Some(-15)], // -1.5
            10,
            1,
            vec![Some(15)], // 1.5
            10,
            1,
        );
        let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 1, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), -23); // -2.3
    }

    #[test]
    fn test_overflow_legacy_mode_returns_null() {
        // Use precision 1 (max value 9), so 5 + 5 = 10 overflows
        let batch = make_batch(vec![Some(5)], 38, 0, vec![Some(5)], 38, 0);
        let result = eval_expr(&batch, WideDecimalOp::Add, 1, 0, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert!(arr.is_null(0));
    }

    #[test]
    fn test_overflow_ansi_mode_returns_error() {
        // Same 5 + 5 > 9 overflow, but Ansi mode must surface an error.
        let batch = make_batch(vec![Some(5)], 38, 0, vec![Some(5)], 38, 0);
        let result = eval_expr(&batch, WideDecimalOp::Add, 1, 0, EvalMode::Ansi);
        assert!(result.is_err());
    }

    #[test]
    fn test_null_propagation() {
        // A null on either side makes the corresponding output element null.
        let batch = make_batch(vec![Some(100), None], 10, 2, vec![None, Some(200)], 10, 2);
        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 2, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert!(arr.is_null(0));
        assert!(arr.is_null(1));
    }

    #[test]
    fn test_zeros() {
        let batch = make_batch(vec![Some(0)], 38, 10, vec![Some(0)], 38, 10);
        let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 10, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 0);
    }

    #[test]
    fn test_max_precision_values() {
        // Max Decimal128(38,0) value: 10^38 - 1; adding zero must not overflow.
        let max_val = 10i128.pow(38) - 1;
        let batch = make_batch(vec![Some(max_val)], 38, 0, vec![Some(0)], 38, 0);
        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 0, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), max_val);
    }

    #[test]
    fn test_add_scale_up_to_output() {
        // When s_out > max(s1, s2), result must be scaled UP
        // Decimal128(10, 2) + Decimal128(10, 2) with output scale 4
        // 1.50 + 0.25 = 1.75, at scale 4 = 17500
        let batch = make_batch(
            vec![Some(150)], // 1.50
            10,
            2,
            vec![Some(25)], // 0.25
            10,
            2,
        );
        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 4, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 17500); // 1.7500
    }

    #[test]
    fn test_subtract_scale_up_to_output() {
        // s_out (4) > max(s1, s2) (2) — verify scale-up path for subtract
        let batch = make_batch(
            vec![Some(300)], // 3.00
            10,
            2,
            vec![Some(100)], // 1.00
            10,
            2,
        );
        let result = eval_expr(&batch, WideDecimalOp::Subtract, 38, 4, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 20000); // 2.0000
    }
}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to