This is an automated email from the ASF dual-hosted git repository.

lihao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new 07109601 fix UDAF fallback bug when handling DeclarativeAggregator (#1369)
07109601 is described below

commit 07109601f765c0ab2a4096fd22f0843fa6fbe296
Author: Zhang Li <[email protected]>
AuthorDate: Mon Sep 29 15:14:04 2025 +0800

    fix UDAF fallback bug when handling DeclarativeAggregator (#1369)
    
    enable UDAF fallback by default
---
 .../spark/sql/auron/AuronFunctionSuite.scala       | 11 +++++
 .../java/org/apache/spark/sql/auron/AuronConf.java |  2 +-
 .../spark/sql/auron/SparkUDAFWrapperContext.scala  | 50 +++++++++++-----------
 3 files changed, 37 insertions(+), 26 deletions(-)

diff --git a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
index de140d55..a0f300a6 100644
--- a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
+++ b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
@@ -100,4 +100,15 @@ class AuronFunctionSuite
       checkAnswer(df, Seq(Row("11", "537061726B2053514C")))
     }
   }
+
+  test("stddev_samp function with UDAF fallback") {
+    withSQLConf("spark.auron.udafFallback.enable" -> "true") {
+      withTable("t1") {
+        sql("create table t1(c1 double) using parquet")
+        sql("insert into t1 values(10.0), (20.0), (30.0), (31.0), (null)")
+        val df = sql("select stddev_samp(c1) from t1")
+        checkAnswer(df, Seq(Row(9.844626283748239)))
+      }
+    }
+  }
 }
diff --git a/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java b/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
index d45ccefd..1dd4d3f6 100644
--- a/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
+++ b/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
@@ -40,7 +40,7 @@ public enum AuronConf {
     INPUT_BATCH_STATISTICS_ENABLE("spark.auron.enableInputBatchStatistics", true),
 
     /// supports UDAF and other aggregate functions not implemented
-    UDAF_FALLBACK_ENABLE("spark.auron.udafFallback.enable", false),
+    UDAF_FALLBACK_ENABLE("spark.auron.udafFallback.enable", true),
 
     // TypedImperativeAggregate one row mem use size
     SUGGESTED_UDAF_ROW_MEM_USAGE("spark.auron.suggested.udaf.memUsedSize", 64),
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
index 8537cb32..a2be0ecc 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
@@ -41,12 +41,7 @@ import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.auron.memory.OnHeapSpillManager
 import org.apache.spark.sql.auron.util.Using
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.JoinedRow
-import org.apache.spark.sql.catalyst.expressions.Nondeterministic
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, JoinedRow, Nondeterministic, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
@@ -68,10 +63,6 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
       bytes
     })
 
-  val inputAttributes: Seq[Attribute] = javaParamsSchema.fields.map { field =>
-    AttributeReference(field.name, field.dataType, field.nullable)()
-  }
-
   private val outputSchema = {
     val schema = StructType(Seq(StructField("", expr.dataType, expr.nullable)))
     ArrowUtils.toArrowSchema(schema)
@@ -93,7 +84,7 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
     override def initialValue: AggregateEvaluator[B, BufferRowsColumn[B]] = {
       val evaluator = expr match {
         case declarative: DeclarativeAggregate =>
-          new DeclarativeEvaluator(declarative, inputAttributes)
+          new DeclarativeEvaluator(declarative)
         case imperative: TypedImperativeAggregate[B] =>
           new TypedImperativeEvaluator(imperative)
       }
@@ -243,7 +234,7 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging {
   }
 }
 
-class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: Seq[Attribute])
+class DeclarativeEvaluator(val agg: DeclarativeAggregate)
     extends AggregateEvaluator[UnsafeRow, DeclarativeAggRowsColumn] {
 
   val initializedRow: UnsafeRow = {
@@ -252,8 +243,20 @@ class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: Seq[A
   val releasedRow: UnsafeRow = null
 
-  val updater: UnsafeProjection =
-    UnsafeProjection.create(agg.updateExpressions, agg.aggBufferAttributes ++ inputAttributes)
+  val updater: UnsafeProjection = {
+    val updateExpressions = agg.updateExpressions.map { expr =>
+      expr.transform {
+        case BoundReference(odin, dt, nullable) =>
+          BoundReference(odin + agg.aggBufferAttributes.length, dt, nullable)
+        case expr => expr
+      }
+    }
+    UnsafeProjection.create(
+      updateExpressions,
+      agg.aggBufferAttributes ++ agg.children.map(c => {
+        AttributeReference("", c.dataType, c.nullable)()
+      }))
+  }
 
   val merger: UnsafeProjection = UnsafeProjection.create(
     agg.mergeExpressions,
@@ -322,7 +325,7 @@ case class DeclarativeAggRowsColumn(
 
   override def resize(len: Int): Unit = {
     rows.appendAll((rows.length until len).map(_ => {
-      val newRow = evaluator.initializedRow.copy()
+      val newRow = evaluator.initializedRow
       rowsMemUsed += newRow.getSizeInBytes
       newRow
     }))
@@ -335,28 +338,25 @@ case class DeclarativeAggRowsColumn(
 
   override def updateRow(i: Int, inputRow: InternalRow): Unit = {
     if (i == rows.length) {
-      val newRow = evaluator.updater(evaluator.joiner(evaluator.initializedRow.copy(), inputRow))
-      rowsMemUsed += newRow.getSizeInBytes
-      rows.append(newRow)
+      rows.append(evaluator.updater(evaluator.joiner(evaluator.initializedRow, inputRow)).copy())
     } else {
       rowsMemUsed -= rows(i).getSizeInBytes
-      rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow))
-      rowsMemUsed += rows(i).getSizeInBytes
+      rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow)).copy()
     }
+    rowsMemUsed += rows(i).getSizeInBytes
   }
 
   override def mergeRow(i: Int, mergeRows: BufferRowsColumn[UnsafeRow], mergeIdx: Int): Unit = {
     mergeRows match {
       case mergeRows: DeclarativeAggRowsColumn =>
+        val mergeRow = mergeRows.rows(mergeIdx)
         if (i == rows.length) {
-          val newRow = mergeRows.rows(mergeIdx)
-          rowsMemUsed += newRow.getSizeInBytes
-          rows.append(newRow)
+          rows.append(mergeRow)
         } else {
           rowsMemUsed -= rows(i).getSizeInBytes
-          rows(i) = evaluator.merger(evaluator.joiner(rows(i), mergeRows.rows(mergeIdx)))
-          rowsMemUsed += rows(i).getSizeInBytes
+          rows(i) = evaluator.merger(evaluator.joiner(rows(i), mergeRow)).copy()
         }
+        rowsMemUsed += rows(i).getSizeInBytes
         mergeRows.rows(mergeIdx) = evaluator.releasedRow
     }
   }

Reply via email to