This is an automated email from the ASF dual-hosted git repository.
parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new c3dd3a4c2 fix: handle scalar decimal value overflow correctly in ANSI mode (#3803)
c3dd3a4c2 is described below
commit c3dd3a4c291f2c89ef043283c84aaaf96a39e8e4
Author: Parth Chandra <[email protected]>
AuthorDate: Fri Mar 27 10:47:01 2026 -0700
fix: handle scalar decimal value overflow correctly in ANSI mode (#3803)
* fix: handle scalar decimal value overflow correctly.
---
.../src/math_funcs/internal/checkoverflow.rs | 183 +++++++++++++++++++--
.../org/apache/comet/CometExpressionSuite.scala | 29 ++++
2 files changed, 196 insertions(+), 16 deletions(-)
diff --git a/native/spark-expr/src/math_funcs/internal/checkoverflow.rs b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
index a9e8f6748..f1fb9c2f0 100644
--- a/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
+++ b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
@@ -199,22 +199,38 @@ impl PhysicalExpr for CheckOverflow {
Ok(ColumnarValue::Array(new_array))
}
ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision,
scale)) => {
- // `fail_on_error` is only true when ANSI is enabled, which we don't support yet
- // (Java side will simply fallback to Spark when it is enabled)
- assert!(
- !self.fail_on_error,
- "fail_on_error (ANSI mode) is not supported yet"
- );
-
- let new_v: Option<i128> = v.and_then(|v| {
- Decimal128Type::validate_decimal_precision(v, precision,
scale)
- .map(|_| v)
- .ok()
- });
-
- Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
- new_v, precision, scale,
- )))
+ if self.fail_on_error {
+ if let Some(val) = v {
+ Decimal128Type::validate_decimal_precision(val,
precision, scale).map_err(
+ |_| {
+ let spark_error =
+ crate::error::decimal_overflow_error(val,
precision, scale);
+ if let Some(ctx) = &self.query_context {
+ DataFusionError::External(Box::new(
+ crate::SparkErrorWithContext::with_context(
+ spark_error,
+ Arc::clone(ctx),
+ ),
+ ))
+ } else {
+ DataFusionError::External(Box::new(spark_error))
+ }
+ },
+ )?;
+ }
+ Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+ v, precision, scale,
+ )))
+ } else {
+ let new_v: Option<i128> = v.and_then(|v| {
+ Decimal128Type::validate_decimal_precision(v,
precision, scale)
+ .map(|_| v)
+ .ok()
+ });
+ Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+ new_v, precision, scale,
+ )))
+ }
}
v => Err(DataFusionError::Execution(format!(
"CheckOverflow's child expression should be decimal array, but found {v:?}"
@@ -239,3 +255,138 @@ impl PhysicalExpr for CheckOverflow {
)))
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use arrow::datatypes::{Field, Schema};
+ use arrow::record_batch::RecordBatch;
+ use std::fmt::{Display, Formatter};
+
+ /// Helper that always returns a fixed Decimal128 scalar.
+ #[derive(Debug, Eq, PartialEq, Hash)]
+ struct ScalarChild(Option<i128>, u8, i8);
+
+ impl Display for ScalarChild {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "ScalarChild({:?})", self.0)
+ }
+ }
+
+ impl PhysicalExpr for ScalarChild {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn data_type(&self, _: &Schema) ->
datafusion::common::Result<DataType> {
+ Ok(DataType::Decimal128(self.1, self.2))
+ }
+ fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
+ Ok(true)
+ }
+ fn evaluate(&self, _: &RecordBatch) ->
datafusion::common::Result<ColumnarValue> {
+ Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+ self.0, self.1, self.2,
+ )))
+ }
+ fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+ vec![]
+ }
+ fn with_new_children(
+ self: Arc<Self>,
+ _: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+ Ok(self)
+ }
+ fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ Display::fmt(self, f)
+ }
+ }
+
+ fn empty_batch() -> RecordBatch {
+ let schema = Schema::new(vec![Field::new("x", DataType::Decimal128(38,
0), true)]);
+ RecordBatch::new_empty(Arc::new(schema))
+ }
+
+ fn make_check_overflow(
+ value: Option<i128>,
+ precision: u8,
+ scale: i8,
+ fail_on_error: bool,
+ ) -> CheckOverflow {
+ CheckOverflow::new(
+ Arc::new(ScalarChild(value, precision, scale)),
+ DataType::Decimal128(precision, scale),
+ fail_on_error,
+ None,
+ None,
+ )
+ }
+
+ // --- scalar, fail_on_error = false (legacy mode) ---
+
+ #[test]
+ fn test_scalar_no_overflow_legacy() {
+ // 999 fits in precision 3, scale 0 → returned as-is
+ let expr = make_check_overflow(Some(999), 3, 0, false);
+ let result = expr.evaluate(&empty_batch()).unwrap();
+ match result {
+ ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) =>
assert_eq!(v, Some(999)),
+ other => panic!("unexpected: {other:?}"),
+ }
+ }
+
+ #[test]
+ fn test_scalar_overflow_returns_null_in_legacy_mode() {
+ // 1000 does not fit in precision 3 → null, no error
+ let expr = make_check_overflow(Some(1000), 3, 0, false);
+ let result = expr.evaluate(&empty_batch()).unwrap();
+ match result {
+ ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) =>
assert_eq!(v, None),
+ other => panic!("unexpected: {other:?}"),
+ }
+ }
+
+ #[test]
+ fn test_scalar_null_passthrough_legacy() {
+ let expr = make_check_overflow(None, 3, 0, false);
+ let result = expr.evaluate(&empty_batch()).unwrap();
+ match result {
+ ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) =>
assert_eq!(v, None),
+ other => panic!("unexpected: {other:?}"),
+ }
+ }
+
+ // --- scalar, fail_on_error = true (ANSI mode) ---
+
+ #[test]
+ fn test_scalar_no_overflow_ansi() {
+ // 999 fits in precision 3 → returned as-is, no error
+ let expr = make_check_overflow(Some(999), 3, 0, true);
+ let result = expr.evaluate(&empty_batch()).unwrap();
+ match result {
+ ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) =>
assert_eq!(v, Some(999)),
+ other => panic!("unexpected: {other:?}"),
+ }
+ }
+
+ #[test]
+ fn test_scalar_overflow_returns_error_in_ansi_mode() {
+ // 1000 does not fit in precision 3 → error, not Ok(None)
+ // This is the case that previously panicked with "fail_on_error (ANSI mode) is not
+ // supported yet".
+ let expr = make_check_overflow(Some(1000), 3, 0, true);
+ let result = expr.evaluate(&empty_batch());
+ assert!(result.is_err(), "expected error on overflow in ANSI mode");
+ }
+
+ #[test]
+ fn test_scalar_null_passthrough_ansi() {
+ // None input → None output even in ANSI mode (no value to overflow)
+ let expr = make_check_overflow(None, 3, 0, true);
+ let result = expr.evaluate(&empty_batch()).unwrap();
+ match result {
+ ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) =>
assert_eq!(v, None),
+ other => panic!("unexpected: {other:?}"),
+ }
+ }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 68c1a82f1..9fdd5a677 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1271,6 +1271,35 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("scalar decimal overflow - legacy mode produces null") {
+ // 1.1e19 * 1.1e19 = 1.21e38 fits in i128 (max ~1.7e38) but exceeds DECIMAL(38,0)'s
+ // max of 10^38-1, so CheckOverflow nulls the result in legacy (non-ANSI) mode.
+ withSQLConf(CometConf.COMET_ENABLED.key -> "true",
SQLConf.ANSI_ENABLED.key -> "false") {
+ withParquetTable(Seq((BigDecimal("11000000000000000000"), 0)), "tbl") {
+ checkSparkAnswerAndOperator("SELECT _1 * _1 FROM tbl")
+ }
+ }
+ }
+
+ test("scalar decimal overflow - ANSI mode throws ArithmeticException") {
+ // 1.1e19 * 1.1e19 = 1.21e38 overflows DECIMAL(38,0). With ANSI mode on, both Spark and
+ // Comet must throw — Comet must not panic or silently return null. Spark reports
+ // NUMERIC_VALUE_OUT_OF_RANGE; Comet's WideDecimalBinaryExpr catches the overflow first
+ // and surfaces it as an arithmetic overflow error.
+ withSQLConf(CometConf.COMET_ENABLED.key -> "true",
SQLConf.ANSI_ENABLED.key -> "true") {
+ withParquetTable(Seq((BigDecimal("11000000000000000000"), 0)), "tbl") {
+ val res = sql("SELECT _1 * _1 FROM tbl")
+ checkSparkAnswerMaybeThrows(res) match {
+ case (Some(sparkExc), Some(cometExc)) =>
+ assert(sparkExc.getMessage.contains("NUMERIC_VALUE_OUT_OF_RANGE"))
+ assert(cometExc.getMessage.toLowerCase.contains("overflow"))
+ case _ =>
+ fail("Expected exception for decimal overflow in ANSI mode")
+ }
+ }
+ }
+ }
+
test("cast decimals to int") {
Seq(16, 1024).foreach { batchSize =>
withSQLConf(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]