Repository: spark
Updated Branches:
  refs/heads/branch-1.0 010040fd0 -> 8f3b9250c


[SQL] Improve SparkSQL Aggregates

* Add native min/max (was using hive before).
* Handle nulls correctly in Avg and Sum.

Author: Michael Armbrust <mich...@databricks.com>

Closes #683 from marmbrus/aggFixes and squashes the following commits:

64fe30b [Michael Armbrust] Improve SparkSQL Aggregates * Add native min/max (was using hive before). * Handle nulls correctly in Avg and Sum.

(cherry picked from commit 19c8fb02bc2c2f76c3c45bfff4b8d093be9d7c66)
Signed-off-by: Reynold Xin <r...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8f3b9250
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8f3b9250
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8f3b9250

Branch: refs/heads/branch-1.0
Commit: 8f3b9250c975debafd663b857ef66e6627eb0b5f
Parents: 010040f
Author: Michael Armbrust <mich...@databricks.com>
Authored: Thu May 8 01:08:43 2014 -0400
Committer: Reynold Xin <r...@apache.org>
Committed: Thu May 8 01:12:10 2014 -0400

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/SqlParser.scala   |  4 +
 .../sql/catalyst/expressions/aggregates.scala   | 85 +++++++++++++++++---
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  7 ++
 .../scala/org/apache/spark/sql/TestData.scala   | 10 +++
 4 files changed, 96 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8f3b9250/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 8c76a3a..b3a3a1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -114,6 +114,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val JOIN = Keyword("JOIN")
   protected val LEFT = Keyword("LEFT")
   protected val LIMIT = Keyword("LIMIT")
+  protected val MAX = Keyword("MAX")
+  protected val MIN = Keyword("MIN")
   protected val NOT = Keyword("NOT")
   protected val NULL = Keyword("NULL")
   protected val ON = Keyword("ON")
@@ -318,6 +320,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => 
CountDistinct(exp :: Nil) } |
     FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
     AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
+    MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
+    MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } |
     IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
       case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
     } |

http://git-wip-us.apache.org/repos/asf/spark/blob/8f3b9250/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index b152f95..7777d37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -86,6 +86,67 @@ abstract class AggregateFunction
   override def newInstance() = makeCopy(productIterator.map { case a: AnyRef 
=> a }.toArray)
 }
 
+case class Min(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = child.nullable
+  override def dataType = child.dataType
+  override def toString = s"MIN($child)"
+
+  override def asPartial: SplitEvaluation = {
+    val partialMin = Alias(Min(child), "PartialMin")()
+    SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
+  }
+
+  override def newInstance() = new MinFunction(child, this)
+}
+
+case class MinFunction(expr: Expression, base: AggregateExpression) extends 
AggregateFunction {
+  def this() = this(null, null) // Required for serialization.
+
+  var currentMin: Any = _
+
+  override def update(input: Row): Unit = {
+    if (currentMin == null) {
+      currentMin = expr.eval(input)
+    } else if(GreaterThan(Literal(currentMin, expr.dataType), 
expr).eval(input) == true) {
+      currentMin = expr.eval(input)
+    }
+  }
+
+  override def eval(input: Row): Any = currentMin
+}
+
+case class Max(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
+  override def references = child.references
+  override def nullable = child.nullable
+  override def dataType = child.dataType
+  override def toString = s"MAX($child)"
+
+  override def asPartial: SplitEvaluation = {
+    val partialMax = Alias(Max(child), "PartialMax")()
+    SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
+  }
+
+  override def newInstance() = new MaxFunction(child, this)
+}
+
+case class MaxFunction(expr: Expression, base: AggregateExpression) extends 
AggregateFunction {
+  def this() = this(null, null) // Required for serialization.
+
+  var currentMax: Any = _
+
+  override def update(input: Row): Unit = {
+    if (currentMax == null) {
+      currentMax = expr.eval(input)
+    } else if(LessThan(Literal(currentMax, expr.dataType), expr).eval(input) 
== true) {
+      currentMax = expr.eval(input)
+    }
+  }
+
+  override def eval(input: Row): Any = currentMax
+}
+
+
 case class Count(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
   override def references = child.references
   override def nullable = false
@@ -97,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
     SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
   }
 
