Repository: spark
Updated Branches:
  refs/heads/master 3bc7055e2 -> 63a492b93


[SPARK-8828] [SQL] Revert SPARK-5680

JIRA: https://issues.apache.org/jira/browse/SPARK-8828

Author: Yijie Shen <[email protected]>

Closes #7667 from yjshen/revert_combinesum_2 and squashes the following commits:

c37ccb1 [Yijie Shen] add test case
8377214 [Yijie Shen] revert spark.sql.useAggregate2 to its default value
e2305ac [Yijie Shen] fix bug - avg on decimal column
7cb0e95 [Yijie Shen] [wip] resolving bugs
1fadb5a [Yijie Shen] remove occurance
17c6248 [Yijie Shen] revert SPARK-5680


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/63a492b9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/63a492b9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/63a492b9

Branch: refs/heads/master
Commit: 63a492b931765b1edd66624421d503f1927825ec
Parents: 3bc7055
Author: Yijie Shen <[email protected]>
Authored: Mon Jul 27 22:47:31 2015 -0700
Committer: Michael Armbrust <[email protected]>
Committed: Mon Jul 27 22:47:33 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/aggregates.scala   | 70 ++------------------
 .../sql/execution/GeneratedAggregate.scala      | 41 +-----------
 .../spark/sql/execution/SparkStrategies.scala   |  2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 31 +++++++++
 .../hive/execution/HiveCompatibilitySuite.scala |  1 -
 ...er_format-0-eff4ef3c207d14d5121368f294697964 |  0
 ...er_format-1-4a03c4328565c60ca99689239f07fb16 |  1 -
 7 files changed, 37 insertions(+), 109 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/63a492b9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 42343d4..5d4b349 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -404,7 +404,7 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
 
         // partialSum already increase the precision by 10
         val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
-        val castedCount = Sum(partialCount.toAttribute)
+        val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType)
         SplitEvaluation(
           Cast(Divide(castedSum, castedCount), dataType),
           partialCount :: partialSum :: Nil)
@@ -490,13 +490,13 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
       case DecimalType.Fixed(_, _) =>
         val partialSum = Alias(Sum(child), "PartialSum")()
         SplitEvaluation(
-          Cast(CombineSum(partialSum.toAttribute), dataType),
+          Cast(Sum(partialSum.toAttribute), dataType),
           partialSum :: Nil)
 
       case _ =>
         val partialSum = Alias(Sum(child), "PartialSum")()
         SplitEvaluation(
-          CombineSum(partialSum.toAttribute),
+          Sum(partialSum.toAttribute),
           partialSum :: Nil)
     }
   }
@@ -522,8 +522,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
 
   private val sum = MutableLiteral(null, calcType)
 
-  private val addFunction =
-    Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
+  private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
 
   override def update(input: InternalRow): Unit = {
     sum.update(addFunction, input)
@@ -538,67 +537,6 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
   }
 }
 
-/**
- * Sum should satisfy 3 cases:
- * 1) sum of all null values = zero
- * 2) sum for table column with no data = null
- * 3) sum of column with null and not null values = sum of not null values
- * Require separate CombineSum Expression and function as it has to distinguish "No data" case
- * versus "data equals null" case, while aggregating results and at each partial expression.i.e.,
- * Combining    PartitionLevel   InputData
- *                           <-- null
- * Zero     <-- Zero         <-- null
- *
- *          <-- null         <-- no data
- * null     <-- null         <-- no data
- */
-case class CombineSum(child: Expression) extends AggregateExpression1 {
-  def this() = this(null)
-
-  override def children: Seq[Expression] = child :: Nil
-  override def nullable: Boolean = true
-  override def dataType: DataType = child.dataType
-  override def toString: String = s"CombineSum($child)"
-  override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
-}
-
-case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
-  extends AggregateFunction1 {
-
-  def this() = this(null, null) // Required for serialization.
-
-  private val calcType =
-    expr.dataType match {
-      case DecimalType.Fixed(precision, scale) =>
-        DecimalType.bounded(precision + 10, scale)
-      case _ =>
-        expr.dataType
-    }
-
-  private val zero = Cast(Literal(0), calcType)
-
-  private val sum = MutableLiteral(null, calcType)
-
-  private val addFunction =
-    Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
-
-  override def update(input: InternalRow): Unit = {
-    val result = expr.eval(input)
-    // partial sum result can be null only when no input rows present
-    if(result != null) {
-      sum.update(addFunction, input)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        Cast(sum, dataType).eval(null)
-      case _ => sum.eval(null)
-    }
-  }
-}
-
 case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
 
   def this() = this(null)

