This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 178ab5dd3 chore: Add type parameter to CometAggregateExpressionSerde (#2249)
178ab5dd3 is described below

commit 178ab5dd3eb6d9587a6b71ae0076c49f135182da
Author: Andy Grove <agr...@apache.org>
AuthorDate: Thu Aug 28 11:59:15 2025 -0600

    chore: Add type parameter to CometAggregateExpressionSerde (#2249)
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  10 +-
 .../scala/org/apache/comet/serde/aggregates.scala  | 131 ++++++++-------------
 2 files changed, 57 insertions(+), 84 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 85aeb6647..84e3445fd 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -174,7 +174,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
   /**
    * Mapping of Spark aggregate expression class to Comet expression handler.
    */
-  private val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde] = Map(
+  private val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde[_]] = 
Map(
     classOf[Sum] -> CometSum,
     classOf[Average] -> CometAverage,
     classOf[Count] -> CometCount,
@@ -498,7 +498,9 @@ object QueryPlanSerde extends Logging with CometExprShim {
     val cometExpr = aggrSerdeMap.get(fn.getClass)
     cometExpr match {
       case Some(handler) =>
-        handler.convert(aggExpr, fn, inputs, binding, conf)
+        handler
+          .asInstanceOf[CometAggregateExpressionSerde[AggregateFunction]]
+          .convert(aggExpr, fn, inputs, binding, conf)
       case _ =>
         withInfo(
           aggExpr,
@@ -2456,7 +2458,7 @@ trait CometExpressionSerde[T <: Expression] {
 /**
  * Trait for providing serialization logic for aggregate expressions.
  */
-trait CometAggregateExpressionSerde {
+trait CometAggregateExpressionSerde[T <: AggregateFunction] {
 
   /**
    * Convert a Spark expression into a protocol buffer representation that can 
be passed into
@@ -2479,7 +2481,7 @@ trait CometAggregateExpressionSerde {
    */
   def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      expr: T,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr]
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala 
b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 0f784b76f..51c895128 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -21,8 +21,8 @@ package org.apache.comet.serde
 
 import scala.collection.JavaConverters.asJavaIterableConverter
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode, 
Expression}
-import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, 
CentralMomentAgg, Corr, Covariance, CovPopulation, CovSample, First, Last, 
StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, 
CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, 
Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, 
LongType, ShortType, StringType}
 
@@ -30,11 +30,11 @@ import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
 
-object CometMin extends CometAggregateExpressionSerde {
+object CometMin extends CometAggregateExpressionSerde[Min] {
 
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      expr: Min,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
@@ -66,11 +66,11 @@ object CometMin extends CometAggregateExpressionSerde {
   }
 }
 
-object CometMax extends CometAggregateExpressionSerde {
+object CometMax extends CometAggregateExpressionSerde[Max] {
 
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      expr: Max,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
@@ -102,10 +102,10 @@ object CometMax extends CometAggregateExpressionSerde {
   }
 }
 
-object CometCount extends CometAggregateExpressionSerde {
+object CometCount extends CometAggregateExpressionSerde[Count] {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      expr: Count,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
@@ -125,17 +125,16 @@ object CometCount extends CometAggregateExpressionSerde {
   }
 }
 
-object CometAverage extends CometAggregateExpressionSerde {
+object CometAverage extends CometAggregateExpressionSerde[Average] {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      avg: Average,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val avg = expr.asInstanceOf[Average]
 
-    if (!AggSerde.avgDataTypeSupported(expr.dataType)) {
-      withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+    if (!AggSerde.avgDataTypeSupported(avg.dataType)) {
+      withInfo(aggExpr, s"Unsupported data type: ${avg.dataType}")
       return None
     }
 
@@ -155,7 +154,7 @@ object CometAverage extends CometAggregateExpressionSerde {
 
     val child = avg.child
     val childExpr = exprToProto(child, inputs, binding)
-    val dataType = serializeDataType(expr.dataType)
+    val dataType = serializeDataType(avg.dataType)
 
     val sumDataType = child.dataType match {
       case decimalType: DecimalType =>
@@ -181,7 +180,7 @@ object CometAverage extends CometAggregateExpressionSerde {
           .setAvg(builder)
           .build())
     } else if (dataType.isEmpty) {
-      withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", child)
+      withInfo(aggExpr, s"datatype ${avg.dataType} is not supported", child)
       None
     } else {
       withInfo(aggExpr, child)
@@ -189,17 +188,16 @@ object CometAverage extends CometAggregateExpressionSerde 
{
     }
   }
 }
-object CometSum extends CometAggregateExpressionSerde {
+object CometSum extends CometAggregateExpressionSerde[Sum] {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      sum: Sum,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val sum = expr.asInstanceOf[Sum]
 
     if (!AggSerde.sumDataTypeSupported(sum.dataType)) {
-      withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+      withInfo(aggExpr, s"Unsupported data type: ${sum.dataType}")
       return None
     }
 
@@ -242,14 +240,13 @@ object CometSum extends CometAggregateExpressionSerde {
   }
 }
 
-object CometFirst extends CometAggregateExpressionSerde {
+object CometFirst extends CometAggregateExpressionSerde[First] {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      first: First,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val first = expr.asInstanceOf[First]
     val child = first.children.head
     val childExpr = exprToProto(child, inputs, binding)
     val dataType = serializeDataType(first.dataType)
@@ -275,14 +272,13 @@ object CometFirst extends CometAggregateExpressionSerde {
   }
 }
 
-object CometLast extends CometAggregateExpressionSerde {
+object CometLast extends CometAggregateExpressionSerde[Last] {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      last: Last,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val last = expr.asInstanceOf[Last]
     val child = last.children.head
     val childExpr = exprToProto(child, inputs, binding)
     val dataType = serializeDataType(last.dataType)
@@ -308,16 +304,15 @@ object CometLast extends CometAggregateExpressionSerde {
   }
 }
 
-object CometBitAndAgg extends CometAggregateExpressionSerde {
+object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      bitAnd: BitAndAgg,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val bitAnd = expr.asInstanceOf[BitAndAgg]
     if (!AggSerde.bitwiseAggTypeSupported(bitAnd.dataType)) {
-      withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+      withInfo(aggExpr, s"Unsupported data type: ${bitAnd.dataType}")
       return None
     }
     val child = bitAnd.child
@@ -343,16 +338,15 @@ object CometBitAndAgg extends 
CometAggregateExpressionSerde {
   }
 }
 
-object CometBitOrAgg extends CometAggregateExpressionSerde {
+object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      bitOr: BitOrAgg,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val bitOr = expr.asInstanceOf[BitOrAgg]
     if (!AggSerde.bitwiseAggTypeSupported(bitOr.dataType)) {
-      withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+      withInfo(aggExpr, s"Unsupported data type: ${bitOr.dataType}")
       return None
     }
     val child = bitOr.child
@@ -378,16 +372,15 @@ object CometBitOrAgg extends 
CometAggregateExpressionSerde {
   }
 }
 
-object CometBitXOrAgg extends CometAggregateExpressionSerde {
+object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      bitXor: BitXorAgg,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val bitXor = expr.asInstanceOf[BitXorAgg]
     if (!AggSerde.bitwiseAggTypeSupported(bitXor.dataType)) {
-      withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+      withInfo(aggExpr, s"Unsupported data type: ${bitXor.dataType}")
       return None
     }
     val child = bitXor.child
@@ -413,7 +406,7 @@ object CometBitXOrAgg extends CometAggregateExpressionSerde 
{
   }
 }
 
-trait CometCovBase extends CometAggregateExpressionSerde {
+trait CometCovBase {
   def convertCov(
       aggExpr: AggregateExpression,
       cov: Covariance,
@@ -446,14 +439,13 @@ trait CometCovBase extends CometAggregateExpressionSerde {
   }
 }
 
-object CometCovSample extends CometCovBase {
+object CometCovSample extends CometAggregateExpressionSerde[CovSample] with 
CometCovBase {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      covSample: CovSample,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val covSample = expr.asInstanceOf[CovSample]
     convertCov(
       aggExpr,
       covSample,
@@ -465,14 +457,13 @@ object CometCovSample extends CometCovBase {
   }
 }
 
-object CometCovPopulation extends CometCovBase {
+object CometCovPopulation extends CometAggregateExpressionSerde[CovPopulation] 
with CometCovBase {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      covPopulation: CovPopulation,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val covPopulation = expr.asInstanceOf[CovPopulation]
     convertCov(
       aggExpr,
       covPopulation,
@@ -484,7 +475,7 @@ object CometCovPopulation extends CometCovBase {
   }
 }
 
-trait CometVariance extends CometAggregateExpressionSerde {
+trait CometVariance {
   def convertVariance(
       aggExpr: AggregateExpression,
       expr: CentralMomentAgg,
@@ -515,31 +506,29 @@ trait CometVariance extends CometAggregateExpressionSerde 
{
 
 }
 
-object CometVarianceSamp extends CometVariance {
+object CometVarianceSamp extends CometAggregateExpressionSerde[VarianceSamp] 
with CometVariance {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      variance: VarianceSamp,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val variance = expr.asInstanceOf[VarianceSamp]
     convertVariance(aggExpr, variance, variance.nullOnDivideByZero, 0, inputs, 
binding)
   }
 }
 
-object CometVariancePop extends CometVariance {
+object CometVariancePop extends CometAggregateExpressionSerde[VariancePop] 
with CometVariance {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      variance: VariancePop,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val variance = expr.asInstanceOf[VariancePop]
     convertVariance(aggExpr, variance, variance.nullOnDivideByZero, 1, inputs, 
binding)
   }
 }
 
-trait CometStddev extends CometAggregateExpressionSerde {
+trait CometStddev {
   def convertStddev(
       aggExpr: AggregateExpression,
       stddev: CentralMomentAgg,
@@ -580,52 +569,35 @@ trait CometStddev extends CometAggregateExpressionSerde {
   }
 }
 
-object CometStddevSamp extends CometStddev {
+object CometStddevSamp extends CometAggregateExpressionSerde[StddevSamp] with 
CometStddev {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      stddev: StddevSamp,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val variance = expr.asInstanceOf[StddevSamp]
-    convertStddev(
-      aggExpr,
-      variance,
-      variance.nullOnDivideByZero,
-      0,
-      inputs,
-      binding,
-      conf: SQLConf)
+    convertStddev(aggExpr, stddev, stddev.nullOnDivideByZero, 0, inputs, 
binding, conf: SQLConf)
   }
 }
 
-object CometStddevPop extends CometStddev {
+object CometStddevPop extends CometAggregateExpressionSerde[StddevPop] with 
CometStddev {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      stddev: StddevPop,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val variance = expr.asInstanceOf[StddevPop]
-    convertStddev(
-      aggExpr,
-      variance,
-      variance.nullOnDivideByZero,
-      1,
-      inputs,
-      binding,
-      conf: SQLConf)
+    convertStddev(aggExpr, stddev, stddev.nullOnDivideByZero, 1, inputs, 
binding, conf: SQLConf)
   }
 }
 
-object CometCorr extends CometAggregateExpressionSerde {
+object CometCorr extends CometAggregateExpressionSerde[Corr] {
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      corr: Corr,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
-    val corr = expr.asInstanceOf[Corr]
     val child1Expr = exprToProto(corr.x, inputs, binding)
     val child2Expr = exprToProto(corr.y, inputs, binding)
     val dataType = serializeDataType(corr.dataType)
@@ -649,17 +621,16 @@ object CometCorr extends CometAggregateExpressionSerde {
   }
 }
 
-object CometBloomFilterAggregate extends CometAggregateExpressionSerde {
+object CometBloomFilterAggregate extends 
CometAggregateExpressionSerde[BloomFilterAggregate] {
 
   override def convert(
       aggExpr: AggregateExpression,
-      expr: Expression,
+      bloomFilter: BloomFilterAggregate,
       inputs: Seq[Attribute],
       binding: Boolean,
       conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
     // We ignore mutableAggBufferOffset and inputAggBufferOffset because they 
are
     // implementation details for Spark's ObjectHashAggregate.
-    val bloomFilter = expr.asInstanceOf[BloomFilterAggregate]
     val childExpr = exprToProto(bloomFilter.child, inputs, binding)
     val numItemsExpr = exprToProto(bloomFilter.estimatedNumItemsExpression, 
inputs, binding)
     val numBitsExpr = exprToProto(bloomFilter.numBitsExpression, inputs, 
binding)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to