dongjoon-hyun closed pull request #23434: [SPARK-22951][SQL][BRANCH-2.2] fix aggregation after dropDuplicates on empty dataframes
URL: https://github.com/apache/spark/pull/23434
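For context, the incorrect behavior being fixed can be reproduced in a one-line spark-shell session (a minimal sketch, not part of the patch; before this fix the deduplicated empty DataFrame was planned as a global aggregation and produced a spurious row):

    // dropDuplicates() on the zero-column empty DataFrame, then count().
    // The new regression test below asserts this returns 0.
    spark.emptyDataFrame.dropDuplicates().count()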
This is a PR merged from a forked repository. As GitHub hides the original diff of a foreign pull request once it is merged, the diff is reproduced below for the sake of provenance:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fe668217a6a5e..2541f95d24901 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1165,7 +1165,13 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
           Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
         }
       }
-      Aggregate(keys, aggCols, child)
+      // SPARK-22951: Physical aggregate operators distinguish global aggregation and grouping
+      // aggregations by checking the number of grouping keys. The key difference here is that a
+      // global aggregation always returns at least one row even if there are no input rows. Here
+      // we append a literal when the grouping key list is empty so that the result aggregate
+      // operator is properly treated as a grouping aggregation.
+      val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
+      Aggregate(nonemptyKeys, aggCols, child)
   }
 }
 
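The comment above rests on the semantic gap between global and grouping aggregation over empty input. A minimal spark-shell sketch of that gap (illustrative only, with an assumed single column a; not part of the patch):

    import org.apache.spark.sql.functions.{first, lit}
    import spark.implicits._  // already in scope in spark-shell

    val empty = Seq.empty[Int].toDF("a")           // empty input
    empty.agg(first("a")).count()                  // 1: a global aggregate always emits one row
    empty.groupBy(lit(1)).agg(first("a")).count()  // 0: a grouping aggregate emits no groups

Grouping by a constant, as lit(1) does here, is exactly the trick the rule now applies when the grouping-key list is empty.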
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index e68423f85c92e..c4c9fc95e96d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -94,6 +94,14 @@ class ReplaceOperatorSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("add one grouping key if necessary when replace Deduplicate with Aggregate") {
+    val input = LocalRelation()
+    val query = Deduplicate(Seq.empty, input, streaming = false) // dropDuplicates()
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = Aggregate(Seq(Literal(1)), input.output, input)
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("don't replace streaming Deduplicate") {
     val input = LocalRelation('a.int, 'b.int)
     val attrA = input.output(0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 87aabf7220246..d8d19dd2160ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData.DecimalData
-import org.apache.spark.sql.types.{Decimal, DecimalType}
+import org.apache.spark.sql.types.DecimalType
 
 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
 
@@ -453,7 +453,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
 
   test("null moments") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
-
     checkAnswer(
       emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
       Row(null, null, null, null, null))
@@ -608,4 +607,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       assert(exchangePlans.length == 1)
     }
   }
+
+  Seq(true, false).foreach { codegen =>
+    test("SPARK-22951: dropDuplicates on empty dataFrames should produce correct aggregate " +
+      s"results when codegen is enabled: $codegen") {
+      withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) {
+        // explicit global aggregations
+        val emptyAgg = Map.empty[String, String]
+        checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row()))
+        checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row()))
+        checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0)))
+        checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row()))
+        checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row()))
+        checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0)))
+
+        // global aggregation is converted to grouping aggregation:
+        assert(spark.emptyDataFrame.dropDuplicates().count() == 0)
+      }
+    }
+  }
 }