Repository: spark
Updated Branches:
  refs/heads/master f94f3624e -> 649888415


[SPARK-23898][SQL] Simplify add & subtract code generation

## What changes were proposed in this pull request?
Code generation for the `Add` and `Subtract` expressions was not done using the
`BinaryArithmetic.doGenCode` method because these expressions also support
`CalendarInterval`. This led to some duplicated code.

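This PR removes that duplication by adding a `calendarIntervalMethod` hook to
`BinaryArithmetic` (alongside the existing `decimalMethod`) and generating the
`CalendarInterval` code in `BinaryArithmetic.doGenCode` instead. `Add` and
`Subtract` now only override `decimalMethod` and `calendarIntervalMethod`.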

## How was this patch tested?
Existing tests.

Author: Herman van Hovell <hvanhov...@databricks.com>

Closes #21005 from hvanhovell/SPARK-23898.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/64988841
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/64988841
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/64988841

Branch: refs/heads/master
Commit: 64988841540464e261b0cbaede43058e7bd36261
Parents: f94f362
Author: Herman van Hovell <hvanhov...@databricks.com>
Authored: Mon Apr 9 21:49:49 2018 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Mon Apr 9 21:49:49 2018 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   | 50 ++++++++------------
 1 file changed, 20 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/64988841/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 478ff3a..defd6f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -43,7 +43,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = 
dataType match {
-    case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
+    case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
     case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
       val originValue = ctx.freshName("origin")
       // codegen would fail to compile if we just write (-($c))
@@ -52,7 +52,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
         ${CodeGenerator.javaType(dt)} $originValue = 
(${CodeGenerator.javaType(dt)})($eval);
         ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue));
       """})
-    case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => 
s"$c.negate()")
+    case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
   }
 
   protected override def nullSafeEval(input: Any): Any = {
@@ -104,7 +104,7 @@ case class Abs(child: Expression)
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = 
dataType match {
-    case dt: DecimalType =>
+    case _: DecimalType =>
       defineCodeGen(ctx, ev, c => s"$c.abs()")
     case dt: NumericType =>
       defineCodeGen(ctx, ev, c => 
s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))")
@@ -117,15 +117,21 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
 
   override def dataType: DataType = left.dataType
 
-  override lazy val resolved = childrenResolved && 
checkInputDataTypes().isSuccess
+  override lazy val resolved: Boolean = childrenResolved && 
checkInputDataTypes().isSuccess
 
   /** Name of the function for this expression on a [[Decimal]] type. */
   def decimalMethod: String =
     sys.error("BinaryArithmetics must override either decimalMethod or 
genCode")
 
+  /** Name of the function for this expression on a [[CalendarInterval]] type. 
*/
+  def calendarIntervalMethod: String =
+    sys.error("BinaryArithmetics must override either calendarIntervalMethod 
or genCode")
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = 
dataType match {
-    case dt: DecimalType =>
+    case _: DecimalType =>
       defineCodeGen(ctx, ev, (eval1, eval2) => 
s"$eval1.$decimalMethod($eval2)")
+    case CalendarIntervalType =>
+      defineCodeGen(ctx, ev, (eval1, eval2) => 
s"$eval1.$calendarIntervalMethod($eval2)")
     // byte and short are casted into int when add, minus, times or divide
     case ByteType | ShortType =>
       defineCodeGen(ctx, ev,
@@ -152,6 +158,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def symbol: String = "+"
 
+  override def decimalMethod: String = "$plus"
+
+  override def calendarIntervalMethod: String = "add"
+
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -161,18 +171,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
       numeric.plus(input1, input2)
     }
   }
-
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = 
dataType match {
-    case dt: DecimalType =>
-      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
-    case ByteType | ShortType =>
-      defineCodeGen(ctx, ev,
-        (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 
$symbol $eval2)")
-    case CalendarIntervalType =>
-      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
-    case _ =>
-      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
-  }
 }
 
 @ExpressionDescription(
@@ -188,6 +186,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
 
   override def symbol: String = "-"
 
+  override def decimalMethod: String = "$minus"
+
+  override def calendarIntervalMethod: String = "subtract"
+
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -197,18 +199,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
       numeric.minus(input1, input2)
     }
   }
-
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = 
dataType match {
-    case dt: DecimalType =>
-      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
-    case ByteType | ShortType =>
-      defineCodeGen(ctx, ev,
-        (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 
$symbol $eval2)")
-    case CalendarIntervalType =>
-      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
-    case _ =>
-      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
-  }
 }
 
 @ExpressionDescription(
@@ -416,7 +406,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def symbol: String = "pmod"
 
-  protected def checkTypesInternal(t: DataType) =
+  protected def checkTypesInternal(t: DataType): TypeCheckResult =
     TypeUtils.checkForNumericExpr(t, "pmod")
 
   override def inputType: AbstractDataType = NumericType


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to