This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new a4c6c12ca chore: Refactor aggregate expression serde (#1380)
a4c6c12ca is described below

commit a4c6c12cafd22b01f3eed88081f1bada4eeb5cdf
Author: Andy Grove <[email protected]>
AuthorDate: Fri Feb 14 06:10:59 2025 -0700

    chore: Refactor aggregate expression serde (#1380)
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 519 ++-------------
 .../scala/org/apache/comet/serde/aggregates.scala  | 734 +++++++++++++++++++++
 2 files changed, 792 insertions(+), 461 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 8f0a4273b..aa1aba11d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -174,35 +174,6 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
     Some(dataType)
   }
 
-  private def sumDataTypeSupported(dt: DataType): Boolean = {
-    dt match {
-      case _: NumericType => true
-      case _ => false
-    }
-  }
-
-  private def avgDataTypeSupported(dt: DataType): Boolean = {
-    dt match {
-      case _: NumericType => true
-      // TODO: implement support for interval types
-      case _ => false
-    }
-  }
-
-  private def minMaxDataTypeSupported(dt: DataType): Boolean = {
-    dt match {
-      case _: NumericType | DateType | TimestampType | BooleanType => true
-      case _ => false
-    }
-  }
-
-  private def bitwiseAggTypeSupported(dt: DataType): Boolean = {
-    dt match {
-      case _: IntegerType | LongType | ShortType | ByteType => true
-      case _ => false
-    }
-  }
-
   def windowExprToProto(
       windowExpr: WindowExpression,
       output: Seq[Attribute],
@@ -215,21 +186,22 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
             case _: Count =>
               Some(agg)
             case min: Min =>
-              if (minMaxDataTypeSupported(min.dataType)) {
+              if (AggSerde.minMaxDataTypeSupported(min.dataType)) {
                 Some(agg)
               } else {
                 withInfo(windowExpr, s"datatype ${min.dataType} is not 
supported", expr)
                 None
               }
             case max: Max =>
-              if (minMaxDataTypeSupported(max.dataType)) {
+              if (AggSerde.minMaxDataTypeSupported(max.dataType)) {
                 Some(agg)
               } else {
                 withInfo(windowExpr, s"datatype ${max.dataType} is not 
supported", expr)
                 None
               }
             case s: Sum =>
-              if (sumDataTypeSupported(s.dataType) && 
!s.dataType.isInstanceOf[DecimalType]) {
+              if (AggSerde.sumDataTypeSupported(s.dataType) && !s.dataType
+                  .isInstanceOf[DecimalType]) {
                 Some(agg)
               } else {
                 withInfo(windowExpr, s"datatype ${s.dataType} is not 
supported", expr)
@@ -379,440 +351,33 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
       return None
     }
 
-    aggExpr.aggregateFunction match {
-      case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) && 
isLegacyMode(s) =>
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(s.dataType)
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val sumBuilder = ExprOuterClass.Sum.newBuilder()
-          sumBuilder.setChild(childExpr.get)
-          sumBuilder.setDatatype(dataType.get)
-          sumBuilder.setFailOnError(getFailOnError(s))
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setSum(sumBuilder)
-              .build())
-        } else {
-          if (dataType.isEmpty) {
-            withInfo(aggExpr, s"datatype ${s.dataType} is not supported", 
child)
-          } else {
-            withInfo(aggExpr, child)
-          }
-          None
-        }
-      case s @ Average(child, _) if avgDataTypeSupported(s.dataType) && 
isLegacyMode(s) =>
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(s.dataType)
-
-        val sumDataType = if (child.dataType.isInstanceOf[DecimalType]) {
-
-          // This is input precision + 10 to be consistent with Spark
-          val precision = Math.min(
-            DecimalType.MAX_PRECISION,
-            child.dataType.asInstanceOf[DecimalType].precision + 10)
-          val newType =
-            DecimalType.apply(precision, 
child.dataType.asInstanceOf[DecimalType].scale)
-          serializeDataType(newType)
-        } else {
-          serializeDataType(child.dataType)
-        }
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val builder = ExprOuterClass.Avg.newBuilder()
-          builder.setChild(childExpr.get)
-          builder.setDatatype(dataType.get)
-          builder.setFailOnError(getFailOnError(s))
-          builder.setSumDatatype(sumDataType.get)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setAvg(builder)
-              .build())
-        } else if (dataType.isEmpty) {
-          withInfo(aggExpr, s"datatype ${s.dataType} is not supported", child)
-          None
-        } else {
-          withInfo(aggExpr, child)
-          None
-        }
-      case Count(children) =>
-        val exprChildren = children.map(exprToProto(_, inputs, binding))
-
-        if (exprChildren.forall(_.isDefined)) {
-          val countBuilder = ExprOuterClass.Count.newBuilder()
-          countBuilder.addAllChildren(exprChildren.map(_.get).asJava)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setCount(countBuilder)
-              .build())
-        } else {
-          withInfo(aggExpr, children: _*)
-          None
-        }
-      case min @ Min(child) if minMaxDataTypeSupported(min.dataType) =>
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(min.dataType)
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val minBuilder = ExprOuterClass.Min.newBuilder()
-          minBuilder.setChild(childExpr.get)
-          minBuilder.setDatatype(dataType.get)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setMin(minBuilder)
-              .build())
-        } else if (dataType.isEmpty) {
-          withInfo(aggExpr, s"datatype ${min.dataType} is not supported", 
child)
-          None
-        } else {
-          withInfo(aggExpr, child)
-          None
-        }
-      case max @ Max(child) if minMaxDataTypeSupported(max.dataType) =>
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(max.dataType)
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val maxBuilder = ExprOuterClass.Max.newBuilder()
-          maxBuilder.setChild(childExpr.get)
-          maxBuilder.setDatatype(dataType.get)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setMax(maxBuilder)
-              .build())
-        } else if (dataType.isEmpty) {
-          withInfo(aggExpr, s"datatype ${max.dataType} is not supported", 
child)
-          None
-        } else {
-          withInfo(aggExpr, child)
-          None
-        }
-      case first @ First(child, ignoreNulls)
-          if !ignoreNulls => // DataFusion doesn't support ignoreNulls true
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(first.dataType)
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val firstBuilder = ExprOuterClass.First.newBuilder()
-          firstBuilder.setChild(childExpr.get)
-          firstBuilder.setDatatype(dataType.get)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setFirst(firstBuilder)
-              .build())
-        } else if (dataType.isEmpty) {
-          withInfo(aggExpr, s"datatype ${first.dataType} is not supported", 
child)
-          None
-        } else {
-          withInfo(aggExpr, child)
-          None
-        }
-      case last @ Last(child, ignoreNulls)
-          if !ignoreNulls => // DataFusion doesn't support ignoreNulls true
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(last.dataType)
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val lastBuilder = ExprOuterClass.Last.newBuilder()
-          lastBuilder.setChild(childExpr.get)
-          lastBuilder.setDatatype(dataType.get)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setLast(lastBuilder)
-              .build())
-        } else if (dataType.isEmpty) {
-          withInfo(aggExpr, s"datatype ${last.dataType} is not supported", 
child)
-          None
-        } else {
-          withInfo(aggExpr, child)
-          None
-        }
-      case bitAnd @ BitAndAgg(child) if 
bitwiseAggTypeSupported(bitAnd.dataType) =>
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(bitAnd.dataType)
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val bitAndBuilder = ExprOuterClass.BitAndAgg.newBuilder()
-          bitAndBuilder.setChild(childExpr.get)
-          bitAndBuilder.setDatatype(dataType.get)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setBitAndAgg(bitAndBuilder)
-              .build())
-        } else if (dataType.isEmpty) {
-          withInfo(aggExpr, s"datatype ${bitAnd.dataType} is not supported", 
child)
-          None
-        } else {
-          withInfo(aggExpr, child)
-          None
-        }
-      case bitOr @ BitOrAgg(child) if bitwiseAggTypeSupported(bitOr.dataType) 
=>
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(bitOr.dataType)
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val bitOrBuilder = ExprOuterClass.BitOrAgg.newBuilder()
-          bitOrBuilder.setChild(childExpr.get)
-          bitOrBuilder.setDatatype(dataType.get)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setBitOrAgg(bitOrBuilder)
-              .build())
-        } else if (dataType.isEmpty) {
-          withInfo(aggExpr, s"datatype ${bitOr.dataType} is not supported", 
child)
-          None
-        } else {
-          withInfo(aggExpr, child)
-          None
-        }
-      case bitXor @ BitXorAgg(child) if 
bitwiseAggTypeSupported(bitXor.dataType) =>
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(bitXor.dataType)
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val bitXorBuilder = ExprOuterClass.BitXorAgg.newBuilder()
-          bitXorBuilder.setChild(childExpr.get)
-          bitXorBuilder.setDatatype(dataType.get)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setBitXorAgg(bitXorBuilder)
-              .build())
-        } else if (dataType.isEmpty) {
-          withInfo(aggExpr, s"datatype ${bitXor.dataType} is not supported", 
child)
-          None
-        } else {
-          withInfo(aggExpr, child)
-          None
-        }
-      case cov @ CovSample(child1, child2, nullOnDivideByZero) =>
-        val child1Expr = exprToProto(child1, inputs, binding)
-        val child2Expr = exprToProto(child2, inputs, binding)
-        val dataType = serializeDataType(cov.dataType)
-
-        if (child1Expr.isDefined && child2Expr.isDefined && 
dataType.isDefined) {
-          val covBuilder = ExprOuterClass.Covariance.newBuilder()
-          covBuilder.setChild1(child1Expr.get)
-          covBuilder.setChild2(child2Expr.get)
-          covBuilder.setNullOnDivideByZero(nullOnDivideByZero)
-          covBuilder.setDatatype(dataType.get)
-          covBuilder.setStatsTypeValue(0)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setCovariance(covBuilder)
-              .build())
-        } else {
-          None
-        }
-      case cov @ CovPopulation(child1, child2, nullOnDivideByZero) =>
-        val child1Expr = exprToProto(child1, inputs, binding)
-        val child2Expr = exprToProto(child2, inputs, binding)
-        val dataType = serializeDataType(cov.dataType)
-
-        if (child1Expr.isDefined && child2Expr.isDefined && 
dataType.isDefined) {
-          val covBuilder = ExprOuterClass.Covariance.newBuilder()
-          covBuilder.setChild1(child1Expr.get)
-          covBuilder.setChild2(child2Expr.get)
-          covBuilder.setNullOnDivideByZero(nullOnDivideByZero)
-          covBuilder.setDatatype(dataType.get)
-          covBuilder.setStatsTypeValue(1)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setCovariance(covBuilder)
-              .build())
-        } else {
-          None
-        }
-      case variance @ VarianceSamp(child, nullOnDivideByZero) =>
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(variance.dataType)
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val varBuilder = ExprOuterClass.Variance.newBuilder()
-          varBuilder.setChild(childExpr.get)
-          varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
-          varBuilder.setDatatype(dataType.get)
-          varBuilder.setStatsTypeValue(0)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setVariance(varBuilder)
-              .build())
-        } else {
-          withInfo(aggExpr, child)
-          None
-        }
-      case variancePop @ VariancePop(child, nullOnDivideByZero) =>
-        val childExpr = exprToProto(child, inputs, binding)
-        val dataType = serializeDataType(variancePop.dataType)
-
-        if (childExpr.isDefined && dataType.isDefined) {
-          val varBuilder = ExprOuterClass.Variance.newBuilder()
-          varBuilder.setChild(childExpr.get)
-          varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
-          varBuilder.setDatatype(dataType.get)
-          varBuilder.setStatsTypeValue(1)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setVariance(varBuilder)
-              .build())
-        } else {
-          withInfo(aggExpr, child)
-          None
-        }
-
-      case std @ StddevSamp(child, nullOnDivideByZero) =>
-        if (CometConf.COMET_EXPR_STDDEV_ENABLED.get(conf)) {
-          val childExpr = exprToProto(child, inputs, binding)
-          val dataType = serializeDataType(std.dataType)
-
-          if (childExpr.isDefined && dataType.isDefined) {
-            val stdBuilder = ExprOuterClass.Stddev.newBuilder()
-            stdBuilder.setChild(childExpr.get)
-            stdBuilder.setNullOnDivideByZero(nullOnDivideByZero)
-            stdBuilder.setDatatype(dataType.get)
-            stdBuilder.setStatsTypeValue(0)
-
-            Some(
-              ExprOuterClass.AggExpr
-                .newBuilder()
-                .setStddev(stdBuilder)
-                .build())
-          } else {
-            withInfo(aggExpr, child)
-            None
-          }
-        } else {
-          withInfo(
-            aggExpr,
-            "stddev disabled by default because it can be slower than Spark. " 
+
-              s"Set ${CometConf.COMET_EXPR_STDDEV_ENABLED}=true to enable it.",
-            child)
-          None
-        }
-
-      case std @ StddevPop(child, nullOnDivideByZero) =>
-        if (CometConf.COMET_EXPR_STDDEV_ENABLED.get(conf)) {
-          val childExpr = exprToProto(child, inputs, binding)
-          val dataType = serializeDataType(std.dataType)
-
-          if (childExpr.isDefined && dataType.isDefined) {
-            val stdBuilder = ExprOuterClass.Stddev.newBuilder()
-            stdBuilder.setChild(childExpr.get)
-            stdBuilder.setNullOnDivideByZero(nullOnDivideByZero)
-            stdBuilder.setDatatype(dataType.get)
-            stdBuilder.setStatsTypeValue(1)
-
-            Some(
-              ExprOuterClass.AggExpr
-                .newBuilder()
-                .setStddev(stdBuilder)
-                .build())
-          } else {
-            withInfo(aggExpr, child)
-            None
-          }
-        } else {
-          withInfo(
-            aggExpr,
-            "stddev disabled by default because it can be slower than Spark. " 
+
-              s"Set ${CometConf.COMET_EXPR_STDDEV_ENABLED}=true to enable it.",
-            child)
-          None
-        }
-
-      case corr @ Corr(child1, child2, nullOnDivideByZero) =>
-        val child1Expr = exprToProto(child1, inputs, binding)
-        val child2Expr = exprToProto(child2, inputs, binding)
-        val dataType = serializeDataType(corr.dataType)
-
-        if (child1Expr.isDefined && child2Expr.isDefined && 
dataType.isDefined) {
-          val corrBuilder = ExprOuterClass.Correlation.newBuilder()
-          corrBuilder.setChild1(child1Expr.get)
-          corrBuilder.setChild2(child2Expr.get)
-          corrBuilder.setNullOnDivideByZero(nullOnDivideByZero)
-          corrBuilder.setDatatype(dataType.get)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setCorrelation(corrBuilder)
-              .build())
-        } else {
-          withInfo(aggExpr, child1, child2)
-          None
-        }
-
-      case bloom_filter @ BloomFilterAggregate(child, numItems, numBits, _, _) 
=>
-        // We ignore mutableAggBufferOffset and inputAggBufferOffset because 
they are
-        // implementation details for Spark's ObjectHashAggregate.
-        val childExpr = exprToProto(child, inputs, binding)
-        val numItemsExpr = exprToProto(numItems, inputs, binding)
-        val numBitsExpr = exprToProto(numBits, inputs, binding)
-        val dataType = serializeDataType(bloom_filter.dataType)
-
-        if (childExpr.isDefined &&
-          (child.dataType
-            .isInstanceOf[ByteType] ||
-            child.dataType
-              .isInstanceOf[ShortType] ||
-            child.dataType
-              .isInstanceOf[IntegerType] ||
-            child.dataType
-              .isInstanceOf[LongType] ||
-            child.dataType
-              .isInstanceOf[StringType]) &&
-          numItemsExpr.isDefined &&
-          numBitsExpr.isDefined &&
-          dataType.isDefined) {
-          val bloomFilterAggBuilder = 
ExprOuterClass.BloomFilterAgg.newBuilder()
-          bloomFilterAggBuilder.setChild(childExpr.get)
-          bloomFilterAggBuilder.setNumItems(numItemsExpr.get)
-          bloomFilterAggBuilder.setNumBits(numBitsExpr.get)
-          bloomFilterAggBuilder.setDatatype(dataType.get)
-
-          Some(
-            ExprOuterClass.AggExpr
-              .newBuilder()
-              .setBloomFilterAgg(bloomFilterAggBuilder)
-              .build())
-        } else {
-          withInfo(aggExpr, child, numItems, numBits)
-          None
-        }
-
+    val cometExpr: CometAggregateExpressionSerde = aggExpr.aggregateFunction 
match {
+      case _: Sum => CometSum
+      case _: Average => CometAverage
+      case _: Count => CometCount
+      case _: Min => CometMin
+      case _: Max => CometMax
+      case _: First => CometFirst
+      case _: Last => CometLast
+      case _: BitAndAgg => CometBitAndAgg
+      case _: BitOrAgg => CometBitOrAgg
+      case _: BitXorAgg => CometBitXOrAgg
+      case _: CovSample => CometCovSample
+      case _: CovPopulation => CometCovPopulation
+      case _: VarianceSamp => CometVarianceSamp
+      case _: VariancePop => CometVariancePop
+      case _: StddevSamp => CometStddevSamp
+      case _: StddevPop => CometStddevPop
+      case _: Corr => CometCorr
+      case _: BloomFilterAggregate => CometBloomFilterAggregate
       case fn =>
         val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
         emitWarning(msg)
         withInfo(aggExpr, msg, fn.children: _*)
-        None
+        return None
+
     }
