This is an automated email from the ASF dual-hosted git repository. richox pushed a commit to branch dev-fix-declar-agg in repository https://gitbox.apache.org/repos/asf/auron.git
commit dc995f1dbb0d4b0ea0b6f332361009771bb8f725 Author: zhangli20 <[email protected]> AuthorDate: Sat Sep 27 23:35:05 2025 +0800 fix UDAF fallback bug when handling DeclarativeAggregator --- .../spark/sql/auron/SparkUDAFWrapperContext.scala | 50 +++++++++++----------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala index 8537cb32..9ea52b08 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala @@ -41,12 +41,7 @@ import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.auron.memory.OnHeapSpillManager import org.apache.spark.sql.auron.util.Using import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.expressions.JoinedRow -import org.apache.spark.sql.catalyst.expressions.Nondeterministic -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, JoinedRow, Nondeterministic, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate @@ -68,10 +63,6 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging { bytes }) - val inputAttributes: Seq[Attribute] = javaParamsSchema.fields.map { field => - AttributeReference(field.name, field.dataType, field.nullable)() - } - private val outputSchema = { val schema = StructType(Seq(StructField("", expr.dataType, expr.nullable))) ArrowUtils.toArrowSchema(schema) @@ -93,7 +84,7 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging { override def initialValue: AggregateEvaluator[B, BufferRowsColumn[B]] = { val evaluator = expr match { case declarative: DeclarativeAggregate => - new DeclarativeEvaluator(declarative, inputAttributes) + new DeclarativeEvaluator(declarative) case imperative: TypedImperativeAggregate[B] => new TypedImperativeEvaluator(imperative) } @@ -243,7 +234,7 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging { } } -class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: Seq[Attribute]) +class DeclarativeEvaluator(val agg: DeclarativeAggregate) extends AggregateEvaluator[UnsafeRow, DeclarativeAggRowsColumn] { val initializedRow: UnsafeRow = { @@ -252,8 +243,20 @@ class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: Seq[A } val releasedRow: UnsafeRow = null - val updater: UnsafeProjection = - UnsafeProjection.create(agg.updateExpressions, agg.aggBufferAttributes ++ inputAttributes) + val updater: UnsafeProjection = { + val updateExpressions = agg.updateExpressions.map { expr => + expr.transform { + case BoundReference(odin, dt, nullable) => + BoundReference(odin + agg.aggBufferAttributes.length, dt, nullable) + case expr => expr + } + } + UnsafeProjection.create( + updateExpressions, + agg.aggBufferAttributes ++ agg.children.map(c => { + AttributeReference("", c.dataType, c.nullable)() + })) + } val merger: UnsafeProjection = UnsafeProjection.create( agg.mergeExpressions, @@ -322,7 +325,7 @@ case class DeclarativeAggRowsColumn( override def resize(len: Int): Unit = { rows.appendAll((rows.length until len).map(_ => { - val newRow = evaluator.initializedRow.copy() + val newRow = evaluator.initializedRow rowsMemUsed += newRow.getSizeInBytes newRow })) @@ -335,28 +338,25 @@ case class DeclarativeAggRowsColumn( override def updateRow(i: Int, inputRow: InternalRow): Unit = { if (i == rows.length) { - val newRow = evaluator.updater(evaluator.joiner(evaluator.initializedRow.copy(), inputRow)) - rowsMemUsed += newRow.getSizeInBytes - rows.append(newRow) + rows.append(evaluator.updater(evaluator.joiner(evaluator.initializedRow, inputRow)).copy()) } else { rowsMemUsed -= rows(i).getSizeInBytes - rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow)) - rowsMemUsed += rows(i).getSizeInBytes + rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow)).copy() } + rowsMemUsed += rows(i).getSizeInBytes } override def mergeRow(i: Int, mergeRows: BufferRowsColumn[UnsafeRow], mergeIdx: Int): Unit = { mergeRows match { case mergeRows: DeclarativeAggRowsColumn => + val mergeRow = mergeRows.rows(mergeIdx) if (i == rows.length) { - val newRow = mergeRows.rows(mergeIdx) - rowsMemUsed += newRow.getSizeInBytes - rows.append(newRow) + rows.append(mergeRow) } else { rowsMemUsed -= rows(i).getSizeInBytes - rows(i) = evaluator.merger(evaluator.joiner(rows(i), mergeRows.rows(mergeIdx))) - rowsMemUsed += rows(i).getSizeInBytes + rows(i) = evaluator.merger(evaluator.joiner(rows(i), mergeRow)).copy() } + rowsMemUsed += rows(i).getSizeInBytes mergeRows.rows(mergeIdx) = evaluator.releasedRow } }
