This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 8e9ff72be0ee [SPARK-52614][SQL][4.0] Support RowEncoder inside Product Encoder
8e9ff72be0ee is described below

commit 8e9ff72be0eeeab5cfff0081f57cb79a0c9bee04
Author: Emil Ejbyfeldt <emil.ejbyfe...@choreograph.com>
AuthorDate: Mon Oct 6 14:10:34 2025 -0400

    [SPARK-52614][SQL][4.0] Support RowEncoder inside Product Encoder
    
    This is a backport of SPARK-52614 (PR #51319) to branch-4.0.
    
    ### What changes were proposed in this pull request?
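    This fixes using a RowEncoder inside a ProductEncoder. Instead of inferring whether the RowEncoder is the top-level value from `path.dataType`, `createDeserializer` now receives an explicit `isTopLevel` flag.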
    
    ### Why are the changes needed?
    The current code performs a dataType check on the path when constructing the RowEncoder deserializer. This is not safe: if the RowEncoder is used inside a ProductEncoder, the check throws because the path Expression is unresolved.
    
    The check was introduced in https://github.com/apache/spark/pull/49785
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. It makes it possible to use a RowEncoder in more cases, for example nested inside a ProductEncoder.
    
    ### How was this patch tested?
    Existing and new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52503 from eejbyfeldt/SPARK-52614-4.0.
    
    Authored-by: Emil Ejbyfeldt <emil.ejbyfe...@choreograph.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../sql/catalyst/DeserializerBuildHelper.scala     | 28 +++++++++-------------
 .../catalyst/encoders/ExpressionEncoderSuite.scala | 16 +++++++++++++
 2 files changed, 27 insertions(+), 17 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 5d1bbb024074..9dcaba8c2bc4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -241,19 +241,12 @@ object DeserializerBuildHelper {
     val walkedTypePath = 
WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName)
     // Assumes we are deserializing the first column of a row.
     val input = GetColumnByOrdinal(0, enc.dataType)
-    enc match {
-      case AgnosticEncoders.RowEncoder(fields) =>
-        val children = fields.zipWithIndex.map { case (f, i) =>
-          createDeserializer(f.enc, GetStructField(input, i), walkedTypePath)
-        }
-        CreateExternalRow(children, enc.schema)
-      case _ =>
-        val deserializer = createDeserializer(
-          enc,
-          upCastToExpectedType(input, enc.dataType, walkedTypePath),
-          walkedTypePath)
-        expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
-    }
+    val deserializer = createDeserializer(
+      enc,
+      upCastToExpectedType(input, enc.dataType, walkedTypePath),
+      walkedTypePath,
+      isTopLevel = true)
+    expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
   }
 
   /**
@@ -265,11 +258,13 @@ object DeserializerBuildHelper {
    *            external representation.
    * @param path The expression which can be used to extract serialized value.
    * @param walkedTypePath The paths from top to bottom to access current 
field when deserializing.
+   * @param isTopLevel true if we are creating a deserializer for the top 
level value.
    */
   private def createDeserializer(
       enc: AgnosticEncoder[_],
       path: Expression,
-      walkedTypePath: WalkedTypePath): Expression = enc match {
+      walkedTypePath: WalkedTypePath,
+      isTopLevel: Boolean = false): Expression = enc match {
     case ae: AgnosticExpressionPathEncoder[_] =>
       ae.fromCatalyst(path)
     case _ if isNativeEncoder(enc) =>
@@ -408,13 +403,12 @@ object DeserializerBuildHelper {
         NewInstance(cls, arguments, Nil, propagateNull = false, dt, 
outerPointerGetter))
 
     case AgnosticEncoders.RowEncoder(fields) =>
-      val isExternalRow = !path.dataType.isInstanceOf[StructType]
       val convertedFields = fields.zipWithIndex.map { case (f, i) =>
         val newTypePath = walkedTypePath.recordField(
           f.enc.clsTag.runtimeClass.getName,
           f.name)
         val deserializer = createDeserializer(f.enc, GetStructField(path, i), 
newTypePath)
-        if (isExternalRow) {
+        if (!isTopLevel) {
           exprs.If(
             Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil),
             exprs.Literal.create(null, externalDataTypeFor(f.enc)),
@@ -460,7 +454,7 @@ object DeserializerBuildHelper {
         Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
         "decode",
         dataTypeForClass(tag.runtimeClass),
-        createDeserializer(encoder, path, walkedTypePath) :: Nil)
+        createDeserializer(encoder, path, walkedTypePath, isTopLevel) :: Nil)
   }
 
   private def deserializeArray(
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 1b5f1b109c45..3d738fe985dd 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -659,6 +659,22 @@ class ExpressionEncoderSuite extends 
CodegenInterpretedPlanTest with AnalysisTes
     assert(fromRow(toRow(new Wrapper(Row(9L, "x")))) == new Wrapper(Row(9L, 
"x")))
   }
 
+  test("SPARK-52614: transforming encoder row encoder in product encoder") {
+    val schema = new StructType().add("a", LongType).add("b", StringType)
+    val wrapperEncoder = TransformingEncoder(
+      classTag[Wrapper[Row]],
+      RowEncoder.encoderFor(schema),
+      new WrapperCodecProvider[Row])
+    val encoder = ExpressionEncoder(ProductEncoder(
+      classTag[V[Wrapper[Row]]],
+      Seq(EncoderField("v", wrapperEncoder, nullable = false, Metadata.empty)),
+      None))
+      .resolveAndBind()
+    val toRow = encoder.createSerializer()
+    val fromRow = encoder.createDeserializer()
+    assert(fromRow(toRow(V(new Wrapper(Row(9L, "x"))))) == V(new 
Wrapper(Row(9L, "x"))))
+  }
+
   // below tests are related to SPARK-49960 and TransformingEncoder usage
   test("""Encoder with OptionEncoder of transformation""".stripMargin) {
     type T = Option[V[V[Int]]]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to