This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 19c1b95b9519 [SPARK-56910][SQL] Simplify Cast to byte/short codegen under ANSI mode
19c1b95b9519 is described below
commit 19c1b95b951905f811401e1e6112cce41dad6a23
Author: Gengliang Wang <[email protected]>
AuthorDate: Wed May 27 11:30:45 2026 -0700
[SPARK-56910][SQL] Simplify Cast to byte/short codegen under ANSI mode
### What changes were proposed in this pull request?
Introduce `CastUtils.java` with nine static helpers for ANSI
overflow-checked narrowing to `byte` / `short`, and use them from `Cast.scala`
(both codegen and eval paths).
Helpers added:
* `shortToByteExact(short)`, `intToByteExact(int)`, `longToByteExact(long)`
* `intToShortExact(int)`, `longToShortExact(long)`
* `floatToByteExact(float)`, `doubleToByteExact(double)`
* `floatToShortExact(float)`, `doubleToShortExact(double)`
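The integral helpers rely on a round-trip test: narrow the value, widen it back, and compare. The minimal standalone Java sketch below illustrates that test. It is not the Spark source. The class name `NarrowingSketch` is hypothetical, and the sketch throws a plain `ArithmeticException` where `CastUtils` calls `QueryExecutionErrors.castingCauseOverflowError`.

```java
// Hypothetical standalone sketch of the round-trip check used by the
// CastUtils integral helpers. It throws a plain ArithmeticException instead
// of Spark's QueryExecutionErrors.castingCauseOverflowError.
final class NarrowingSketch {
  private NarrowingSketch() {}

  // A long fits in a byte exactly when narrowing it and widening it back
  // yields the original value.
  static byte longToByteExact(long v) {
    if (v == (byte) v) return (byte) v;
    throw new ArithmeticException(v + " overflows BYTE");
  }

  // Same idea for int -> short.
  static short intToShortExact(int v) {
    if (v == (short) v) return (short) v;
    throw new ArithmeticException(v + " overflows SHORT");
  }

  public static void main(String[] args) {
    System.out.println(longToByteExact(127L));   // 127
    System.out.println(longToByteExact(-128L));  // -128
    System.out.println(intToShortExact(32767));  // 32767
    System.out.println((byte) 128);              // -128: a bare cast wraps silently
    try {
      longToByteExact(128L);
    } catch (ArithmeticException e) {
      System.out.println("rejected: " + e.getMessage()); // rejected: 128 overflows BYTE
    }
  }
}
```

The bare `(byte) 128` line shows why ANSI mode needs the check. Plain Java narrowing wraps silently instead of failing.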
`ByteExactNumeric` / `ShortExactNumeric` only expose same-type identity
narrowing (their `toByte(byte)` / `toShort(short)` are trivial), so unlike the
`int` / `long` targets refactored in #55934 — which delegate to
`LongExactNumeric.toInt` / `FloatExactNumeric.toInt` /
`DoubleExactNumeric.toInt` / `toLong` — there is no existing Scala object to
route the byte/short narrowing through. The Java helper is the cleanest fit.
`Cast.scala` changes:
* `castIntegralTypeToIntegralTypeExactCode`: the `byte` / `short` branch
(previously an inline 5-line if/throw block) emits a single
`CastUtils.${integralPrefix(from)}To${target.capitalize}Exact($c)` call. The
`int` branch (added in #55934) is unchanged.
* `castFractionToIntegralTypeCode`: the `byte` / `short` branch (previously
an inline 5-line floor/ceil block plus `lowerAndUpperBound`) emits a single
`CastUtils.${fractionalPrefix(from)}To${target.capitalize}Exact($c)` call. The
`int` / `long` branch (added in #55934) is unchanged. The now-unused
`lowerAndUpperBound` Scala helper is removed.
* Eval paths for `castToByte` and `castToShort` add ANSI cases for the
`ShortType` (byte target only) / `IntegerType` / `LongType` / `FloatType` /
`DoubleType` source types. These delegate to the new helpers and take
precedence over the existing multi-line `exactNumeric.toInt(b)` +
bounds-check body, which remains in place as the fallback `NumericType` case.
* Two small Scala helpers, `integralPrefix(from: DataType)` and
`fractionalPrefix(from: DataType)`, handle the method-name dispatch.
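The fractional branch keeps the existing `floor` / `ceil` bounds check. The sketch below shows why that condition matches Java's cast semantics. On a `double`, `(byte) v` truncates toward zero, so 127.9 and -128.9 still fit, while 128.0, -129.0, and NaN are rejected. The class name is hypothetical, and a plain `ArithmeticException` stands in for Spark's overflow error.

```java
// Hypothetical standalone sketch of the floor/ceil bounds check used by the
// CastUtils fractional helpers. It throws a plain ArithmeticException instead
// of Spark's QueryExecutionErrors.castingCauseOverflowError.
final class FractionalNarrowingSketch {
  private FractionalNarrowingSketch() {}

  // (byte) v truncates toward zero, so v fits exactly when floor(v) <= MAX
  // (relevant for positive v) and ceil(v) >= MIN (relevant for negative v).
  // NaN fails both comparisons and is rejected.
  static byte doubleToByteExact(double v) {
    if (Math.floor(v) <= Byte.MAX_VALUE && Math.ceil(v) >= Byte.MIN_VALUE) return (byte) v;
    throw new ArithmeticException(v + " overflows BYTE");
  }

  // Float input is widened to double by Math.floor / Math.ceil.
  static short floatToShortExact(float v) {
    if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v;
    throw new ArithmeticException(v + " overflows SHORT");
  }

  public static void main(String[] args) {
    System.out.println(doubleToByteExact(127.9));     // 127
    System.out.println(doubleToByteExact(-128.9));    // -128
    System.out.println(floatToShortExact(-32768.5f)); // -32768
    for (double bad : new double[] {128.0, -129.0, Double.NaN}) {
      try {
        doubleToByteExact(bad);
      } catch (ArithmeticException e) {
        System.out.println("rejected: " + bad);
      }
    }
  }
}
```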
### Why are the changes needed?
Part of SPARK-56908 (umbrella). The byte/short narrowing ANSI bodies were 5
lines each across 8 codegen call sites; this PR collapses them to one line per
call site, matching the int/long target work merged in #55934.
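Each one-line call is assembled by name from four parts: a source-type prefix, `To`, the capitalized target type, and `Exact`. The Java sketch below approximates that dispatch for illustration only. The real code is Scala in `Cast.scala`. The string type names, the `emit` method, and the placeholder variable names are assumptions that keep the example self-contained.

```java
import java.util.Map;
import java.util.Set;

// Hypothetical Java approximation of the Scala dispatch in Cast.scala
// (integralPrefix / fractionalPrefix + integralType.capitalize). Source
// types are plain strings here rather than Spark DataType objects.
final class CastCallSketch {
  private CastCallSketch() {}

  private static final Map<String, String> PREFIX = Map.of(
      "ShortType", "short", "IntegerType", "int", "LongType", "long",
      "FloatType", "float", "DoubleType", "double");

  // The nine helpers added to CastUtils by this commit.
  private static final Set<String> HELPERS = Set.of(
      "shortToByteExact", "intToByteExact", "longToByteExact",
      "intToShortExact", "longToShortExact",
      "floatToByteExact", "doubleToByteExact",
      "floatToShortExact", "doubleToShortExact");

  static String helperName(String fromType, String targetType) {
    String prefix = PREFIX.get(fromType);
    String target = Character.toUpperCase(targetType.charAt(0)) + targetType.substring(1);
    String name = prefix + "To" + target + "Exact";
    if (prefix == null || !HELPERS.contains(name)) {
      throw new IllegalStateException("No CastUtils helper for " + fromType + " -> " + targetType);
    }
    return name;
  }

  // Builds the single generated line; evPrim and c are placeholder names.
  static String emit(String fromType, String targetType, String evPrim, String c) {
    return evPrim + " = org.apache.spark.sql.catalyst.expressions.CastUtils."
        + helperName(fromType, targetType) + "(" + c + ");";
  }

  public static void main(String[] args) {
    // value_1 = org.apache.spark.sql.catalyst.expressions.CastUtils.intToByteExact(input_0);
    System.out.println(emit("IntegerType", "byte", "value_1", "input_0"));
  }
}
```

The membership check against the nine helper names exists only in the sketch. The Scala helpers only reject unexpected source types, via `SparkException.internalError`.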
### Does this PR introduce _any_ user-facing change?
No. The compiled behavior is identical; only the emitted Java source text
changes.
### How was this patch tested?
```
build/sbt "catalyst/testOnly *CastSuite *CastWithAnsiOnSuite *CastWithAnsiOffSuite *AnsiCastSuite *TryCastSuite"
```
307/307 pass.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Cursor 1.x
Closes #55935 from gengliangwang/SPARK-56910-cast-byte-short.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
(cherry picked from commit 6165bb0c99f219bc49586548abe61e683857f546)
Signed-off-by: Gengliang Wang <[email protected]>
---
.../spark/sql/catalyst/expressions/CastUtils.java | 98 ++++++++++++++++++++++
.../spark/sql/catalyst/expressions/Cast.scala | 82 +++++++++---------
2 files changed, 142 insertions(+), 38 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
new file mode 100644
index 000000000000..700f7e41d233
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/CastUtils.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.sql.errors.QueryExecutionErrors;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+
+/**
+ * Static helpers used by {@code Cast.doGenCode} (and corresponding eval
+ * paths) for ANSI overflow-checked narrowing to {@code byte} / {@code short}.
+ *
+ * <p>Narrowing to {@code int} / {@code long} is handled by calling the existing
+ * {@code LongExactNumeric} / {@code FloatExactNumeric} / {@code DoubleExactNumeric}
+ * Scala objects directly from codegen (see SPARK-56909). The helpers below
+ * cover {@code byte} / {@code short} only, since {@code ByteExactNumeric} /
+ * {@code ShortExactNumeric} don't expose a cross-type narrowing API.
+ *
+ * <p>The source and target {@link DataType} objects referenced by the overflow
+ * error message are held in {@code private static final} fields so the happy
+ * path performs no per-row {@code references[]} lookups.
+ */
+public final class CastUtils {
+
+ private CastUtils() {}
+
+ private static final DataType SHORT = DataTypes.ShortType;
+ private static final DataType INT = DataTypes.IntegerType;
+ private static final DataType LONG = DataTypes.LongType;
+ private static final DataType BYTE = DataTypes.ByteType;
+ private static final DataType FLOAT = DataTypes.FloatType;
+ private static final DataType DOUBLE = DataTypes.DoubleType;
+
+ // ----- integral narrowing (ANSI: throw on overflow) -----
+
+ public static byte shortToByteExact(short v) {
+ if (v == (byte) v) return (byte) v;
+ throw QueryExecutionErrors.castingCauseOverflowError(v, SHORT, BYTE);
+ }
+
+ public static byte intToByteExact(int v) {
+ if (v == (byte) v) return (byte) v;
+ throw QueryExecutionErrors.castingCauseOverflowError(v, INT, BYTE);
+ }
+
+ public static byte longToByteExact(long v) {
+ if (v == (byte) v) return (byte) v;
+ throw QueryExecutionErrors.castingCauseOverflowError(v, LONG, BYTE);
+ }
+
+ public static short intToShortExact(int v) {
+ if (v == (short) v) return (short) v;
+ throw QueryExecutionErrors.castingCauseOverflowError(v, INT, SHORT);
+ }
+
+ public static short longToShortExact(long v) {
+ if (v == (short) v) return (short) v;
+ throw QueryExecutionErrors.castingCauseOverflowError(v, LONG, SHORT);
+ }
+
+ // ----- fractional -> integral (ANSI: throw on overflow) -----
+ // Mirrors castFractionToIntegralTypeCode: floor(v) <= MAX && ceil(v) >= MIN.
+
+ public static byte floatToByteExact(float v) {
+ if (Math.floor(v) <= Byte.MAX_VALUE && Math.ceil(v) >= Byte.MIN_VALUE) return (byte) v;
+ throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, BYTE);
+ }
+
+ public static byte doubleToByteExact(double v) {
+ if (Math.floor(v) <= Byte.MAX_VALUE && Math.ceil(v) >= Byte.MIN_VALUE) return (byte) v;
+ throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, BYTE);
+ }
+
+ public static short floatToShortExact(float v) {
+ if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v;
+ throw QueryExecutionErrors.castingCauseOverflowError(v, FLOAT, SHORT);
+ }
+
+ public static short doubleToShortExact(double v) {
+ if (Math.floor(v) <= Short.MAX_VALUE && Math.ceil(v) >= Short.MIN_VALUE) return (short) v;
+ throw QueryExecutionErrors.castingCauseOverflowError(v, DOUBLE, SHORT);
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 419ca3f32d88..0611c3e9bfb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -984,6 +984,14 @@ case class Cast(
errorOrNull(t, from, ShortType)
}
})
+ case IntegerType if ansiEnabled =>
+ b => CastUtils.intToShortExact(b.asInstanceOf[Int])
+ case LongType if ansiEnabled =>
+ b => CastUtils.longToShortExact(b.asInstanceOf[Long])
+ case FloatType if ansiEnabled =>
+ b => CastUtils.floatToShortExact(b.asInstanceOf[Float])
+ case DoubleType if ansiEnabled =>
+ b => CastUtils.doubleToShortExact(b.asInstanceOf[Double])
case x: NumericType if ansiEnabled =>
val exactNumeric = PhysicalNumericType.exactNumeric(x)
b =>
@@ -1040,6 +1048,16 @@ case class Cast(
errorOrNull(t, from, ByteType)
}
})
+ case ShortType if ansiEnabled =>
+ b => CastUtils.shortToByteExact(b.asInstanceOf[Short])
+ case IntegerType if ansiEnabled =>
+ b => CastUtils.intToByteExact(b.asInstanceOf[Int])
+ case LongType if ansiEnabled =>
+ b => CastUtils.longToByteExact(b.asInstanceOf[Long])
+ case FloatType if ansiEnabled =>
+ b => CastUtils.floatToByteExact(b.asInstanceOf[Float])
+ case DoubleType if ansiEnabled =>
+ b => CastUtils.doubleToByteExact(b.asInstanceOf[Double])
case x: NumericType if ansiEnabled =>
val exactNumeric = PhysicalNumericType.exactNumeric(x)
b =>
@@ -1999,28 +2017,13 @@ case class Cast(
}).getClass.getCanonicalName.stripSuffix("$")
(c, evPrim, _) => code"$evPrim = $numericObj.toInt($c);"
} else {
- val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
- val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
- (c, evPrim, _) =>
- code"""
- if ($c == ($integralType) $c) {
- $evPrim = ($integralType) $c;
- } else {
- throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
- }
- """
- }
- }
-
-
- private[this] def lowerAndUpperBound(integralType: String): (String, String) = {
- val (min, max, typeIndicator) = integralType.toLowerCase(Locale.ROOT) match {
- case "long" => (Long.MinValue, Long.MaxValue, "L")
- case "int" => (Int.MinValue, Int.MaxValue, "")
- case "short" => (Short.MinValue, Short.MaxValue, "")
- case "byte" => (Byte.MinValue, Byte.MaxValue, "")
+ // Byte / short narrowing: call the matching CastUtils helper. Existing *ExactNumeric
+ // objects don't expose cross-type narrowing to byte / short (their toByte / toShort are
+ // same-type identities), so a Java helper is the cleanest fit.
+ val castUtils = classOf[CastUtils].getName
+ val method = s"${integralPrefix(from)}To${integralType.capitalize}Exact"
+ (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);"
}
- (min.toString + typeIndicator, max.toString + typeIndicator)
}
private[this] def castFractionToIntegralTypeCode(
@@ -2042,26 +2045,29 @@ case class Cast(
val method = s"to${integralType.capitalize}"
(c, evPrim, _) => code"$evPrim = $numericObj.$method($c);"
} else {
- val (min, max) = lowerAndUpperBound(integralType)
- val mathClass = classOf[Math].getName
- val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
- val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
- // When casting floating values to integral types, Spark uses the method `Numeric.toInt`
- // Or `Numeric.toLong` directly. For positive floating values, it is equivalent to
- // `Math.floor`; for negative floating values, it is equivalent to `Math.ceil`.
- // So, we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound`
- // to check if the floating value x is in the range of an integral type after rounding.
- (c, evPrim, _) =>
- code"""
- if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) {
- $evPrim = ($integralType) $c;
- } else {
- throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
- }
- """
+ // Float / double -> byte / short: same rationale as the integral byte / short branch
+ // above -- no equivalent *ExactNumeric API, so route through CastUtils.
+ val castUtils = classOf[CastUtils].getName
+ val method = s"${fractionalPrefix(from)}To${integralType.capitalize}Exact"
+ (c, evPrim, _) => code"$evPrim = $castUtils.$method($c);"
}
}
+ private[this] def integralPrefix(from: DataType): String = from match {
+ case ShortType => "short"
+ case IntegerType => "int"
+ case LongType => "long"
+ case _ => throw SparkException.internalError(
+ s"Unexpected source type $from for castIntegralTypeToIntegralTypeExactCode")
+ }
+
+ private[this] def fractionalPrefix(from: DataType): String = from match {
+ case FloatType => "float"
+ case DoubleType => "double"
+ case _ => throw SparkException.internalError(
+ s"Unexpected source type $from for castFractionToIntegralTypeCode")
+ }
+
private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case _: StringType if ansiEnabled =>
val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")