Repository: spark
Updated Branches:
  refs/heads/branch-1.6 895128505 -> 3cb1b6d39


[SPARK-11942][SQL] fix encoder life cycle for CoGroup

We should pass resolved encoders to the logical `CoGroup` and bind them in the
physical `CoGroup`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9928 from cloud-fan/cogroup.

(cherry picked from commit e5aaae6e1145b8c25c4872b2992ab425da9c6f9b)
Signed-off-by: Michael Armbrust <mich...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3cb1b6d3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3cb1b6d3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3cb1b6d3

Branch: refs/heads/branch-1.6
Commit: 3cb1b6d398a8066a6c9d800a16b8cd21f3fde50f
Parents: 8951285
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Nov 24 09:28:39 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Nov 24 09:28:51 2015 -0800

----------------------------------------------------------------------
 .../catalyst/plans/logical/basicOperators.scala | 27 +++++++++++---------
 .../org/apache/spark/sql/GroupedDataset.scala   |  4 ++-
 .../spark/sql/execution/basicOperators.scala    | 20 ++++++++-------
 .../org/apache/spark/sql/DatasetSuite.scala     | 12 +++++++++
 4 files changed, 41 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3cb1b6d3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 737e62f..5665fd7 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -553,19 +553,22 @@ case class MapGroups[K, T, U](
 
 /** Factory for constructing new `CoGroup` nodes. */
 object CoGroup {
-  def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder](
-      func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
+  def apply[Key, Left, Right, Result : Encoder](
+      func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
+      keyEnc: ExpressionEncoder[Key],
+      leftEnc: ExpressionEncoder[Left],
+      rightEnc: ExpressionEncoder[Right],
       leftGroup: Seq[Attribute],
       rightGroup: Seq[Attribute],
       left: LogicalPlan,
-      right: LogicalPlan): CoGroup[K, Left, Right, R] = {
+      right: LogicalPlan): CoGroup[Key, Left, Right, Result] = {
     CoGroup(
       func,
-      encoderFor[K],
-      encoderFor[Left],
-      encoderFor[Right],
-      encoderFor[R],
-      encoderFor[R].schema.toAttributes,
+      keyEnc,
+      leftEnc,
+      rightEnc,
+      encoderFor[Result],
+      encoderFor[Result].schema.toAttributes,
       leftGroup,
       rightGroup,
       left,
@@ -577,12 +580,12 @@ object CoGroup {
 * A relation produced by applying `func` to each grouping key and associated values from left and
  * right children.
  */
-case class CoGroup[K, Left, Right, R](
-    func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
-    kEncoder: ExpressionEncoder[K],
+case class CoGroup[Key, Left, Right, Result](
+    func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
+    keyEnc: ExpressionEncoder[Key],
     leftEnc: ExpressionEncoder[Left],
     rightEnc: ExpressionEncoder[Right],
-    rEncoder: ExpressionEncoder[R],
+    resultEnc: ExpressionEncoder[Result],
     output: Seq[Attribute],
     leftGroup: Seq[Attribute],
     rightGroup: Seq[Attribute],

http://git-wip-us.apache.org/repos/asf/spark/blob/3cb1b6d3/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 793a86b..a10a893 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -304,11 +304,13 @@ class GroupedDataset[K, V] private[sql](
   def cogroup[U, R : Encoder](
       other: GroupedDataset[K, U])(
       f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
-    implicit def uEnc: Encoder[U] = other.unresolvedVEncoder
     new Dataset[R](
       sqlContext,
       CoGroup(
         f,
+        this.resolvedKEncoder,
+        this.resolvedVEncoder,
+        other.resolvedVEncoder,
         this.groupingAttributes,
         other.groupingAttributes,
         this.logicalPlan,

http://git-wip-us.apache.org/repos/asf/spark/blob/3cb1b6d3/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index d57b8e7..a42aea0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -375,12 +375,12 @@ case class MapGroups[K, T, U](
  * iterators containing all elements in the group from left and right side.
  * The result of this function is encoded and flattened before being output.
  */
-case class CoGroup[K, Left, Right, R](
-    func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
-    kEncoder: ExpressionEncoder[K],
+case class CoGroup[Key, Left, Right, Result](
+    func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
+    keyEnc: ExpressionEncoder[Key],
     leftEnc: ExpressionEncoder[Left],
     rightEnc: ExpressionEncoder[Right],
-    rEncoder: ExpressionEncoder[R],
+    resultEnc: ExpressionEncoder[Result],
     output: Seq[Attribute],
     leftGroup: Seq[Attribute],
     rightGroup: Seq[Attribute],
@@ -397,15 +397,17 @@ case class CoGroup[K, Left, Right, R](
     left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
       val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
       val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
-      val groupKeyEncoder = kEncoder.bind(leftGroup)
+      val boundKeyEnc = keyEnc.bind(leftGroup)
+      val boundLeftEnc = leftEnc.bind(left.output)
+      val boundRightEnc = rightEnc.bind(right.output)
 
       new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
         case (key, leftResult, rightResult) =>
           val result = func(
-            groupKeyEncoder.fromRow(key),
-            leftResult.map(leftEnc.fromRow),
-            rightResult.map(rightEnc.fromRow))
-          result.map(rEncoder.toRow)
+            boundKeyEnc.fromRow(key),
+            leftResult.map(boundLeftEnc.fromRow),
+            rightResult.map(boundRightEnc.fromRow))
+          result.map(resultEnc.toRow)
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/3cb1b6d3/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index dbdd7ba..13eede1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -340,6 +340,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext 
{
       1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
   }
 
+  test("cogroup with complex data") {
+    val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS()
+    val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS()
+    val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, 
data1, data2) =>
+      Iterator(key -> (data1.map(_._2.a).mkString + 
data2.map(_._2.a).mkString))
+    }
+
+    checkAnswer(
+      cogrouped,
+      1 -> "a", 2 -> "bc", 3 -> "d")
+  }
+
   test("SPARK-11436: we should rebind right encoder when join 2 datasets") {
     val ds1 = Seq("1", "2").toDS().as("a")
     val ds2 = Seq(2, 3).toDS().as("b")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to