This is an automated email from the ASF dual-hosted git repository.
cloud-fan pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new a8a156fca1fa [SPARK-56913][SQL] Simplify BinaryArithmetic byte/short codegen under ANSI mode
a8a156fca1fa is described below
commit a8a156fca1fa4d93855c7339cd0bf90c57ef7101
Author: Gengliang Wang
AuthorDate: Thu May 21 21:32:43 2026 +0800
[SPARK-56913][SQL] Simplify BinaryArithmetic byte/short codegen under ANSI mode
### What changes were proposed in this pull request?
In `BinaryArithmetic.doGenCode`, the `Byte`/`Short` ANSI overflow-check
branch previously emitted ~7 lines per call site (an `int` tmpResult, an
overflow check, and a cast back). After this PR it emits a single static call
into the existing `ByteExactNumeric` / `ShortExactNumeric` objects (in
`numerics.scala`), which already implement the same overflow check and
`BINARY_ARITHMETIC_OVERFLOW` error that the eval path uses.
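For readers unfamiliar with `numerics.scala`, here is a minimal sketch of the widen, check, narrow pattern that the removed inline code spelled out and that the static call now delegates to. `ExactByteSketch` is a hypothetical name, not the actual `ByteExactNumeric` source: it throws a plain `ArithmeticException` with an illustrative message where Spark raises its `BINARY_ARITHMETIC_OVERFLOW` error. A `Short` version would be identical with `Short.MinValue` / `Short.MaxValue`.

```scala
// Hypothetical stand-in for ByteExactNumeric (not Spark code).
// Byte arithmetic is evaluated in Int on the JVM, so the Int result cannot
// itself overflow; we only need to check that it fits back into a Byte.
object ExactByteSketch {
  private def checked(res: Int, x: Byte, op: String, y: Byte, hint: String): Byte = {
    if (res < Byte.MinValue || res > Byte.MaxValue) {
      // Spark throws its BINARY_ARITHMETIC_OVERFLOW error here; the message
      // below is illustrative only.
      throw new ArithmeticException(
        s"Overflow in $x $op $y; use $hint to tolerate overflow and return NULL instead")
    }
    res.toByte // safe: the value is known to be in range
  }

  def plus(x: Byte, y: Byte): Byte = checked(x + y, x, "+", y, "try_add")
  def minus(x: Byte, y: Byte): Byte = checked(x - y, x, "-", y, "try_subtract")
  def times(x: Byte, y: Byte): Byte = checked(x * y, x, "*", y, "try_multiply")
}
```

For example, `ExactByteSketch.plus(100.toByte, 27.toByte)` returns `127`, while `ExactByteSketch.plus(127.toByte, 1.toByte)` throws instead of wrapping to `-128`.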
The codegen rewrite uses the same
`getClass.getCanonicalName.stripSuffix("$")` pattern as the adjacent
`MathUtils` / `IntervalMathUtils` calls. The Scala compiler emits `public static`
forwarders on the class that shares a top-level object's name (its companion
class, or a generated mirror class if there is none), so the generated Java
code can call e.g. `org.apache.spark.sql.types.ByteExactNumeric.plus(a, b)`
directly.
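The string assembly can be sketched as below. `CodegenCallSketch` and `byteShortCall` are hypothetical helpers, not part of Spark. To keep the sketch runnable without Spark on the classpath, the caller passes the class name as a string in the form `ByteExactNumeric.getClass.getCanonicalName` returns for a top-level object, i.e. with the trailing `$` of the object's module class.

```scala
// Hypothetical helper mirroring the new Byte/Short ANSI branch (not Spark code):
// map the operator symbol to a method name, drop the module-class `$`, and
// build the static call that ends up in the generated Java source.
object CodegenCallSketch {
  def methodNameFor(symbol: String): String = symbol match {
    case "+" => "plus"
    case "-" => "minus"
    case "*" => "times"
    case other =>
      // The real branch raises SparkException.internalError instead.
      throw new IllegalStateException(
        s"Unexpected symbol '$other' for Byte/Short BinaryArithmetic")
  }

  def byteShortCall(objectClassName: String, symbol: String,
      eval1: String, eval2: String): String = {
    // "pkg.ByteExactNumeric$" -> "pkg.ByteExactNumeric", the class that
    // carries the static forwarders.
    val numericObj = objectClassName.stripSuffix("$")
    s"$numericObj.${methodNameFor(symbol)}($eval1, $eval2)"
  }
}
```

Calling `CodegenCallSketch.byteShortCall("org.apache.spark.sql.types.ByteExactNumeric$", "+", "a", "b")` yields the string `org.apache.spark.sql.types.ByteExactNumeric.plus(a, b)`.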
Primitive `int`/`long`/`float`/`double` branches are intentionally left
inline (single bytecode op; routing those through a static method would be a
runtime regression).
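For contrast, this sketch models in Scala what the inline, non-checking expressions compute on the JVM: the `(byte)(eval1 + eval2)` cast emitted for the non-ANSI `Byte`/`Short` branch in the diff below, and a plain `int` addition, both wrap around silently. `InlineSemanticsSketch` is a hypothetical name; `java.lang.Math.addExact` appears only as a point of comparison for a checked call.

```scala
// Hypothetical illustration (not Spark code) of JVM semantics for the
// inline expressions that stay in generated code.
object InlineSemanticsSketch {
  // Models `(byte)(eval1 + eval2)`: compute in Int, then narrow, keeping the
  // low 8 bits (so 127 + 1 becomes -128).
  def legacyBytePlus(a: Byte, b: Byte): Byte = (a + b).toByte

  // Models inline `eval1 + eval2` on ints: a single iadd that wraps on overflow.
  def legacyIntPlus(a: Int, b: Int): Int = a + b

  // A checked call for comparison: throws ArithmeticException on overflow.
  def exactIntPlus(a: Int, b: Int): Int = java.lang.Math.addExact(a, b)
}
```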
### Why are the changes needed?
Part of SPARK-56908 (umbrella). The Byte/Short ANSI branch is the largest
single inline body in `BinaryArithmetic.doGenCode`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
```
build/sbt "catalyst/testOnly *ArithmeticExpressionSuite"
```
35/35 pass.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Cursor 1.x
Closes #55938 from gengliangwang/SPARK-56913-arithmetic-byte-short.
Lead-authored-by: Gengliang Wang
Co-authored-by: Wenchen Fan
Signed-off-by: Wenchen Fan
(cherry picked from commit c3096ee570572f385a409d07988e7a75c524ecd1)
Signed-off-by: Wenchen Fan
---
.../sql/catalyst/expressions/arithmetic.scala | 39 ++++++++--------------
1 file changed, 14 insertions(+), 25 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1c93a6586761..348b45472c57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -301,32 +301,21 @@ abstract class BinaryArithmetic extends BinaryOperator with SupportQueryContext
val mathUtils = IntervalMathUtils.getClass.getCanonicalName.stripSuffix("$")
defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathUtils.${exactMathMethod.get}($eval1, $eval2)")
// byte and short are casted into int when add, minus, times or divide
+ case ByteType | ShortType if failOnError =>
+ val methodName = symbol match {
+ case "+" => "plus"
+ case "-" => "minus"
+ case "*" => "times"
+ case _ =>
+ throw SparkException.internalError(
+ s"Unexpected symbol '$symbol' for Byte/Short BinaryArithmetic")
+ }
+ val numericObj = (if (dataType == ByteType) ByteExactNumeric else ShortExactNumeric)
+ .getClass.getCanonicalName.stripSuffix("$")
+ defineCodeGen(ctx, ev, (eval1, eval2) => s"$numericObj.$methodName($eval1, $eval2)")
case ByteType | ShortType =>
- nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
- val tmpResult = ctx.freshName("tmpResult")
- val try_suggestion = symbol match {
- case "+" => "try_add"
- case "-" => "try_subtract"
- case "*" => "try_multiply"
- case _ => "unknown_function"
- }
- val overflowCheck = if (failOnError) {
- val javaType = CodeGenerator.boxedType(dataType)
- s"""
- |if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) {
- | throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(
- | $eval1, "$symbol", $eval2, "$try_suggestion");
- |}
- """.stripMargin
- } else {
- ""
- }
- s"""
- |${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2;
- |$overflowCheck
- |${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult);
- """.stripMargin
- })
+ defineCodeGen(ctx, ev, (eval1, eval2) =>
+ s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case IntegerType | LongType if failOnError && exactMathMethod.isDefined =>
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val errorContext = getContextOrNullCode(ctx)