This is an automated email from the ASF dual-hosted git repository. mbutrovich pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 178ab5dd3 chore: Add type parameter to CometAggregateExpressionSerde (#2249) 178ab5dd3 is described below commit 178ab5dd3eb6d9587a6b71ae0076c49f135182da Author: Andy Grove <agr...@apache.org> AuthorDate: Thu Aug 28 11:59:15 2025 -0600 chore: Add type parameter to CometAggregateExpressionSerde (#2249) --- .../org/apache/comet/serde/QueryPlanSerde.scala | 10 +- .../scala/org/apache/comet/serde/aggregates.scala | 131 ++++++++------------- 2 files changed, 57 insertions(+), 84 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 85aeb6647..84e3445fd 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -174,7 +174,7 @@ object QueryPlanSerde extends Logging with CometExprShim { /** * Mapping of Spark aggregate expression class to Comet expression handler. */ - private val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde] = Map( + private val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde[_]] = Map( classOf[Sum] -> CometSum, classOf[Average] -> CometAverage, classOf[Count] -> CometCount, @@ -498,7 +498,9 @@ object QueryPlanSerde extends Logging with CometExprShim { val cometExpr = aggrSerdeMap.get(fn.getClass) cometExpr match { case Some(handler) => - handler.convert(aggExpr, fn, inputs, binding, conf) + handler + .asInstanceOf[CometAggregateExpressionSerde[AggregateFunction]] + .convert(aggExpr, fn, inputs, binding, conf) case _ => withInfo( aggExpr, @@ -2456,7 +2458,7 @@ trait CometExpressionSerde[T <: Expression] { /** * Trait for providing serialization logic for aggregate expressions. */ -trait CometAggregateExpressionSerde { +trait CometAggregateExpressionSerde[T <: AggregateFunction] { /** * Convert a Spark expression into a protocol buffer representation that can be passed into @@ -2479,7 +2481,7 @@ trait CometAggregateExpressionSerde { */ def convert( aggExpr: AggregateExpression, - expr: Expression, + expr: T, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 0f784b76f..51c895128 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -21,8 +21,8 @@ package org.apache.comet.serde import scala.collection.JavaConverters.asJavaIterableConverter -import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Covariance, CovPopulation, CovSample, First, Last, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, LongType, ShortType, StringType} @@ -30,11 +30,11 @@ import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType} -object CometMin extends CometAggregateExpressionSerde { +object CometMin extends CometAggregateExpressionSerde[Min] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + expr: Min, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { @@ -66,11 +66,11 @@ object CometMin extends CometAggregateExpressionSerde { } } -object CometMax extends CometAggregateExpressionSerde { +object CometMax extends CometAggregateExpressionSerde[Max] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + expr: Max, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { @@ -102,10 +102,10 @@ object CometMax extends CometAggregateExpressionSerde { } } -object CometCount extends CometAggregateExpressionSerde { +object CometCount extends CometAggregateExpressionSerde[Count] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + expr: Count, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { @@ -125,17 +125,16 @@ object CometCount extends CometAggregateExpressionSerde { } } -object CometAverage extends CometAggregateExpressionSerde { +object CometAverage extends CometAggregateExpressionSerde[Average] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + avg: Average, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val avg = expr.asInstanceOf[Average] - if (!AggSerde.avgDataTypeSupported(expr.dataType)) { - withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}") + if (!AggSerde.avgDataTypeSupported(avg.dataType)) { + withInfo(aggExpr, s"Unsupported data type: ${avg.dataType}") return None } @@ -155,7 +154,7 @@ object CometAverage extends CometAggregateExpressionSerde { val child = avg.child val childExpr = exprToProto(child, inputs, binding) - val dataType = serializeDataType(expr.dataType) + val dataType = serializeDataType(avg.dataType) val sumDataType = child.dataType match { case decimalType: DecimalType => @@ -181,7 +180,7 @@ object CometAverage extends CometAggregateExpressionSerde { .setAvg(builder) .build()) } else if (dataType.isEmpty) { - withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", child) + withInfo(aggExpr, s"datatype ${avg.dataType} is not supported", child) None } else { withInfo(aggExpr, child) @@ -189,17 +188,16 @@ object CometAverage extends CometAggregateExpressionSerde { } } } -object CometSum extends CometAggregateExpressionSerde { +object CometSum extends CometAggregateExpressionSerde[Sum] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + sum: Sum, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val sum = expr.asInstanceOf[Sum] if (!AggSerde.sumDataTypeSupported(sum.dataType)) { - withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}") + withInfo(aggExpr, s"Unsupported data type: ${sum.dataType}") return None } @@ -242,14 +240,13 @@ object CometSum extends CometAggregateExpressionSerde { } } -object CometFirst extends CometAggregateExpressionSerde { +object CometFirst extends CometAggregateExpressionSerde[First] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + first: First, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val first = expr.asInstanceOf[First] val child = first.children.head val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(first.dataType) @@ -275,14 +272,13 @@ object CometFirst extends CometAggregateExpressionSerde { } } -object CometLast extends CometAggregateExpressionSerde { +object CometLast extends CometAggregateExpressionSerde[Last] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + last: Last, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val last = expr.asInstanceOf[Last] val child = last.children.head val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(last.dataType) @@ -308,16 +304,15 @@ object CometLast extends CometAggregateExpressionSerde { } } -object CometBitAndAgg extends CometAggregateExpressionSerde { +object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + bitAnd: BitAndAgg, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val bitAnd = expr.asInstanceOf[BitAndAgg] if (!AggSerde.bitwiseAggTypeSupported(bitAnd.dataType)) { - withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}") + withInfo(aggExpr, s"Unsupported data type: ${bitAnd.dataType}") return None } val child = bitAnd.child @@ -343,16 +338,15 @@ object CometBitAndAgg extends CometAggregateExpressionSerde { } } -object CometBitOrAgg extends CometAggregateExpressionSerde { +object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + bitOr: BitOrAgg, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val bitOr = expr.asInstanceOf[BitOrAgg] if (!AggSerde.bitwiseAggTypeSupported(bitOr.dataType)) { - withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}") + withInfo(aggExpr, s"Unsupported data type: ${bitOr.dataType}") return None } val child = bitOr.child @@ -378,16 +372,15 @@ object CometBitOrAgg extends CometAggregateExpressionSerde { } } -object CometBitXOrAgg extends CometAggregateExpressionSerde { +object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + bitXor: BitXorAgg, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val bitXor = expr.asInstanceOf[BitXorAgg] if (!AggSerde.bitwiseAggTypeSupported(bitXor.dataType)) { - withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}") + withInfo(aggExpr, s"Unsupported data type: ${bitXor.dataType}") return None } val child = bitXor.child @@ -413,7 +406,7 @@ object CometBitXOrAgg extends CometAggregateExpressionSerde { } } -trait CometCovBase extends CometAggregateExpressionSerde { +trait CometCovBase { def convertCov( aggExpr: AggregateExpression, cov: Covariance, @@ -446,14 +439,13 @@ trait CometCovBase extends CometAggregateExpressionSerde { } } -object CometCovSample extends CometCovBase { +object CometCovSample extends CometAggregateExpressionSerde[CovSample] with CometCovBase { override def convert( aggExpr: AggregateExpression, - expr: Expression, + covSample: CovSample, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val covSample = expr.asInstanceOf[CovSample] convertCov( aggExpr, covSample, @@ -465,14 +457,13 @@ object CometCovSample extends CometCovBase { } } -object CometCovPopulation extends CometCovBase { +object CometCovPopulation extends CometAggregateExpressionSerde[CovPopulation] with CometCovBase { override def convert( aggExpr: AggregateExpression, - expr: Expression, + covPopulation: CovPopulation, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val covPopulation = expr.asInstanceOf[CovPopulation] convertCov( aggExpr, covPopulation, @@ -484,7 +475,7 @@ object CometCovPopulation extends CometCovBase { } } -trait CometVariance extends CometAggregateExpressionSerde { +trait CometVariance { def convertVariance( aggExpr: AggregateExpression, expr: CentralMomentAgg, @@ -515,31 +506,29 @@ trait CometVariance extends CometAggregateExpressionSerde { } -object CometVarianceSamp extends CometVariance { +object CometVarianceSamp extends CometAggregateExpressionSerde[VarianceSamp] with CometVariance { override def convert( aggExpr: AggregateExpression, - expr: Expression, + variance: VarianceSamp, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val variance = expr.asInstanceOf[VarianceSamp] convertVariance(aggExpr, variance, variance.nullOnDivideByZero, 0, inputs, binding) } } -object CometVariancePop extends CometVariance { +object CometVariancePop extends CometAggregateExpressionSerde[VariancePop] with CometVariance { override def convert( aggExpr: AggregateExpression, - expr: Expression, + variance: VariancePop, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val variance = expr.asInstanceOf[VariancePop] convertVariance(aggExpr, variance, variance.nullOnDivideByZero, 1, inputs, binding) } } -trait CometStddev extends CometAggregateExpressionSerde { +trait CometStddev { def convertStddev( aggExpr: AggregateExpression, stddev: CentralMomentAgg, @@ -580,52 +569,35 @@ trait CometStddev extends CometAggregateExpressionSerde { } } -object CometStddevSamp extends CometStddev { +object CometStddevSamp extends CometAggregateExpressionSerde[StddevSamp] with CometStddev { override def convert( aggExpr: AggregateExpression, - expr: Expression, + stddev: StddevSamp, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val variance = expr.asInstanceOf[StddevSamp] - convertStddev( - aggExpr, - variance, - variance.nullOnDivideByZero, - 0, - inputs, - binding, - conf: SQLConf) + convertStddev(aggExpr, stddev, stddev.nullOnDivideByZero, 0, inputs, binding, conf: SQLConf) } } -object CometStddevPop extends CometStddev { +object CometStddevPop extends CometAggregateExpressionSerde[StddevPop] with CometStddev { override def convert( aggExpr: AggregateExpression, - expr: Expression, + stddev: StddevPop, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val variance = expr.asInstanceOf[StddevPop] - convertStddev( - aggExpr, - variance, - variance.nullOnDivideByZero, - 1, - inputs, - binding, - conf: SQLConf) + convertStddev(aggExpr, stddev, stddev.nullOnDivideByZero, 1, inputs, binding, conf: SQLConf) } } -object CometCorr extends CometAggregateExpressionSerde { +object CometCorr extends CometAggregateExpressionSerde[Corr] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + corr: Corr, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { - val corr = expr.asInstanceOf[Corr] val child1Expr = exprToProto(corr.x, inputs, binding) val child2Expr = exprToProto(corr.y, inputs, binding) val dataType = serializeDataType(corr.dataType) @@ -649,17 +621,16 @@ object CometCorr extends CometAggregateExpressionSerde { } } -object CometBloomFilterAggregate extends CometAggregateExpressionSerde { +object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilterAggregate] { override def convert( aggExpr: AggregateExpression, - expr: Expression, + bloomFilter: BloomFilterAggregate, inputs: Seq[Attribute], binding: Boolean, conf: SQLConf): Option[ExprOuterClass.AggExpr] = { // We ignore mutableAggBufferOffset and inputAggBufferOffset because they are // implementation details for Spark's ObjectHashAggregate. - val bloomFilter = expr.asInstanceOf[BloomFilterAggregate] val childExpr = exprToProto(bloomFilter.child, inputs, binding) val numItemsExpr = exprToProto(bloomFilter.estimatedNumItemsExpression, inputs, binding) val numBitsExpr = exprToProto(bloomFilter.numBitsExpression, inputs, binding) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org