This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch dev-fix-declar-agg
in repository https://gitbox.apache.org/repos/asf/auron.git

commit dc995f1dbb0d4b0ea0b6f332361009771bb8f725
Author: zhangli20 <[email protected]>
AuthorDate: Sat Sep 27 23:35:05 2025 +0800

    fix UDAF fallback bug when handling DeclarativeAggregate
---
 .../spark/sql/auron/SparkUDAFWrapperContext.scala  | 50 +++++++++++-----------
 1 file changed, 25 insertions(+), 25 deletions(-)

diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
index 8537cb32..9ea52b08 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
@@ -41,12 +41,7 @@ import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.auron.memory.OnHeapSpillManager
 import org.apache.spark.sql.auron.util.Using
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.JoinedRow
-import org.apache.spark.sql.catalyst.expressions.Nondeterministic
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, BoundReference, JoinedRow, Nondeterministic, 
UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
@@ -68,10 +63,6 @@ case class SparkUDAFWrapperContext[B](serialized: 
ByteBuffer) extends Logging {
       bytes
     })
 
-  val inputAttributes: Seq[Attribute] = javaParamsSchema.fields.map { field =>
-    AttributeReference(field.name, field.dataType, field.nullable)()
-  }
-
   private val outputSchema = {
     val schema = StructType(Seq(StructField("", expr.dataType, expr.nullable)))
     ArrowUtils.toArrowSchema(schema)
@@ -93,7 +84,7 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) 
extends Logging {
     override def initialValue: AggregateEvaluator[B, BufferRowsColumn[B]] = {
       val evaluator = expr match {
         case declarative: DeclarativeAggregate =>
-          new DeclarativeEvaluator(declarative, inputAttributes)
+          new DeclarativeEvaluator(declarative)
         case imperative: TypedImperativeAggregate[B] =>
           new TypedImperativeEvaluator(imperative)
       }
@@ -243,7 +234,7 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] 
extends Logging {
   }
 }
 
-class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: 
Seq[Attribute])
+class DeclarativeEvaluator(val agg: DeclarativeAggregate)
     extends AggregateEvaluator[UnsafeRow, DeclarativeAggRowsColumn] {
 
   val initializedRow: UnsafeRow = {
@@ -252,8 +243,20 @@ class DeclarativeEvaluator(val agg: DeclarativeAggregate, 
inputAttributes: Seq[A
   }
   val releasedRow: UnsafeRow = null
 
-  val updater: UnsafeProjection =
-    UnsafeProjection.create(agg.updateExpressions, agg.aggBufferAttributes ++ 
inputAttributes)
+  val updater: UnsafeProjection = {
+    val updateExpressions = agg.updateExpressions.map { expr =>
+      expr.transform {
+        case BoundReference(odin, dt, nullable) =>
+          BoundReference(odin + agg.aggBufferAttributes.length, dt, nullable)
+        case expr => expr
+      }
+    }
+    UnsafeProjection.create(
+      updateExpressions,
+      agg.aggBufferAttributes ++ agg.children.map(c => {
+        AttributeReference("", c.dataType, c.nullable)()
+      }))
+  }
 
   val merger: UnsafeProjection = UnsafeProjection.create(
     agg.mergeExpressions,
@@ -322,7 +325,7 @@ case class DeclarativeAggRowsColumn(
 
   override def resize(len: Int): Unit = {
     rows.appendAll((rows.length until len).map(_ => {
-      val newRow = evaluator.initializedRow.copy()
+      val newRow = evaluator.initializedRow
       rowsMemUsed += newRow.getSizeInBytes
       newRow
     }))
@@ -335,28 +338,25 @@ case class DeclarativeAggRowsColumn(
 
   override def updateRow(i: Int, inputRow: InternalRow): Unit = {
     if (i == rows.length) {
-      val newRow = 
evaluator.updater(evaluator.joiner(evaluator.initializedRow.copy(), inputRow))
-      rowsMemUsed += newRow.getSizeInBytes
-      rows.append(newRow)
+      rows.append(evaluator.updater(evaluator.joiner(evaluator.initializedRow, 
inputRow)).copy())
     } else {
       rowsMemUsed -= rows(i).getSizeInBytes
-      rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow))
-      rowsMemUsed += rows(i).getSizeInBytes
+      rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow)).copy()
     }
+    rowsMemUsed += rows(i).getSizeInBytes
   }
 
   override def mergeRow(i: Int, mergeRows: BufferRowsColumn[UnsafeRow], 
mergeIdx: Int): Unit = {
     mergeRows match {
       case mergeRows: DeclarativeAggRowsColumn =>
+        val mergeRow = mergeRows.rows(mergeIdx)
         if (i == rows.length) {
-          val newRow = mergeRows.rows(mergeIdx)
-          rowsMemUsed += newRow.getSizeInBytes
-          rows.append(newRow)
+          rows.append(mergeRow)
         } else {
           rowsMemUsed -= rows(i).getSizeInBytes
-          rows(i) = evaluator.merger(evaluator.joiner(rows(i), 
mergeRows.rows(mergeIdx)))
-          rowsMemUsed += rows(i).getSizeInBytes
+          rows(i) = evaluator.merger(evaluator.joiner(rows(i), 
mergeRow)).copy()
         }
+        rowsMemUsed += rows(i).getSizeInBytes
         mergeRows.rows(mergeIdx) = evaluator.releasedRow
     }
   }

Reply via email to