This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f212c61  [SPARK-34868][SQL] Support divide an year-month interval by a 
numeric
f212c61 is described below

commit f212c61c435f74cf021e4e780ef9a20ff6ab8c90
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Fri Mar 26 05:56:56 2021 +0000

    [SPARK-34868][SQL] Support divide an year-month interval by a numeric
    
    ### What changes were proposed in this pull request?
    1. Add new expression `DivideYMInterval` which divides a 
`YearMonthIntervalType` expression by a `NumericType` expression, including 
ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, and DecimalType.
    2. Extend binary arithmetic rules to support `year-month interval / 
numeric`.
    
    ### Why are the changes needed?
    To conform to the ANSI SQL standard, which requires such an operation on 
year-month intervals:
    <img width="656" alt="Screenshot 2021-03-25 at 18 44 58" 
src="https://user-images.githubusercontent.com/1580697/112501559-68f07080-8d9a-11eb-8781-66e6631bb7ef.png">
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    By running new tests:
    ```
    $ build/sbt "test:testOnly *IntervalExpressionsSuite"
    $ build/sbt "test:testOnly *ColumnExpressionSuite"
    ```
    
    Closes #31961 from MaxGekk/div-ym-interval-by-num.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  1 +
 .../catalyst/expressions/intervalExpressions.scala | 54 +++++++++++++++++++++-
 .../expressions/IntervalExpressionsSuite.scala     | 33 +++++++++++++
 .../apache/spark/sql/ColumnExpressionSuite.scala   | 34 ++++++++++++++
 4 files changed, 120 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 66546f8..fedf9ec 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -380,6 +380,7 @@ class Analyzer(override val catalogManager: CatalogManager)
         }
         case d @ Divide(l, r, f) if d.childrenResolved => (l.dataType, 
r.dataType) match {
           case (CalendarIntervalType, _) => DivideInterval(l, r, f)
+          case (YearMonthIntervalType, _) => DivideYMInterval(l, r)
           case _ => d
         }
       }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 8c64d23..78b3871 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
 import java.math.RoundingMode
 import java.util.Locale
 
-import com.google.common.math.DoubleMath
+import com.google.common.math.{DoubleMath, IntMath, LongMath}
 
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
CodeGenerator, ExprCode}
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.internal.SQLConf
@@ -341,3 +341,53 @@ case class MultiplyDTInterval(
 
   override def toString: String = s"($left * $right)"
 }
+
+// Divides a year-month interval (a month count held as Int) by a numeric value; every branch rounds the quotient HALF_UP to whole months.
+case class DivideYMInterval(
+    interval: Expression,
+    num: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant 
with Serializable {
+  override def left: Expression = interval
+  override def right: Expression = num
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(YearMonthIntervalType, 
NumericType)
+  override def dataType: DataType = YearMonthIntervalType
+
+  @transient
+  private lazy val evalFunc: (Int, Any) => Any = right.dataType match { // divisor-type-specific division, chosen once per instance
+    case LongType => (months: Int, num) => // NOTE(review): Int.MinValue / -1L is 2^31, which .toInt silently wraps to Int.MinValue; consider Math.toIntExact
+      LongMath.divide(months, num.asInstanceOf[Long], 
RoundingMode.HALF_UP).toInt
+    case _: IntegralType => (months: Int, num) => // remaining integral divisors; NOTE(review): IntMath.divide(Int.MinValue, -1) also overflows without an error
+      IntMath.divide(months, num.asInstanceOf[Number].intValue(), 
RoundingMode.HALF_UP)
+    case _: DecimalType => (months: Int, num) => // NOTE(review): confirm a zero decimal divisor raises ArithmeticException rather than null from Decimal./ (which would NPE below) -- TODO test
+      val decimalRes = ((new Decimal).set(months) / 
num.asInstanceOf[Decimal]).toJavaBigDecimal
+      decimalRes.setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()
+    case _: FractionalType => (months: Int, num) => // remaining fractional divisors; roundToInt throws ArithmeticException on NaN/infinite quotients (e.g. x / 0.0)
+      DoubleMath.roundToInt(months / num.asInstanceOf[Number].doubleValue(), 
RoundingMode.HALF_UP)
+  }
+
+  override def nullSafeEval(interval: Any, num: Any): Any = { // non-null case: interval is the month count, num the divisor value
+    evalFunc(interval.asInstanceOf[Int], num)
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = 
right.dataType match {
+    case LongType => // each codegen branch mirrors the matching evalFunc branch; the (int) cast below shares its silent-overflow caveat
+      val math = classOf[LongMath].getName
+      val javaType = CodeGenerator.javaType(dataType)
+      defineCodeGen(ctx, ev, (m, n) =>
+        s"($javaType)($math.divide($m, $n, java.math.RoundingMode.HALF_UP))")
+    case _: IntegralType =>
+      val math = classOf[IntMath].getName
+      defineCodeGen(ctx, ev, (m, n) => s"$math.divide($m, $n, 
java.math.RoundingMode.HALF_UP)")
+    case _: DecimalType =>
+      defineCodeGen(ctx, ev, (m, n) =>
+        s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
+        ".setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()")
+    case _: FractionalType =>
+      val math = classOf[DoubleMath].getName
+      defineCodeGen(ctx, ev, (m, n) =>
+        s"$math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP)")
+  }
+
+  override def toString: String = s"($left / $right)"
+}
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
index bc9a50f..6971b08 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
@@ -344,4 +344,37 @@ class IntervalExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
         DayTimeIntervalType, numType)
     }
   }
