This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 1c6ed48edf4d [SPARK-57181][SQL] Simplify Pmod codegen by sharing a 
MathUtils.pmod helper with eval
1c6ed48edf4d is described below

commit 1c6ed48edf4d310fcbad2d9b0295086d83611a43
Author: Gengliang Wang <[email protected]>
AuthorDate: Tue Jun 2 13:44:13 2026 -0700

    [SPARK-57181][SQL] Simplify Pmod codegen by sharing a MathUtils.pmod helper 
with eval
    
    ### What changes were proposed in this pull request?
    
    `Pmod.doGenCode` emitted the positive-modulo `remainder`/adjust block 
inline, once for byte/short and once for the int/long/float/double case (~6-8 
lines each), duplicating the algorithm already implemented by `Pmod`'s private 
`pmod` eval methods. This adds `MathUtils.pmod` overloads (Int, Long, Byte, 
Short, Float, Double) -- the exact bodies moved out of `Pmod` -- and routes 
both the eval dispatch (`pmodFunc`) and codegen through them. The primitive 
codegen cases collapse to a single `MathUtils.pmod(...)` call.
    
    ### Why are the changes needed?
    
    Part of SPARK-56908 (umbrella). `Pmod` over IntegerType is emitted by every 
`HashPartitioning` (`Pmod(Murmur3Hash(...), numPartitions)`), so collapsing the 
inline block shrinks the generated Java on a very common path, and the eval and 
codegen paths now share one implementation instead of duplicating the algorithm 
(helping with the JVM 64KB method / constant-pool limits, Janino compile time, 
and JIT work).
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. The compiled behavior is identical; only the emitted Java source text 
changes.
    
    ### How was this patch tested?
    
    Existing `ArithmeticExpressionSuite."pmod"` covers all numeric types, 
negative operands / divisors, mod-by-zero (ANSI on/off), and 
`checkConsistencyBetweenInterpretedAndCodegenAllowingException` across all 
numeric types (which verifies eval and codegen agree -- exactly the invariant 
this refactor must preserve).
    
    ```
    build/sbt "catalyst/testOnly *ArithmeticExpressionSuite"
    ```
    
    35/35 pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code (Opus 4.8)
    
    Closes #56232 from gengliangwang/spark-pmod-codegen.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
    (cherry picked from commit 576070aeb9ddab611609116e7753c2fda9b0de7c)
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../apache/spark/sql/catalyst/util/MathUtils.scala | 34 +++++++++++
 .../sql/catalyst/expressions/arithmetic.scala      | 66 ++++------------------
 2 files changed, 46 insertions(+), 54 deletions(-)

diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
index b113bccc74df..a4c6c75358f4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -89,6 +89,40 @@ object MathUtils {
 
   def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b))
 
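+  // Positive modulo (`pmod` SQL function / `HashPartitioning` semantics): with r = a % n,
+  // returns (r + n) % n when r < 0, otherwise r, so a positive n gives a result in [0, n)
+  // (barring NaN). Unlike `floorMod`, the result need not share the sign of a negative n:
+  // pmod(3, -2) == 1 whereas floorMod(3, -2) == -1. Shared by `Pmod`'s eval and codegen paths
+  // so the two never diverge.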
+
+  // Int/Long: `r + n` is computed in the operand type, so it can overflow when r < 0 and n is
+  // close to MinValue, e.g. pmod(-1, Int.MinValue) == Int.MaxValue (mathematically -1).
+  // NOTE(review): same arithmetic as the removed `Pmod.pmod` overloads, so behavior is
+  // unchanged by this refactor; presumably harmless for `HashPartitioning` since its divisor
+  // is the partition count (expected > 0) -- confirm before relying on it.
+  def pmod(a: Int, n: Int): Int = {
+    val r = a % n
+    if (r < 0) (r + n) % n else r
+  }
+
+  // Same arithmetic (and the same MinValue-adjacent overflow caveat) as the Int overload.
+  def pmod(a: Long, n: Long): Long = {
+    val r = a % n
+    if (r < 0) (r + n) % n else r
+  }
+
+  // Byte/Short: `%` and `+` promote to Int, so `r + n` cannot overflow here; the Int result is
+  // narrowed back to the operand type with `toByte` / `toShort`.
+  def pmod(a: Byte, n: Byte): Byte = {
+    val r = a % n
+    if (r < 0) ((r + n) % n).toByte else r.toByte
+  }
+
+  // Same Int promotion as the Byte overload; the result is narrowed with `toShort`.
+  def pmod(a: Short, n: Short): Short = {
+    val r = a % n
+    if (r < 0) ((r + n) % n).toShort else r.toShort
+  }
+
+  // Float/Double: a NaN remainder fails `r < 0` and is returned unchanged.
+  def pmod(a: Float, n: Float): Float = {
+    val r = a % n
+    if (r < 0) (r + n) % n else r
+  }
+
+  // Same as the Float overload, in double precision.
+  def pmod(a: Double, n: Double): Double = {
+    val r = a % n
+    if (r < 0) (r + n) % n else r
+  }
+
   def withOverflow[A](f: => A, hint: String = "", context: QueryContext = 
null): A = {
     try {
       f
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 23fffb162a8f..cb7ba16aeb81 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -1080,12 +1080,12 @@ case class Pmod(
   }
 
   private lazy val pmodFunc: (Any, Any) => Any = dataType match {
-    case _: IntegerType => (l, r) => pmod(l.asInstanceOf[Int], 
r.asInstanceOf[Int])
-    case _: LongType => (l, r) => pmod(l.asInstanceOf[Long], 
r.asInstanceOf[Long])
-    case _: ShortType => (l, r) => pmod(l.asInstanceOf[Short], 
r.asInstanceOf[Short])
-    case _: ByteType => (l, r) => pmod(l.asInstanceOf[Byte], 
r.asInstanceOf[Byte])
-    case _: FloatType => (l, r) => pmod(l.asInstanceOf[Float], 
r.asInstanceOf[Float])
-    case _: DoubleType => (l, r) => pmod(l.asInstanceOf[Double], 
r.asInstanceOf[Double])
+    case _: IntegerType => (l, r) => MathUtils.pmod(l.asInstanceOf[Int], 
r.asInstanceOf[Int])
+    case _: LongType => (l, r) => MathUtils.pmod(l.asInstanceOf[Long], 
r.asInstanceOf[Long])
+    case _: ShortType => (l, r) => MathUtils.pmod(l.asInstanceOf[Short], 
r.asInstanceOf[Short])
+    case _: ByteType => (l, r) => MathUtils.pmod(l.asInstanceOf[Byte], 
r.asInstanceOf[Byte])
+    case _: FloatType => (l, r) => MathUtils.pmod(l.asInstanceOf[Float], 
r.asInstanceOf[Float])
+    case _: DoubleType => (l, r) => MathUtils.pmod(l.asInstanceOf[Double], 
r.asInstanceOf[Double])
     case DecimalType.Fixed(precision, scale) => (l, r) => checkDecimalOverflow(
       pmod(l.asInstanceOf[Decimal], r.asInstanceOf[Decimal]), precision, scale)
   }
@@ -1120,6 +1120,7 @@ case class Pmod(
     val remainder = ctx.freshName("remainder")
     val javaType = CodeGenerator.javaType(dataType)
     val errorContext = getContextOrNullCode(ctx)
+    val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
     val result = dataType match {
       case DecimalType.Fixed(precision, scale) =>
         val decimalAdd = "$plus"
@@ -1135,25 +1136,12 @@ case class Pmod(
            |${ev.isNull} = ${ev.value} == null;
            |""".stripMargin
 
-      // byte and short are casted into int when add, minus, times or divide
-      case ByteType | ShortType =>
-        s"""
-          $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value});
-          if ($remainder < 0) {
-            ${ev.value}=($javaType)(($remainder + ${eval2.value}) % 
${eval2.value});
-          } else {
-            ${ev.value}=$remainder;
-          }
-        """
+      // The positive-modulo arithmetic is the same fixed algorithm for every 
primitive numeric
+      // type, so delegate to the shared MathUtils.pmod helper (also used by 
the eval path) instead
+      // of emitting the remainder/adjust block inline. byte/short are widened 
to int by `%`, and
+      // the matching MathUtils.pmod overload narrows the result back.
       case _ =>
-        s"""
-          $javaType $remainder = ${eval1.value} % ${eval2.value};
-          if ($remainder < 0) {
-            ${ev.value}=($remainder + ${eval2.value}) % ${eval2.value};
-          } else {
-            ${ev.value}=$remainder;
-          }
-        """
+        s"${ev.value} = $mathUtils.pmod(${eval1.value}, ${eval2.value});"
     }
 
     // evaluate right first as we have a chance to skip left if right is 0
@@ -1198,36 +1186,6 @@ case class Pmod(
     }
   }
 
-  private def pmod(a: Int, n: Int): Int = {
-    val r = a % n
-    if (r < 0) {(r + n) % n} else r
-  }
-
-  private def pmod(a: Long, n: Long): Long = {
-    val r = a % n
-    if (r < 0) {(r + n) % n} else r
-  }
-
-  private def pmod(a: Byte, n: Byte): Byte = {
-    val r = a % n
-    if (r < 0) {((r + n) % n).toByte} else r.toByte
-  }
-
-  private def pmod(a: Double, n: Double): Double = {
-    val r = a % n
-    if (r < 0) {(r + n) % n} else r
-  }
-
-  private def pmod(a: Short, n: Short): Short = {
-    val r = a % n
-    if (r < 0) {((r + n) % n).toShort} else r.toShort
-  }
-
-  private def pmod(a: Float, n: Float): Float = {
-    val r = a % n
-    if (r < 0) {(r + n) % n} else r
-  }
-
   private def pmod(a: Decimal, n: Decimal): Decimal = {
     val r = a % n
     if (r != null && r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to