dongjoon-hyun closed pull request #23434: [SPARK-22951][SQL][BRANCH-2.2] fix 
aggregation after dropDuplicates on empty dataframes
URL: https://github.com/apache/spark/pull/23434
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fe668217a6a5e..2541f95d24901 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1165,7 +1165,13 @@ object ReplaceDeduplicateWithAggregate extends 
Rule[LogicalPlan] {
           Alias(new First(attr).toAggregateExpression(), 
attr.name)(attr.exprId)
         }
       }
-      Aggregate(keys, aggCols, child)
+      // SPARK-22951: Physical aggregate operators distinguish global 
aggregations and grouping
+      // aggregations by checking the number of grouping keys. The key 
difference here is that a
+      // global aggregation always returns at least one row even if there are 
no input rows. Here
+      // we append a literal when the grouping key list is empty so that the 
result aggregate
+      // operator is properly treated as a grouping aggregation.
+      val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
+      Aggregate(nonemptyKeys, aggCols, child)
   }
 }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index e68423f85c92e..c4c9fc95e96d4 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -94,6 +94,14 @@ class ReplaceOperatorSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("add one grouping key if necessary when replace Deduplicate with 
Aggregate") {
+    val input = LocalRelation()
+    val query = Deduplicate(Seq.empty, input, streaming = false) // 
dropDuplicates()
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = Aggregate(Seq(Literal(1)), input.output, input)
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("don't replace streaming Deduplicate") {
     val input = LocalRelation('a.int, 'b.int)
     val attrA = input.output(0)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 87aabf7220246..d8d19dd2160ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData.DecimalData
-import org.apache.spark.sql.types.{Decimal, DecimalType}
+import org.apache.spark.sql.types.DecimalType
 
 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: 
Double)
 
@@ -453,7 +453,6 @@ class DataFrameAggregateSuite extends QueryTest with 
SharedSQLContext {
 
   test("null moments") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
-
     checkAnswer(
       emptyTableData.agg(variance('a), var_samp('a), var_pop('a), 
skewness('a), kurtosis('a)),
       Row(null, null, null, null, null))
@@ -608,4 +607,23 @@ class DataFrameAggregateSuite extends QueryTest with 
SharedSQLContext {
       assert(exchangePlans.length == 1)
     }
   }
+
+  Seq(true, false).foreach { codegen =>
+    test("SPARK-22951: dropDuplicates on empty dataFrames should produce 
correct aggregate " +
+      s"results when codegen is enabled: $codegen") {
+      withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) {
+        // explicit global aggregations
+        val emptyAgg = Map.empty[String, String]
+        checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row()))
+        checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row()))
+        checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), 
Seq(Row(0)))
+        checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), 
Seq(Row()))
+        
checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), 
Seq(Row()))
+        
checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), 
Seq(Row(0)))
+
+        // global aggregation is converted to grouping aggregation:
+        assert(spark.emptyDataFrame.dropDuplicates().count() == 0)
+      }
+    }
+  }
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to