This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 1c408c31941b [SPARK-52023][SQL][3.5] Fix data corruption/segfault returning Option[Product] from udaf
1c408c31941b is described below

commit 1c408c31941baf005be6f5bc294128b2ac177815
Author: Emil Ejbyfeldt <emil.ejbyfe...@choreograph.com>
AuthorDate: Wed Jul 2 06:51:40 2025 -0700

    [SPARK-52023][SQL][3.5] Fix data corruption/segfault returning Option[Product] from udaf
    
    ### What changes were proposed in this pull request?
    
    This fixes defining a udaf returning an `Option[Product]` so that it produces correct results, instead of the current behavior where it throws an exception, segfaults, or produces incorrect results.
    
    ### Why are the changes needed?
    
    Fixes a correctness issue.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Fixes a correctness issue.
    
    ### How was this patch tested?
    
    Existing and new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51347 from eejbyfeldt/3.5-SPARK-52023.
    
    Authored-by: Emil Ejbyfeldt <emil.ejbyfe...@choreograph.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../spark/sql/execution/aggregate/udaf.scala       |  2 +-
 .../spark/sql/hive/execution/UDAQuerySuite.scala   | 28 ++++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index e517376bc5fc..fe6307b5bbe8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -530,7 +530,7 @@ case class ScalaAggregator[IN, BUF, OUT](
 
   def eval(buffer: BUF): Any = {
     val row = outputSerializer(agg.finish(buffer))
-    if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType)
+    if (outputEncoder.isSerializedAsStructForTopLevel) row else row.get(0, dataType)
   }
 
   private[this] lazy val bufferRow = new UnsafeRow(bufferEncoder.namedExpressions.length)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
index 0bd6b1403d39..31d0452c7061 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -60,6 +60,22 @@ object LongProductSumAgg extends Aggregator[(jlLong, jlLong), Long, jlLong] {
   def outputEncoder: Encoder[jlLong] = Encoders.LONG
 }
 
+final case class Reduce[T: Encoder](r: (T, T) => T)(implicit i: Encoder[Option[T]])
+  extends Aggregator[T, Option[T], T] {
+  def zero: Option[T] = None
+  def reduce(b: Option[T], a: T): Option[T] = Some(b.fold(a)(r(_, a)))
+  def merge(b1: Option[T], b2: Option[T]): Option[T] =
+    (b1, b2) match {
+      case (Some(a), Some(b)) => Some(r(a, b))
+      case (Some(a), None) => Some(a)
+      case (None, Some(b)) => Some(b)
+      case (None, None) => None
+    }
+  def finish(reduction: Option[T]): T = reduction.get
+  def bufferEncoder: Encoder[Option[T]] = implicitly
+  def outputEncoder: Encoder[T] = implicitly
+}
+
 @SQLUserDefinedType(udt = classOf[CountSerDeUDT])
 case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Int)
 
@@ -180,6 +196,9 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
     data4.write.saveAsTable("agg4")
 
+    val data5 = Seq[(Int, (Int, Int))]((1, (2, 3))).toDF("key", "value")
+    data5.write.saveAsTable("agg5")
+
     val emptyDF = spark.createDataFrame(
       sparkContext.emptyRDD[Row],
       StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
@@ -190,6 +209,9 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
     spark.udf.register("longProductSum", udaf(LongProductSumAgg))
     spark.udf.register("arraysum", udaf(ArrayDataAgg))
+    spark.udf.register("reduceOptionPair", udaf(Reduce[Option[(Int, Int)]](
+      (opt1, opt2) =>
+        opt1.zip(opt2).map { case ((a1, b1), (a2, b2)) => (a1 + a2, b1 + b2) }.headOption)))
   }
 
   override def afterAll(): Unit = {
@@ -371,6 +393,12 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       Row(Seq(12.0, 15.0, 18.0)) :: Nil)
   }
 
+  test("SPARK-52023: Returning Option[Product] from udaf") {
+    checkAnswer(
+      spark.sql("SELECT reduceOptionPair(value) FROM agg5 GROUP BY key"),
+      Row(Row(2, 3)) :: Nil)
+  }
+
   test("verify aggregator ser/de behavior") {
     val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
     val agg = udaf(CountSerDeAgg)

