This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch dev-columnar-agg-buf
in repository https://gitbox.apache.org/repos/asf/auron.git

commit e18a17b610c256746056dcdcea270972fbede16a
Author: zhangli20 <[email protected]>
AuthorDate: Wed Jan 21 20:05:08 2026 +0800

    Implement columnar aggregate buffers
---
 .../sql/execution/auron/plan/NativeAggExec.scala   | 13 -------------
 .../sql/execution/auron/plan/NativeAggBase.scala   | 22 ++++++++++++----------
 2 files changed, 12 insertions(+), 23 deletions(-)

diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala
index 7e4b9d6f..99f9247b 100644
--- a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala
+++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggExec.scala
@@ -17,16 +17,12 @@
 package org.apache.spark.sql.execution.auron.plan
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.ExprId
 import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.expressions.aggregate.Final
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.auron.plan.NativeAggBase.AggExecMode
-import org.apache.spark.sql.types.BinaryType
 
 import org.apache.auron.sparkver
 
@@ -55,15 +51,6 @@ case class NativeAggExec(
   @sparkver("3.3 / 3.4 / 3.5")
   override val initialInputBufferOffset: Int = theInitialInputBufferOffset
 
-  override def output: Seq[Attribute] =
-    if (aggregateExpressions.map(_.mode).contains(Final)) {
-      groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
-    } else {
-      groupingExpressions.map(_.toAttribute) :+
-        AttributeReference(NativeAggBase.AGG_BUF_COLUMN_NAME, BinaryType, 
nullable = false)(
-          ExprId.apply(NativeAggBase.AGG_BUF_COLUMN_EXPR_ID))
-    }
-
   @sparkver("3.2 / 3.3 / 3.4 / 3.5")
   override def isStreaming: Boolean = false
 
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala
index 292368b9..f595dacc 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.auron.plan
 import scala.annotation.tailrec
 import scala.collection.JavaConverters._
 import scala.collection.immutable.SortedMap
+
 import org.apache.spark.OneToOneDependency
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.auron.NativeConverters
@@ -52,6 +53,7 @@ import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, LongType}
+
 import org.apache.auron.{protobuf => pb}
 import org.apache.auron.jni.AuronAdaptor
 import org.apache.auron.metric.SparkMetricNode
@@ -146,7 +148,7 @@ abstract class NativeAggBase(
       val aggBufferAttrs = nativeAggrInfos
         .flatMap(_.aggBufferDataTypes)
         .map(AttributeReference("", _, nullable = false)(ExprId.apply(0)))
-      groupingExpressions.map(_.toAttribute) :+ aggBufferAttrs
+      groupingExpressions.map(_.toAttribute) ++ aggBufferAttrs
     }
 
   override def outputPartitioning: Partitioning =
@@ -207,10 +209,6 @@ abstract class NativeAggBase(
 }
 
 object NativeAggBase extends Logging {
-
-  val AGG_BUF_COLUMN_EXPR_ID = 9223372036854775807L
-  val AGG_BUF_COLUMN_NAME = s"#$AGG_BUF_COLUMN_EXPR_ID"
-
   trait AggExecMode;
   case object HashAgg extends AggExecMode
   case object SortAgg extends AggExecMode
@@ -222,7 +220,7 @@ object NativeAggBase extends Logging {
       outputAttr: Attribute)
 
   def getNativeAggrInfo(aggr: AggregateExpression, aggrAttr: Attribute): 
NativeAggrInfo = {
-    val aggBufferDataTypes = computeNativeAggBufferDataTypes(aggr)
+    val aggBufferDataTypes = 
computeNativeAggBufferDataTypes(aggr.aggregateFunction)
     val reducedAggr = AggregateExpression(
       aggr.aggregateFunction
         .mapChildren(e => createPlaceholder(e))
@@ -235,7 +233,11 @@ object NativeAggBase extends Logging {
 
     aggr.mode match {
       case Partial =>
-        NativeAggrInfo(aggr.mode, NativeConverters.convertAggregateExpr(aggr) 
:: Nil, aggBufferDataTypes, outputAttr)
+        NativeAggrInfo(
+          aggr.mode,
+          NativeConverters.convertAggregateExpr(aggr) :: Nil,
+          aggBufferDataTypes,
+          outputAttr)
 
       case PartialMerge | Final =>
         NativeAggrInfo(
@@ -297,12 +299,12 @@ object NativeAggBase extends Logging {
   def computeNativeAggBufferDataTypes(aggr: AggregateFunction): Seq[DataType] 
= {
     aggr match {
       case _: Count => Seq(LongType)
-      case f: Max  => Seq(f.dataType)
+      case f: Max => Seq(f.dataType)
       case f: Min => Seq(f.dataType)
       case f: Sum => Seq(f.dataType)
       case f: Average => Seq(f.dataType, LongType)
-      case f@First(_, true) => Seq(f.dataType)
-      case f@First(_, false) => Seq(f.dataType, BooleanType)
+      case f @ First(_, true) => Seq(f.dataType)
+      case f @ First(_, false) => Seq(f.dataType, BooleanType)
       case _ => Seq(BinaryType)
     }
   }

Reply via email to