Repository: spark
Updated Branches:
  refs/heads/branch-1.6 285792b6c -> 7e90893e9


[SPARK-11795][SQL] combine grouping attributes into a single NamedExpression

We use `ExpressionEncoder.tuple` to build the result encoder. It assumes that any
non-flat input encoder points to a single struct-type field.
However, our keyEncoder always points to one or more flat fields
(`groupingAttributes`), so we should combine them into a single `NamedExpression`.

Author: Wenchen Fan <[email protected]>

Closes #9792 from cloud-fan/agg.

(cherry picked from commit dbf428c87ab34b6f76c75946043bdf5f60c9b1b3)
Signed-off-by: Michael Armbrust <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7e90893e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7e90893e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7e90893e

Branch: refs/heads/branch-1.6
Commit: 7e90893e9960bbc767500b2c5ecacaf87bfd176b
Parents: 285792b
Author: Wenchen Fan <[email protected]>
Authored: Wed Nov 18 10:33:17 2015 -0800
Committer: Michael Armbrust <[email protected]>
Committed: Wed Nov 18 10:33:27 2015 -0800

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/sql/GroupedDataset.scala    | 9 +++++++--
 .../src/test/scala/org/apache/spark/sql/DatasetSuite.scala  | 5 ++---
 2 files changed, 9 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7e90893e/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index c66162e..3f84e22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, 
encoderFor}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, 
Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 
@@ -187,7 +187,12 @@ class GroupedDataset[K, T] private[sql](
     val namedColumns =
       columns.map(
         _.withInputType(resolvedTEncoder, dataAttributes).named)
-    val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ 
namedColumns, logicalPlan)
+    val keyColumn = if (groupingAttributes.length > 1) {
+      Alias(CreateStruct(groupingAttributes), "key")()
+    } else {
+      groupingAttributes.head
+    }
+    val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, 
logicalPlan)
     val execution = new QueryExecution(sqlContext, aggregate)
 
     new Dataset(

http://git-wip-us.apache.org/repos/asf/spark/blob/7e90893e/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 198962b..b6db583 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -84,8 +84,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 2), ("b", 3), ("c", 4))
   }
 
-  ignore("Dataset should set the resolved encoders internally for maps") {
-    // TODO: Enable this once we fix SPARK-11793.
+  test("map and group by with class data") {
     // We inject a group by here to make sure this test case is future proof
     // when we implement better pipelining and local execution mode.
     val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), 
ClassData("two", 2)).toDS()
@@ -94,7 +93,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
     checkAnswer(
       ds,
-      (ClassData("one", 1), 1L), (ClassData("two", 2), 1L))
+      (ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
   }
 
   test("select") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to