This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 251df783c65e [SPARK-56914][SQL] Simplify decimal arithmetic codegen under ANSI mode
251df783c65e is described below
commit 251df783c65ed22e3388b2c04e337ce1bdc069ea
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri May 29 09:06:29 2026 -0700
[SPARK-56914][SQL] Simplify decimal arithmetic codegen under ANSI mode
### What changes were proposed in this pull request?
Use `CastUtils.changePrecisionExact` / `changePrecisionOrNull` (added in
SPARK-56911) from the `DecimalType.Fixed` codegen branches of:
* `BinaryArithmetic.doGenCode` (covers `Add` / `Subtract` / `Multiply` on
`Decimal`).
* `BinaryDivModLike.doGenCode` (covers `Divide` / `IntegralDivide` /
`Remainder` / `Pmod` on `Decimal`).
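Each codegen call site replaces `eval1.$op(eval2).toPrecision(p, s,
ROUND_HALF_UP, !failOnError, ctx)` with a call to
`CastUtils.changePrecisionExact` (ANSI mode, `failOnError`) or
`CastUtils.changePrecisionOrNull` (non-ANSI mode). The ANSI branches emit no
null check. The non-ANSI branches keep one, because `changePrecisionOrNull`
can return `null`.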
The eval path (`BinaryArithmetic.checkDecimalOverflow`) keeps its original
one-line `value.toPrecision(p, s, ROUND_HALF_UP, !failOnError,
getContextOrNull())` call. This follows the review on #55938: routing a
one-line eval call through a new helper would only add indirection around the
same logic, with no real benefit.
### Why are the changes needed?
Part of SPARK-56908 (umbrella). Decimal arithmetic is widespread in TPC-DS
plans, and the `BinaryArithmetic` Decimal branch was one of the longer ANSI
codegen bodies still emitted inline.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
```
build/sbt "catalyst/testOnly *ArithmeticExpressionSuite *DecimalSuite"
```
All 60 tests pass.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Cursor 1.x
Closes #55939 from gengliangwang/SPARK-56914-decimal-arithmetic.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
---
.../sql/catalyst/expressions/arithmetic.scala | 52 +++++++++++++---------
1 file changed, 32 insertions(+), 20 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 348b45472c57..23fffb162a8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -278,19 +278,21 @@ abstract class BinaryArithmetic extends BinaryOperator
with SupportQueryContext
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
dataType match {
case DecimalType.Fixed(precision, scale) =>
- val errorContextCode = getContextOrNullCode(ctx, failOnError)
- val updateIsNull = if (failOnError) {
- ""
+ val castUtils = classOf[CastUtils].getName
+ if (failOnError) {
+ val errorContextCode = getContextOrNullCode(ctx)
+ defineCodeGen(ctx, ev, (eval1, eval2) =>
+ s"$castUtils.changePrecisionExact(" +
+ s"$eval1.$decimalMethod($eval2), $precision, $scale,
$errorContextCode)")
} else {
- s"${ev.isNull} = ${ev.value} == null;"
+ nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+ s"""
+ |${ev.value} = $castUtils.changePrecisionOrNull(
+ | $eval1.$decimalMethod($eval2), $precision, $scale);
+ |${ev.isNull} = ${ev.value} == null;
+ """.stripMargin
+ })
}
- nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
- s"""
- |${ev.value} = $eval1.$decimalMethod($eval2).toPrecision(
- | $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError},
$errorContextCode);
- |$updateIsNull
- """.stripMargin
- })
case CalendarIntervalType =>
val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
defineCodeGen(ctx, ev, (eval1, eval2) =>
s"$iu.$calendarIntervalMethod($eval1, $eval2)")
@@ -706,16 +708,26 @@ trait DivModLike extends BinaryArithmetic {
val errorContextCode = getContextOrNullCode(ctx, failOnError)
val operation = super.dataType match {
case DecimalType.Fixed(precision, scale) =>
+ val castUtils = classOf[CastUtils].getName
val decimalValue = ctx.freshName("decimalValue")
- s"""
- |Decimal $decimalValue =
${eval1.value}.$decimalMethod(${eval2.value}).toPrecision(
- | $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError},
$errorContextCode);
- |if ($decimalValue != null) {
- | ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")};
- |} else {
- | ${ev.isNull} = true;
- |}
- |""".stripMargin
+ if (failOnError) {
+ s"""
+ |Decimal $decimalValue = $castUtils.changePrecisionExact(
+ | ${eval1.value}.$decimalMethod(${eval2.value}), $precision,
$scale,
+ | $errorContextCode);
+ |${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")};
+ |""".stripMargin
+ } else {
+ s"""
+ |Decimal $decimalValue = $castUtils.changePrecisionOrNull(
+ | ${eval1.value}.$decimalMethod(${eval2.value}), $precision,
$scale);
+ |if ($decimalValue != null) {
+ | ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")};
+ |} else {
+ | ${ev.isNull} = true;
+ |}
+ |""".stripMargin
+ }
case _ => s"${ev.value} = ($javaType)(${eval1.value} $symbol
${eval2.value});"
}
val checkIntegralDivideOverflow = if (checkDivideOverflow) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]