This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 1c408c31941b [SPARK-52023][SQL][3.5] Fix data corruption/segfault returning Option[Product] from udaf
1c408c31941b is described below

commit 1c408c31941baf005be6f5bc294128b2ac177815
Author: Emil Ejbyfeldt <emil.ejbyfe...@choreograph.com>
AuthorDate: Wed Jul 2 06:51:40 2025 -0700

    [SPARK-52023][SQL][3.5] Fix data corruption/segfault returning Option[Product] from udaf

    ### What changes were proposed in this pull request?

    This change makes a udaf that returns an `Option[Product]` produce correct results. Previously, such a udaf would throw an exception, segfault, or produce incorrect results.

    ### Why are the changes needed?

    It fixes a correctness issue.

    ### Does this PR introduce _any_ user-facing change?

    Yes, it fixes a correctness issue.

    ### How was this patch tested?

    Existing and new unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #51347 from eejbyfeldt/3.5-SPARK-52023.

    Authored-by: Emil Ejbyfeldt <emil.ejbyfe...@choreograph.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../spark/sql/execution/aggregate/udaf.scala       |  2 +-
 .../spark/sql/hive/execution/UDAQuerySuite.scala   | 28 ++++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index e517376bc5fc..fe6307b5bbe8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -530,7 +530,7 @@ case class ScalaAggregator[IN, BUF, OUT](
 
   def eval(buffer: BUF): Any = {
     val row = outputSerializer(agg.finish(buffer))
-    if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType)
+    if (outputEncoder.isSerializedAsStructForTopLevel) row else row.get(0, dataType)
   }
 
   private[this] lazy val bufferRow = new UnsafeRow(bufferEncoder.namedExpressions.length)
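The one-line change above turns on the difference between two predicates of
the output encoder. Paraphrased from ExpressionEncoder in Spark 3.5, shown
here for context only (this sketch is not part of the patch):

    // True whenever the serialized form is a struct, which is also the
    // case for an Option[Product] output.
    def isSerializedAsStruct: Boolean =
      objSerializer.dataType.isInstanceOf[StructType]

    // True if T itself is an Option type.
    def isOptionType: Boolean =
      classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)

    // A top-level row can never be null, so an Option[Product] cannot be
    // flattened into top-level columns; it is kept as a single nullable
    // struct column and must be unwrapped with row.get(0, dataType).
    def isSerializedAsStructForTopLevel: Boolean =
      isSerializedAsStruct && !isOptionType

With the old isSerializedAsStruct check, eval returned the raw serialization
row for an Option[Product] result, misreading its single nullable struct
column as an already flattened struct; hence the exceptions, segfaults, and
incorrect results described above.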
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
index 0bd6b1403d39..31d0452c7061 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -60,6 +60,22 @@ object LongProductSumAgg extends Aggregator[(jlLong, jlLong), Long, jlLong] {
   def outputEncoder: Encoder[jlLong] = Encoders.LONG
 }
 
+final case class Reduce[T: Encoder](r: (T, T) => T)(implicit i: Encoder[Option[T]])
+  extends Aggregator[T, Option[T], T] {
+  def zero: Option[T] = None
+  def reduce(b: Option[T], a: T): Option[T] = Some(b.fold(a)(r(_, a)))
+  def merge(b1: Option[T], b2: Option[T]): Option[T] =
+    (b1, b2) match {
+      case (Some(a), Some(b)) => Some(r(a, b))
+      case (Some(a), None) => Some(a)
+      case (None, Some(b)) => Some(b)
+      case (None, None) => None
+    }
+  def finish(reduction: Option[T]): T = reduction.get
+  def bufferEncoder: Encoder[Option[T]] = implicitly
+  def outputEncoder: Encoder[T] = implicitly
+}
+
 @SQLUserDefinedType(udt = classOf[CountSerDeUDT])
 case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Int)
 
@@ -180,6 +196,9 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
     data4.write.saveAsTable("agg4")
 
+    val data5 = Seq[(Int, (Int, Int))]((1, (2, 3))).toDF("key", "value")
+    data5.write.saveAsTable("agg5")
+
     val emptyDF = spark.createDataFrame(
       sparkContext.emptyRDD[Row],
       StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
@@ -190,6 +209,9 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
     spark.udf.register("longProductSum", udaf(LongProductSumAgg))
     spark.udf.register("arraysum", udaf(ArrayDataAgg))
+    spark.udf.register("reduceOptionPair", udaf(Reduce[Option[(Int, Int)]](
+      (opt1, opt2) =>
+        opt1.zip(opt2).map { case ((a1, b1), (a2, b2)) => (a1 + a2, b1 + b2) }.headOption)))
   }
 
   override def afterAll(): Unit = {
@@ -371,6 +393,12 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       Row(Seq(12.0, 15.0, 18.0)) :: Nil)
   }
 
+  test("SPARK-52023: Returning Option[Product] from udaf") {
+    checkAnswer(
+      spark.sql("SELECT reduceOptionPair(value) FROM agg5 GROUP BY key"),
+      Row(Row(2, 3)) :: Nil)
+  }
+
   test("verify aggregator ser/de behavior") {
     val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
     val agg = udaf(CountSerDeAgg)
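To see the fixed behavior end to end, the following shell-style sketch
defines a comparable aggregator outside the test suite. It assumes a running
SparkSession bound to `spark`; the names PairSumAgg, pairSum, and the table
t are illustrative and not taken from the patch:

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.udaf
    import spark.implicits._

    // An aggregator whose OUT type is Option[Product]: element-wise sum
    // of (Int, Int) pairs, wrapped in Some by finish().
    object PairSumAgg extends Aggregator[(Int, Int), (Int, Int), Option[(Int, Int)]] {
      def zero: (Int, Int) = (0, 0)
      def reduce(b: (Int, Int), a: (Int, Int)): (Int, Int) =
        (b._1 + a._1, b._2 + a._2)
      def merge(b1: (Int, Int), b2: (Int, Int)): (Int, Int) =
        (b1._1 + b2._1, b1._2 + b2._2)
      def finish(reduction: (Int, Int)): Option[(Int, Int)] = Some(reduction)
      def bufferEncoder: Encoder[(Int, Int)] =
        Encoders.tuple(Encoders.scalaInt, Encoders.scalaInt)
      // ExpressionEncoder is an internal API, used here for brevity to get
      // an encoder for a top-level Option[Product].
      def outputEncoder: Encoder[Option[(Int, Int)]] =
        ExpressionEncoder[Option[(Int, Int)]]()
    }

    spark.udf.register("pairSum", udaf(PairSumAgg))
    Seq((1, (2, 3)), (1, (4, 5))).toDF("key", "value").createOrReplaceTempView("t")
    // With the fix, this yields one row containing the struct (6, 8);
    // before it, the same query could throw, segfault, or return garbage.
    spark.sql("SELECT pairSum(value) FROM t GROUP BY key").show()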