Repository: spark
Updated Branches:
  refs/heads/master fa4ec3606 -> a9385271a


[SPARK-8221][SQL]Add pmod function

https://issues.apache.org/jira/browse/SPARK-8221

One concern is that the result can still be negative when the divisor is negative (e.g.
pmod(-7, -3) = -1); Hive's pmod returns the same value in that case.

Author: zhichao.li <[email protected]>

Closes #6783 from zhichao-li/pmod2 and squashes the following commits:

7083eb9 [zhichao.li] update to the latest type checking
d26dba7 [zhichao.li] add pmod


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a9385271
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a9385271
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a9385271

Branch: refs/heads/master
Commit: a9385271a9f6b97ec6aa619cf56ee556ba2fb0de
Parents: fa4ec36
Author: zhichao.li <[email protected]>
Authored: Wed Jul 15 10:43:38 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Wed Jul 15 10:43:38 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../catalyst/analysis/HiveTypeCoercion.scala    |  6 ++
 .../sql/catalyst/expressions/arithmetic.scala   | 94 ++++++++++++++++++++
 .../expressions/ArithmeticExpressionSuite.scala | 16 +++-
 .../scala/org/apache/spark/sql/functions.scala  | 17 ++++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 37 ++++++++
 6 files changed, 170 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a9385271/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ec75f51..d2678ce 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -115,6 +115,7 @@ object FunctionRegistry {
     expression[Log2]("log2"),
     expression[Pow]("pow"),
     expression[Pow]("power"),
+    expression[Pmod]("pmod"),
     expression[UnaryPositive]("positive"),
     expression[Rint]("rint"),
     expression[Round]("round"),

http://git-wip-us.apache.org/repos/asf/spark/blob/a9385271/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 15da5ee..2508791 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -426,6 +426,12 @@ object HiveTypeCoercion {
             DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
           )
 
+        case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ 
DecimalType.Expression(p2, s2)) =>
+          Cast(
+            Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited)),
+            DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+          )
+
         // When we compare 2 decimal types with different precisions, cast 
them to the smallest
         // common precision.
         case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),

http://git-wip-us.apache.org/repos/asf/spark/blob/a9385271/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1a55a08..394ef55 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -377,3 +377,131 @@ case class MinOf(left: Expression, right: Expression) 
extends BinaryArithmetic {
   override def symbol: String = "min"
   override def prettyName: String = symbol
 }
