This is an automated email from the ASF dual-hosted git repository.
lihao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git
The following commit(s) were added to refs/heads/master by this push:
new 07109601 fix UDAF fallback bug when handling DeclarativeAggregator (#1369)
07109601 is described below
commit 07109601f765c0ab2a4096fd22f0843fa6fbe296
Author: Zhang Li <[email protected]>
AuthorDate: Mon Sep 29 15:14:04 2025 +0800
fix UDAF fallback bug when handling DeclarativeAggregator (#1369)
enable UDAF fallback by default
---
.../spark/sql/auron/AuronFunctionSuite.scala | 11 +++++
.../java/org/apache/spark/sql/auron/AuronConf.java | 2 +-
.../spark/sql/auron/SparkUDAFWrapperContext.scala | 50 +++++++++++-----------
3 files changed, 37 insertions(+), 26 deletions(-)
diff --git a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
index de140d55..a0f300a6 100644
--- a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
+++ b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
@@ -100,4 +100,15 @@ class AuronFunctionSuite
checkAnswer(df, Seq(Row("11", "537061726B2053514C")))
}
}
+
+ test("stddev_samp function with UDAF fallback") {
+ withSQLConf("spark.auron.udafFallback.enable" -> "true") {
+ withTable("t1") {
+ sql("create table t1(c1 double) using parquet")
+ sql("insert into t1 values(10.0), (20.0), (30.0), (31.0), (null)")
+ val df = sql("select stddev_samp(c1) from t1")
+ checkAnswer(df, Seq(Row(9.844626283748239)))
+ }
+ }
+ }
}
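
For reference, the expected answer in the new test can be reproduced by hand; a minimal sketch in plain Scala, assuming stddev_samp's SQL semantics (the NULL row is skipped, so n = 4 and the divisor is n - 1 = 3):

    // Sample standard deviation over {10.0, 20.0, 30.0, 31.0}.
    val xs = Seq(10.0, 20.0, 30.0, 31.0)
    val mean = xs.sum / xs.length                       // 22.75
    val m2 = xs.map(x => (x - mean) * (x - mean)).sum   // 290.75
    val stddev = math.sqrt(m2 / (xs.length - 1))        // sqrt(96.9166...) = 9.844626283748239
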
diff --git a/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java b/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
index d45ccefd..1dd4d3f6 100644
--- a/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
+++ b/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
@@ -40,7 +40,7 @@ public enum AuronConf {
INPUT_BATCH_STATISTICS_ENABLE("spark.auron.enableInputBatchStatistics", true),
/// supports UDAF and other aggregate functions not implemented
- UDAF_FALLBACK_ENABLE("spark.auron.udafFallback.enable", false),
+ UDAF_FALLBACK_ENABLE("spark.auron.udafFallback.enable", true),
// TypedImperativeAggregate one row mem use size
SUGGESTED_UDAF_ROW_MEM_USAGE("spark.auron.suggested.udaf.memUsedSize", 64),
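
For reference, this flips spark.auron.udafFallback.enable to true by default, so aggregate functions without a native Auron implementation are evaluated through the Spark fallback path instead of failing plan conversion. A minimal sketch of opting back out, assuming a standard SparkSession setup (the app name is illustrative):

    // Restore the previous behavior for a single job by disabling the fallback.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .appName("udaf-fallback-opt-out") // hypothetical app name
      .config("spark.auron.udafFallback.enable", "false")
      .getOrCreate()

The same key can also be set per statement, as the new test above does via withSQLConf.
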
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
index 8537cb32..a2be0ecc 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/SparkUDAFWrapperContext.scala
@@ -41,12 +41,7 @@ import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.auron.memory.OnHeapSpillManager
import org.apache.spark.sql.auron.util.Using
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.JoinedRow
-import org.apache.spark.sql.catalyst.expressions.Nondeterministic
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, JoinedRow, Nondeterministic, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
@@ -68,10 +63,6 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
bytes
})
- val inputAttributes: Seq[Attribute] = javaParamsSchema.fields.map { field =>
- AttributeReference(field.name, field.dataType, field.nullable)()
- }
-
private val outputSchema = {
val schema = StructType(Seq(StructField("", expr.dataType, expr.nullable)))
ArrowUtils.toArrowSchema(schema)
@@ -93,7 +84,7 @@ case class SparkUDAFWrapperContext[B](serialized: ByteBuffer) extends Logging {
override def initialValue: AggregateEvaluator[B, BufferRowsColumn[B]] = {
val evaluator = expr match {
case declarative: DeclarativeAggregate =>
- new DeclarativeEvaluator(declarative, inputAttributes)
+ new DeclarativeEvaluator(declarative)
case imperative: TypedImperativeAggregate[B] =>
new TypedImperativeEvaluator(imperative)
}
@@ -243,7 +234,7 @@ trait AggregateEvaluator[B, R <: BufferRowsColumn[B]] extends Logging {
}
}
-class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: Seq[Attribute])
+class DeclarativeEvaluator(val agg: DeclarativeAggregate)
extends AggregateEvaluator[UnsafeRow, DeclarativeAggRowsColumn] {
val initializedRow: UnsafeRow = {
@@ -252,8 +243,20 @@ class DeclarativeEvaluator(val agg: DeclarativeAggregate, inputAttributes: Seq[A
}
val releasedRow: UnsafeRow = null
- val updater: UnsafeProjection =
- UnsafeProjection.create(agg.updateExpressions, agg.aggBufferAttributes ++ inputAttributes)
+ val updater: UnsafeProjection = {
+ val updateExpressions = agg.updateExpressions.map { expr =>
+ expr.transform {
+ case BoundReference(ordinal, dt, nullable) =>
+ BoundReference(ordinal + agg.aggBufferAttributes.length, dt, nullable)
+ case expr => expr
+ }
+ }
+ UnsafeProjection.create(
+ updateExpressions,
+ agg.aggBufferAttributes ++ agg.children.map(c => {
+ AttributeReference("", c.dataType, c.nullable)()
+ }))
+ }
val merger: UnsafeProjection = UnsafeProjection.create(
agg.mergeExpressions,
@@ -322,7 +325,7 @@ case class DeclarativeAggRowsColumn(
override def resize(len: Int): Unit = {
rows.appendAll((rows.length until len).map(_ => {
- val newRow = evaluator.initializedRow.copy()
+ val newRow = evaluator.initializedRow
rowsMemUsed += newRow.getSizeInBytes
newRow
}))
@@ -335,28 +338,25 @@ case class DeclarativeAggRowsColumn(
override def updateRow(i: Int, inputRow: InternalRow): Unit = {
if (i == rows.length) {
- val newRow = evaluator.updater(evaluator.joiner(evaluator.initializedRow.copy(), inputRow))
- rowsMemUsed += newRow.getSizeInBytes
- rows.append(newRow)
+ rows.append(evaluator.updater(evaluator.joiner(evaluator.initializedRow, inputRow)).copy())
} else {
rowsMemUsed -= rows(i).getSizeInBytes
- rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow))
- rowsMemUsed += rows(i).getSizeInBytes
+ rows(i) = evaluator.updater(evaluator.joiner(rows(i), inputRow)).copy()
}
+ rowsMemUsed += rows(i).getSizeInBytes
}
override def mergeRow(i: Int, mergeRows: BufferRowsColumn[UnsafeRow], mergeIdx: Int): Unit = {
mergeRows match {
case mergeRows: DeclarativeAggRowsColumn =>
+ val mergeRow = mergeRows.rows(mergeIdx)
if (i == rows.length) {
- val newRow = mergeRows.rows(mergeIdx)
- rowsMemUsed += newRow.getSizeInBytes
- rows.append(newRow)
+ rows.append(mergeRow)
} else {
rowsMemUsed -= rows(i).getSizeInBytes
- rows(i) = evaluator.merger(evaluator.joiner(rows(i), mergeRows.rows(mergeIdx)))
- rowsMemUsed += rows(i).getSizeInBytes
+ rows(i) = evaluator.merger(evaluator.joiner(rows(i), mergeRow)).copy()
}
+ rowsMemUsed += rows(i).getSizeInBytes
mergeRows.rows(mergeIdx) = evaluator.releasedRow
}
}
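
For reference, the shape of the fix as read from the hunks above: a DeclarativeAggregate's updateExpressions may contain input references that are already bound against the raw input schema (ordinals starting at 0), while the updater projection evaluates them against a JoinedRow laid out as [aggregation buffer ++ input columns]. The new updater therefore shifts each bound input reference by the buffer width before building the projection. A minimal sketch of the rebinding idea, with a hypothetical buffer width (not the Auron code itself):

    import org.apache.spark.sql.catalyst.expressions.BoundReference
    import org.apache.spark.sql.types.DoubleType

    // Joined-row layout: [buffer(0) .. buffer(2), input(0), ...]
    val bufferWidth = 3 // hypothetical three-column aggregation buffer
    val inputRef = BoundReference(0, DoubleType, nullable = true)
    val rebound = inputRef.copy(ordinal = inputRef.ordinal + bufferWidth)
    // rebound now reads the first input column of the joined row, at ordinal 3.

The .copy() calls moved onto the projection results in updateRow and mergeRow look consistent with UnsafeProjection reusing a single output buffer across invocations: a projected row must be copied before it is stored, whereas a row that is only fed into the next projection need not be.
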