This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ebacb9163f26 [SPARK-48718][SQL] Handle and fix the case when 
deserializer in cogroup is resolved during application of DeduplicateRelation 
rule
ebacb9163f26 is described below

commit ebacb9163f268de83ad721a509d6298a6690f338
Author: Xinyi Yu <[email protected]>
AuthorDate: Wed Jun 26 09:43:27 2024 +0800

    [SPARK-48718][SQL] Handle and fix the case when deserializer in cogroup is 
resolved during application of DeduplicateRelation rule
    
    ### What changes were proposed in this pull request?
    A followup for https://github.com/apache/spark/pull/41554/files.
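    Handle the case where the cogroup deserializers have already been resolved 
by the time the DeduplicateRelations rule is applied. Previously the rule 
unconditionally cast them to UnresolvedDeserializer, which failed with a cast 
error when they were already resolved.
    See the added test case for an example.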
    
    ### Why are the changes needed?
    Fix a bug introduced in a previous commit.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added a new test case to DatasetSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #47091 from anchovYu/fix-cogroup-dedup-rel.
    
    Lead-authored-by: Xinyi Yu <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/analysis/DeduplicateRelations.scala | 18 ++++++++++++------
 .../scala/org/apache/spark/sql/DatasetSuite.scala    | 20 ++++++++++++++++++++
 2 files changed, 32 insertions(+), 6 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 3e4344f98bce..0fa11b9c4503 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -255,12 +255,18 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
                 val newRightGroup = rewriteAttrs(c.rightGroup, rightAttrMap)
                 val newLeftOrder = rewriteAttrs(c.leftOrder, leftAttrMap)
                 val newRightOrder = rewriteAttrs(c.rightOrder, rightAttrMap)
-                val newKeyDes = 
c.keyDeserializer.asInstanceOf[UnresolvedDeserializer]
-                  .copy(inputAttributes = newLeftGroup)
-                val newLeftDes = 
c.leftDeserializer.asInstanceOf[UnresolvedDeserializer]
-                  .copy(inputAttributes = newLeftAttr)
-                val newRightDes = 
c.rightDeserializer.asInstanceOf[UnresolvedDeserializer]
-                  .copy(inputAttributes = newRightAttr)
+                val newKeyDes = c.keyDeserializer match {
+                  case u: UnresolvedDeserializer => u.copy(inputAttributes = 
newLeftGroup)
+                  case e: Expression => 
e.withNewChildren(rewriteAttrs(e.children, leftAttrMap))
+                }
+                val newLeftDes = c.leftDeserializer match {
+                  case u: UnresolvedDeserializer => u.copy(inputAttributes = 
newLeftAttr)
+                  case e: Expression => 
e.withNewChildren(rewriteAttrs(e.children, leftAttrMap))
+                }
+                val newRightDes = c.rightDeserializer match {
+                  case u: UnresolvedDeserializer => u.copy(inputAttributes = 
newRightAttr)
+                  case e: Expression => 
e.withNewChildren(rewriteAttrs(e.children, rightAttrMap))
+                }
                 c.copy(keyDeserializer = newKeyDes, leftDeserializer = 
newLeftDes,
                   rightDeserializer = newRightDes, leftGroup = newLeftGroup,
                   rightGroup = newRightGroup, leftAttr = newLeftAttr, 
rightAttr = newRightAttr,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b939ed40c7db..fdb2ec30fdd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
 
 import scala.collection.immutable.HashSet
+import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 import scala.util.Random
 
@@ -952,6 +953,25 @@ class DatasetSuite extends QueryTest
     assert(result2.length == 3)
   }
 
+  test("SPARK-48718: cogroup deserializer expr is resolved before dedup 
relation") {
+    val lhs = spark.createDataFrame(
+      List(Row(123L)).asJava,
+      StructType(Seq(StructField("GROUPING_KEY", LongType)))
+    )
+    val rhs = spark.createDataFrame(
+      List(Row(0L, 123L)).asJava,
+      StructType(Seq(StructField("ID", LongType), StructField("GROUPING_KEY", 
LongType)))
+    )
+
+    val lhsKV = lhs.groupByKey((r: Row) => r.getAs[Long]("GROUPING_KEY"))
+    val rhsKV = rhs.groupByKey((r: Row) => r.getAs[Long]("GROUPING_KEY"))
+    val cogrouped = lhsKV.cogroup(rhsKV)(
+      (a: Long, b: Iterator[Row], c: Iterator[Row]) => Iterator(0L)
+    )
+    val joined = rhs.join(cogrouped, col("ID") === col("value"), "left")
+    checkAnswer(joined, Row(0L, 123L, 0L) :: Nil)
+  }
+
   test("SPARK-34806: observation on datasets") {
     val namedObservation = Observation("named")
     val unnamedObservation = Observation()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to