+
+  test("SPARK-34868: divide year-month interval by numeric") {
+    Seq( // ((interval, divisor), expected quotient); a null divisor yields null
+      (Period.ofYears(-123), Literal(null, DecimalType.USER_DEFAULT)) -> null,
+      (Period.ofMonths(0), 10) -> Period.ofMonths(0),
+      (Period.ofMonths(200), Double.PositiveInfinity) -> Period.ofMonths(0), // x / +-Infinity is +-0.0, which rounds to 0 months
+      (Period.ofMonths(-200), Float.NegativeInfinity) -> Period.ofMonths(0),
+      (Period.ofYears(100), -1.toByte) -> Period.ofYears(-100),
+      (Period.ofYears(1), 2.toShort) -> Period.ofMonths(6),
+      (Period.ofYears(-1), -3) -> Period.ofMonths(4),
+      (Period.ofMonths(-1000), 0.5f) -> Period.ofMonths(-2000),
+      (Period.ofYears(1000), 100d) -> Period.ofYears(10),
+      (Period.ofMonths(2), BigDecimal(0.1)) -> Period.ofMonths(20)
+    ).foreach { case ((period, num), expected) =>
+      checkEvaluation(DivideYMInterval(Literal(period), Literal(num)), 
expected)
+    }
+
+    Seq( // ((interval, divisor), expected ArithmeticException message)
+      (Period.ofMonths(1), 0) -> "/ by zero",
+      (Period.ofMonths(Int.MinValue), 0d) -> "input is infinite or NaN", // x / 0.0 is infinite
+      (Period.ofMonths(-100), Float.NaN) -> "input is infinite or NaN"
+    ).foreach { case ((period, num), expectedErrMsg) =>
+      checkExceptionInExpression[ArithmeticException](
+        DivideYMInterval(Literal(period), Literal(num)),
+        expectedErrMsg)
+    } // NOTE(review): no case yet for Period.ofMonths(Int.MinValue) / -1 (Int overflow) or a zero decimal divisor
+
+    numericTypes.foreach { numType => // interpreted and codegen paths must agree for every numeric divisor type
+      checkConsistencyBetweenInterpretedAndCodegenAllowingException(
+        (interval: Expression, num: Expression) => DivideYMInterval(interval, 
num),
+        YearMonthIntervalType, numType)
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 60044ad..8c57b7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -2647,4 +2647,38 @@ class ColumnExpressionSuite extends QueryTest with 
SharedSparkSession {
     assert(e.isInstanceOf[ArithmeticException])
     assert(e.getMessage.contains("overflow"))
   }
+
+  test("SPARK-34868: divide year-month interval by numeric") {
+    checkAnswer( // each case divides interval column "i" by divisor column "n"
+      Seq((Period.ofYears(0), 10.toByte)).toDF("i", "n").select($"i" / $"n"),
+      Row(Period.ofYears(0)))
+    checkAnswer(
+      Seq((Period.ofYears(10), 3.toShort)).toDF("i", "n").select($"i" / $"n"),
+      Row(Period.ofYears(3).plusMonths(4))) // 120 months / 3 = 40 months
+    checkAnswer(
+      Seq((Period.ofYears(1000), "2")).toDF("i", "n").select($"i" / $"n"), // string divisor; presumably implicitly cast to a numeric type
+      Row(Period.ofYears(500)))
+    checkAnswer(
+      Seq((Period.ofMonths(1).multipliedBy(Int.MaxValue), Int.MaxValue))
+        .toDF("i", "n").select($"i" / $"n"),
+      Row(Period.ofMonths(1)))
+    checkAnswer(
+      Seq((Period.ofYears(-1), 12L)).toDF("i", "n").select($"i" / $"n"),
+      Row(Period.ofMonths(-1)))
+    checkAnswer(
+      Seq((Period.ofMonths(-1), 0.499f)).toDF("i", "n").select($"i" / $"n"),
+      Row(Period.ofMonths(-2))) // -1 / 0.499 is about -2.004, rounded HALF_UP to -2
+    checkAnswer(
+      Seq((Period.ofMonths(10000000), 10000000d)).toDF("i", "n").select($"i" / 
$"n"),
+      Row(Period.ofMonths(1)))
+    checkAnswer(
+      Seq((Period.ofMonths(-1), BigDecimal(0.5))).toDF("i", "n").select($"i" / 
$"n"),
+      Row(Period.ofMonths(-2)))
+
+    val e = intercept[SparkException] { // integral division by zero; the job failure's cause is checked below
+      Seq((Period.ofYears(9999), 0)).toDF("i", "n").select($"i" / 
$"n").collect()
+    }.getCause
+    assert(e.isInstanceOf[ArithmeticException])
+    assert(e.getMessage.contains("/ by zero"))
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to