GitHub user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/19501#discussion_r144739274
--- Diff:
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ---
@@ -636,4 +637,33 @@ class DataFrameAggregateSuite extends QueryTest with
SharedSQLContext {
spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c,
d"),
Seq(Row(3, 4, 9)))
}
+
+ test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary
shuffle") {
+ withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+ val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3",
4)).toDF("a", "b", "c")
+ .repartition(col("a"))
+
+ val objHashAggDF = df
+ .withColumn("d", expr("(a, b, c)"))
+ .groupBy("a", "b").agg(collect_list("d").as("e"))
+ .withColumn("f", expr("(b, e)"))
+ .groupBy("a").agg(collect_list("f").as("g"))
+ val aggPlan = objHashAggDF.queryExecution.executedPlan
+
+ val sortAggPlans = aggPlan.collect {
+ case sortAgg: SortAggregateExec => sortAgg
+ }
+ assert(sortAggPlans.isEmpty)
+
+ val objHashAggPlans = aggPlan.collect {
+ case objHashAgg: ObjectHashAggregateExec => objHashAgg
+ }
+ assert(objHashAggPlans.length > 0)
--- End diff --
Not a big deal at all, but maybe use `objHashAggPlans.nonEmpty` instead of `objHashAggPlans.length > 0`?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]