This is an automated email from the ASF dual-hosted git repository.

cloud-fan pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new a8a156fca1fa [SPARK-56913][SQL] Simplify BinaryArithmetic byte/short 
codegen under ANSI mode
a8a156fca1fa is described below

commit a8a156fca1fa4d93855c7339cd0bf90c57ef7101
Author: Gengliang Wang <[email protected]>
AuthorDate: Thu May 21 21:32:43 2026 +0800

    [SPARK-56913][SQL] Simplify BinaryArithmetic byte/short codegen under ANSI 
mode
    
    ### What changes were proposed in this pull request?
    
    In `BinaryArithmetic.doGenCode`, the `Byte`/`Short` ANSI overflow-check
    branch previously emitted ~7 lines per call site (an `int` temporary result,
    an overflow check, and a cast back). After this PR it emits a single static
    call into the existing `ByteExactNumeric` / `ShortExactNumeric` objects (in
    `numerics.scala`), which already implement the same overflow check and
    `BINARY_ARITHMETIC_OVERFLOW` error that the eval path uses. The non-ANSI
    `Byte`/`Short` branch is likewise reduced to a single cast expression,
    e.g. `(byte)(a + b)`.
    
    The codegen rewrite uses the same
    `getClass.getCanonicalName.stripSuffix("$")` pattern as the adjacent
    `MathUtils` / `IntervalMathUtils` calls. For a top-level `object`, the Scala
    compiler emits `public static` forwarder methods on the class of the same
    name (without the `$` suffix), so the generated Java code can call e.g.
    `org.apache.spark.sql.types.ByteExactNumeric.plus(a, b)` directly.
    
    Primitive `int`/`long`/`float`/`double` branches are intentionally left 
inline (single bytecode op; routing those through a static method would be a 
runtime regression).
    
    ### Why are the changes needed?
    
    Part of SPARK-56908 (umbrella). The Byte/Short ANSI branch is the largest 
single inline body in `BinaryArithmetic.doGenCode`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    ```
    build/sbt "catalyst/testOnly *ArithmeticExpressionSuite"
    ```
    
    35/35 pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Cursor 1.x
    
    Closes #55938 from gengliangwang/SPARK-56913-arithmetic-byte-short.
    
    Lead-authored-by: Gengliang Wang <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit c3096ee570572f385a409d07988e7a75c524ecd1)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/expressions/arithmetic.scala      | 39 ++++++++--------------
 1 file changed, 14 insertions(+), 25 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1c93a6586761..348b45472c57 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -301,32 +301,21 @@ abstract class BinaryArithmetic extends BinaryOperator 
with SupportQueryContext
       val mathUtils = 
IntervalMathUtils.getClass.getCanonicalName.stripSuffix("$")
       defineCodeGen(ctx, ev, (eval1, eval2) => 
s"$mathUtils.${exactMathMethod.get}($eval1, $eval2)")
     // byte and short are casted into int when add, minus, times or divide
+    case ByteType | ShortType if failOnError =>
+      val methodName = symbol match {
+        case "+" => "plus"
+        case "-" => "minus"
+        case "*" => "times"
+        case _ =>
+          throw SparkException.internalError(
+            s"Unexpected symbol '$symbol' for Byte/Short BinaryArithmetic")
+      }
+      val numericObj = (if (dataType == ByteType) ByteExactNumeric else 
ShortExactNumeric)
+        .getClass.getCanonicalName.stripSuffix("$")
+      defineCodeGen(ctx, ev, (eval1, eval2) => 
s"$numericObj.$methodName($eval1, $eval2)")
     case ByteType | ShortType =>
-      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
-        val tmpResult = ctx.freshName("tmpResult")
-        val try_suggestion = symbol match {
-          case "+" => "try_add"
-          case "-" => "try_subtract"
-          case "*" => "try_multiply"
-          case _ => "unknown_function"
-        }
-        val overflowCheck = if (failOnError) {
-          val javaType = CodeGenerator.boxedType(dataType)
-          s"""
-             |if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > 
$javaType.MAX_VALUE) {
-             |  throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(
-             |  $eval1, "$symbol", $eval2, "$try_suggestion");
-             |}
-           """.stripMargin
-        } else {
-          ""
-        }
-        s"""
-           |${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2;
-           |$overflowCheck
-           |${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult);
-         """.stripMargin
-      })
+      defineCodeGen(ctx, ev, (eval1, eval2) =>
+        s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
     case IntegerType | LongType if failOnError && exactMathMethod.isDefined =>
       nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
         val errorContext = getContextOrNullCode(ctx)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to