http://git-wip-us.apache.org/repos/asf/spark/blob/63a492b9/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 5ad4691..1cd1420 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -108,7 +108,7 @@ case class GeneratedAggregate(
           Add(
             Coalesce(currentSum :: zero :: Nil),
             Cast(expr, calcType)
-          ) :: currentSum :: zero :: Nil)
+          ) :: currentSum :: Nil)
         val result =
           expr.dataType match {
             case DecimalType.Fixed(_, _) =>
@@ -118,45 +118,6 @@ case class GeneratedAggregate(
 
         AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
 
-      case cs @ CombineSum(expr) =>
-        val calcType =
-          expr.dataType match {
-            case DecimalType.Fixed(p, s) =>
-              DecimalType.bounded(p + 10, s)
-            case _ =>
-              expr.dataType
-          }
-
-        val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
-        val initialValue = Literal.create(null, calcType)
-
-        // Coalesce avoids double calculation...
-        // but really, common sub expression elimination would be better....
-        val zero = Cast(Literal(0), calcType)
-        // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
-        // UnscaledValue will be null if and only if x is null; helps with Average on decimals
-        val actualExpr = expr match {
-          case UnscaledValue(e) => e
-          case _ => expr
-        }
-        // partial sum result can be null only when no input rows present
-        val updateFunction = If(
-          IsNotNull(actualExpr),
-          Coalesce(
-            Add(
-              Coalesce(currentSum :: zero :: Nil),
-              Cast(expr, calcType)) :: currentSum :: zero :: Nil),
-          currentSum)
-
-        val result =
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              Cast(currentSum, cs.dataType)
-            case _ => currentSum
-          }
-
-        AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
       case m @ Max(expr) =>
         val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
         val initialValue = Literal.create(null, expr.dataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/63a492b9/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 306bbfe..d88a022 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -201,7 +201,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
 
     def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
-      case _: CombineSum | _: Sum | _: Count | _: Max | _: Min |  _: CombineSetsAndCount => true
+      case _: Sum | _: Count | _: Max | _: Min |  _: CombineSetsAndCount => true
       // The generated set implementation is pretty limited ATM.
       case CollectHashSet(exprs) if exprs.size == 1  &&
            Seq(IntegerType, LongType).contains(exprs.head.dataType) => true

http://git-wip-us.apache.org/repos/asf/spark/blob/63a492b9/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 358e319..42724ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -227,6 +227,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       Seq(Row("1"), Row("2")))
   }
 
+  test("SPARK-8828 sum should return null if all input values are null") {
+    withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") {
+      withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
+        checkAnswer(
+          sql("select sum(a), avg(a) from allNulls"),
+          Seq(Row(null, null))
+        )
+      }
+      withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
+        checkAnswer(
+          sql("select sum(a), avg(a) from allNulls"),
+          Seq(Row(null, null))
+        )
+      }
+    }
+    withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+      withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
+        checkAnswer(
+          sql("select sum(a), avg(a) from allNulls"),
+          Seq(Row(null, null))
+        )
+      }
+      withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
+        checkAnswer(
+          sql("select sum(a), avg(a) from allNulls"),
+          Seq(Row(null, null))
+        )
+      }
+    }
+  }
+
   test("aggregation with codegen") {
     val originalValue = sqlContext.conf.codegenEnabled
     sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)

http://git-wip-us.apache.org/repos/asf/spark/blob/63a492b9/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index b12b383..ec959cb 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -822,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "udaf_covar_pop",
     "udaf_covar_samp",
     "udaf_histogram_numeric",
-    "udaf_number_format",
     "udf2",
     "udf5",
     "udf6",

http://git-wip-us.apache.org/repos/asf/spark/blob/63a492b9/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 b/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964
deleted file mode 100644
index e69de29..0000000

http://git-wip-us.apache.org/repos/asf/spark/blob/63a492b9/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16
deleted file mode 100644
index c6f275a..0000000
--- a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16
+++ /dev/null
@@ -1 +0,0 @@
-0.0    NULL    NULL    NULL


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to