This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new c67622af222d [SPARK-56909][SQL] Simplify Cast to int/long codegen under ANSI mode
c67622af222d is described below

commit c67622af222d9f795b9f34e5366fdbd63b875fdc
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri May 22 15:10:24 2026 -0700

    [SPARK-56909][SQL] Simplify Cast to int/long codegen under ANSI mode
    
    ### What changes were proposed in this pull request?
    
    In `Cast.scala`, the ANSI codegen for narrowing casts to `int` / `long` previously emitted a 5-line inline body per call site (bounds check + cast + throw). After this PR it emits a single static call into the existing `LongExactNumeric` / `FloatExactNumeric` / `DoubleExactNumeric` objects in `numerics.scala`, which already implement the same overflow check + `castingCauseOverflowError` throw that this codegen needs.
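To make the per-call-site difference concrete, here is a small Scala sketch that renders both shapes (long to int shown) as plain strings. `EmittedShapes`, its methods, and the `value_*` / `fromDt` / `toDt` names are hypothetical placeholders; Spark itself builds these snippets with its `code"..."` interpolator and `ctx.addReferenceObj`, not with a helper like this.

```scala
// Hypothetical renderer for the two ANSI codegen shapes (long -> int shown).
// It only builds plain strings; Spark uses the `code"..."` interpolator instead.
object EmittedShapes {
  // Old shape: bounds check + cast + throw, repeated at every call site.
  def inlineBody(c: String, evPrim: String, fromDt: String, toDt: String): String =
    s"""if ($c == (int) $c) {
       |  $evPrim = (int) $c;
       |} else {
       |  throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
       |}""".stripMargin

  // New shape: one static call; the helper owns both the check and the throw.
  def staticCall(c: String, evPrim: String): String =
    s"$evPrim = org.apache.spark.sql.types.LongExactNumeric.toInt($c);"
}

val before = EmittedShapes.inlineBody("value_1", "value_0", "fromDt", "toDt")
val after = EmittedShapes.staticCall("value_1", "value_0")
```

In the diff, the static-call branch also skips the two `ctx.addReferenceObj` calls for the source and target `DataType`s, since only the inline throw needed them.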
    
    The rewrite uses the same `getClass.getCanonicalName.stripSuffix("$")` pattern as the adjacent `MathUtils` / `IntervalMathUtils` calls. A top-level Scala object compiles to a module class whose name ends in `$`, and the Scala compiler also emits `public static` forwarders on the class of the same name without the `$`, so generated Java code can call e.g. `org.apache.spark.sql.types.LongExactNumeric.toInt(v)` directly.
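The naming pattern can be sketched as follows, assuming a hypothetical object `ExactOps` standing in for `LongExactNumeric` and hypothetical helpers `staticCallTarget` and `emitCall`. Compiled as a top-level object in a regular source file, `ExactOps` produces a module class `ExactOps$` plus a class `ExactOps` carrying the static forwarders. In a REPL or script the object is nested inside a wrapper, so the derived name carries a wrapper prefix and may not be callable from Java.

```scala
// Hypothetical stand-in for a top-level Scala object such as LongExactNumeric.
object ExactOps {
  def toInt(x: Long): Int =
    if (x == x.toInt) x.toInt else throw new ArithmeticException(s"$x overflows int")
}

// The pattern used in Cast.scala: drop the trailing `$` of the module class name
// so that the generated Java source names the class holding the static forwarders.
def staticCallTarget(obj: AnyRef): String =
  obj.getClass.getCanonicalName.stripSuffix("$")

// Hypothetical template for the single generated Java statement at one cast site.
def emitCall(target: String, method: String, input: String, output: String): String =
  s"$output = $target.$method($input);"

val target = staticCallTarget(ExactOps)
val line = emitCall(target, "toInt", "value_1", "value_0")
```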
    
    Touched `Cast.scala` helpers:
    * `castIntegralTypeToIntegralTypeExactCode`: the `int` target branch now emits `LongExactNumeric.toInt($c)` (byte/short narrowing stays inline and is handled separately in SPARK-56910).
    * `castFractionToIntegralTypeCode`: the `int` / `long` target branches now emit `FloatExactNumeric.toInt` / `toLong` or `DoubleExactNumeric.toInt` / `toLong`, depending on the source type (byte/short narrowing stays inline and is handled separately in SPARK-56910).
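The branch selection in the two helpers above can be summarized as a small lookup. This is a hypothetical sketch: source and target types are modeled as strings rather than Spark `DataType`s, and `exactHelperFor` is not a Spark API.

```scala
// Hypothetical summary of which ANSI narrowing casts now emit a static helper call.
// `None` means the branch keeps its inline bounds check (e.g. byte/short targets).
def exactHelperFor(from: String, to: String): Option[String] = (from, to) match {
  case ("long", "int")            => Some("LongExactNumeric.toInt")
  case ("float", "int" | "long")  => Some(s"FloatExactNumeric.to${to.capitalize}")
  case ("double", "int" | "long") => Some(s"DoubleExactNumeric.to${to.capitalize}")
  case _                          => None
}
```

The `to.capitalize` step mirrors the `s"to${integralType.capitalize}"` method-name construction in the diff.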
    
    Primitive widening branches and the non-ANSI paths are untouched.
    
    ### Why are the changes needed?
    
    Part of SPARK-56908 (umbrella). The narrow-cast ANSI branches in `Cast.doGenCode` are some of the longer inline bodies still emitted per call site. Multiplied across the many cast paths in a TPC-DS plan, they contribute meaningfully to the generated source size and Janino compile time, and push whole-stage methods closer to the 64KB JVM method limit.
    
    Compared to v1 of this PR (which added a new `CastUtils.java` with `longToIntExact` / `floatToIntExact` / etc.), this version calls the existing `LongExactNumeric.toInt` / `FloatExactNumeric.toInt` / `toLong` / `DoubleExactNumeric.toInt` / `toLong` directly. Those are public static forwarders on top-level Scala objects that already implement the same `castingCauseOverflowError(v, FROM, TO)` throw, so no new helper class is needed. This applies the same lesson cloud-fan called out on #55938.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    ```
    build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite \
      *CastWithAnsiOffSuite *AnsiCastSuite *TryCastSuite *ExpressionClassIdentitySuite"
    ```
    
    307/307 pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Cursor 1.x
    
    Closes #55934 from gengliangwang/SPARK-56909-cast-int-long.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
    (cherry picked from commit 9e1344250a4e284c6e865105377c7cde4678afb5)
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 82 ++++++++++++++--------
 1 file changed, 54 insertions(+), 28 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index c51d3508d04a..419ca3f32d88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,7 @@ import java.time.{ZoneId, ZoneOffset}
 import java.util.Locale
 import java.util.concurrent.TimeUnit._
 
-import org.apache.spark.{QueryContext, SparkArithmeticException, SparkIllegalArgumentException}
+import org.apache.spark.{QueryContext, SparkArithmeticException, SparkException, SparkIllegalArgumentException}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -1988,16 +1988,28 @@ case class Cast(
       from: DataType,
       to: DataType): CastFunction = {
     assert(ansiEnabled)
-    val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
-    val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
-    (c, evPrim, _) =>
-      code"""
-        if ($c == ($integralType) $c) {
-          $evPrim = ($integralType) $c;
-        } else {
-          throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
-        }
-      """
+    if (integralType == "int") {
+      // Integral -> Int: call the existing *ExactNumeric.toInt directly. It already does the
+      // bounds check and throws castingCauseOverflowError -- same as the inline body.
+      // Only LongType reaches this branch today (`castToIntCode` gates on `case LongType`).
+      val numericObj = (from match {
+        case LongType => LongExactNumeric
+        case _ => throw SparkException.internalError(
+          s"Unexpected source type $from for castIntegralTypeToIntegralTypeExactCode int branch")
+      }).getClass.getCanonicalName.stripSuffix("$")
+      (c, evPrim, _) => code"$evPrim = $numericObj.toInt($c);"
+    } else {
+      val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
+      val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
+      (c, evPrim, _) =>
+        code"""
+          if ($c == ($integralType) $c) {
+            $evPrim = ($integralType) $c;
+          } else {
+            throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
+          }
+        """
+    }
   }
 
 
@@ -2017,23 +2029,37 @@ case class Cast(
       from: DataType,
       to: DataType): CastFunction = {
     assert(ansiEnabled)
-    val (min, max) = lowerAndUpperBound(integralType)
-    val mathClass = classOf[Math].getName
-    val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
-    val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
-    // When casting floating values to integral types, Spark uses the method `Numeric.toInt`
-    // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to `Math.floor`;
-    // for negative floating values, it is equivalent to `Math.ceil`.
-    // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound`
-    // to check if the floating value x is in the range of an integral type after rounding.
-    (c, evPrim, _) =>
-      code"""
-        if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) {
-          $evPrim = ($integralType) $c;
-        } else {
-          throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
-        }
-      """
+    if (integralType == "int" || integralType == "long") {
+      // Float/Double -> Int/Long: call FloatExactNumeric/DoubleExactNumeric.toInt/toLong
+      // directly. Each already does the floor/ceil bounds check and throws
+      // castingCauseOverflowError -- same as the inline body.
+      val numericObj = (from match {
+        case FloatType => FloatExactNumeric
+        case DoubleType => DoubleExactNumeric
+        case _ => throw SparkException.internalError(
+          s"Unexpected source type $from for castFractionToIntegralTypeCode")
+      }).getClass.getCanonicalName.stripSuffix("$")
+      val method = s"to${integralType.capitalize}"
+      (c, evPrim, _) => code"$evPrim = $numericObj.$method($c);"
+    } else {
+      val (min, max) = lowerAndUpperBound(integralType)
+      val mathClass = classOf[Math].getName
+      val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
+      val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
+      // When casting floating values to integral types, Spark uses the method `Numeric.toInt`
+      // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to
+      // `Math.floor`; for negative floating values, it is equivalent to `Math.ceil`.
+      // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound`
+      // to check if the floating value x is in the range of an integral type after rounding.
+      (c, evPrim, _) =>
+        code"""
+          if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) {
+            $evPrim = ($integralType) $c;
+          } else {
+            throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
+          }
+        """
+    }
   }
 
   private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
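
For reference, the two overflow checks that this diff moves out of the generated code can be restated in plain Scala. This is a hypothetical sketch, not the `numerics.scala` source: `ExactChecks` is an illustrative name, only the `Int` target is shown, and `ArithmeticException` stands in for `QueryExecutionErrors.castingCauseOverflowError`. Per the commit message, the `*ExactNumeric` helpers implement the same checks.

```scala
// Hypothetical restatement of the ANSI overflow checks for an Int target.
object ExactChecks {
  // Integral narrowing: the value must survive a round trip through Int.
  def longToInt(x: Long): Int =
    if (x == x.toInt) x.toInt
    else throw new ArithmeticException(s"$x is out of range for int")

  // Fractional narrowing: Double.toInt truncates toward zero, so the truncated value
  // is in range iff floor(x) <= Int.MaxValue and ceil(x) >= Int.MinValue.
  // For NaN both comparisons are false, so NaN is rejected as well.
  def doubleToInt(x: Double): Int =
    if (Math.floor(x) <= Int.MaxValue && Math.ceil(x) >= Int.MinValue) x.toInt
    else throw new ArithmeticException(s"$x is out of range for int")
}
```

A round-trip check like the integral one would wrongly reject in-range fractional values such as `1.5` (whose `toInt` is `1`), and the JVM's double-to-int conversion saturates instead of failing (NaN becomes 0), which is why the fractional path tests floor/ceil bounds before casting.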