+
+/**
+ * Positive modulus: `left` mod `right`, adjusted so the result is non-negative whenever the
+ * divisor is positive (e.g. pmod(-7, 3) = 2). A negative divisor can still produce a negative
+ * result (e.g. pmod(-7, -3) = -1). A zero divisor yields null instead of throwing.
+ */
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
+
+  override def toString: String = s"pmod($left, $right)"
+
+  override def symbol: String = "pmod"
+
+  // A zero divisor produces null even when both inputs are non-null.
+  override def nullable: Boolean = true
+
+  protected def checkTypesInternal(t: DataType) =
+    TypeUtils.checkForNumericExpr(t, "pmod")
+
+  override def inputType: AbstractDataType = NumericType
+
+  protected override def nullSafeEval(left: Any, right: Any): Any =
+    if (isZero(right)) {
+      // Integral `%` by zero would throw ArithmeticException; return null instead.
+      null
+    } else {
+      dataType match {
+        case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int])
+        case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long])
+        case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short])
+        case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte])
+        case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float])
+        case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double])
+        case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal])
+      }
+    }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+      val javaType = ctx.javaType(dataType)
+      // Fresh name so that nested pmod expressions in one generated method cannot clash.
+      val r = ctx.freshName("r")
+      val (divisorIsZero, compute) = dataType match {
+        case _: DecimalType =>
+          // Decimal's Scala `+` operator is compiled to the JVM method name `$plus`.
+          val decimalAdd = "$plus"
+          (s"$eval2.isZero()",
+            s"""
+              $javaType $r = $eval1.remainder($eval2);
+              if ($r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
+                ${ev.primitive} = ($r.$decimalAdd($eval2)).remainder($eval2);
+              } else {
+                ${ev.primitive} = $r;
+              }
+            """)
+        // byte and short are promoted to int by Java's `%` and `+`, so cast the results back.
+        case ByteType | ShortType =>
+          (s"$eval2 == 0",
+            s"""
+              $javaType $r = ($javaType)($eval1 % $eval2);
+              if ($r < 0) {
+                ${ev.primitive} = ($javaType)(($r + $eval2) % $eval2);
+              } else {
+                ${ev.primitive} = $r;
+              }
+            """)
+        case _ =>
+          (s"$eval2 == 0",
+            s"""
+              $javaType $r = $eval1 % $eval2;
+              if ($r < 0) {
+                ${ev.primitive} = ($r + $eval2) % $eval2;
+              } else {
+                ${ev.primitive} = $r;
+              }
+            """)
+      }
+      s"""
+        if ($divisorIsZero) {
+          ${ev.isNull} = true;
+        } else {
+          $compute
+        }
+      """
+    })
+  }
+
+  // Decimal is tested via isZero; other numeric boxes are compared numerically against 0.
+  private def isZero(divisor: Any): Boolean = divisor match {
+    case d: Decimal => d.isZero
+    case other => other == 0
+  }
+
+  // Each helper takes a % n and, when that remainder is negative, shifts it by n once.
+  private def pmod(a: Int, n: Int): Int = {
+    val r = a % n
+    if (r < 0) {(r + n) % n} else r
+  }
+
+  private def pmod(a: Long, n: Long): Long = {
+    val r = a % n
+    if (r < 0) {(r + n) % n} else r
+  }
+
+  private def pmod(a: Byte, n: Byte): Byte = {
+    val r = a % n
+    if (r < 0) {((r + n) % n).toByte} else r.toByte
+  }
+
+  private def pmod(a: Double, n: Double): Double = {
+    val r = a % n
+    if (r < 0) {(r + n) % n} else r
+  }
+
+  private def pmod(a: Short, n: Short): Short = {
+    val r = a % n
+    if (r < 0) {((r + n) % n).toShort} else r.toShort
+  }
+
+  private def pmod(a: Float, n: Float): Float = {
+    val r = a % n
+    if (r < 0) {(r + n) % n} else r
+  }
+
+  private def pmod(a: Decimal, n: Decimal): Decimal = {
+    val r = a % n
+    if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a9385271/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 6c93698..e7e5231 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types.Decimal
 
-
 class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
 
   /**
@@ -158,4 +157,20 @@ class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper
     checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 
3.toByte)),
       Array(1.toByte, 2.toByte))
   }
+
+  test("pmod") {
+    testNumericDataTypes { convert =>
+      val left = Literal(convert(7))
+      val right = Literal(convert(3))
+      checkEvaluation(Pmod(left, right), convert(1))
+      checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null)
+      checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null)
+      checkEvaluation(Pmod(left, Literal(convert(0))), null)  // mod by 0
+    }
+    checkEvaluation(Pmod(-7, 3), 2)
+    checkEvaluation(Pmod(-7, -3), -1)  // a negative divisor can still give a negative result
+    checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005)
+    checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1))
+    checkEvaluation(Pmod(2L, Long.MaxValue), 2L)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a9385271/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5119ee3..c7deaca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1372,6 +1372,30 @@ object functions {
   def pow(l: Double, rightName: String): Column = pow(l, Column(rightName))
 
   /**
+   * Returns the positive value of `dividend` mod `divisor`: the remainder is shifted into
+   * `[0, divisor)` when the divisor is positive (e.g. pmod(-7, 3) = 2). For a negative divisor
+   * the result can be negative (e.g. pmod(-7, -3) = -1).
+   *
+   * @param dividend the value being divided
+   * @param divisor the value to divide by
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr)
+
+  /**
+   * Returns the positive value of column `dividendColName` mod column `divisorColName`.
+   * Equivalent to `pmod(Column(dividendColName), Column(divisorColName))`; see that overload
+   * for how negative divisors are handled.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def pmod(dividendColName: String, divisorColName: String): Column =
+    pmod(Column(dividendColName), Column(divisorColName))
+
+  /**
    * Returns the double value that is closest in value to the argument and
    * is equal to a mathematical integer.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/a9385271/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 6cebec9..70bd787 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -403,4 +403,24 @@ class DataFrameFunctionsSuite extends QueryTest {
       Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
     )
   }
+
+  test("pmod") {
+    val intData = Seq((7, 3), (-7, 3)).toDF("a", "b")
+    // Each case runs the same computation through the Column API and through SQL text;
+    // both forms must produce the expected rows.
+    Seq(
+      (pmod('a, 'b), "pmod(a, b)", Seq(Row(1), Row(2))),
+      (pmod('a, lit(3)), "pmod(a, 3)", Seq(Row(1), Row(2))),
+      (pmod(lit(-7), 'b), "pmod(-7, b)", Seq(Row(2), Row(2)))
+    ).foreach { case (viaColumn, viaSql, expected) =>
+      checkAnswer(intData.select(viaColumn), expected)
+      checkAnswer(intData.selectExpr(viaSql), expected)
+    }
+
+    val doubleData = Seq((7.2, 4.1)).toDF("a", "b")
+    // Floating-point result is the same value Hive returns for these inputs.
+    checkAnswer(doubleData.select(pmod('a, 'b)), Seq(Row(3.1000000000000005)))
+    // Literal-only operands: the single-row frame yields exactly one result row.
+    checkAnswer(doubleData.select(pmod(lit(2), lit(Int.MaxValue))), Seq(Row(2)))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to