+    cometExpr.convert(aggExpr, aggExpr.aggregateFunction, inputs, binding, 
conf)
   }
 
   def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode 
= {
@@ -3441,5 +3006,37 @@ trait CometExpressionSerde {
       binding: Boolean): Option[ExprOuterClass.Expr]
 }
 
+/**
+ * Trait for providing serialization logic for aggregate expressions.
+ *
+ * Each implementation translates one Spark aggregate function into the corresponding
+ * `ExprOuterClass.AggExpr` protocol buffer message that is passed into native code.
+ */
+trait CometAggregateExpressionSerde {
+
+  /**
+   * Convert a Spark aggregate expression into a protocol buffer representation that can be
+   * passed into native code.
+   *
+   * @param aggExpr
+   *   The aggregate expression wrapping `expr`. Reasons for a failed conversion are expected to
+   *   be attached to this node.
+   * @param expr
+   *   The aggregate function (the caller passes `aggExpr.aggregateFunction`).
+   * @param inputs
+   *   The input attributes.
+   * @param binding
+   *   Whether the attributes are bound (this is only relevant in aggregate expressions).
+   * @param conf
+   *   SQLConf
+   * @return
+   *   Protocol buffer representation, or None if the expression could not be converted. In this
+   *   case it is expected that the input expression will have been tagged with reasons why it
+   *   could not be converted.
+   */
+  def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr]
+}
+
 /** Marker trait for an expression that is not guaranteed to be 100% 
compatible with Spark */
 trait IncompatExpr {}
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala 
b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
new file mode 100644
index 000000000..da5e9ff53
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -0,0 +1,734 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.collection.JavaConverters.asJavaIterableConverter
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, 
CentralMomentAgg, Corr, Covariance, CovPopulation, CovSample, First, Last, 
StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, 
LongType, ShortType, StringType}
+
+import org.apache.comet.CometConf
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+import org.apache.comet.shims.ShimQueryPlanSerde
+
+object CometMin extends CometAggregateExpressionSerde {
+
+  /**
+   * Serializes a Spark `Min` aggregate into an `ExprOuterClass.Min` message.
+   *
+   * Returns None, after tagging `aggExpr` with the reason, when the result type is rejected by
+   * `AggSerde.minMaxDataTypeSupported`, when the result type cannot be serialized, or when the
+   * child expression cannot be converted.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val resultType = expr.dataType
+    if (AggSerde.minMaxDataTypeSupported(resultType)) {
+      val input = expr.children.head
+      val inputProto = exprToProto(input, inputs, binding)
+      val resultProtoType = serializeDataType(resultType)
+      (inputProto, resultProtoType) match {
+        case (Some(childProto), Some(typeProto)) =>
+          val minProto = ExprOuterClass.Min
+            .newBuilder()
+            .setChild(childProto)
+            .setDatatype(typeProto)
+          Some(ExprOuterClass.AggExpr.newBuilder().setMin(minProto).build())
+        case (_, None) =>
+          withInfo(aggExpr, s"datatype ${resultType} is not supported", input)
+          None
+        case _ =>
+          withInfo(aggExpr, input)
+          None
+      }
+    } else {
+      withInfo(aggExpr, s"Unsupported data type: ${resultType}")
+      None
+    }
+  }
+}
+
+object CometMax extends CometAggregateExpressionSerde {
+
+  /**
+   * Serializes a Spark `Max` aggregate into an `ExprOuterClass.Max` message.
+   *
+   * Returns None, after tagging `aggExpr` with the reason, when the result type is rejected by
+   * `AggSerde.minMaxDataTypeSupported`, when the result type cannot be serialized, or when the
+   * child expression cannot be converted.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val resultType = expr.dataType
+    if (AggSerde.minMaxDataTypeSupported(resultType)) {
+      val input = expr.children.head
+      val inputProto = exprToProto(input, inputs, binding)
+      val resultProtoType = serializeDataType(resultType)
+      (inputProto, resultProtoType) match {
+        case (Some(childProto), Some(typeProto)) =>
+          val maxProto = ExprOuterClass.Max
+            .newBuilder()
+            .setChild(childProto)
+            .setDatatype(typeProto)
+          Some(ExprOuterClass.AggExpr.newBuilder().setMax(maxProto).build())
+        case (_, None) =>
+          withInfo(aggExpr, s"datatype ${resultType} is not supported", input)
+          None
+        case _ =>
+          withInfo(aggExpr, input)
+          None
+      }
+    } else {
+      withInfo(aggExpr, s"Unsupported data type: ${resultType}")
+      None
+    }
+  }
+}
+
+object CometCount extends CometAggregateExpressionSerde {
+
+  /**
+   * Serializes a Spark `Count` aggregate into an `ExprOuterClass.Count` message that carries
+   * every child expression. If any child cannot be converted, `aggExpr` is tagged with the
+   * children and None is returned.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val converted = expr.children.map(exprToProto(_, inputs, binding))
+    if (converted.exists(_.isEmpty)) {
+      withInfo(aggExpr, expr.children: _*)
+      None
+    } else {
+      val countProto = ExprOuterClass.Count
+        .newBuilder()
+        .addAllChildren(converted.flatten.asJava)
+      Some(ExprOuterClass.AggExpr.newBuilder().setCount(countProto).build())
+    }
+  }
+}
+
+object CometAverage extends CometAggregateExpressionSerde with ShimQueryPlanSerde {
+
+  /**
+   * Serializes a Spark `Average` aggregate into an `ExprOuterClass.Avg` message.
+   *
+   * Only legacy eval mode is supported (checked with `isLegacyMode`). The intermediate sum type
+   * follows Spark: for a decimal input the precision is widened by 10 (capped at
+   * `DecimalType.MAX_PRECISION`) and the scale is kept; any other input type is used as-is.
+   *
+   * Returns None, after tagging `aggExpr` with the reason, when the result type is rejected by
+   * `AggSerde.avgDataTypeSupported`, when the aggregate is not in legacy mode, or when the child,
+   * the result type or the sum type cannot be serialized.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val avg = expr.asInstanceOf[Average]
+
+    if (!AggSerde.avgDataTypeSupported(expr.dataType)) {
+      withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+      return None
+    }
+
+    if (!isLegacyMode(avg)) {
+      withInfo(aggExpr, "Average is only supported in legacy mode")
+      return None
+    }
+
+    val child = avg.child
+    val childExpr = exprToProto(child, inputs, binding)
+    val dataType = serializeDataType(expr.dataType)
+
+    val sumDataType = child.dataType match {
+      case decimalType: DecimalType =>
+        // This is input precision + 10 to be consistent with Spark
+        val precision = math.min(DecimalType.MAX_PRECISION, decimalType.precision + 10)
+        serializeDataType(DecimalType(precision, decimalType.scale))
+      case other =>
+        serializeDataType(other)
+    }
+
+    (childExpr, dataType, sumDataType) match {
+      case (Some(childProto), Some(resultType), Some(sumType)) =>
+        val builder = ExprOuterClass.Avg.newBuilder()
+        builder.setChild(childProto)
+        builder.setDatatype(resultType)
+        builder.setFailOnError(getFailOnError(avg))
+        builder.setSumDatatype(sumType)
+
+        Some(
+          ExprOuterClass.AggExpr
+            .newBuilder()
+            .setAvg(builder)
+            .build())
+      case (_, None, _) =>
+        withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", child)
+        None
+      case (Some(_), Some(_), None) =>
+        // The sum type is serialized separately from the result type; fall back with a reason
+        // instead of throwing NoSuchElementException from an unchecked `.get`.
+        withInfo(aggExpr, s"datatype ${child.dataType} is not supported", child)
+        None
+      case _ =>
+        withInfo(aggExpr, child)
+        None
+    }
+  }
+}
+
+object CometSum extends CometAggregateExpressionSerde with ShimQueryPlanSerde {
+
+  /**
+   * Serializes a Spark `Sum` aggregate into an `ExprOuterClass.Sum` message.
+   *
+   * The result type must pass `AggSerde.sumDataTypeSupported` and the aggregate must be in legacy
+   * eval mode. Otherwise, or when the child or result type cannot be serialized, `aggExpr` is
+   * tagged with the reason and None is returned.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val sum = expr.asInstanceOf[Sum]
+
+    if (!AggSerde.sumDataTypeSupported(sum.dataType)) {
+      withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+      None
+    } else if (!isLegacyMode(sum)) {
+      withInfo(aggExpr, "Sum is only supported in legacy mode")
+      None
+    } else {
+      val input = sum.child
+      val inputProto = exprToProto(input, inputs, binding)
+      val resultProtoType = serializeDataType(sum.dataType)
+      (inputProto, resultProtoType) match {
+        case (Some(childProto), Some(typeProto)) =>
+          val sumProto = ExprOuterClass.Sum
+            .newBuilder()
+            .setChild(childProto)
+            .setDatatype(typeProto)
+            .setFailOnError(getFailOnError(sum))
+          Some(ExprOuterClass.AggExpr.newBuilder().setSum(sumProto).build())
+        case (_, None) =>
+          withInfo(aggExpr, s"datatype ${sum.dataType} is not supported", input)
+          None
+        case _ =>
+          withInfo(aggExpr, input)
+          None
+      }
+    }
+  }
+}
+
+object CometFirst extends CometAggregateExpressionSerde {
+
+  /**
+   * Serializes a Spark `First` aggregate into an `ExprOuterClass.First` message.
+   *
+   * `ignoreNulls = true` is rejected. Otherwise None is returned, with `aggExpr` tagged, when the
+   * child or the result type cannot be serialized.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val first = expr.asInstanceOf[First]
+    if (first.ignoreNulls) {
+      // DataFusion doesn't support ignoreNulls true
+      withInfo(aggExpr, "First not supported when ignoreNulls=true")
+      None
+    } else {
+      val input = first.children.head
+      val inputProto = exprToProto(input, inputs, binding)
+      val resultProtoType = serializeDataType(first.dataType)
+      (inputProto, resultProtoType) match {
+        case (Some(childProto), Some(typeProto)) =>
+          val firstProto = ExprOuterClass.First
+            .newBuilder()
+            .setChild(childProto)
+            .setDatatype(typeProto)
+          Some(ExprOuterClass.AggExpr.newBuilder().setFirst(firstProto).build())
+        case (_, None) =>
+          withInfo(aggExpr, s"datatype ${first.dataType} is not supported", input)
+          None
+        case _ =>
+          withInfo(aggExpr, input)
+          None
+      }
+    }
+  }
+}
+
+object CometLast extends CometAggregateExpressionSerde {
+
+  /**
+   * Serializes a Spark `Last` aggregate into an `ExprOuterClass.Last` message.
+   *
+   * `ignoreNulls = true` is rejected. Otherwise None is returned, with `aggExpr` tagged, when the
+   * child or the result type cannot be serialized.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val last = expr.asInstanceOf[Last]
+    if (last.ignoreNulls) {
+      // DataFusion doesn't support ignoreNulls true
+      withInfo(aggExpr, "Last not supported when ignoreNulls=true")
+      None
+    } else {
+      val input = last.children.head
+      val inputProto = exprToProto(input, inputs, binding)
+      val resultProtoType = serializeDataType(last.dataType)
+      (inputProto, resultProtoType) match {
+        case (Some(childProto), Some(typeProto)) =>
+          val lastProto = ExprOuterClass.Last
+            .newBuilder()
+            .setChild(childProto)
+            .setDatatype(typeProto)
+          Some(ExprOuterClass.AggExpr.newBuilder().setLast(lastProto).build())
+        case (_, None) =>
+          withInfo(aggExpr, s"datatype ${last.dataType} is not supported", input)
+          None
+        case _ =>
+          withInfo(aggExpr, input)
+          None
+      }
+    }
+  }
+}
+
+object CometBitAndAgg extends CometAggregateExpressionSerde {
+
+  /**
+   * Converts a Spark `BitAndAgg` aggregate into a Comet `AggExpr` protobuf message.
+   *
+   * Only types accepted by `AggSerde.bitwiseAggTypeSupported` are converted. For any other
+   * type, or when the child / result type cannot be serialized, an explanation is attached to
+   * `aggExpr` via `withInfo` and `None` is returned.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val bitAnd = expr.asInstanceOf[BitAndAgg]
+    if (AggSerde.bitwiseAggTypeSupported(bitAnd.dataType)) {
+      val input = bitAnd.child
+      (exprToProto(input, inputs, binding), serializeDataType(bitAnd.dataType)) match {
+        case (Some(inputProto), Some(resultType)) =>
+          val bitAndProto = ExprOuterClass.BitAndAgg
+            .newBuilder()
+            .setChild(inputProto)
+            .setDatatype(resultType)
+          Some(ExprOuterClass.AggExpr.newBuilder().setBitAndAgg(bitAndProto).build())
+        case (_, None) =>
+          withInfo(aggExpr, s"datatype ${bitAnd.dataType} is not supported", input)
+          None
+        case _ =>
+          withInfo(aggExpr, input)
+          None
+      }
+    } else {
+      withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+      None
+    }
+  }
+}
+
+object CometBitOrAgg extends CometAggregateExpressionSerde {
+
+  /**
+   * Converts a Spark `BitOrAgg` aggregate into a Comet `AggExpr` protobuf message.
+   *
+   * Only types accepted by `AggSerde.bitwiseAggTypeSupported` are converted. For any other
+   * type, or when the child / result type cannot be serialized, an explanation is attached to
+   * `aggExpr` via `withInfo` and `None` is returned.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val bitOr = expr.asInstanceOf[BitOrAgg]
+    if (AggSerde.bitwiseAggTypeSupported(bitOr.dataType)) {
+      val input = bitOr.child
+      (exprToProto(input, inputs, binding), serializeDataType(bitOr.dataType)) match {
+        case (Some(inputProto), Some(resultType)) =>
+          val bitOrProto = ExprOuterClass.BitOrAgg
+            .newBuilder()
+            .setChild(inputProto)
+            .setDatatype(resultType)
+          Some(ExprOuterClass.AggExpr.newBuilder().setBitOrAgg(bitOrProto).build())
+        case (_, None) =>
+          withInfo(aggExpr, s"datatype ${bitOr.dataType} is not supported", input)
+          None
+        case _ =>
+          withInfo(aggExpr, input)
+          None
+      }
+    } else {
+      withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+      None
+    }
+  }
+}
+
+object CometBitXOrAgg extends CometAggregateExpressionSerde {
+
+  /**
+   * Converts a Spark `BitXorAgg` aggregate into a Comet `AggExpr` protobuf message.
+   *
+   * Only types accepted by `AggSerde.bitwiseAggTypeSupported` are converted. For any other
+   * type, or when the child / result type cannot be serialized, an explanation is attached to
+   * `aggExpr` via `withInfo` and `None` is returned.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val bitXor = expr.asInstanceOf[BitXorAgg]
+    if (AggSerde.bitwiseAggTypeSupported(bitXor.dataType)) {
+      val input = bitXor.child
+      (exprToProto(input, inputs, binding), serializeDataType(bitXor.dataType)) match {
+        case (Some(inputProto), Some(resultType)) =>
+          val bitXorProto = ExprOuterClass.BitXorAgg
+            .newBuilder()
+            .setChild(inputProto)
+            .setDatatype(resultType)
+          Some(ExprOuterClass.AggExpr.newBuilder().setBitXorAgg(bitXorProto).build())
+        case (_, None) =>
+          withInfo(aggExpr, s"datatype ${bitXor.dataType} is not supported", input)
+          None
+        case _ =>
+          withInfo(aggExpr, input)
+          None
+      }
+    } else {
+      withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+      None
+    }
+  }
+}
+
+trait CometCovBase extends CometAggregateExpressionSerde {
+
+  /**
+   * Shared conversion of Spark covariance aggregates into a Comet `Covariance` protobuf message.
+   *
+   * @param aggExpr            the enclosing aggregate expression; explanations for an
+   *                           unsupported conversion are attached to it via `withInfo`
+   * @param cov                the covariance aggregate whose `left` / `right` inputs are serialized
+   * @param nullOnDivideByZero copied into the protobuf message unchanged
+   * @param statsType          numeric stats type written via `setStatsTypeValue`; in this file
+   *                           `CometCovSample` passes 0 and `CometCovPopulation` passes 1
+   * @param conf               currently unused; kept for signature parity with `convert`
+   * @return the serialized aggregate, or `None` if either input or the result type cannot be
+   *         serialized
+   */
+  def convertCov(
+      aggExpr: AggregateExpression,
+      cov: Covariance,
+      nullOnDivideByZero: Boolean,
+      statsType: Int,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val child1Expr = exprToProto(cov.left, inputs, binding)
+    val child2Expr = exprToProto(cov.right, inputs, binding)
+    val dataType = serializeDataType(cov.dataType)
+
+    (child1Expr, child2Expr, dataType) match {
+      case (Some(left), Some(right), Some(resultType)) =>
+        val builder = ExprOuterClass.Covariance
+          .newBuilder()
+          .setChild1(left)
+          .setChild2(right)
+          .setNullOnDivideByZero(nullOnDivideByZero)
+          .setDatatype(resultType)
+          .setStatsTypeValue(statsType)
+        Some(ExprOuterClass.AggExpr.newBuilder().setCovariance(builder).build())
+      case (_, _, None) =>
+        withInfo(aggExpr, s"datatype ${cov.dataType} is not supported", cov.left, cov.right)
+        None
+      case _ =>
+        // Propagate whatever explanation the inputs recorded (as CometCorr does) instead of
+        // replacing it with a generic message.
+        withInfo(aggExpr, cov.left, cov.right)
+        None
+    }
+  }
+}
+
+object CometCovSample extends CometCovBase {
+
+  /** Converts Spark's `CovSample` aggregate through `convertCov` using stats type 0. */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val covSample = expr.asInstanceOf[CovSample]
+    convertCov(
+      aggExpr = aggExpr,
+      cov = covSample,
+      nullOnDivideByZero = covSample.nullOnDivideByZero,
+      statsType = 0,
+      inputs = inputs,
+      binding = binding,
+      conf = conf)
+  }
+}
+
+object CometCovPopulation extends CometCovBase {
+
+  /** Converts Spark's `CovPopulation` aggregate through `convertCov` using stats type 1. */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val covPopulation = expr.asInstanceOf[CovPopulation]
+    convertCov(
+      aggExpr = aggExpr,
+      cov = covPopulation,
+      nullOnDivideByZero = covPopulation.nullOnDivideByZero,
+      statsType = 1,
+      inputs = inputs,
+      binding = binding,
+      conf = conf)
+  }
+}
+
+trait CometVariance extends CometAggregateExpressionSerde {
+
+  /**
+   * Shared conversion of Spark variance aggregates into a Comet `Variance` protobuf message.
+   *
+   * @param statsType numeric stats type written via `setStatsTypeValue`; in this file
+   *                  `CometVarianceSamp` passes 0 and `CometVariancePop` passes 1
+   * @return the serialized aggregate, or `None` (after propagating the child's info to
+   *         `aggExpr` via `withInfo`) when the child or the result type cannot be serialized
+   */
+  def convertVariance(
+      aggExpr: AggregateExpression,
+      expr: CentralMomentAgg,
+      nullOnDivideByZero: Boolean,
+      statsType: Int,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.AggExpr] = {
+    val input = expr.child
+    (exprToProto(input, inputs, binding), serializeDataType(expr.dataType)) match {
+      case (Some(inputProto), Some(resultType)) =>
+        val varianceProto = ExprOuterClass.Variance
+          .newBuilder()
+          .setChild(inputProto)
+          .setNullOnDivideByZero(nullOnDivideByZero)
+          .setDatatype(resultType)
+          .setStatsTypeValue(statsType)
+        Some(ExprOuterClass.AggExpr.newBuilder().setVariance(varianceProto).build())
+      case _ =>
+        withInfo(aggExpr, input)
+        None
+    }
+  }
+
+}
+
+object CometVarianceSamp extends CometVariance {
+
+  /** Converts Spark's `VarianceSamp` aggregate through `convertVariance` using stats type 0. */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val varianceSamp = expr.asInstanceOf[VarianceSamp]
+    convertVariance(
+      aggExpr = aggExpr,
+      expr = varianceSamp,
+      nullOnDivideByZero = varianceSamp.nullOnDivideByZero,
+      statsType = 0,
+      inputs = inputs,
+      binding = binding)
+  }
+}
+
+object CometVariancePop extends CometVariance {
+
+  /** Converts Spark's `VariancePop` aggregate through `convertVariance` using stats type 1. */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val variancePop = expr.asInstanceOf[VariancePop]
+    convertVariance(
+      aggExpr = aggExpr,
+      expr = variancePop,
+      nullOnDivideByZero = variancePop.nullOnDivideByZero,
+      statsType = 1,
+      inputs = inputs,
+      binding = binding)
+  }
+}
+
+trait CometStddev extends CometAggregateExpressionSerde {
+
+  /**
+   * Shared conversion of Spark standard-deviation aggregates into a Comet `Stddev` message.
+   *
+   * Conversion only happens when `CometConf.COMET_EXPR_STDDEV_ENABLED` is true in `conf`.
+   * Otherwise an explanation naming that config key is attached to `aggExpr`.
+   *
+   * @param statsType numeric stats type written via `setStatsTypeValue`; in this file
+   *                  `CometStddevSamp` passes 0 and `CometStddevPop` passes 1
+   * @return the serialized aggregate, or `None` when disabled or when the child / result type
+   *         cannot be serialized
+   */
+  def convertStddev(
+      aggExpr: AggregateExpression,
+      stddev: CentralMomentAgg,
+      nullOnDivideByZero: Boolean,
+      statsType: Int,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val child = stddev.child
+    if (!CometConf.COMET_EXPR_STDDEV_ENABLED.get(conf)) {
+      // Interpolate the config *key*: interpolating the ConfigEntry itself uses its
+      // toString, which is not the string a user has to set.
+      withInfo(
+        aggExpr,
+        "stddev disabled by default because it can be slower than Spark. " +
+          s"Set ${CometConf.COMET_EXPR_STDDEV_ENABLED.key}=true to enable it.",
+        child)
+      None
+    } else {
+      (exprToProto(child, inputs, binding), serializeDataType(stddev.dataType)) match {
+        case (Some(childProto), Some(resultType)) =>
+          val builder = ExprOuterClass.Stddev
+            .newBuilder()
+            .setChild(childProto)
+            .setNullOnDivideByZero(nullOnDivideByZero)
+            .setDatatype(resultType)
+            .setStatsTypeValue(statsType)
+          Some(ExprOuterClass.AggExpr.newBuilder().setStddev(builder).build())
+        case _ =>
+          withInfo(aggExpr, child)
+          None
+      }
+    }
+  }
+}
+
+object CometStddevSamp extends CometStddev {
+
+  /** Converts Spark's `StddevSamp` aggregate through `convertStddev` using stats type 0. */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val stddevSamp = expr.asInstanceOf[StddevSamp]
+    convertStddev(
+      aggExpr = aggExpr,
+      stddev = stddevSamp,
+      nullOnDivideByZero = stddevSamp.nullOnDivideByZero,
+      statsType = 0,
+      inputs = inputs,
+      binding = binding,
+      conf = conf)
+  }
+}
+
+object CometStddevPop extends CometStddev {
+
+  /** Converts Spark's `StddevPop` aggregate through `convertStddev` using stats type 1. */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val stddevPop = expr.asInstanceOf[StddevPop]
+    convertStddev(
+      aggExpr = aggExpr,
+      stddev = stddevPop,
+      nullOnDivideByZero = stddevPop.nullOnDivideByZero,
+      statsType = 1,
+      inputs = inputs,
+      binding = binding,
+      conf = conf)
+  }
+}
+
+object CometCorr extends CometAggregateExpressionSerde {
+
+  /**
+   * Converts Spark's `Corr` aggregate into a Comet `Correlation` protobuf message.
+   *
+   * Returns `None`, after propagating info from both inputs to `aggExpr` via `withInfo`, when
+   * either input or the result data type cannot be serialized.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val corr = expr.asInstanceOf[Corr]
+    val x = corr.x
+    val y = corr.y
+    val serialized = (
+      exprToProto(x, inputs, binding),
+      exprToProto(y, inputs, binding),
+      serializeDataType(corr.dataType))
+    serialized match {
+      case (Some(xProto), Some(yProto), Some(resultType)) =>
+        val correlation = ExprOuterClass.Correlation
+          .newBuilder()
+          .setChild1(xProto)
+          .setChild2(yProto)
+          .setNullOnDivideByZero(corr.nullOnDivideByZero)
+          .setDatatype(resultType)
+        Some(ExprOuterClass.AggExpr.newBuilder().setCorrelation(correlation).build())
+      case _ =>
+        withInfo(aggExpr, x, y)
+        None
+    }
+  }
+}
+
+object CometBloomFilterAggregate extends CometAggregateExpressionSerde {
+
+  /**
+   * Converts Spark's `BloomFilterAggregate` into a Comet `BloomFilterAgg` protobuf message.
+   *
+   * Only byte, short, int, long and string inputs are converted. For any other input type, a
+   * reason naming that type is attached to `aggExpr`. When the child, item-count or bit-count
+   * expression (or the result type) cannot be serialized, their info is propagated via
+   * `withInfo`. Both cases return `None`.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    // We ignore mutableAggBufferOffset and inputAggBufferOffset because they are
+    // implementation details for Spark's ObjectHashAggregate.
+    val bloomFilter = expr.asInstanceOf[BloomFilterAggregate]
+    val inputType = bloomFilter.child.dataType
+    // Type patterns (equivalent to isInstanceOf) match any instance of each type class,
+    // e.g. every StringType instance, not only the companion singleton.
+    val inputTypeSupported = inputType match {
+      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: StringType => true
+      case _ => false
+    }
+    val childExpr = exprToProto(bloomFilter.child, inputs, binding)
+    val numItemsExpr = exprToProto(bloomFilter.estimatedNumItemsExpression, inputs, binding)
+    val numBitsExpr = exprToProto(bloomFilter.numBitsExpression, inputs, binding)
+    val dataType = serializeDataType(bloomFilter.dataType)
+
+    (childExpr, numItemsExpr, numBitsExpr, dataType) match {
+      case (Some(child), Some(numItems), Some(numBits), Some(resultType))
+          if inputTypeSupported =>
+        val builder = ExprOuterClass.BloomFilterAgg
+          .newBuilder()
+          .setChild(child)
+          .setNumItems(numItems)
+          .setNumBits(numBits)
+          .setDatatype(resultType)
+        Some(ExprOuterClass.AggExpr.newBuilder().setBloomFilterAgg(builder).build())
+      case _ if !inputTypeSupported =>
+        withInfo(
+          aggExpr,
+          s"BloomFilterAggregate does not support input data type $inputType",
+          bloomFilter.child)
+        None
+      case _ =>
+        withInfo(
+          aggExpr,
+          bloomFilter.child,
+          bloomFilter.estimatedNumItemsExpression,
+          bloomFilter.numBitsExpression)
+        None
+    }
+  }
+}
+
+object AggSerde {
+  import org.apache.spark.sql.types._
+
+  /** True for the fixed-width integral types: byte, short, int and long. */
+  private def isIntegral(dt: DataType): Boolean = dt match {
+    case ByteType | ShortType | IntegerType | LongType => true
+    case _ => false
+  }
+
+  /** True for integral, float, double and decimal types. */
+  private def isNumeric(dt: DataType): Boolean = dt match {
+    case FloatType | DoubleType | _: DecimalType => true
+    case other => isIntegral(other)
+  }
+
+  /** Types accepted for MIN / MAX: numeric types plus boolean, date and timestamp. */
+  def minMaxDataTypeSupported(dt: DataType): Boolean = dt match {
+    case BooleanType | DateType | TimestampType => true
+    case other => isNumeric(other)
+  }
+
+  /** Types accepted for AVG: numeric types only. */
+  def avgDataTypeSupported(dt: DataType): Boolean = isNumeric(dt)
+
+  /** Types accepted for SUM: numeric types only. */
+  def sumDataTypeSupported(dt: DataType): Boolean = isNumeric(dt)
+
+  /** Types accepted for BIT_AND / BIT_OR / BIT_XOR: integral types only. */
+  def bitwiseAggTypeSupported(dt: DataType): Boolean = isIntegral(dt)
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to