This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 919c9ad7c perf: Optimize some decimal expressions (#3619)
919c9ad7c is described below
commit 919c9ad7c54a93f2d6b2980b2b672760aacc12a8
Author: Andy Grove <[email protected]>
AuthorDate: Thu Mar 12 07:56:58 2026 -0600
perf: Optimize some decimal expressions (#3619)
* feat: fused WideDecimalBinaryExpr for Decimal128 add/sub/mul
Replace the 4-node expression tree (Cast→BinaryExpr→Cast→Cast) used for
Decimal128 arithmetic that may overflow with a single fused expression
that performs i256 register arithmetic directly. This reduces per-batch
allocation from 4 intermediate arrays (112 bytes/elem) to 1 output array
(16 bytes/elem).
The new WideDecimalBinaryExpr evaluates children, performs add/sub/mul
using i256 intermediates via try_binary, applies scale adjustment with
HALF_UP rounding, checks precision bounds, and outputs a single
Decimal128 array. Follows the same pattern as decimal_div.
* feat: add criterion benchmark for wide decimal binary expr
Add benchmark comparing old Cast->BinaryExpr->Cast chain vs fused
WideDecimalBinaryExpr for Decimal128 add/sub/mul. Covers four cases:
add with same scale, add with different scales, multiply, and subtract.
* feat: fuse CheckOverflow with Cast and WideDecimalBinaryExpr
Eliminate redundant CheckOverflow when wrapping WideDecimalBinaryExpr
(which already handles overflow). Fuse Cast(Decimal128→Decimal128) +
CheckOverflow into a single DecimalRescaleCheckOverflow expression that
rescales and validates precision in one pass.
* fix: address PR review feedback for decimal optimizations
- Handle scale-up when s_out > max(s1, s2) in add/subtract
- Propagate errors in scalar path when fail_on_error=true
- Guard against large scale delta (>38) overflow in rescale
- Assert precision <= 38 in precision_bound
- Assert exp <= 76 in i256_pow10
- Remove unnecessary _ prefix on used variables in planner
- Use value.signum() instead of manual sign check
- Verify Cast target type matches before fusing with CheckOverflow
- Validate children count in with_new_children for both expressions
- Add tests for scale-up, scalar error propagation, and large delta
* style: apply cargo fmt
* fix: add defensive checks for CheckOverflow bypass and multiply scale-up
- Validate WideDecimalBinaryExpr output type matches CheckOverflow
data_type before bypassing the overflow check
- Handle s_out > natural_scale (scale-up) in multiply path for
consistency with add/subtract
---
native/core/src/execution/planner.rs | 81 ++-
native/spark-expr/Cargo.toml | 4 +
native/spark-expr/benches/wide_decimal.rs | 166 ++++++
native/spark-expr/src/lib.rs | 3 +-
.../math_funcs/internal/decimal_rescale_check.rs | 482 ++++++++++++++++++
native/spark-expr/src/math_funcs/internal/mod.rs | 2 +
native/spark-expr/src/math_funcs/mod.rs | 2 +
.../src/math_funcs/wide_decimal_binary_expr.rs | 560 +++++++++++++++++++++
8 files changed, 1273 insertions(+), 27 deletions(-)
diff --git a/native/core/src/execution/planner.rs
b/native/core/src/execution/planner.rs
index b79b43f6c..15bbabe88 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -126,8 +126,9 @@ use datafusion_comet_proto::{
use
datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId;
use datafusion_comet_spark_expr::{
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation,
Covariance, CreateNamedStruct,
- GetArrayStructFields, GetStructField, IfExpr, ListExtract,
NormalizeNaNAndZero, RandExpr,
- RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn,
Variance,
+ DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr,
ListExtract,
+ NormalizeNaNAndZero, RandExpr, RandnExpr, SparkCastOptions, Stddev,
SumDecimal, ToJson,
+ UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
};
use itertools::Itertools;
use jni::objects::GlobalRef;
@@ -408,10 +409,45 @@ impl PhysicalPlanner {
)))
}
ExprStruct::CheckOverflow(expr) => {
- let child = self.create_expr(expr.child.as_ref().unwrap(),
input_schema)?;
+ let child =
+ self.create_expr(expr.child.as_ref().unwrap(),
Arc::clone(&input_schema))?;
let data_type =
to_arrow_datatype(expr.datatype.as_ref().unwrap());
let fail_on_error = expr.fail_on_error;
+ // WideDecimalBinaryExpr already handles overflow — skip
redundant check
+ // but only if its output type matches CheckOverflow's
declared type
+ if child
+ .as_any()
+ .downcast_ref::<WideDecimalBinaryExpr>()
+ .is_some()
+ {
+ let child_type = child.data_type(&input_schema)?;
+ if child_type == data_type {
+ return Ok(child);
+ }
+ }
+
+ // Fuse Cast(Decimal128→Decimal128) + CheckOverflow into
single rescale+check
+ // Only fuse when the Cast target type matches the
CheckOverflow output type
+ if let Some(cast) = child.as_any().downcast_ref::<Cast>() {
+ if let (
+ DataType::Decimal128(p_out, s_out),
+ Ok(DataType::Decimal128(_p_in, s_in)),
+ ) = (&data_type, cast.child.data_type(&input_schema))
+ {
+ let cast_target = cast.data_type(&input_schema)?;
+ if cast_target == data_type {
+ return
Ok(Arc::new(DecimalRescaleCheckOverflow::new(
+ Arc::clone(&cast.child),
+ s_in,
+ *p_out,
+ *s_out,
+ fail_on_error,
+ )));
+ }
+ }
+ }
+
// Look up query context from registry if expr_id is present
let query_context = spark_expr.expr_id.and_then(|expr_id| {
let registry = &self.query_context_registry;
@@ -740,29 +776,22 @@ impl PhysicalPlanner {
|| (op == DataFusionOperator::Multiply && p1 + p2 >=
DECIMAL128_MAX_PRECISION) =>
{
let data_type = return_type.map(to_arrow_datatype).unwrap();
- // For some Decimal128 operations, we need wider internal
digits.
- // Cast left and right to Decimal256 and cast the result back
to Decimal128
- let left = Arc::new(Cast::new(
- left,
- DataType::Decimal256(p1, s1),
- SparkCastOptions::new_without_timezone(EvalMode::Legacy,
false),
- None,
- None,
- ));
- let right = Arc::new(Cast::new(
- right,
- DataType::Decimal256(p2, s2),
- SparkCastOptions::new_without_timezone(EvalMode::Legacy,
false),
- None,
- None,
- ));
- let child = Arc::new(BinaryExpr::new(left, op, right));
- Ok(Arc::new(Cast::new(
- child,
- data_type,
- SparkCastOptions::new_without_timezone(EvalMode::Legacy,
false),
- None,
- None,
+ let (p_out, s_out) = match &data_type {
+ DataType::Decimal128(p, s) => (*p, *s),
+ dt => {
+ return Err(ExecutionError::GeneralError(format!(
+ "Expected Decimal128 return type, got {dt:?}"
+ )))
+ }
+ };
+ let wide_op = match op {
+ DataFusionOperator::Plus => WideDecimalOp::Add,
+ DataFusionOperator::Minus => WideDecimalOp::Subtract,
+ DataFusionOperator::Multiply => WideDecimalOp::Multiply,
+ _ => unreachable!(),
+ };
+ Ok(Arc::new(WideDecimalBinaryExpr::new(
+ left, right, wide_op, p_out, s_out, eval_mode,
)))
}
(
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 9f08e480f..d4639c86e 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -105,6 +105,10 @@ path = "tests/spark_expr_reg.rs"
name = "cast_from_boolean"
harness = false
+[[bench]]
+name = "wide_decimal"
+harness = false
+
[[bench]]
name = "cast_non_int_numeric_timestamp"
harness = false
diff --git a/native/spark-expr/benches/wide_decimal.rs
b/native/spark-expr/benches/wide_decimal.rs
new file mode 100644
index 000000000..ec932ae68
--- /dev/null
+++ b/native/spark-expr/benches/wide_decimal.rs
@@ -0,0 +1,166 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Benchmarks comparing the old Cast->BinaryExpr->Cast chain vs the fused
WideDecimalBinaryExpr
+//! for Decimal128 arithmetic that requires wider intermediate precision.
+
+use arrow::array::builder::Decimal128Builder;
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Field, Schema};
+use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+use datafusion::logical_expr::Operator;
+use datafusion::physical_expr::expressions::{BinaryExpr, Column};
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion_comet_spark_expr::{
+ Cast, EvalMode, SparkCastOptions, WideDecimalBinaryExpr, WideDecimalOp,
+};
+use std::sync::Arc;
+
+const BATCH_SIZE: usize = 8192;
+
+/// Build a RecordBatch with two Decimal128 columns.
+fn make_decimal_batch(p1: u8, s1: i8, p2: u8, s2: i8) -> RecordBatch {
+ let mut left = Decimal128Builder::new();
+ let mut right = Decimal128Builder::new();
+ for i in 0..BATCH_SIZE as i128 {
+ left.append_value(123456789012345_i128 + i * 1000);
+ right.append_value(987654321098765_i128 - i * 1000);
+ }
+ let left = left.finish().with_data_type(DataType::Decimal128(p1, s1));
+ let right = right.finish().with_data_type(DataType::Decimal128(p2, s2));
+ let schema = Schema::new(vec![
+ Field::new("left", DataType::Decimal128(p1, s1), false),
+ Field::new("right", DataType::Decimal128(p2, s2), false),
+ ]);
+ RecordBatch::try_new(Arc::new(schema), vec![Arc::new(left),
Arc::new(right)]).unwrap()
+}
+
+/// Old approach: Cast(Decimal128->Decimal256) both sides, BinaryExpr,
Cast(Decimal256->Decimal128).
+fn build_old_expr(
+ p1: u8,
+ s1: i8,
+ p2: u8,
+ s2: i8,
+ op: Operator,
+ out_type: DataType,
+) -> Arc<dyn PhysicalExpr> {
+ let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
+ let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
+ let cast_opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy,
false);
+ let left_cast = Arc::new(Cast::new(
+ left_col,
+ DataType::Decimal256(p1, s1),
+ cast_opts.clone(),
+ None,
+ None,
+ ));
+ let right_cast = Arc::new(Cast::new(
+ right_col,
+ DataType::Decimal256(p2, s2),
+ cast_opts.clone(),
+ None,
+ None,
+ ));
+ let binary = Arc::new(BinaryExpr::new(left_cast, op, right_cast));
+ Arc::new(Cast::new(binary, out_type, cast_opts, None, None))
+}
+
+/// New approach: single fused WideDecimalBinaryExpr.
+fn build_new_expr(op: WideDecimalOp, p_out: u8, s_out: i8) -> Arc<dyn
PhysicalExpr> {
+ let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
+ let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
+ Arc::new(WideDecimalBinaryExpr::new(
+ left_col,
+ right_col,
+ op,
+ p_out,
+ s_out,
+ EvalMode::Legacy,
+ ))
+}
+
+fn bench_case(
+ group: &mut criterion::BenchmarkGroup<criterion::measurement::WallTime>,
+ name: &str,
+ batch: &RecordBatch,
+ old_expr: &Arc<dyn PhysicalExpr>,
+ new_expr: &Arc<dyn PhysicalExpr>,
+) {
+ group.bench_with_input(BenchmarkId::new("old", name), batch, |b, batch| {
+ b.iter(|| old_expr.evaluate(batch).unwrap());
+ });
+ group.bench_with_input(BenchmarkId::new("fused", name), batch, |b, batch| {
+ b.iter(|| new_expr.evaluate(batch).unwrap());
+ });
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let mut group = c.benchmark_group("wide_decimal");
+
+ // Case 1: Add with same scale - Decimal128(38,10) + Decimal128(38,10) ->
Decimal128(38,10)
+ // Triggers wide path because max(s1,s2) + max(p1-s1, p2-s2) = 10 + 28 =
38 >= 38
+ {
+ let batch = make_decimal_batch(38, 10, 38, 10);
+ let old = build_old_expr(38, 10, 38, 10, Operator::Plus,
DataType::Decimal128(38, 10));
+ let new = build_new_expr(WideDecimalOp::Add, 38, 10);
+ bench_case(&mut group, "add_same_scale", &batch, &old, &new);
+ }
+
+ // Case 2: Add with different scales - Decimal128(38,6) + Decimal128(38,4)
-> Decimal128(38,6)
+ {
+ let batch = make_decimal_batch(38, 6, 38, 4);
+ let old = build_old_expr(38, 6, 38, 4, Operator::Plus,
DataType::Decimal128(38, 6));
+ let new = build_new_expr(WideDecimalOp::Add, 38, 6);
+ bench_case(&mut group, "add_diff_scale", &batch, &old, &new);
+ }
+
+ // Case 3: Multiply - Decimal128(20,10) * Decimal128(20,10) ->
Decimal128(38,6)
+ // Triggers wide path because p1 + p2 = 40 >= 38
+ {
+ let batch = make_decimal_batch(20, 10, 20, 10);
+ let old = build_old_expr(
+ 20,
+ 10,
+ 20,
+ 10,
+ Operator::Multiply,
+ DataType::Decimal128(38, 6),
+ );
+ let new = build_new_expr(WideDecimalOp::Multiply, 38, 6);
+ bench_case(&mut group, "multiply", &batch, &old, &new);
+ }
+
+ // Case 4: Subtract with same scale - Decimal128(38,18) -
Decimal128(38,18) -> Decimal128(38,18)
+ {
+ let batch = make_decimal_batch(38, 18, 38, 18);
+ let old = build_old_expr(
+ 38,
+ 18,
+ 38,
+ 18,
+ Operator::Minus,
+ DataType::Decimal128(38, 18),
+ );
+ let new = build_new_expr(WideDecimalOp::Subtract, 38, 18);
+ bench_case(&mut group, "subtract", &batch, &old, &new);
+ }
+
+ group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 072fa1fad..ba19d6a9b 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -80,7 +80,8 @@ pub use json_funcs::{FromJson, ToJson};
pub use math_funcs::{
create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round,
spark_unhex,
- spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero,
+ spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow,
NegativeExpr,
+ NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
};
pub use query_context::{create_query_context_map, QueryContext,
QueryContextMap};
pub use string_funcs::*;
diff --git a/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs
b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs
new file mode 100644
index 000000000..132240495
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs
@@ -0,0 +1,482 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Fused decimal rescale + overflow check expression.
+//!
+//! Replaces the pattern `CheckOverflow(Cast(expr, Decimal128(p2,s2)),
Decimal128(p2,s2))`
+//! with a single expression that rescales and validates precision in one pass.
+
+use arrow::array::{as_primitive_array, Array, ArrayRef, Decimal128Array};
+use arrow::datatypes::{DataType, Decimal128Type, Schema};
+use arrow::error::ArrowError;
+use arrow::record_batch::RecordBatch;
+use datafusion::common::{DataFusionError, ScalarValue};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::hash::Hash;
+use std::{
+ any::Any,
+ fmt::{Display, Formatter},
+ sync::Arc,
+};
+
+/// A fused expression that rescales a Decimal128 value (changing scale) and
checks
+/// for precision overflow in a single pass. Replaces the two-step
+/// `CheckOverflow(Cast(expr, Decimal128(p,s)))` pattern.
+#[derive(Debug, Eq)]
+pub struct DecimalRescaleCheckOverflow {
+ child: Arc<dyn PhysicalExpr>,
+ input_scale: i8,
+ output_precision: u8,
+ output_scale: i8,
+ fail_on_error: bool,
+}
+
+impl Hash for DecimalRescaleCheckOverflow {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.child.hash(state);
+ self.input_scale.hash(state);
+ self.output_precision.hash(state);
+ self.output_scale.hash(state);
+ self.fail_on_error.hash(state);
+ }
+}
+
+impl PartialEq for DecimalRescaleCheckOverflow {
+ fn eq(&self, other: &Self) -> bool {
+ self.child.eq(&other.child)
+ && self.input_scale == other.input_scale
+ && self.output_precision == other.output_precision
+ && self.output_scale == other.output_scale
+ && self.fail_on_error == other.fail_on_error
+ }
+}
+
+impl DecimalRescaleCheckOverflow {
+ pub fn new(
+ child: Arc<dyn PhysicalExpr>,
+ input_scale: i8,
+ output_precision: u8,
+ output_scale: i8,
+ fail_on_error: bool,
+ ) -> Self {
+ Self {
+ child,
+ input_scale,
+ output_precision,
+ output_scale,
+ fail_on_error,
+ }
+ }
+}
+
+impl Display for DecimalRescaleCheckOverflow {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(
+ f,
+ "DecimalRescaleCheckOverflow [child: {}, input_scale: {}, output:
Decimal128({}, {}), fail_on_error: {}]",
+ self.child, self.input_scale, self.output_precision,
self.output_scale, self.fail_on_error
+ )
+ }
+}
+
+/// Maximum absolute value for a given decimal precision: 10^p - 1.
+/// Precision must be <= 38 (max for Decimal128).
+#[inline]
+fn precision_bound(precision: u8) -> i128 {
+ assert!(
+ precision <= 38,
+ "precision_bound: precision {precision} exceeds maximum 38"
+ );
+ 10i128.pow(precision as u32) - 1
+}
+
+/// Rescale a single i128 value by the given delta (output_scale - input_scale)
+/// and check precision bounds. Returns `Ok(value)` or `Ok(i128::MAX)` as
sentinel
+/// for overflow in legacy mode, or `Err` in ANSI mode.
+#[inline]
+fn rescale_and_check(
+ value: i128,
+ delta: i8,
+ scale_factor: i128,
+ bound: i128,
+ fail_on_error: bool,
+) -> Result<i128, ArrowError> {
+ let rescaled = if delta > 0 {
+ // Scale up: multiply. Check for overflow.
+ match value.checked_mul(scale_factor) {
+ Some(v) => v,
+ None => {
+ if fail_on_error {
+ return Err(ArrowError::ComputeError(
+ "Decimal overflow during rescale".to_string(),
+ ));
+ }
+ return Ok(i128::MAX); // sentinel
+ }
+ }
+ } else if delta < 0 {
+ // Scale down with HALF_UP rounding
+ // divisor = 10^(-delta), half = divisor / 2
+ let divisor = scale_factor; // already 10^abs(delta)
+ let half = divisor / 2;
+ let sign = value.signum();
+ (value + sign * half) / divisor
+ } else {
+ value
+ };
+
+ // Precision check
+ if rescaled.abs() > bound {
+ if fail_on_error {
+ return Err(ArrowError::ComputeError(
+ "Decimal overflow: value does not fit in
precision".to_string(),
+ ));
+ }
+ Ok(i128::MAX) // sentinel for null_if_overflow_precision
+ } else {
+ Ok(rescaled)
+ }
+}
+
+impl PhysicalExpr for DecimalRescaleCheckOverflow {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ Display::fmt(self, f)
+ }
+
+ fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
+ Ok(DataType::Decimal128(
+ self.output_precision,
+ self.output_scale,
+ ))
+ }
+
+ fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
+ Ok(true)
+ }
+
+ fn evaluate(&self, batch: &RecordBatch) ->
datafusion::common::Result<ColumnarValue> {
+ let arg = self.child.evaluate(batch)?;
+ let delta = self.output_scale - self.input_scale;
+ let abs_delta = delta.unsigned_abs();
+ // If abs_delta > 38, the scale factor overflows i128. In that case,
+ // any non-zero value will overflow the output precision, so we treat
+ // it as an immediate overflow condition.
+ if abs_delta > 38 {
+ return Err(DataFusionError::Execution(format!(
+ "DecimalRescaleCheckOverflow: scale delta {delta} exceeds
maximum supported range"
+ )));
+ }
+ let scale_factor = 10i128.pow(abs_delta as u32);
+ let bound = precision_bound(self.output_precision);
+ let fail_on_error = self.fail_on_error;
+ let p_out = self.output_precision;
+ let s_out = self.output_scale;
+
+ match arg {
+ ColumnarValue::Array(array)
+ if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
+ {
+ let decimal_array =
as_primitive_array::<Decimal128Type>(&array);
+
+ let result: Decimal128Array =
+ arrow::compute::kernels::arity::try_unary(decimal_array,
|value| {
+ rescale_and_check(value, delta, scale_factor, bound,
fail_on_error)
+ })?;
+
+ let result = if !fail_on_error {
+ result.null_if_overflow_precision(p_out)
+ } else {
+ result
+ };
+
+ let result = result
+ .with_precision_and_scale(p_out, s_out)
+ .map(|a| Arc::new(a) as ArrayRef)?;
+
+ Ok(ColumnarValue::Array(result))
+ }
+ ColumnarValue::Scalar(ScalarValue::Decimal128(v, _precision,
_scale)) => {
+ let new_v = match v {
+ Some(val) => {
+ let r = rescale_and_check(val, delta, scale_factor,
bound, fail_on_error)
+ .map_err(|e|
DataFusionError::ArrowError(Box::new(e), None))?;
+ if r == i128::MAX {
+ None
+ } else {
+ Some(r)
+ }
+ }
+ None => None,
+ };
+ Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+ new_v, p_out, s_out,
+ )))
+ }
+ v => Err(DataFusionError::Execution(format!(
+ "DecimalRescaleCheckOverflow expects Decimal128, but found
{v:?}"
+ ))),
+ }
+ }
+
+ fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+ vec![&self.child]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+ if children.len() != 1 {
+ return Err(DataFusionError::Internal(format!(
+ "DecimalRescaleCheckOverflow expects 1 child, got {}",
+ children.len()
+ )));
+ }
+ Ok(Arc::new(DecimalRescaleCheckOverflow::new(
+ Arc::clone(&children[0]),
+ self.input_scale,
+ self.output_precision,
+ self.output_scale,
+ self.fail_on_error,
+ )))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use arrow::array::{AsArray, Decimal128Array};
+ use arrow::datatypes::{Field, Schema};
+ use arrow::record_batch::RecordBatch;
+ use datafusion::physical_expr::expressions::Column;
+
+ fn make_batch(values: Vec<Option<i128>>, precision: u8, scale: i8) ->
RecordBatch {
+ let arr =
+
Decimal128Array::from(values).with_data_type(DataType::Decimal128(precision,
scale));
+ let schema = Schema::new(vec![Field::new("col",
arr.data_type().clone(), true)]);
+ RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arr)]).unwrap()
+ }
+
+ fn eval_expr(
+ batch: &RecordBatch,
+ input_scale: i8,
+ output_precision: u8,
+ output_scale: i8,
+ fail_on_error: bool,
+ ) -> datafusion::common::Result<ArrayRef> {
+ let child: Arc<dyn PhysicalExpr> = Arc::new(Column::new("col", 0));
+ let expr = DecimalRescaleCheckOverflow::new(
+ child,
+ input_scale,
+ output_precision,
+ output_scale,
+ fail_on_error,
+ );
+ match expr.evaluate(batch)? {
+ ColumnarValue::Array(arr) => Ok(arr),
+ _ => panic!("expected array"),
+ }
+ }
+
+ #[test]
+ fn test_scale_up() {
+ // Decimal128(10,2) -> Decimal128(10,4): 1.50 (150) -> 1.5000 (15000)
+ let batch = make_batch(vec![Some(150), Some(-300)], 10, 2);
+ let result = eval_expr(&batch, 2, 10, 4, false).unwrap();
+ let arr = result.as_primitive::<Decimal128Type>();
+ assert_eq!(arr.value(0), 15000); // 1.5000
+ assert_eq!(arr.value(1), -30000); // -3.0000
+ }
+
+ #[test]
+ fn test_scale_down_with_half_up_rounding() {
+ // Decimal128(10,4) -> Decimal128(10,2)
+ // 1.2350 (12350) -> round to 1.24 (124) with HALF_UP
+ // 1.2349 (12349) -> round to 1.23 (123) with HALF_UP
+ // -1.2350 (-12350) -> round to -1.24 (-124) with HALF_UP
+ let batch = make_batch(vec![Some(12350), Some(12349), Some(-12350)],
10, 4);
+ let result = eval_expr(&batch, 4, 10, 2, false).unwrap();
+ let arr = result.as_primitive::<Decimal128Type>();
+ assert_eq!(arr.value(0), 124); // 1.24
+ assert_eq!(arr.value(1), 123); // 1.23
+ assert_eq!(arr.value(2), -124); // -1.24
+ }
+
+ #[test]
+ fn test_same_scale_precision_check_only() {
+ // Same scale, just check precision. Value 999 fits in precision 3,
1000 does not.
+ let batch = make_batch(vec![Some(999), Some(1000)], 38, 0);
+ let result = eval_expr(&batch, 0, 3, 0, false).unwrap();
+ let arr = result.as_primitive::<Decimal128Type>();
+ assert_eq!(arr.value(0), 999);
+ assert!(arr.is_null(1)); // overflow -> null in legacy mode
+ }
+
+ #[test]
+ fn test_overflow_null_in_legacy_mode() {
+ // Scale up causes overflow: 10^37 * 100 > i128::MAX range for
precision 38
+ // Use precision 3, value 10 (which is 10 at scale 0), scale up to
scale 2 -> 1000, which overflows precision 3
+ let batch = make_batch(vec![Some(10)], 38, 0);
+ let result = eval_expr(&batch, 0, 3, 2, false).unwrap();
+ let arr = result.as_primitive::<Decimal128Type>();
+ assert!(arr.is_null(0)); // 10 * 100 = 1000 > 999 (max for precision 3)
+ }
+
+ #[test]
+ fn test_overflow_error_in_ansi_mode() {
+ let batch = make_batch(vec![Some(10)], 38, 0);
+ let result = eval_expr(&batch, 0, 3, 2, true);
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_null_propagation() {
+ let batch = make_batch(vec![Some(100), None, Some(200)], 10, 2);
+ let result = eval_expr(&batch, 2, 10, 4, false).unwrap();
+ let arr = result.as_primitive::<Decimal128Type>();
+ assert!(!arr.is_null(0));
+ assert!(arr.is_null(1));
+ assert!(!arr.is_null(2));
+ }
+
+ #[test]
+ fn test_scalar_path() {
+ let schema = Schema::new(vec![Field::new("col",
DataType::Decimal128(10, 2), true)]);
+ let batch = RecordBatch::new_empty(Arc::new(schema));
+
+ let scalar_expr = DecimalRescaleCheckOverflow::new(
+ Arc::new(ScalarChild(Some(150), 10, 2)),
+ 2,
+ 10,
+ 4,
+ false,
+ );
+ let result = scalar_expr.evaluate(&batch).unwrap();
+ match result {
+ ColumnarValue::Scalar(ScalarValue::Decimal128(v, p, s)) => {
+ assert_eq!(v, Some(15000));
+ assert_eq!(p, 10);
+ assert_eq!(s, 4);
+ }
+ _ => panic!("expected decimal scalar"),
+ }
+ }
+
+ /// Helper expression that always returns a Decimal128 scalar.
+ #[derive(Debug, Eq, PartialEq, Hash)]
+ struct ScalarChild(Option<i128>, u8, i8);
+
+ impl Display for ScalarChild {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "ScalarChild({:?})", self.0)
+ }
+ }
+
+ impl PhysicalExpr for ScalarChild {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn data_type(&self, _: &Schema) ->
datafusion::common::Result<DataType> {
+ Ok(DataType::Decimal128(self.1, self.2))
+ }
+ fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
+ Ok(true)
+ }
+ fn evaluate(&self, _batch: &RecordBatch) ->
datafusion::common::Result<ColumnarValue> {
+ Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+ self.0, self.1, self.2,
+ )))
+ }
+ fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+ vec![]
+ }
+ fn with_new_children(
+ self: Arc<Self>,
+ _children: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+ Ok(self)
+ }
+ fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ Display::fmt(self, f)
+ }
+ }
+
+ #[test]
+ fn test_scalar_null() {
+ let schema = Schema::new(vec![Field::new("col",
DataType::Decimal128(10, 2), true)]);
+ let batch = RecordBatch::new_empty(Arc::new(schema));
+ let expr =
+ DecimalRescaleCheckOverflow::new(Arc::new(ScalarChild(None, 10,
2)), 2, 10, 4, false);
+ let result = expr.evaluate(&batch).unwrap();
+ match result {
+ ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => {
+ assert_eq!(v, None);
+ }
+ _ => panic!("expected decimal scalar"),
+ }
+ }
+
+ #[test]
+ fn test_scalar_overflow_legacy() {
+ let schema = Schema::new(vec![Field::new("col",
DataType::Decimal128(38, 0), true)]);
+ let batch = RecordBatch::new_empty(Arc::new(schema));
+ let expr = DecimalRescaleCheckOverflow::new(
+ Arc::new(ScalarChild(Some(10), 38, 0)),
+ 0,
+ 3,
+ 2,
+ false,
+ );
+ let result = expr.evaluate(&batch).unwrap();
+ match result {
+ ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => {
+ assert_eq!(v, None); // 10 * 100 = 1000 > 999
+ }
+ _ => panic!("expected decimal scalar"),
+ }
+ }
+
+ #[test]
+ fn test_scalar_overflow_ansi_returns_error() {
+ // fail_on_error=true must propagate the error, not silently return
None
+ let schema = Schema::new(vec![Field::new("col",
DataType::Decimal128(38, 0), true)]);
+ let batch = RecordBatch::new_empty(Arc::new(schema));
+ let expr = DecimalRescaleCheckOverflow::new(
+ Arc::new(ScalarChild(Some(10), 38, 0)),
+ 0,
+ 3,
+ 2,
+ true, // fail_on_error = true
+ );
+ let result = expr.evaluate(&batch);
+ assert!(result.is_err()); // must be error, not Ok(None)
+ }
+
+ #[test]
+ fn test_large_scale_delta_returns_error() {
+ // delta = output_scale - input_scale = 38 - (-1) = 39
+ // 10i128.pow(39) would overflow, so we must reject gracefully
+ let batch = make_batch(vec![Some(1)], 38, -1);
+ let result = eval_expr(&batch, -1, 38, 38, false);
+ assert!(result.is_err());
+ }
+}
diff --git a/native/spark-expr/src/math_funcs/internal/mod.rs
b/native/spark-expr/src/math_funcs/internal/mod.rs
index 29295f0d5..dff26146e 100644
--- a/native/spark-expr/src/math_funcs/internal/mod.rs
+++ b/native/spark-expr/src/math_funcs/internal/mod.rs
@@ -16,11 +16,13 @@
// under the License.
mod checkoverflow;
+mod decimal_rescale_check;
mod make_decimal;
mod normalize_nan;
mod unscaled_value;
pub use checkoverflow::CheckOverflow;
+pub use decimal_rescale_check::DecimalRescaleCheckOverflow;
pub use make_decimal::spark_make_decimal;
pub use normalize_nan::NormalizeNaNAndZero;
pub use unscaled_value::spark_unscaled_value;
diff --git a/native/spark-expr/src/math_funcs/mod.rs
b/native/spark-expr/src/math_funcs/mod.rs
index 35c1dc650..1219bc720 100644
--- a/native/spark-expr/src/math_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -26,6 +26,7 @@ mod negative;
mod round;
pub(crate) mod unhex;
mod utils;
+mod wide_decimal_binary_expr;
pub use ceil::spark_ceil;
pub use div::spark_decimal_div;
@@ -36,3 +37,4 @@ pub use modulo_expr::create_modulo_expr;
pub use negative::{create_negate_expr, NegativeExpr};
pub use round::spark_round;
pub use unhex::spark_unhex;
+pub use wide_decimal_binary_expr::{WideDecimalBinaryExpr, WideDecimalOp};
diff --git a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs
b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs
new file mode 100644
index 000000000..644252b46
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs
@@ -0,0 +1,560 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Fused wide-decimal binary expression for Decimal128 add/sub/mul that may
overflow.
+//!
+//! Instead of building a 4-node expression tree (Cast→BinaryExpr→Cast→Cast),
this performs
+//! i256 intermediate arithmetic in a single expression, producing only one
output array.
+
+use crate::math_funcs::utils::get_precision_scale;
+use crate::EvalMode;
+use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array};
+use arrow::datatypes::{i256, DataType, Decimal128Type, Schema};
+use arrow::error::ArrowError;
+use arrow::record_batch::RecordBatch;
+use datafusion::common::Result;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::fmt::{Display, Formatter};
+use std::hash::Hash;
+use std::{any::Any, sync::Arc};
+
/// The arithmetic operation to perform.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub enum WideDecimalOp {
    Add,
    Subtract,
    Multiply,
}

impl Display for WideDecimalOp {
    /// Renders the operator as its SQL symbol (`+`, `-`, or `*`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            WideDecimalOp::Add => "+",
            WideDecimalOp::Subtract => "-",
            WideDecimalOp::Multiply => "*",
        };
        f.write_str(symbol)
    }
}
+
/// A fused expression that evaluates Decimal128 add/sub/mul using i256
/// intermediate arithmetic, applies scale adjustment with HALF_UP rounding,
/// checks precision bounds, and outputs a single Decimal128 array.
#[derive(Debug, Eq)]
pub struct WideDecimalBinaryExpr {
    // Left operand; `evaluate` interprets its result as Decimal128.
    left: Arc<dyn PhysicalExpr>,
    // Right operand; `evaluate` interprets its result as Decimal128.
    right: Arc<dyn PhysicalExpr>,
    // The arithmetic operation to apply (add, subtract, or multiply).
    op: WideDecimalOp,
    // Precision of the output Decimal128 type.
    output_precision: u8,
    // Scale of the output Decimal128 type.
    output_scale: i8,
    // Overflow behavior: Ansi raises an error; other modes null out
    // overflowing values.
    eval_mode: EvalMode,
}
+
impl Hash for WideDecimalBinaryExpr {
    // Manual impl that hashes exactly the fields compared by `PartialEq`,
    // keeping the Hash/Eq contract consistent.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.left.hash(state);
        self.right.hash(state);
        self.op.hash(state);
        self.output_precision.hash(state);
        self.output_scale.hash(state);
        self.eval_mode.hash(state);
    }
}
+
+impl PartialEq for WideDecimalBinaryExpr {
+ fn eq(&self, other: &Self) -> bool {
+ self.left.eq(&other.left)
+ && self.right.eq(&other.right)
+ && self.op == other.op
+ && self.output_precision == other.output_precision
+ && self.output_scale == other.output_scale
+ && self.eval_mode == other.eval_mode
+ }
+}
+
impl WideDecimalBinaryExpr {
    /// Creates a new fused wide-decimal binary expression.
    ///
    /// * `left` / `right` - child expressions producing the operands
    /// * `op` - the arithmetic operation (add, subtract, or multiply)
    /// * `output_precision` / `output_scale` - the Decimal128 output type
    /// * `eval_mode` - Ansi errors on overflow; other modes produce null
    pub fn new(
        left: Arc<dyn PhysicalExpr>,
        right: Arc<dyn PhysicalExpr>,
        op: WideDecimalOp,
        output_precision: u8,
        output_scale: i8,
        eval_mode: EvalMode,
    ) -> Self {
        Self {
            left,
            right,
            op,
            output_precision,
            output_scale,
            eval_mode,
        }
    }
}
+
impl Display for WideDecimalBinaryExpr {
    // Human-readable form showing both children, the operator symbol, and
    // the output Decimal128(precision, scale) type.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "WideDecimalBinaryExpr [{} {} {}, output: Decimal128({}, {})]",
            self.left, self.op, self.right, self.output_precision, self.output_scale
        )
    }
}
+
+/// Compute `value / divisor` with HALF_UP rounding.
+#[inline]
+fn div_round_half_up(value: i256, divisor: i256) -> i256 {
+ let (quot, rem) = (value / divisor, value % divisor);
+ // HALF_UP: if |remainder| * 2 >= |divisor|, round away from zero
+ let abs_rem_x2 = if rem < i256::ZERO {
+ rem.wrapping_neg()
+ } else {
+ rem
+ }
+ .wrapping_mul(i256::from_i128(2));
+ let abs_divisor = if divisor < i256::ZERO {
+ divisor.wrapping_neg()
+ } else {
+ divisor
+ };
+ if abs_rem_x2 >= abs_divisor {
+ if (value < i256::ZERO) != (divisor < i256::ZERO) {
+ quot.wrapping_sub(i256::ONE)
+ } else {
+ quot.wrapping_add(i256::ONE)
+ }
+ } else {
+ quot
+ }
+}
+
+/// i256 constant for 10.
+const I256_TEN: i256 = i256::from_i128(10);
+
+/// Compute 10^exp as i256. Panics if exp > 76 (max representable power of 10
in i256).
+#[inline]
+fn i256_pow10(exp: u32) -> i256 {
+ assert!(exp <= 76, "i256_pow10: exponent {exp} exceeds maximum 76");
+ let mut result = i256::ONE;
+ for _ in 0..exp {
+ result = result.wrapping_mul(I256_TEN);
+ }
+ result
+}
+
+/// Maximum i128 value for a given decimal precision (1-indexed).
+/// precision p allows values in [-10^p + 1, 10^p - 1].
+#[inline]
+fn max_for_precision(precision: u8) -> i256 {
+ i256_pow10(precision as u32).wrapping_sub(i256::ONE)
+}
+
impl PhysicalExpr for WideDecimalBinaryExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    // Output type is always the configured Decimal128(precision, scale).
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::Decimal128(
            self.output_precision,
            self.output_scale,
        ))
    }

    // Always nullable: non-Ansi overflow handling produces nulls.
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(true)
    }

    /// Evaluates both children, performs the operation with i256
    /// intermediates, rescales the result to the output scale (HALF_UP when
    /// scaling down), checks the precision bound, and returns one
    /// Decimal128 array.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        let left_val = self.left.evaluate(batch)?;
        let right_val = self.right.evaluate(batch)?;

        // Normalize every scalar/array combination to two equal-length arrays.
        let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (&left_val, &right_val) {
            (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
            (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
                (l.to_array_of_size(r.len())?, Arc::clone(r))
            }
            (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
                (Arc::clone(l), r.to_array_of_size(l.len())?)
            }
            (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
        };

        let left = left_arr.as_primitive::<Decimal128Type>();
        let right = right_arr.as_primitive::<Decimal128Type>();
        // Input precisions are unused: i256 intermediates cannot overflow
        // for any pair of valid Decimal128 inputs. Only the scales matter.
        let (_p1, s1) = get_precision_scale(left.data_type());
        let (_p2, s2) = get_precision_scale(right.data_type());

        let p_out = self.output_precision;
        let s_out = self.output_scale;
        let op = self.op;
        let eval_mode = self.eval_mode;

        // Results outside [-bound, bound] overflow the output precision.
        let bound = max_for_precision(p_out);
        let neg_bound = i256::ZERO.wrapping_sub(bound);

        let result: Decimal128Array = match op {
            WideDecimalOp::Add | WideDecimalOp::Subtract => {
                // Align both operands to the larger of the two input scales
                // before adding/subtracting.
                let max_scale = std::cmp::max(s1, s2);
                let l_scale_up = i256_pow10((max_scale - s1) as u32);
                let r_scale_up = i256_pow10((max_scale - s2) as u32);
                // After add/sub at max_scale, we may need to rescale to s_out.
                // Positive diff -> divide (with HALF_UP rounding); negative
                // diff -> multiply; zero -> no adjustment. Factors are
                // hoisted out of the per-element closure.
                let scale_diff = max_scale as i16 - s_out as i16;
                let (need_scale_down, need_scale_up) = (scale_diff > 0, scale_diff < 0);
                let rescale_divisor = if need_scale_down {
                    i256_pow10(scale_diff as u32)
                } else {
                    i256::ONE
                };
                let scale_up_factor = if need_scale_up {
                    i256_pow10((-scale_diff) as u32)
                } else {
                    i256::ONE
                };

                arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
                    // Widen to i256, align scales, then operate.
                    let l256 = i256::from_i128(l).wrapping_mul(l_scale_up);
                    let r256 = i256::from_i128(r).wrapping_mul(r_scale_up);
                    let raw = match op {
                        WideDecimalOp::Add => l256.wrapping_add(r256),
                        WideDecimalOp::Subtract => l256.wrapping_sub(r256),
                        _ => unreachable!(),
                    };
                    let result = if need_scale_down {
                        div_round_half_up(raw, rescale_divisor)
                    } else if need_scale_up {
                        raw.wrapping_mul(scale_up_factor)
                    } else {
                        raw
                    };
                    check_overflow_and_convert(result, bound, neg_bound, eval_mode)
                })?
            }
            WideDecimalOp::Multiply => {
                // Multiplication naturally produces scale s1 + s2; rescale
                // from there to s_out using the same down/up scheme as above.
                let natural_scale = s1 + s2;
                let scale_diff = natural_scale as i16 - s_out as i16;
                let (need_scale_down, need_scale_up) = (scale_diff > 0, scale_diff < 0);
                let rescale_divisor = if need_scale_down {
                    i256_pow10(scale_diff as u32)
                } else {
                    i256::ONE
                };
                let scale_up_factor = if need_scale_up {
                    i256_pow10((-scale_diff) as u32)
                } else {
                    i256::ONE
                };

                arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
                    let raw = i256::from_i128(l).wrapping_mul(i256::from_i128(r));
                    let result = if need_scale_down {
                        div_round_half_up(raw, rescale_divisor)
                    } else if need_scale_up {
                        raw.wrapping_mul(scale_up_factor)
                    } else {
                        raw
                    };
                    check_overflow_and_convert(result, bound, neg_bound, eval_mode)
                })?
            }
        };

        // In non-Ansi modes, overflowed elements carry the i128::MAX
        // sentinel from check_overflow_and_convert; nullify them here.
        let result = if eval_mode != EvalMode::Ansi {
            result.null_if_overflow_precision(p_out)
        } else {
            result
        };
        let result = result.with_data_type(DataType::Decimal128(p_out, s_out));
        Ok(ColumnarValue::Array(Arc::new(result)))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.left, &self.right]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        if children.len() != 2 {
            return Err(datafusion::common::DataFusionError::Internal(format!(
                "WideDecimalBinaryExpr expects 2 children, got {}",
                children.len()
            )));
        }
        Ok(Arc::new(WideDecimalBinaryExpr::new(
            Arc::clone(&children[0]),
            Arc::clone(&children[1]),
            self.op,
            self.output_precision,
            self.output_scale,
            self.eval_mode,
        )))
    }

    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Display::fmt(self, f)
    }
}
+
+/// Check if the i256 result fits in the output precision. In Ansi mode,
return an error
+/// on overflow. In Legacy/Try mode, return i128::MAX as a sentinel value that
will be
+/// nullified by `null_if_overflow_precision`.
+#[inline]
+fn check_overflow_and_convert(
+ result: i256,
+ bound: i256,
+ neg_bound: i256,
+ eval_mode: EvalMode,
+) -> Result<i128, ArrowError> {
+ if result > bound || result < neg_bound {
+ if eval_mode == EvalMode::Ansi {
+ return Err(ArrowError::ComputeError("Arithmetic
overflow".to_string()));
+ }
+ // Sentinel value — will be nullified by null_if_overflow_precision
+ Ok(i128::MAX)
+ } else {
+ Ok(result.to_i128().unwrap())
+ }
+}
+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Decimal128Array;
    use arrow::datatypes::{Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::physical_expr::expressions::Column;

    // Builds a two-column batch ("left", "right") of Decimal128 arrays with
    // the given unscaled values, precisions, and scales.
    fn make_batch(
        left_values: Vec<Option<i128>>,
        left_precision: u8,
        left_scale: i8,
        right_values: Vec<Option<i128>>,
        right_precision: u8,
        right_scale: i8,
    ) -> RecordBatch {
        let left_arr = Decimal128Array::from(left_values)
            .with_data_type(DataType::Decimal128(left_precision, left_scale));
        let right_arr = Decimal128Array::from(right_values)
            .with_data_type(DataType::Decimal128(right_precision, right_scale));
        let schema = Schema::new(vec![
            Field::new("left", left_arr.data_type().clone(), true),
            Field::new("right", right_arr.data_type().clone(), true),
        ]);
        RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(left_arr), Arc::new(right_arr)],
        )
        .unwrap()
    }

    // Evaluates a WideDecimalBinaryExpr over columns 0 and 1 of `batch` and
    // unwraps the resulting array.
    fn eval_expr(
        batch: &RecordBatch,
        op: WideDecimalOp,
        output_precision: u8,
        output_scale: i8,
        eval_mode: EvalMode,
    ) -> Result<ArrayRef> {
        let left: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
        let right: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
        let expr =
            WideDecimalBinaryExpr::new(left, right, op, output_precision, output_scale, eval_mode);
        match expr.evaluate(batch)? {
            ColumnarValue::Array(arr) => Ok(arr),
            _ => panic!("expected array"),
        }
    }

    #[test]
    fn test_add_same_scale() {
        // Decimal128(38, 10) + Decimal128(38, 10) -> Decimal128(38, 10)
        let batch = make_batch(
            vec![Some(1000000000), Some(2500000000)], // 0.1, 0.25 at scale 10
            38,
            10,
            vec![Some(2000000000), Some(7500000000)], // 0.2, 0.75 at scale 10
            38,
            10,
        );
        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 10, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 3000000000); // 0.1 + 0.2
        assert_eq!(arr.value(1), 10000000000); // 0.25 + 0.75
    }

    #[test]
    fn test_subtract_same_scale() {
        let batch = make_batch(
            vec![Some(5000), Some(1000)],
            38,
            2,
            vec![Some(3000), Some(2000)],
            38,
            2,
        );
        let result = eval_expr(&batch, WideDecimalOp::Subtract, 38, 2, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 2000); // 50.00 - 30.00
        assert_eq!(arr.value(1), -1000); // 10.00 - 20.00
    }

    #[test]
    fn test_add_different_scales() {
        // Decimal128(10, 2) + Decimal128(10, 4) -> output scale 4
        // Exercises alignment of the lower-scale operand to the higher scale.
        let batch = make_batch(
            vec![Some(150)], // 1.50
            10,
            2,
            vec![Some(2500)], // 0.2500
            10,
            4,
        );
        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 4, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 17500); // 1.5000 + 0.2500 = 1.7500
    }

    #[test]
    fn test_multiply_with_scale_reduction() {
        // Decimal128(20, 5) * Decimal128(20, 5) -> natural scale 10, output scale 6
        // 1.00000 * 2.00000 = 2.000000
        let batch = make_batch(
            vec![Some(100000)], // 1.00000
            20,
            5,
            vec![Some(200000)], // 2.00000
            20,
            5,
        );
        let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 6, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 2000000); // 2.000000
    }

    #[test]
    fn test_multiply_half_up_rounding() {
        // Test HALF_UP rounding: 1.5 * 1.5 = 2.25; with output scale 1 it
        // should round to 2.3.
        // Input: scale 1, values 15 (1.5) * 15 (1.5) = natural scale 2, raw = 225
        // Output scale 1: 225 / 10 = 22 remainder 5 -> HALF_UP rounds to 23
        let batch = make_batch(
            vec![Some(15)], // 1.5
            10,
            1,
            vec![Some(15)], // 1.5
            10,
            1,
        );
        let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 1, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 23); // 2.3
    }

    #[test]
    fn test_multiply_half_up_rounding_negative() {
        // HALF_UP rounds away from zero for negatives too:
        // -1.5 * 1.5 = -2.25, output scale 1: -225/10 => -22 rem -5 -> -23
        let batch = make_batch(
            vec![Some(-15)], // -1.5
            10,
            1,
            vec![Some(15)], // 1.5
            10,
            1,
        );
        let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 1, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), -23); // -2.3
    }

    #[test]
    fn test_overflow_legacy_mode_returns_null() {
        // Use precision 1 (max value 9), so 5 + 5 = 10 overflows; Legacy
        // mode must yield null rather than an error.
        let batch = make_batch(vec![Some(5)], 38, 0, vec![Some(5)], 38, 0);
        let result = eval_expr(&batch, WideDecimalOp::Add, 1, 0, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert!(arr.is_null(0));
    }

    #[test]
    fn test_overflow_ansi_mode_returns_error() {
        // Same overflow as above, but Ansi mode must surface an error.
        let batch = make_batch(vec![Some(5)], 38, 0, vec![Some(5)], 38, 0);
        let result = eval_expr(&batch, WideDecimalOp::Add, 1, 0, EvalMode::Ansi);
        assert!(result.is_err());
    }

    #[test]
    fn test_null_propagation() {
        // A null on either side must produce a null output element.
        let batch = make_batch(vec![Some(100), None], 10, 2, vec![None, Some(200)], 10, 2);
        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 2, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert!(arr.is_null(0));
        assert!(arr.is_null(1));
    }

    #[test]
    fn test_zeros() {
        let batch = make_batch(vec![Some(0)], 38, 10, vec![Some(0)], 38, 10);
        let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 10, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 0);
    }

    #[test]
    fn test_max_precision_values() {
        // Max Decimal128(38,0) value: 10^38 - 1; adding zero must not
        // overflow or change the value.
        let max_val = 10i128.pow(38) - 1;
        let batch = make_batch(vec![Some(max_val)], 38, 0, vec![Some(0)], 38, 0);
        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 0, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), max_val);
    }

    #[test]
    fn test_add_scale_up_to_output() {
        // When s_out > max(s1, s2), result must be scaled UP
        // Decimal128(10, 2) + Decimal128(10, 2) with output scale 4
        // 1.50 + 0.25 = 1.75, at scale 4 = 17500
        let batch = make_batch(
            vec![Some(150)], // 1.50
            10,
            2,
            vec![Some(25)], // 0.25
            10,
            2,
        );
        let result = eval_expr(&batch, WideDecimalOp::Add, 38, 4, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 17500); // 1.7500
    }

    #[test]
    fn test_subtract_scale_up_to_output() {
        // s_out (4) > max(s1, s2) (2) — verify the scale-up path for subtract
        let batch = make_batch(
            vec![Some(300)], // 3.00
            10,
            2,
            vec![Some(100)], // 1.00
            10,
            2,
        );
        let result = eval_expr(&batch, WideDecimalOp::Subtract, 38, 4, EvalMode::Legacy).unwrap();
        let arr = result.as_primitive::<Decimal128Type>();
        assert_eq!(arr.value(0), 20000); // 2.0000
    }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]