This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9e1344250a4e [SPARK-56909][SQL] Simplify Cast to int/long codegen under ANSI mode
9e1344250a4e is described below
commit 9e1344250a4e284c6e865105377c7cde4678afb5
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri May 22 15:10:24 2026 -0700
[SPARK-56909][SQL] Simplify Cast to int/long codegen under ANSI mode
### What changes were proposed in this pull request?
In `Cast.scala`, the ANSI codegen for narrowing casts to `int` / `long`
previously emitted a 5-line inline body per call site (bounds check + cast +
throw). After this PR it emits a single static call into the existing
`LongExactNumeric` / `FloatExactNumeric` / `DoubleExactNumeric` objects in
`numerics.scala`, which already implement the same overflow check +
`castingCauseOverflowError` throw that this codegen needs.
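For readers unfamiliar with the check being moved, here is a minimal Scala sketch of the "bounds check + cast + throw" shape for `long` to `int`. `LongToIntSketch` and `toIntExact` are hypothetical names used only for illustration, and `ArithmeticException` stands in for the error that `castingCauseOverflowError` builds in Spark.

```scala
// Hypothetical illustration; not the code in numerics.scala.
object LongToIntSketch {
  // Same condition as the removed inline body `if ($c == ($integralType) $c)`:
  // a long fits in an int exactly when narrowing it and widening it again
  // gives back the original value.
  def toIntExact(v: Long): Int =
    if (v == v.toInt) v.toInt
    else throw new ArithmeticException(s"Casting $v to int causes overflow")

  def main(args: Array[String]): Unit = {
    println(toIntExact(123L))
    // One past Int.MaxValue no longer round-trips, so the call fails.
    println(scala.util.Try(toIntExact(Int.MaxValue.toLong + 1L)).isFailure)
  }
}
```

With the change described above, each call site would contain one call to such a helper rather than a copy of the `if` / `else` body.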
The rewrite uses the same `getClass.getCanonicalName.stripSuffix("$")`
pattern as the adjacent `MathUtils` / `IntervalMathUtils` calls. The Scala
compiler emits `public static` forwarders on the companion class of top-level
objects, so generated Java code can call e.g.
`org.apache.spark.sql.types.LongExactNumeric.toInt(v)` directly.
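The naming pattern can be sketched roughly as follows. `ExactNumericDemo`, `javaCallableName`, and `emitCall` are hypothetical stand-ins, not Spark APIs, and the sketch assumes the object is declared at the top level of a source file, as the `*ExactNumeric` objects are.

```scala
// Hypothetical stand-in for a top-level object such as LongExactNumeric.
object ExactNumericDemo {
  // java.lang.Math.toIntExact throws ArithmeticException on overflow.
  def toInt(v: Long): Int = Math.toIntExact(v)
}

object CanonicalNameSketch {
  // A top-level Scala object `Foo` compiles to a module class named `Foo$`.
  // Dropping the trailing `$` yields the class name that carries the static
  // forwarders, which is the name Java code can call through.
  def javaCallableName(obj: AnyRef): String =
    obj.getClass.getCanonicalName.stripSuffix("$")

  // Builds the single statement that would be placed in generated Java code.
  // Illustrative helper only.
  def emitCall(target: String, integralType: String, input: String, result: String): String = {
    val method = s"to${integralType.capitalize}" // "int" -> "toInt", "long" -> "toLong"
    s"$result = $target.$method($input);"
  }

  def main(args: Array[String]): Unit = {
    val target = javaCallableName(ExactNumericDemo)
    println(emitCall(target, "int", "value_0", "result_0"))
  }
}
```

The method name is derived from the target type with `capitalize`, mirroring the `s"to${integralType.capitalize}"` expression in the diff.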
Touched `Cast.scala` helpers:
* `castIntegralTypeToIntegralTypeExactCode`: the `int` target branch now
emits `LongExactNumeric.toInt($c)` (byte/short narrowing stays inline;
refactored in SPARK-56910).
* `castFractionToIntegralTypeCode`: the `int` / `long` target branches now
emit `FloatExactNumeric` / `DoubleExactNumeric` `toInt` / `toLong` (byte/short
narrowing stays inline; refactored in SPARK-56910).
Primitive widening branches and the non-ANSI paths are untouched.
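As a hedged sketch of the fractional case, the function below applies the `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound` condition from the inline body to a `double` to `int` cast. `FractionToIntSketch` and `doubleToIntExact` are hypothetical names, and `ArithmeticException` stands in for Spark's overflow error.

```scala
// Hypothetical illustration of the range check; not the Spark implementation.
object FractionToIntSketch {
  // Truncation toward zero equals floor for positive values and ceil for
  // negative ones, so the truncated result is in range exactly when
  // floor(x) <= max and ceil(x) >= min. NaN fails both comparisons.
  def doubleToIntExact(x: Double): Int =
    if (math.floor(x) <= Int.MaxValue && math.ceil(x) >= Int.MinValue) x.toInt
    else throw new ArithmeticException(s"Casting $x to int causes overflow")

  def main(args: Array[String]): Unit = {
    println(doubleToIntExact(2147483647.9))  // fraction is dropped, still in range
    println(scala.util.Try(doubleToIntExact(2147483648.0)).isFailure)
    println(scala.util.Try(doubleToIntExact(Double.NaN)).isFailure)
  }
}
```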
### Why are the changes needed?
Part of SPARK-56908 (umbrella). The narrow-cast ANSI branches in
`Cast.doGenCode` are some of the longer inline bodies still emitted per call
site. Multiplied across the many cast paths in a TPC-DS plan, they contribute
meaningfully to the generated source size and Janino compile time, and push
whole-stage methods closer to the 64KB JVM method limit.
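To give a rough feel for the per-call-site difference, the sketch below renders both shapes as plain strings and compares their sizes. It is an illustration only, not a measurement of real generated code: `CodeSizeSketch` is hypothetical, and `fromDt` / `toDt` are placeholders for the reference objects the real template registers.

```scala
// Hypothetical illustration; the templates are adapted from the diff.
object CodeSizeSketch {
  // Old shape: bounds check + cast + throw, repeated at every call site.
  def inlineBody(c: String, ev: String): String =
    s"""if ($c == (int) $c) {
       |  $ev = (int) $c;
       |} else {
       |  throw QueryExecutionErrors.castingCauseOverflowError($c, fromDt, toDt);
       |}""".stripMargin

  // New shape: one static call per call site.
  def callBody(c: String, ev: String): String =
    s"$ev = org.apache.spark.sql.types.LongExactNumeric.toInt($c);"

  def main(args: Array[String]): Unit = {
    val before = inlineBody("value_0", "result_0")
    val after = callBody("value_0", "result_0")
    println(s"inline: ${before.split("\n").length} lines, ${before.length} chars")
    println(s"call:   ${after.split("\n").length} lines, ${after.length} chars")
  }
}
```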
Compared to v1 of this PR (which added a new `CastUtils.java` with
`longToIntExact` / `floatToIntExact` / etc.), this version calls the existing
`LongExactNumeric.toInt` / `FloatExactNumeric.toInt` / `toLong` /
`DoubleExactNumeric.toInt` / `toLong` directly. Those are public static
forwarders on top-level Scala objects that already implement the same
`castingCauseOverflowError(v, FROM, TO)` throw — no new helper class needed.
(Applying the same lesson cloud-fan called out on #55938.)
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
```
build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite \
*CastWithAnsiOffSuite *AnsiCastSuite *TryCastSuite \
*ExpressionClassIdentitySuite"
```
307/307 pass.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Cursor 1.x
Closes #55934 from gengliangwang/SPARK-56909-cast-int-long.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
---
.../spark/sql/catalyst/expressions/Cast.scala | 82 ++++++++++++++--------
1 file changed, 54 insertions(+), 28 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index c51d3508d04a..419ca3f32d88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,7 @@ import java.time.{ZoneId, ZoneOffset}
import java.util.Locale
import java.util.concurrent.TimeUnit._
-import org.apache.spark.{QueryContext, SparkArithmeticException, SparkIllegalArgumentException}
+import org.apache.spark.{QueryContext, SparkArithmeticException, SparkException, SparkIllegalArgumentException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -1988,16 +1988,28 @@ case class Cast(
from: DataType,
to: DataType): CastFunction = {
assert(ansiEnabled)
- val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
- val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
- (c, evPrim, _) =>
- code"""
- if ($c == ($integralType) $c) {
- $evPrim = ($integralType) $c;
- } else {
- throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
- }
- """
+ if (integralType == "int") {
+ // Integral -> Int: call the existing *ExactNumeric.toInt directly. It already does the
+ // bounds check and throws castingCauseOverflowError -- same as the inline body.
+ // Only LongType reaches this branch today (`castToIntCode` gates on `case LongType`).
+ val numericObj = (from match {
+ case LongType => LongExactNumeric
+ case _ => throw SparkException.internalError(
+ s"Unexpected source type $from for castIntegralTypeToIntegralTypeExactCode int branch")
+ }).getClass.getCanonicalName.stripSuffix("$")
+ (c, evPrim, _) => code"$evPrim = $numericObj.toInt($c);"
+ } else {
+ val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
+ val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
+ (c, evPrim, _) =>
+ code"""
+ if ($c == ($integralType) $c) {
+ $evPrim = ($integralType) $c;
+ } else {
+ throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
+ }
+ """
+ }
}
@@ -2017,23 +2029,37 @@ case class Cast(
from: DataType,
to: DataType): CastFunction = {
assert(ansiEnabled)
- val (min, max) = lowerAndUpperBound(integralType)
- val mathClass = classOf[Math].getName
- val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
- val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
- // When casting floating values to integral types, Spark uses the method `Numeric.toInt`
- // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to `Math.floor`;
- // for negative floating values, it is equivalent to `Math.ceil`.
- // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound`
- // to check if the floating value x is in the range of an integral type after rounding.
- (c, evPrim, _) =>
- code"""
- if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) {
- $evPrim = ($integralType) $c;
- } else {
- throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
- }
- """
+ if (integralType == "int" || integralType == "long") {
+ // Float/Double -> Int/Long: call FloatExactNumeric/DoubleExactNumeric.toInt/toLong
+ // directly. Each already does the floor/ceil bounds check and throws
+ // castingCauseOverflowError -- same as the inline body.
+ val numericObj = (from match {
+ case FloatType => FloatExactNumeric
+ case DoubleType => DoubleExactNumeric
+ case _ => throw SparkException.internalError(
+ s"Unexpected source type $from for castFractionToIntegralTypeCode")
+ }).getClass.getCanonicalName.stripSuffix("$")
+ val method = s"to${integralType.capitalize}"
+ (c, evPrim, _) => code"$evPrim = $numericObj.$method($c);"
+ } else {
+ val (min, max) = lowerAndUpperBound(integralType)
+ val mathClass = classOf[Math].getName
+ val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
+ val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
+ // When casting floating values to integral types, Spark uses the method `Numeric.toInt`
+ // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to
+ // `Math.floor`; for negative floating values, it is equivalent to `Math.ceil`.
+ // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound`
+ // to check if the floating value x is in the range of an integral type after rounding.
+ (c, evPrim, _) =>
+ code"""
+ if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) {
+ $evPrim = ($integralType) $c;
+ } else {
+ throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
+ }
+ """
+ }
}
private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {