This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 1c6ed48edf4d [SPARK-57181][SQL] Simplify Pmod codegen by sharing a
MathUtils.pmod helper with eval
1c6ed48edf4d is described below
commit 1c6ed48edf4d310fcbad2d9b0295086d83611a43
Author: Gengliang Wang <[email protected]>
AuthorDate: Tue Jun 2 13:44:13 2026 -0700
[SPARK-57181][SQL] Simplify Pmod codegen by sharing a MathUtils.pmod helper
with eval
### What changes were proposed in this pull request?
`Pmod.doGenCode` emitted the positive-modulo `remainder`/adjust block
inline, once for byte/short and once for the int/long/float/double case (~6-8
lines each), duplicating the algorithm already implemented by `Pmod`'s private
`pmod` eval methods. This adds `MathUtils.pmod` overloads (Int, Long, Byte,
Short, Float, Double) -- the exact bodies moved out of `Pmod` -- and routes
both the eval dispatch (`pmodFunc`) and codegen through them. The primitive
codegen cases collapse to a single `MathUtils.pmod(...)` call.
### Why are the changes needed?
Part of SPARK-56908 (umbrella). `Pmod` over IntegerType is emitted by every
`HashPartitioning` (`Pmod(Murmur3Hash(...), numPartitions)`), so collapsing the
inline block shrinks the generated Java on a very common path, and the eval and
codegen paths now share one implementation instead of duplicating the algorithm
(helping with the JVM 64KB method / constant-pool limits, Janino compile time,
and JIT work).
### Does this PR introduce _any_ user-facing change?
No. The compiled behavior is identical; only the emitted Java source text
changes.
### How was this patch tested?
Existing `ArithmeticExpressionSuite."pmod"` covers all numeric types,
negative operands / divisors, mod-by-zero (ANSI on/off), and
`checkConsistencyBetweenInterpretedAndCodegenAllowingException` across all
numeric types (which verifies eval and codegen agree -- exactly the invariant
this refactor must preserve).
```
build/sbt "catalyst/testOnly *ArithmeticExpressionSuite"
```
35/35 pass.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (Opus 4.8)
Closes #56232 from gengliangwang/spark-pmod-codegen.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
(cherry picked from commit 576070aeb9ddab611609116e7753c2fda9b0de7c)
Signed-off-by: Gengliang Wang <[email protected]>
---
.../apache/spark/sql/catalyst/util/MathUtils.scala | 34 +++++++++++
.../sql/catalyst/expressions/arithmetic.scala | 66 ++++------------------
2 files changed, 46 insertions(+), 54 deletions(-)
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
index b113bccc74df..a4c6c75358f4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -89,6 +89,40 @@ object MathUtils {
def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b))
+ // Positive modulo (`pmod`): the remainder `a % n`, shifted by `n` when it is negative, so
+ // the result is non-negative for a positive `n` (for a negative `n` it equals `a % n`).
+ // Unlike `floorMod`, this matches the `pmod` SQL function /
`HashPartitioning` semantics.
+ // Shared by `Pmod`'s eval and codegen paths so the two never diverge.
+
+ def pmod(a: Int, n: Int): Int = {
+ val r = a % n
+ if (r < 0) (r + n) % n else r
+ }
+
+ def pmod(a: Long, n: Long): Long = {
+ val r = a % n
+ if (r < 0) (r + n) % n else r
+ }
+
+ def pmod(a: Byte, n: Byte): Byte = {
+ val r = a % n
+ if (r < 0) ((r + n) % n).toByte else r.toByte
+ }
+
+ def pmod(a: Short, n: Short): Short = {
+ val r = a % n
+ if (r < 0) ((r + n) % n).toShort else r.toShort
+ }
+
+ def pmod(a: Float, n: Float): Float = {
+ val r = a % n
+ if (r < 0) (r + n) % n else r
+ }
+
+ def pmod(a: Double, n: Double): Double = {
+ val r = a % n
+ if (r < 0) (r + n) % n else r
+ }
+
def withOverflow[A](f: => A, hint: String = "", context: QueryContext =
null): A = {
try {
f
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 23fffb162a8f..cb7ba16aeb81 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -1080,12 +1080,12 @@ case class Pmod(
}
private lazy val pmodFunc: (Any, Any) => Any = dataType match {
- case _: IntegerType => (l, r) => pmod(l.asInstanceOf[Int],
r.asInstanceOf[Int])
- case _: LongType => (l, r) => pmod(l.asInstanceOf[Long],
r.asInstanceOf[Long])
- case _: ShortType => (l, r) => pmod(l.asInstanceOf[Short],
r.asInstanceOf[Short])
- case _: ByteType => (l, r) => pmod(l.asInstanceOf[Byte],
r.asInstanceOf[Byte])
- case _: FloatType => (l, r) => pmod(l.asInstanceOf[Float],
r.asInstanceOf[Float])
- case _: DoubleType => (l, r) => pmod(l.asInstanceOf[Double],
r.asInstanceOf[Double])
+ case _: IntegerType => (l, r) => MathUtils.pmod(l.asInstanceOf[Int],
r.asInstanceOf[Int])
+ case _: LongType => (l, r) => MathUtils.pmod(l.asInstanceOf[Long],
r.asInstanceOf[Long])
+ case _: ShortType => (l, r) => MathUtils.pmod(l.asInstanceOf[Short],
r.asInstanceOf[Short])
+ case _: ByteType => (l, r) => MathUtils.pmod(l.asInstanceOf[Byte],
r.asInstanceOf[Byte])
+ case _: FloatType => (l, r) => MathUtils.pmod(l.asInstanceOf[Float],
r.asInstanceOf[Float])
+ case _: DoubleType => (l, r) => MathUtils.pmod(l.asInstanceOf[Double],
r.asInstanceOf[Double])
case DecimalType.Fixed(precision, scale) => (l, r) => checkDecimalOverflow(
pmod(l.asInstanceOf[Decimal], r.asInstanceOf[Decimal]), precision, scale)
}
@@ -1120,6 +1120,7 @@ case class Pmod(
val remainder = ctx.freshName("remainder")
val javaType = CodeGenerator.javaType(dataType)
val errorContext = getContextOrNullCode(ctx)
+ val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
val result = dataType match {
case DecimalType.Fixed(precision, scale) =>
val decimalAdd = "$plus"
@@ -1135,25 +1136,12 @@ case class Pmod(
|${ev.isNull} = ${ev.value} == null;
|""".stripMargin
- // byte and short are casted into int when add, minus, times or divide
- case ByteType | ShortType =>
- s"""
- $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value});
- if ($remainder < 0) {
- ${ev.value}=($javaType)(($remainder + ${eval2.value}) %
${eval2.value});
- } else {
- ${ev.value}=$remainder;
- }
- """
+ // The positive-modulo arithmetic is the same fixed algorithm for every
primitive numeric
+ // type, so delegate to the shared MathUtils.pmod helper (also used by
the eval path) instead
+ // of emitting the remainder/adjust block inline. byte/short are widened
to int by `%`, and
+ // the matching MathUtils.pmod overload narrows the result back.
case _ =>
- s"""
- $javaType $remainder = ${eval1.value} % ${eval2.value};
- if ($remainder < 0) {
- ${ev.value}=($remainder + ${eval2.value}) % ${eval2.value};
- } else {
- ${ev.value}=$remainder;
- }
- """
+ s"${ev.value} = $mathUtils.pmod(${eval1.value}, ${eval2.value});"
}
// evaluate right first as we have a chance to skip left if right is 0
@@ -1198,36 +1186,6 @@ case class Pmod(
}
}
- private def pmod(a: Int, n: Int): Int = {
- val r = a % n
- if (r < 0) {(r + n) % n} else r
- }
-
- private def pmod(a: Long, n: Long): Long = {
- val r = a % n
- if (r < 0) {(r + n) % n} else r
- }
-
- private def pmod(a: Byte, n: Byte): Byte = {
- val r = a % n
- if (r < 0) {((r + n) % n).toByte} else r.toByte
- }
-
- private def pmod(a: Double, n: Double): Double = {
- val r = a % n
- if (r < 0) {(r + n) % n} else r
- }
-
- private def pmod(a: Short, n: Short): Short = {
- val r = a % n
- if (r < 0) {((r + n) % n).toShort} else r.toShort
- }
-
- private def pmod(a: Float, n: Float): Float = {
- val r = a % n
- if (r < 0) {(r + n) % n} else r
- }
-
private def pmod(a: Decimal, n: Decimal): Decimal = {
val r = a % n
if (r != null && r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]