-  override def newInstance()= new CountFunction(child, this)
+  override def newInstance() = new CountFunction(child, this)
 }
 
 case class CountDistinct(expressions: Seq[Expression]) extends 
AggregateExpression {
@@ -106,7 +167,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
   override def nullable = false
   override def dataType = IntegerType
   override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
-  override def newInstance()= new CountDistinctFunction(expressions, this)
+  override def newInstance() = new CountDistinctFunction(expressions, this)
 }
 
 case class Average(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
@@ -126,7 +187,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
       partialCount :: partialSum :: Nil)
   }
 
-  override def newInstance()= new AverageFunction(child, this)
+  override def newInstance() = new AverageFunction(child, this)
 }
 
 case class Sum(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
@@ -142,7 +203,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
       partialSum :: Nil)
   }
 
-  override def newInstance()= new SumFunction(child, this)
+  override def newInstance() = new SumFunction(child, this)
 }
 
 case class SumDistinct(child: Expression)
@@ -153,7 +214,7 @@ case class SumDistinct(child: Expression)
   override def dataType = child.dataType
   override def toString = s"SUM(DISTINCT $child)"
 
-  override def newInstance()= new SumDistinctFunction(child, this)
+  override def newInstance() = new SumDistinctFunction(child, this)
 }
 
 case class First(child: Expression) extends PartialAggregate with 
trees.UnaryNode[Expression] {
@@ -168,7 +229,7 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
       First(partialFirst.toAttribute),
       partialFirst :: Nil)
   }
-  override def newInstance()= new FirstFunction(child, this)
+  override def newInstance() = new FirstFunction(child, this)
 }
 
 case class AverageFunction(expr: Expression, base: AggregateExpression)
@@ -176,11 +237,13 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
 
   def this() = this(null, null) // Required for serialization.
 
+  private val zero = Cast(Literal(0), expr.dataType)
+
   private var count: Long = _
-  private val sum = MutableLiteral(Cast(Literal(0), 
expr.dataType).eval(EmptyRow))
+  private val sum = MutableLiteral(zero.eval(EmptyRow))
   private val sumAsDouble = Cast(sum, DoubleType)
 
-  private val addFunction = Add(sum, expr)
+  private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
 
   override def eval(input: Row): Any =
     sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
@@ -209,9 +272,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
 case class SumFunction(expr: Expression, base: AggregateExpression) extends 
AggregateFunction {
   def this() = this(null, null) // Required for serialization.
 
-  private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null))
+  private val zero = Cast(Literal(0), expr.dataType)
+
+  private val sum = MutableLiteral(zero.eval(null))
 
-  private val addFunction = Add(sum, expr)
+  private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
 
   override def update(input: Row): Unit = {
     sum.update(addFunction, input)

http://git-wip-us.apache.org/repos/asf/spark/blob/8f3b9250/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index dde957d..e966d89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -50,6 +50,13 @@ class SQLQuerySuite extends QueryTest {
       Seq((1,3),(2,3),(3,3)))
   }
 
+  test("aggregates with nulls") {
+    checkAnswer(
+      sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"),
+      (1, 3, 2, 6, 3) :: Nil
+    )
+  }
+
   test("select *") {
     checkAnswer(
       sql("SELECT * FROM testData"),

http://git-wip-us.apache.org/repos/asf/spark/blob/8f3b9250/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index b5973c0..aa71e27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -84,4 +84,14 @@ object TestData {
       List.fill(2)(StringData(null)) ++
       List.fill(2)(StringData("test")))
   nullableRepeatedData.registerAsTable("nullableRepeatedData")
+
+  case class NullInts(a: Integer)
+  val nullInts =
+    TestSQLContext.sparkContext.parallelize(
+      NullInts(1) ::
+      NullInts(2) ::
+      NullInts(3) ::
+      NullInts(null) :: Nil
+    )
+  nullInts.registerAsTable("nullInts")
 }

Reply via email to