parthchandra commented on code in PR #3619:
URL: https://github.com/apache/datafusion-comet/pull/3619#discussion_r2897212225


##########
native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs:
##########
@@ -0,0 +1,502 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Fused wide-decimal binary expression for Decimal128 add/sub/mul that may overflow.
+//!
+//! Instead of building a 4-node expression tree (Cast→BinaryExpr→Cast→Cast), this performs
+//! i256 intermediate arithmetic in a single expression, producing only one output array.
+
+use crate::math_funcs::utils::get_precision_scale;
+use crate::EvalMode;
+use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array};
+use arrow::datatypes::{i256, DataType, Decimal128Type, Schema};
+use arrow::error::ArrowError;
+use arrow::record_batch::RecordBatch;
+use datafusion::common::Result;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::fmt::{Display, Formatter};
+use std::hash::Hash;
+use std::{any::Any, sync::Arc};
+
+/// The arithmetic operation performed by a `WideDecimalBinaryExpr`.
+#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
+pub enum WideDecimalOp {
+    Add, // displayed as "+"
+    Subtract, // displayed as "-"
+    Multiply, // displayed as "*"
+}
+
+impl Display for WideDecimalOp { // renders the operator as its arithmetic symbol
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            WideDecimalOp::Add => write!(f, "+"),
+            WideDecimalOp::Subtract => write!(f, "-"),
+            WideDecimalOp::Multiply => write!(f, "*"),
+        }
+    }
+}
+
+/// A fused expression that evaluates Decimal128 add/sub/mul using i256 intermediate arithmetic,
+/// applies scale adjustment with HALF_UP rounding, checks precision bounds, and outputs
+/// a single Decimal128 array.
+#[derive(Debug, Eq)] // PartialEq and Hash are implemented by hand below
+pub struct WideDecimalBinaryExpr {
+    left: Arc<dyn PhysicalExpr>, // left operand expression
+    right: Arc<dyn PhysicalExpr>, // right operand expression
+    op: WideDecimalOp, // operation applied to the two operands
+    output_precision: u8, // precision of the Decimal128 result type
+    output_scale: i8, // scale of the Decimal128 result type
+    eval_mode: EvalMode, // NOTE(review): presumably controls overflow handling — confirm in evaluate()
+}
+
+impl Hash for WideDecimalBinaryExpr { // hashes all six fields, the same set PartialEq compares
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.left.hash(state);
+        self.right.hash(state);
+        self.op.hash(state);
+        self.output_precision.hash(state);
+        self.output_scale.hash(state);
+        self.eval_mode.hash(state);
+    }
+}
+
+impl PartialEq for WideDecimalBinaryExpr { // equal iff operands, op, output type and eval mode all match
+    fn eq(&self, other: &Self) -> bool {
+        self.left.eq(&other.left)
+            && self.right.eq(&other.right)
+            && self.op == other.op
+            && self.output_precision == other.output_precision
+            && self.output_scale == other.output_scale
+            && self.eval_mode == other.eval_mode
+    }
+}
+
+impl WideDecimalBinaryExpr {
+    pub fn new( // stores the arguments as given; precision/scale are not validated here
+        left: Arc<dyn PhysicalExpr>,
+        right: Arc<dyn PhysicalExpr>,
+        op: WideDecimalOp,
+        output_precision: u8,
+        output_scale: i8,
+        eval_mode: EvalMode,
+    ) -> Self {
+        Self {
+            left,
+            right,
+            op,
+            output_precision,
+            output_scale,
+            eval_mode,
+        }
+    }
+}
+
+impl Display for WideDecimalBinaryExpr { // renders as "WideDecimalBinaryExpr [<left> <op> <right>, output: Decimal128(<p>, <s>)]"
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "WideDecimalBinaryExpr [{} {} {}, output: Decimal128({}, {})]",
+            self.left, self.op, self.right, self.output_precision, 
self.output_scale
+        )
+    }
+}
+
+/// Compute `value / divisor`, rounding ties away from zero (HALF_UP); `divisor` must be non-zero.
+#[inline]
+fn div_round_half_up(value: i256, divisor: i256) -> i256 {
+    let (quot, rem) = (value / divisor, value % divisor); // assumes truncating division (rem takes value's sign)
+    // HALF_UP: if |remainder| * 2 >= |divisor|, round away from zero
+    let abs_rem_x2 = if rem < i256::ZERO {
+        rem.wrapping_neg()
+    } else {
+        rem
+    }
+    .wrapping_mul(i256::from_i128(2));
+    let abs_divisor = if divisor < i256::ZERO {
+        divisor.wrapping_neg()
+    } else {
+        divisor
+    };
+    if abs_rem_x2 >= abs_divisor {
+        if (value < i256::ZERO) != (divisor < i256::ZERO) { // exact quotient is negative: step further from zero
+            quot.wrapping_sub(i256::ONE)
+        } else {
+            quot.wrapping_add(i256::ONE)
+        }
+    } else {
+        quot
+    }
+}
+
+/// The i256 value 10, used as the multiplier in `i256_pow10`.
+const I256_TEN: i256 = i256::from_i128(10);
+
+/// Compute 10^exp as i256 by repeated wrapping multiplication (no overflow check: exp > 76 wraps).
+#[inline]
+fn i256_pow10(exp: u32) -> i256 {
+    let mut result = i256::ONE;
+    for _ in 0..exp {
+        result = result.wrapping_mul(I256_TEN);
+    }
+    result
+}
+
+/// Maximum unscaled value (as i256) for a given decimal precision, i.e. 10^precision - 1.
+/// precision p allows values in [-10^p + 1, 10^p - 1].
+#[inline]
+fn max_for_precision(precision: u8) -> i256 {
+    i256_pow10(precision as u32).wrapping_sub(i256::ONE)
+}
+
+impl PhysicalExpr for WideDecimalBinaryExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::Decimal128(
+            self.output_precision,
+            self.output_scale,
+        ))
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let left_val = self.left.evaluate(batch)?;
+        let right_val = self.right.evaluate(batch)?;
+
+        let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (&left_val, 
&right_val) {
+            (ColumnarValue::Array(l), ColumnarValue::Array(r)) => 
(Arc::clone(l), Arc::clone(r)),
+            (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
+                (l.to_array_of_size(r.len())?, Arc::clone(r))
+            }
+            (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
+                (Arc::clone(l), r.to_array_of_size(l.len())?)
+            }
+            (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => 
(l.to_array()?, r.to_array()?),
+        };
+
+        let left = left_arr.as_primitive::<Decimal128Type>();
+        let right = right_arr.as_primitive::<Decimal128Type>();
+        let (_p1, s1) = get_precision_scale(left.data_type());
+        let (_p2, s2) = get_precision_scale(right.data_type());
+
+        let p_out = self.output_precision;
+        let s_out = self.output_scale;
+        let op = self.op;
+        let eval_mode = self.eval_mode;
+
+        let bound = max_for_precision(p_out);
+        let neg_bound = i256::ZERO.wrapping_sub(bound);
+
+        let result: Decimal128Array = match op {
+            WideDecimalOp::Add | WideDecimalOp::Subtract => {
+                let max_scale = std::cmp::max(s1, s2);
+                let l_scale_up = i256_pow10((max_scale - s1) as u32);
+                let r_scale_up = i256_pow10((max_scale - s2) as u32);
+                let need_rescale = s_out < max_scale;
+                let rescale_divisor = if need_rescale {
+                    i256_pow10((max_scale - s_out) as u32)
+                } else {
+                    i256::ONE

Review Comment:
   @andygrove did you miss this comment? 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to