This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new ebacb9163f26 [SPARK-48718][SQL] Handle and fix the case when deserializer in cogroup is resolved during application of DeduplicateRelation rule
ebacb9163f26 is described below
commit ebacb9163f268de83ad721a509d6298a6690f338
Author: Xinyi Yu <[email protected]>
AuthorDate: Wed Jun 26 09:43:27 2024 +0800
    [SPARK-48718][SQL] Handle and fix the case when deserializer in cogroup is resolved during application of DeduplicateRelation rule
### What changes were proposed in this pull request?
A followup for https://github.com/apache/spark/pull/41554/files.
Handle the case when the deserializer in cogroup is resolved when applying
DeduplicateRelation rule. Otherwise, it will throw an uncastable error.
See the added test case as an example.
### Why are the changes needed?
Fix a bug introduced in a previous commit.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Add a new test case.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47091 from anchovYu/fix-cogroup-dedup-rel.
Lead-authored-by: Xinyi Yu <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/analysis/DeduplicateRelations.scala | 18 ++++++++++++------
.../scala/org/apache/spark/sql/DatasetSuite.scala | 20 ++++++++++++++++++++
2 files changed, 32 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 3e4344f98bce..0fa11b9c4503 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -255,12 +255,18 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         val newRightGroup = rewriteAttrs(c.rightGroup, rightAttrMap)
         val newLeftOrder = rewriteAttrs(c.leftOrder, leftAttrMap)
         val newRightOrder = rewriteAttrs(c.rightOrder, rightAttrMap)
-        val newKeyDes = c.keyDeserializer.asInstanceOf[UnresolvedDeserializer]
-          .copy(inputAttributes = newLeftGroup)
-        val newLeftDes = c.leftDeserializer.asInstanceOf[UnresolvedDeserializer]
-          .copy(inputAttributes = newLeftAttr)
-        val newRightDes = c.rightDeserializer.asInstanceOf[UnresolvedDeserializer]
-          .copy(inputAttributes = newRightAttr)
+        val newKeyDes = c.keyDeserializer match {
+          case u: UnresolvedDeserializer => u.copy(inputAttributes = newLeftGroup)
+          case e: Expression => e.withNewChildren(rewriteAttrs(e.children, leftAttrMap))
+        }
+        val newLeftDes = c.leftDeserializer match {
+          case u: UnresolvedDeserializer => u.copy(inputAttributes = newLeftAttr)
+          case e: Expression => e.withNewChildren(rewriteAttrs(e.children, leftAttrMap))
+        }
+        val newRightDes = c.rightDeserializer match {
+          case u: UnresolvedDeserializer => u.copy(inputAttributes = newRightAttr)
+          case e: Expression => e.withNewChildren(rewriteAttrs(e.children, rightAttrMap))
+        }
         c.copy(keyDeserializer = newKeyDes, leftDeserializer = newLeftDes,
           rightDeserializer = newRightDes, leftGroup = newLeftGroup,
           rightGroup = newRightGroup, leftAttr = newLeftAttr, rightAttr = newRightAttr,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b939ed40c7db..fdb2ec30fdd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
 
 import scala.collection.immutable.HashSet
+import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
 import scala.util.Random
@@ -952,6 +953,25 @@ class DatasetSuite extends QueryTest
     assert(result2.length == 3)
   }
 
+  test("SPARK-48718: cogroup deserializer expr is resolved before dedup relation") {
+    val lhs = spark.createDataFrame(
+      List(Row(123L)).asJava,
+      StructType(Seq(StructField("GROUPING_KEY", LongType)))
+    )
+    val rhs = spark.createDataFrame(
+      List(Row(0L, 123L)).asJava,
+      StructType(Seq(StructField("ID", LongType), StructField("GROUPING_KEY", LongType)))
+    )
+
+    val lhsKV = lhs.groupByKey((r: Row) => r.getAs[Long]("GROUPING_KEY"))
+    val rhsKV = rhs.groupByKey((r: Row) => r.getAs[Long]("GROUPING_KEY"))
+    val cogrouped = lhsKV.cogroup(rhsKV)(
+      (a: Long, b: Iterator[Row], c: Iterator[Row]) => Iterator(0L)
+    )
+    val joined = rhs.join(cogrouped, col("ID") === col("value"), "left")
+    checkAnswer(joined, Row(0L, 123L, 0L) :: Nil)
+  }
+
   test("SPARK-34806: observation on datasets") {
     val namedObservation = Observation("named")
     val unnamedObservation = Observation()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]