Repository: spark
Updated Branches:
  refs/heads/master 408a3ff2c -> 0f3fa2f28


[SPARK-24996][SQL] Use DSL in DeclarativeAggregate

## What changes were proposed in this pull request?

This PR refactors the aggregate expressions that were not using the Catalyst DSL, in order to simplify them.
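
As an illustrative sketch only (not part of the patch; the `child` attribute is hypothetical), this is the kind of rewrite applied, relying on the implicits in `org.apache.spark.sql.catalyst.dsl.expressions._`:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._

// Hypothetical input attribute, for illustration only.
val child = 'child.double

// Constructor style, as before this patch:
val before: Expression = If(IsNull(child), Literal(0.0), Add(child, Literal(1.0)))

// DSL style, as after this patch; both build the same expression tree:
val after: Expression = If(child.isNull, Literal(0.0), child + 1.0)

assert(before.semanticEquals(after))
```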

## How was this patch tested?

NA

Author: Marco Gaido <marcogaid...@gmail.com>

Closes #21970 from mgaido91/SPARK-24996.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0f3fa2f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0f3fa2f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0f3fa2f2

Branch: refs/heads/master
Commit: 0f3fa2f289f53a8ceea3b0a52fa6dc319001b10b
Parents: 408a3ff
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Mon Aug 6 19:46:51 2018 -0400
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Mon Aug 6 19:46:51 2018 -0400

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/dsl/package.scala |  2 +
 .../expressions/aggregate/Average.scala         |  2 +-
 .../aggregate/CentralMomentAgg.scala            | 40 +++++++++-----------
 .../catalyst/expressions/aggregate/Corr.scala   | 13 +++----
 .../expressions/aggregate/Covariance.scala      | 16 ++++----
 .../catalyst/expressions/aggregate/First.scala  |  7 ++--
 .../catalyst/expressions/aggregate/Last.scala   |  7 ++--
 .../catalyst/expressions/aggregate/Max.scala    |  5 ++-
 .../catalyst/expressions/aggregate/Min.scala    |  5 ++-
 .../catalyst/expressions/aggregate/Sum.scala    |  7 ++--
 .../expressions/windowExpressions.scala         | 30 +++++++--------
 11 files changed, 65 insertions(+), 69 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 75387fa..2b582b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -167,6 +167,8 @@ package object dsl {
     def upper(e: Expression): Expression = Upper(e)
     def lower(e: Expression): Expression = Lower(e)
     def coalesce(args: Expression*): Expression = Coalesce(args)
+    def greatest(args: Expression*): Expression = Greatest(args)
+    def least(args: Expression*): Expression = Least(args)
     def sqrt(e: Expression): Expression = Sqrt(e)
     def abs(e: Expression): Expression = Abs(e)
     def star(names: String*): Expression = names match {

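As a hedged usage sketch (attributes hypothetical, not from the patch), the two new helpers mirror the existing `coalesce` wrapper and accept varargs instead of an explicit `Seq`:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._

// Hypothetical attributes, for illustration only.
val a = 'a.int
val b = 'b.int

// Constructor style requires wrapping the children in a Seq:
val oldMax: Expression = Greatest(Seq(a, b))

// The new DSL helpers take varargs directly:
val newMax: Expression = greatest(a, b) // Greatest(Seq(a, b))
val newMin: Expression = least(a, b)    // Least(Seq(a, b))
```
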
http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index f1fad77..5ecb77b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -68,7 +68,7 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
     Add(
       sum,
       coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))),
-    /* count = */ If(IsNull(child), count, count + 1L)
+    /* count = */ If(child.isNull, count, count + 1L)
   )
 
   override lazy val updateExpressions = updateExpressionsDef

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 6bbb083..e2ff0ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -75,7 +75,7 @@ abstract class CentralMomentAgg(child: Expression)
     val n2 = n.right
     val newN = n1 + n2
     val delta = avg.right - avg.left
-    val deltaN = If(newN === Literal(0.0), Literal(0.0), delta / newN)
+    val deltaN = If(newN === 0.0, 0.0, delta / newN)
     val newAvg = avg.left + deltaN * n2
 
     // higher order moments computed according to:
@@ -102,7 +102,7 @@ abstract class CentralMomentAgg(child: Expression)
   }
 
   protected def updateExpressionsDef: Seq[Expression] = {
-    val newN = n + Literal(1.0)
+    val newN = n + 1.0
     val delta = child - avg
     val deltaN = delta / newN
     val newAvg = avg + deltaN
@@ -123,11 +123,11 @@ abstract class CentralMomentAgg(child: Expression)
     }
 
     trimHigherOrder(Seq(
-      If(IsNull(child), n, newN),
-      If(IsNull(child), avg, newAvg),
-      If(IsNull(child), m2, newM2),
-      If(IsNull(child), m3, newM3),
-      If(IsNull(child), m4, newM4)
+      If(child.isNull, n, newN),
+      If(child.isNull, avg, newAvg),
+      If(child.isNull, m2, newM2),
+      If(child.isNull, m3, newM3),
+      If(child.isNull, m4, newM4)
     ))
   }
 }
@@ -142,8 +142,7 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
   override protected def momentOrder = 2
 
   override val evaluateExpression: Expression = {
-    If(n === Literal(0.0), Literal.create(null, DoubleType),
-      Sqrt(m2 / n))
+    If(n === 0.0, Literal.create(null, DoubleType), sqrt(m2 / n))
   }
 
   override def prettyName: String = "stddev_pop"
@@ -159,9 +158,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
   override protected def momentOrder = 2
 
   override val evaluateExpression: Expression = {
-    If(n === Literal(0.0), Literal.create(null, DoubleType),
-      If(n === Literal(1.0), Literal(Double.NaN),
-        Sqrt(m2 / (n - Literal(1.0)))))
+    If(n === 0.0, Literal.create(null, DoubleType),
+      If(n === 1.0, Double.NaN, sqrt(m2 / (n - 1.0))))
   }
 
   override def prettyName: String = "stddev_samp"
@@ -175,8 +173,7 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
   override protected def momentOrder = 2
 
   override val evaluateExpression: Expression = {
-    If(n === Literal(0.0), Literal.create(null, DoubleType),
-      m2 / n)
+    If(n === 0.0, Literal.create(null, DoubleType), m2 / n)
   }
 
   override def prettyName: String = "var_pop"
@@ -190,9 +187,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
   override protected def momentOrder = 2
 
   override val evaluateExpression: Expression = {
-    If(n === Literal(0.0), Literal.create(null, DoubleType),
-      If(n === Literal(1.0), Literal(Double.NaN),
-        m2 / (n - Literal(1.0))))
+    If(n === 0.0, Literal.create(null, DoubleType),
+      If(n === 1.0, Double.NaN, m2 / (n - 1.0)))
   }
 
   override def prettyName: String = "var_samp"
@@ -207,9 +203,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) {
   override protected def momentOrder = 3
 
   override val evaluateExpression: Expression = {
-    If(n === Literal(0.0), Literal.create(null, DoubleType),
-      If(m2 === Literal(0.0), Literal(Double.NaN),
-        Sqrt(n) * m3 / Sqrt(m2 * m2 * m2)))
+    If(n === 0.0, Literal.create(null, DoubleType),
+      If(m2 === 0.0, Double.NaN, sqrt(n) * m3 / sqrt(m2 * m2 * m2)))
   }
 }
 
@@ -220,9 +215,8 @@ case class Kurtosis(child: Expression) extends CentralMomentAgg(child) {
   override protected def momentOrder = 4
 
   override val evaluateExpression: Expression = {
-    If(n === Literal(0.0), Literal.create(null, DoubleType),
-      If(m2 === Literal(0.0), Literal(Double.NaN),
-        n * m4 / (m2 * m2) - Literal(3.0)))
+    If(n === 0.0, Literal.create(null, DoubleType),
+      If(m2 === 0.0, Double.NaN, n * m4 / (m2 * m2) - 3.0))
   }
 
   override def prettyName: String = "kurtosis"

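For context, a minimal sketch of the equivalences these hunks rely on, mirroring the `StddevPop` change above (the buffer attributes are hypothetical stand-ins for `CentralMomentAgg`'s `n` and `m2`):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.DoubleType

// Hypothetical stand-ins for the aggregation buffer attributes.
val n  = 'n.double
val m2 = 'm2.double

// Constructor style, as before the patch:
val before: Expression =
  If(EqualTo(n, Literal(0.0)), Literal.create(null, DoubleType),
    Sqrt(Divide(m2, n)))

// DSL style, as after the patch: === builds EqualTo, / builds Divide,
// bare 0.0 is lifted to Literal(0.0), and sqrt wraps the Sqrt constructor.
val after: Expression =
  If(n === 0.0, Literal.create(null, DoubleType), sqrt(m2 / n))
```
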
http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 3cdef72..e14cc71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -54,9 +54,9 @@ abstract class PearsonCorrelation(x: Expression, y: Expression)
     val n2 = n.right
     val newN = n1 + n2
     val dx = xAvg.right - xAvg.left
-    val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+    val dxN = If(newN === 0.0, 0.0, dx / newN)
     val dy = yAvg.right - yAvg.left
-    val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+    val dyN = If(newN === 0.0, 0.0, dy / newN)
     val newXAvg = xAvg.left + dxN * n2
     val newYAvg = yAvg.left + dyN * n2
     val newCk = ck.left + ck.right + dx * dyN * n1 * n2
@@ -67,7 +67,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression)
   }
 
   protected def updateExpressionsDef: Seq[Expression] = {
-    val newN = n + Literal(1.0)
+    val newN = n + 1.0
     val dx = x - xAvg
     val dxN = dx / newN
     val dy = y - yAvg
@@ -78,7 +78,7 @@ abstract class PearsonCorrelation(x: Expression, y: Expression)
     val newXMk = xMk + dx * (x - newXAvg)
     val newYMk = yMk + dy * (y - newYAvg)
 
-    val isNull = IsNull(x) || IsNull(y)
+    val isNull = x.isNull || y.isNull
     Seq(
       If(isNull, n, newN),
       If(isNull, xAvg, newXAvg),
@@ -99,9 +99,8 @@ case class Corr(x: Expression, y: Expression)
   extends PearsonCorrelation(x, y) {
 
   override val evaluateExpression: Expression = {
-    If(n === Literal(0.0), Literal.create(null, DoubleType),
-      If(n === Literal(1.0), Literal(Double.NaN),
-        ck / Sqrt(xMk * yMk)))
+    If(n === 0.0, Literal.create(null, DoubleType),
+      If(n === 1.0, Double.NaN, ck / sqrt(xMk * yMk)))
   }
 
   override def prettyName: String = "corr"

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index 72a7c62..ee28eb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -50,9 +50,9 @@ abstract class Covariance(x: Expression, y: Expression)
     val n2 = n.right
     val newN = n1 + n2
     val dx = xAvg.right - xAvg.left
-    val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+    val dxN = If(newN === 0.0, 0.0, dx / newN)
     val dy = yAvg.right - yAvg.left
-    val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+    val dyN = If(newN === 0.0, 0.0, dy / newN)
     val newXAvg = xAvg.left + dxN * n2
     val newYAvg = yAvg.left + dyN * n2
     val newCk = ck.left + ck.right + dx * dyN * n1 * n2
@@ -61,7 +61,7 @@ abstract class Covariance(x: Expression, y: Expression)
   }
 
   protected def updateExpressionsDef: Seq[Expression] = {
-    val newN = n + Literal(1.0)
+    val newN = n + 1.0
     val dx = x - xAvg
     val dy = y - yAvg
     val dyN = dy / newN
@@ -69,7 +69,7 @@ abstract class Covariance(x: Expression, y: Expression)
     val newYAvg = yAvg + dyN
     val newCk = ck + dx * (y - newYAvg)
 
-    val isNull = IsNull(x) || IsNull(y)
+    val isNull = x.isNull || y.isNull
     Seq(
       If(isNull, n, newN),
       If(isNull, xAvg, newXAvg),
@@ -83,8 +83,7 @@ abstract class Covariance(x: Expression, y: Expression)
   usage = "_FUNC_(expr1, expr2) - Returns the population covariance of a set of number pairs.")
 case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) {
   override val evaluateExpression: Expression = {
-    If(n === Literal(0.0), Literal.create(null, DoubleType),
-      ck / n)
+    If(n === 0.0, Literal.create(null, DoubleType), ck / n)
   }
   override def prettyName: String = "covar_pop"
 }
@@ -94,9 +93,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance
   usage = "_FUNC_(expr1, expr2) - Returns the sample covariance of a set of number pairs.")
 case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) {
   override val evaluateExpression: Expression = {
-    If(n === Literal(0.0), Literal.create(null, DoubleType),
-      If(n === Literal(1.0), Literal(Double.NaN),
-        ck / (n - Literal(1.0))))
+    If(n === 0.0, Literal.create(null, DoubleType),
+      If(n === 1.0, Double.NaN, ck / (n - 1.0)))
   }
   override def prettyName: String = "covar_samp"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 4e671e1..f51bfd5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
@@ -80,8 +81,8 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
   override lazy val updateExpressions: Seq[Expression] = {
     if (ignoreNulls) {
       Seq(
-        /* first = */ If(Or(valueSet, IsNull(child)), first, child),
-        /* valueSet = */ Or(valueSet, IsNotNull(child))
+        /* first = */ If(valueSet || child.isNull, first, child),
+        /* valueSet = */ valueSet || child.isNotNull
       )
     } else {
       Seq(
@@ -97,7 +98,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
     // false, we are safe to do so because first.right will be null in this case).
     Seq(
       /* first = */ If(valueSet.left, first.left, first.right),
-      /* valueSet = */ Or(valueSet.left, valueSet.right)
+      /* valueSet = */ valueSet.left || valueSet.right
     )
   }
 

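A hedged sketch of the boolean operators used in the `First` hunks above (attributes hypothetical): `||` builds `Or`, and `isNull`/`isNotNull` build `IsNull`/`IsNotNull`:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._

// Hypothetical stand-ins for First's buffer attribute and input.
val valueSet = 'valueSet.boolean
val child    = 'child.int

// Constructor style vs. DSL style; both build the same Or/IsNotNull tree:
val before: Expression = Or(valueSet, IsNotNull(child))
val after: Expression  = valueSet || child.isNotNull
```
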
http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 0ccabb9..2650d7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
@@ -80,8 +81,8 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
   override lazy val updateExpressions: Seq[Expression] = {
     if (ignoreNulls) {
       Seq(
-        /* last = */ If(IsNull(child), last, child),
-        /* valueSet = */ Or(valueSet, IsNotNull(child))
+        /* last = */ If(child.isNull, last, child),
+        /* valueSet = */ valueSet || child.isNotNull
       )
     } else {
       Seq(
@@ -95,7 +96,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
     // Prefer the right hand expression if it has been set.
     Seq(
       /* last = */ If(valueSet.right, last.right, last.left),
-      /* valueSet = */ Or(valueSet.right, valueSet.left)
+      /* valueSet = */ valueSet.right || valueSet.left
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
index 58fd1d8..71099eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
@@ -45,12 +46,12 @@ case class Max(child: Expression) extends DeclarativeAggregate {
   )
 
   override lazy val updateExpressions: Seq[Expression] = Seq(
-    /* max = */ Greatest(Seq(max, child))
+    /* max = */ greatest(max, child)
   )
 
   override lazy val mergeExpressions: Seq[Expression] = {
     Seq(
-      /* max = */ Greatest(Seq(max.left, max.right))
+      /* max = */ greatest(max.left, max.right)
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
index b2724ee..8c4ba93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
@@ -45,12 +46,12 @@ case class Min(child: Expression) extends DeclarativeAggregate {
   )
 
   override lazy val updateExpressions: Seq[Expression] = Seq(
-    /* min = */ Least(Seq(min, child))
+    /* min = */ least(min, child)
   )
 
   override lazy val mergeExpressions: Seq[Expression] = {
     Seq(
-      /* min = */ Least(Seq(min.left, min.right))
+      /* min = */ least(min.left, min.right)
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 86e40a9..761dba1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
@@ -61,12 +62,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
     if (child.nullable) {
       Seq(
         /* sum = */
-        Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
+        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
       )
     } else {
       Seq(
         /* sum = */
-        Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType))
+        coalesce(sum, zero) + child.cast(sumDataType)
       )
     }
   }
@@ -74,7 +75,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   override lazy val mergeExpressions: Seq[Expression] = {
     Seq(
       /* sum = */
-      Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left))
+      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
     )
   }
 

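As a hedged before/after sketch of the `Sum` update expression (using a hypothetical `LongType` buffer; the real `sumDataType` depends on the input type):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.LongType

// Hypothetical stand-ins for Sum's buffer attribute, zero literal and input.
val sum   = 'sum.long
val zero  = Literal(0L)
val child = 'child.int

// Constructor style, as before the patch:
val before: Expression =
  Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, LongType)), sum))

// DSL style, as after the patch; the same tree, readable left to right:
val after: Expression =
  coalesce(coalesce(sum, zero) + child.cast(LongType), sum)
```
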
http://git-wip-us.apache.org/repos/asf/spark/blob/0f3fa2f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 53c6f01..707f312 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -21,6 +21,7 @@ import java.util.Locale
 
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, NoOp}
 import org.apache.spark.sql.types._
 
@@ -476,7 +477,7 @@ abstract class RowNumberLike extends AggregateWindowFunction {
   protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)()
   override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil
   override val initialValues: Seq[Expression] = zero :: Nil
-  override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil
+  override val updateExpressions: Seq[Expression] = rowNumber + one :: Nil
 }
 
 /**
@@ -527,7 +528,7 @@ case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
   // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must
   // return the same value for equal values in the partition.
   override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
-  override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType))
+  override val evaluateExpression = rowNumber.cast(DoubleType) / n.cast(DoubleType)
   override def prettyName: String = "cume_dist"
 }
 
@@ -587,8 +588,7 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow
   private val bucketSize = AttributeReference("bucketSize", IntegerType, nullable = false)()
   private val bucketsWithPadding =
     AttributeReference("bucketsWithPadding", IntegerType, nullable = false)()
-  private def bucketOverflow(e: Expression) =
-    If(GreaterThanOrEqual(rowNumber, bucketThreshold), e, zero)
+  private def bucketOverflow(e: Expression) = If(rowNumber >= bucketThreshold, e, zero)
 
   override val aggBufferAttributes = Seq(
     rowNumber,
@@ -602,15 +602,14 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow
     zero,
     zero,
     zero,
-    Cast(Divide(n, buckets), IntegerType),
-    Cast(Remainder(n, buckets), IntegerType)
+    (n / buckets).cast(IntegerType),
+    (n % buckets).cast(IntegerType)
   )
 
   override val updateExpressions = Seq(
-    Add(rowNumber, one),
-    Add(bucket, bucketOverflow(one)),
-    Add(bucketThreshold, bucketOverflow(
-      Add(bucketSize, If(LessThan(bucket, bucketsWithPadding), one, zero)))),
+    rowNumber + one,
+    bucket + bucketOverflow(one),
+    bucketThreshold + bucketOverflow(bucketSize + If(bucket < bucketsWithPadding, one, zero)),
     NoOp,
     NoOp
   )
@@ -644,7 +643,7 @@ abstract class RankLike extends AggregateWindowFunction {
   protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)()
   protected val zero = Literal(0)
   protected val one = Literal(1)
-  protected val increaseRowNumber = Add(rowNumber, one)
+  protected val increaseRowNumber = rowNumber + one
 
   /**
    * Different RankLike implementations use different source expressions to update their rank value.
@@ -653,7 +652,7 @@ abstract class RankLike extends AggregateWindowFunction {
   protected def rankSource: Expression = rowNumber
 
   /** Increase the rank when the current rank == 0 or when the one of order attributes changes. */
-  protected val increaseRank = If(And(orderEquals, Not(EqualTo(rank, zero))), rank, rankSource)
+  protected val increaseRank = If(orderEquals && rank =!= zero, rank, rankSource)
 
   override val aggBufferAttributes: Seq[AttributeReference] = rank +: rowNumber +: orderAttrs
   override val initialValues = zero +: one +: orderInit
@@ -707,7 +706,7 @@ case class Rank(children: Seq[Expression]) extends RankLike {
 case class DenseRank(children: Seq[Expression]) extends RankLike {
   def this() = this(Nil)
   override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order)
-  override protected def rankSource = Add(rank, one)
+  override protected def rankSource = rank + one
   override val updateExpressions = increaseRank +: children
   override val aggBufferAttributes = rank +: orderAttrs
   override val initialValues = zero +: orderInit
@@ -736,8 +735,7 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase
   def this() = this(Nil)
   override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order)
   override def dataType: DataType = DoubleType
-  override val evaluateExpression = If(GreaterThan(n, one),
-      Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)),
-      Literal(0.0d))
+  override val evaluateExpression =
+    If(n > one, (rank - one).cast(DoubleType) / (n - one).cast(DoubleType), 0.0d)
   override def prettyName: String = "percent_rank"
 }

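Finally, a hedged reference for the comparison operators used in the window-function hunks above (attributes hypothetical); each expands to the corresponding Catalyst predicate constructor:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._

// Hypothetical attributes, for illustration only.
val rank = 'rank.int
val zero = Literal(0)

val ge: Expression = rank >= zero   // GreaterThanOrEqual(rank, zero)
val lt: Expression = rank < zero    // LessThan(rank, zero)
val ne: Expression = rank =!= zero  // Not(EqualTo(rank, zero))
```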
