This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new a4c6c12ca chore: Refactor aggregate expression serde (#1380)
a4c6c12ca is described below
commit a4c6c12cafd22b01f3eed88081f1bada4eeb5cdf
Author: Andy Grove <[email protected]>
AuthorDate: Fri Feb 14 06:10:59 2025 -0700
chore: Refactor aggregate expression serde (#1380)
---
.../org/apache/comet/serde/QueryPlanSerde.scala | 519 ++-------------
.../scala/org/apache/comet/serde/aggregates.scala | 734 +++++++++++++++++++++
2 files changed, 792 insertions(+), 461 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 8f0a4273b..aa1aba11d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -174,35 +174,6 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
Some(dataType)
}
- private def sumDataTypeSupported(dt: DataType): Boolean = {
- dt match {
- case _: NumericType => true
- case _ => false
- }
- }
-
- private def avgDataTypeSupported(dt: DataType): Boolean = {
- dt match {
- case _: NumericType => true
- // TODO: implement support for interval types
- case _ => false
- }
- }
-
- private def minMaxDataTypeSupported(dt: DataType): Boolean = {
- dt match {
- case _: NumericType | DateType | TimestampType | BooleanType => true
- case _ => false
- }
- }
-
- private def bitwiseAggTypeSupported(dt: DataType): Boolean = {
- dt match {
- case _: IntegerType | LongType | ShortType | ByteType => true
- case _ => false
- }
- }
-
def windowExprToProto(
windowExpr: WindowExpression,
output: Seq[Attribute],
@@ -215,21 +186,22 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
case _: Count =>
Some(agg)
case min: Min =>
- if (minMaxDataTypeSupported(min.dataType)) {
+ if (AggSerde.minMaxDataTypeSupported(min.dataType)) {
Some(agg)
} else {
withInfo(windowExpr, s"datatype ${min.dataType} is not
supported", expr)
None
}
case max: Max =>
- if (minMaxDataTypeSupported(max.dataType)) {
+ if (AggSerde.minMaxDataTypeSupported(max.dataType)) {
Some(agg)
} else {
withInfo(windowExpr, s"datatype ${max.dataType} is not
supported", expr)
None
}
case s: Sum =>
- if (sumDataTypeSupported(s.dataType) &&
!s.dataType.isInstanceOf[DecimalType]) {
+ if (AggSerde.sumDataTypeSupported(s.dataType) && !s.dataType
+ .isInstanceOf[DecimalType]) {
Some(agg)
} else {
withInfo(windowExpr, s"datatype ${s.dataType} is not
supported", expr)
@@ -379,440 +351,33 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
return None
}
- aggExpr.aggregateFunction match {
- case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) &&
isLegacyMode(s) =>
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(s.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val sumBuilder = ExprOuterClass.Sum.newBuilder()
- sumBuilder.setChild(childExpr.get)
- sumBuilder.setDatatype(dataType.get)
- sumBuilder.setFailOnError(getFailOnError(s))
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setSum(sumBuilder)
- .build())
- } else {
- if (dataType.isEmpty) {
- withInfo(aggExpr, s"datatype ${s.dataType} is not supported",
child)
- } else {
- withInfo(aggExpr, child)
- }
- None
- }
- case s @ Average(child, _) if avgDataTypeSupported(s.dataType) &&
isLegacyMode(s) =>
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(s.dataType)
-
- val sumDataType = if (child.dataType.isInstanceOf[DecimalType]) {
-
- // This is input precision + 10 to be consistent with Spark
- val precision = Math.min(
- DecimalType.MAX_PRECISION,
- child.dataType.asInstanceOf[DecimalType].precision + 10)
- val newType =
- DecimalType.apply(precision,
child.dataType.asInstanceOf[DecimalType].scale)
- serializeDataType(newType)
- } else {
- serializeDataType(child.dataType)
- }
-
- if (childExpr.isDefined && dataType.isDefined) {
- val builder = ExprOuterClass.Avg.newBuilder()
- builder.setChild(childExpr.get)
- builder.setDatatype(dataType.get)
- builder.setFailOnError(getFailOnError(s))
- builder.setSumDatatype(sumDataType.get)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setAvg(builder)
- .build())
- } else if (dataType.isEmpty) {
- withInfo(aggExpr, s"datatype ${s.dataType} is not supported", child)
- None
- } else {
- withInfo(aggExpr, child)
- None
- }
- case Count(children) =>
- val exprChildren = children.map(exprToProto(_, inputs, binding))
-
- if (exprChildren.forall(_.isDefined)) {
- val countBuilder = ExprOuterClass.Count.newBuilder()
- countBuilder.addAllChildren(exprChildren.map(_.get).asJava)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setCount(countBuilder)
- .build())
- } else {
- withInfo(aggExpr, children: _*)
- None
- }
- case min @ Min(child) if minMaxDataTypeSupported(min.dataType) =>
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(min.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val minBuilder = ExprOuterClass.Min.newBuilder()
- minBuilder.setChild(childExpr.get)
- minBuilder.setDatatype(dataType.get)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setMin(minBuilder)
- .build())
- } else if (dataType.isEmpty) {
- withInfo(aggExpr, s"datatype ${min.dataType} is not supported",
child)
- None
- } else {
- withInfo(aggExpr, child)
- None
- }
- case max @ Max(child) if minMaxDataTypeSupported(max.dataType) =>
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(max.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val maxBuilder = ExprOuterClass.Max.newBuilder()
- maxBuilder.setChild(childExpr.get)
- maxBuilder.setDatatype(dataType.get)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setMax(maxBuilder)
- .build())
- } else if (dataType.isEmpty) {
- withInfo(aggExpr, s"datatype ${max.dataType} is not supported",
child)
- None
- } else {
- withInfo(aggExpr, child)
- None
- }
- case first @ First(child, ignoreNulls)
- if !ignoreNulls => // DataFusion doesn't support ignoreNulls true
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(first.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val firstBuilder = ExprOuterClass.First.newBuilder()
- firstBuilder.setChild(childExpr.get)
- firstBuilder.setDatatype(dataType.get)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setFirst(firstBuilder)
- .build())
- } else if (dataType.isEmpty) {
- withInfo(aggExpr, s"datatype ${first.dataType} is not supported",
child)
- None
- } else {
- withInfo(aggExpr, child)
- None
- }
- case last @ Last(child, ignoreNulls)
- if !ignoreNulls => // DataFusion doesn't support ignoreNulls true
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(last.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val lastBuilder = ExprOuterClass.Last.newBuilder()
- lastBuilder.setChild(childExpr.get)
- lastBuilder.setDatatype(dataType.get)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setLast(lastBuilder)
- .build())
- } else if (dataType.isEmpty) {
- withInfo(aggExpr, s"datatype ${last.dataType} is not supported",
child)
- None
- } else {
- withInfo(aggExpr, child)
- None
- }
- case bitAnd @ BitAndAgg(child) if
bitwiseAggTypeSupported(bitAnd.dataType) =>
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(bitAnd.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val bitAndBuilder = ExprOuterClass.BitAndAgg.newBuilder()
- bitAndBuilder.setChild(childExpr.get)
- bitAndBuilder.setDatatype(dataType.get)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setBitAndAgg(bitAndBuilder)
- .build())
- } else if (dataType.isEmpty) {
- withInfo(aggExpr, s"datatype ${bitAnd.dataType} is not supported",
child)
- None
- } else {
- withInfo(aggExpr, child)
- None
- }
- case bitOr @ BitOrAgg(child) if bitwiseAggTypeSupported(bitOr.dataType)
=>
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(bitOr.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val bitOrBuilder = ExprOuterClass.BitOrAgg.newBuilder()
- bitOrBuilder.setChild(childExpr.get)
- bitOrBuilder.setDatatype(dataType.get)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setBitOrAgg(bitOrBuilder)
- .build())
- } else if (dataType.isEmpty) {
- withInfo(aggExpr, s"datatype ${bitOr.dataType} is not supported",
child)
- None
- } else {
- withInfo(aggExpr, child)
- None
- }
- case bitXor @ BitXorAgg(child) if
bitwiseAggTypeSupported(bitXor.dataType) =>
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(bitXor.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val bitXorBuilder = ExprOuterClass.BitXorAgg.newBuilder()
- bitXorBuilder.setChild(childExpr.get)
- bitXorBuilder.setDatatype(dataType.get)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setBitXorAgg(bitXorBuilder)
- .build())
- } else if (dataType.isEmpty) {
- withInfo(aggExpr, s"datatype ${bitXor.dataType} is not supported",
child)
- None
- } else {
- withInfo(aggExpr, child)
- None
- }
- case cov @ CovSample(child1, child2, nullOnDivideByZero) =>
- val child1Expr = exprToProto(child1, inputs, binding)
- val child2Expr = exprToProto(child2, inputs, binding)
- val dataType = serializeDataType(cov.dataType)
-
- if (child1Expr.isDefined && child2Expr.isDefined &&
dataType.isDefined) {
- val covBuilder = ExprOuterClass.Covariance.newBuilder()
- covBuilder.setChild1(child1Expr.get)
- covBuilder.setChild2(child2Expr.get)
- covBuilder.setNullOnDivideByZero(nullOnDivideByZero)
- covBuilder.setDatatype(dataType.get)
- covBuilder.setStatsTypeValue(0)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setCovariance(covBuilder)
- .build())
- } else {
- None
- }
- case cov @ CovPopulation(child1, child2, nullOnDivideByZero) =>
- val child1Expr = exprToProto(child1, inputs, binding)
- val child2Expr = exprToProto(child2, inputs, binding)
- val dataType = serializeDataType(cov.dataType)
-
- if (child1Expr.isDefined && child2Expr.isDefined &&
dataType.isDefined) {
- val covBuilder = ExprOuterClass.Covariance.newBuilder()
- covBuilder.setChild1(child1Expr.get)
- covBuilder.setChild2(child2Expr.get)
- covBuilder.setNullOnDivideByZero(nullOnDivideByZero)
- covBuilder.setDatatype(dataType.get)
- covBuilder.setStatsTypeValue(1)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setCovariance(covBuilder)
- .build())
- } else {
- None
- }
- case variance @ VarianceSamp(child, nullOnDivideByZero) =>
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(variance.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val varBuilder = ExprOuterClass.Variance.newBuilder()
- varBuilder.setChild(childExpr.get)
- varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
- varBuilder.setDatatype(dataType.get)
- varBuilder.setStatsTypeValue(0)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setVariance(varBuilder)
- .build())
- } else {
- withInfo(aggExpr, child)
- None
- }
- case variancePop @ VariancePop(child, nullOnDivideByZero) =>
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(variancePop.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val varBuilder = ExprOuterClass.Variance.newBuilder()
- varBuilder.setChild(childExpr.get)
- varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
- varBuilder.setDatatype(dataType.get)
- varBuilder.setStatsTypeValue(1)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setVariance(varBuilder)
- .build())
- } else {
- withInfo(aggExpr, child)
- None
- }
-
- case std @ StddevSamp(child, nullOnDivideByZero) =>
- if (CometConf.COMET_EXPR_STDDEV_ENABLED.get(conf)) {
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(std.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val stdBuilder = ExprOuterClass.Stddev.newBuilder()
- stdBuilder.setChild(childExpr.get)
- stdBuilder.setNullOnDivideByZero(nullOnDivideByZero)
- stdBuilder.setDatatype(dataType.get)
- stdBuilder.setStatsTypeValue(0)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setStddev(stdBuilder)
- .build())
- } else {
- withInfo(aggExpr, child)
- None
- }
- } else {
- withInfo(
- aggExpr,
- "stddev disabled by default because it can be slower than Spark. "
+
- s"Set ${CometConf.COMET_EXPR_STDDEV_ENABLED}=true to enable it.",
- child)
- None
- }
-
- case std @ StddevPop(child, nullOnDivideByZero) =>
- if (CometConf.COMET_EXPR_STDDEV_ENABLED.get(conf)) {
- val childExpr = exprToProto(child, inputs, binding)
- val dataType = serializeDataType(std.dataType)
-
- if (childExpr.isDefined && dataType.isDefined) {
- val stdBuilder = ExprOuterClass.Stddev.newBuilder()
- stdBuilder.setChild(childExpr.get)
- stdBuilder.setNullOnDivideByZero(nullOnDivideByZero)
- stdBuilder.setDatatype(dataType.get)
- stdBuilder.setStatsTypeValue(1)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setStddev(stdBuilder)
- .build())
- } else {
- withInfo(aggExpr, child)
- None
- }
- } else {
- withInfo(
- aggExpr,
- "stddev disabled by default because it can be slower than Spark. "
+
- s"Set ${CometConf.COMET_EXPR_STDDEV_ENABLED}=true to enable it.",
- child)
- None
- }
-
- case corr @ Corr(child1, child2, nullOnDivideByZero) =>
- val child1Expr = exprToProto(child1, inputs, binding)
- val child2Expr = exprToProto(child2, inputs, binding)
- val dataType = serializeDataType(corr.dataType)
-
- if (child1Expr.isDefined && child2Expr.isDefined &&
dataType.isDefined) {
- val corrBuilder = ExprOuterClass.Correlation.newBuilder()
- corrBuilder.setChild1(child1Expr.get)
- corrBuilder.setChild2(child2Expr.get)
- corrBuilder.setNullOnDivideByZero(nullOnDivideByZero)
- corrBuilder.setDatatype(dataType.get)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setCorrelation(corrBuilder)
- .build())
- } else {
- withInfo(aggExpr, child1, child2)
- None
- }
-
- case bloom_filter @ BloomFilterAggregate(child, numItems, numBits, _, _)
=>
- // We ignore mutableAggBufferOffset and inputAggBufferOffset because
they are
- // implementation details for Spark's ObjectHashAggregate.
- val childExpr = exprToProto(child, inputs, binding)
- val numItemsExpr = exprToProto(numItems, inputs, binding)
- val numBitsExpr = exprToProto(numBits, inputs, binding)
- val dataType = serializeDataType(bloom_filter.dataType)
-
- if (childExpr.isDefined &&
- (child.dataType
- .isInstanceOf[ByteType] ||
- child.dataType
- .isInstanceOf[ShortType] ||
- child.dataType
- .isInstanceOf[IntegerType] ||
- child.dataType
- .isInstanceOf[LongType] ||
- child.dataType
- .isInstanceOf[StringType]) &&
- numItemsExpr.isDefined &&
- numBitsExpr.isDefined &&
- dataType.isDefined) {
- val bloomFilterAggBuilder =
ExprOuterClass.BloomFilterAgg.newBuilder()
- bloomFilterAggBuilder.setChild(childExpr.get)
- bloomFilterAggBuilder.setNumItems(numItemsExpr.get)
- bloomFilterAggBuilder.setNumBits(numBitsExpr.get)
- bloomFilterAggBuilder.setDatatype(dataType.get)
-
- Some(
- ExprOuterClass.AggExpr
- .newBuilder()
- .setBloomFilterAgg(bloomFilterAggBuilder)
- .build())
- } else {
- withInfo(aggExpr, child, numItems, numBits)
- None
- }
-
+ val cometExpr: CometAggregateExpressionSerde = aggExpr.aggregateFunction
match {
+ case _: Sum => CometSum
+ case _: Average => CometAverage
+ case _: Count => CometCount
+ case _: Min => CometMin
+ case _: Max => CometMax
+ case _: First => CometFirst
+ case _: Last => CometLast
+ case _: BitAndAgg => CometBitAndAgg
+ case _: BitOrAgg => CometBitOrAgg
+ case _: BitXorAgg => CometBitXOrAgg
+ case _: CovSample => CometCovSample
+ case _: CovPopulation => CometCovPopulation
+ case _: VarianceSamp => CometVarianceSamp
+ case _: VariancePop => CometVariancePop
+ case _: StddevSamp => CometStddevSamp
+ case _: StddevPop => CometStddevPop
+ case _: Corr => CometCorr
+ case _: BloomFilterAggregate => CometBloomFilterAggregate
case fn =>
val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
emitWarning(msg)
withInfo(aggExpr, msg, fn.children: _*)
- None
+ return None
+
}
+ cometExpr.convert(aggExpr, aggExpr.aggregateFunction, inputs, binding,
conf)
}
def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode
= {
@@ -3441,5 +3006,37 @@ trait CometExpressionSerde {
binding: Boolean): Option[ExprOuterClass.Expr]
}
+/**
+ * Trait for providing serialization logic for aggregate expressions.
+ */
+trait CometAggregateExpressionSerde {
+
+ /**
+ * Convert a Spark expression into a protocol buffer representation that can
be passed into
+ * native code.
+ *
+ * @param aggExpr
+ * The aggregate expression.
+ * @param expr
+ * The aggregate function.
+ * @param inputs
+ * The input attributes.
+ * @param binding
+ * Whether the attributes are bound (this is only relevant in aggregate
expressions).
+ * @param conf
+ * SQLConf
+ * @return
+ * Protocol buffer representation, or None if the expression could not be
converted. In this
+ * case it is expected that the input expression will have been tagged
with reasons why it
+ * could not be converted.
+ */
+ def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr]
+}
+
/** Marker trait for an expression that is not guaranteed to be 100%
compatible with Spark */
trait IncompatExpr {}
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
new file mode 100644
index 000000000..da5e9ff53
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -0,0 +1,734 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.collection.JavaConverters.asJavaIterableConverter
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate,
CentralMomentAgg, Corr, Covariance, CovPopulation, CovSample, First, Last,
StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType,
LongType, ShortType, StringType}
+
+import org.apache.comet.CometConf
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
+import org.apache.comet.shims.ShimQueryPlanSerde
+
+object CometMin extends CometAggregateExpressionSerde {
+
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ if (!AggSerde.minMaxDataTypeSupported(expr.dataType)) {
+ withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+ return None
+ }
+ val child = expr.children.head
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(expr.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val builder = ExprOuterClass.Min.newBuilder()
+ builder.setChild(childExpr.get)
+ builder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setMin(builder)
+ .build())
+ } else if (dataType.isEmpty) {
+ withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", child)
+ None
+ } else {
+ withInfo(aggExpr, child)
+ None
+ }
+ }
+}
+
+object CometMax extends CometAggregateExpressionSerde {
+
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ if (!AggSerde.minMaxDataTypeSupported(expr.dataType)) {
+ withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+ return None
+ }
+ val child = expr.children.head
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(expr.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val builder = ExprOuterClass.Max.newBuilder()
+ builder.setChild(childExpr.get)
+ builder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setMax(builder)
+ .build())
+ } else if (dataType.isEmpty) {
+ withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", child)
+ None
+ } else {
+ withInfo(aggExpr, child)
+ None
+ }
+ }
+}
+
+object CometCount extends CometAggregateExpressionSerde {
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val exprChildren = expr.children.map(exprToProto(_, inputs, binding))
+ if (exprChildren.forall(_.isDefined)) {
+ val builder = ExprOuterClass.Count.newBuilder()
+ builder.addAllChildren(exprChildren.map(_.get).asJava)
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setCount(builder)
+ .build())
+ } else {
+ withInfo(aggExpr, expr.children: _*)
+ None
+ }
+ }
+}
+
+object CometAverage extends CometAggregateExpressionSerde with
ShimQueryPlanSerde {
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val avg = expr.asInstanceOf[Average]
+
+ if (!AggSerde.avgDataTypeSupported(expr.dataType)) {
+ withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+ return None
+ }
+
+ if (!isLegacyMode(avg)) {
+ withInfo(aggExpr, "Average is only supported in legacy mode")
+ return None
+ }
+
+ val child = avg.child
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(expr.dataType)
+
+ val sumDataType = child.dataType match {
+ case decimalType: DecimalType =>
+ // This is input precision + 10 to be consistent with Spark
+ val precision = Math.min(DecimalType.MAX_PRECISION,
decimalType.precision + 10)
+ val newType =
+ DecimalType.apply(precision, decimalType.scale)
+ serializeDataType(newType)
+ case _ =>
+ serializeDataType(child.dataType)
+ }
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val builder = ExprOuterClass.Avg.newBuilder()
+ builder.setChild(childExpr.get)
+ builder.setDatatype(dataType.get)
+ builder.setFailOnError(getFailOnError(avg))
+ builder.setSumDatatype(sumDataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setAvg(builder)
+ .build())
+ } else if (dataType.isEmpty) {
+ withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", child)
+ None
+ } else {
+ withInfo(aggExpr, child)
+ None
+ }
+ }
+}
+object CometSum extends CometAggregateExpressionSerde with ShimQueryPlanSerde {
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val sum = expr.asInstanceOf[Sum]
+
+ if (!AggSerde.sumDataTypeSupported(sum.dataType)) {
+ withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+ return None
+ }
+
+ if (!isLegacyMode(sum)) {
+ withInfo(aggExpr, "Sum is only supported in legacy mode")
+ return None
+ }
+
+ val childExpr = exprToProto(sum.child, inputs, binding)
+ val dataType = serializeDataType(sum.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val builder = ExprOuterClass.Sum.newBuilder()
+ builder.setChild(childExpr.get)
+ builder.setDatatype(dataType.get)
+ builder.setFailOnError(getFailOnError(sum))
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setSum(builder)
+ .build())
+ } else {
+ if (dataType.isEmpty) {
+ withInfo(aggExpr, s"datatype ${sum.dataType} is not supported",
sum.child)
+ } else {
+ withInfo(aggExpr, sum.child)
+ }
+ None
+ }
+ }
+}
+
+object CometFirst extends CometAggregateExpressionSerde {
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val first = expr.asInstanceOf[First]
+ if (first.ignoreNulls) {
+ // DataFusion doesn't support ignoreNulls true
+ withInfo(aggExpr, "First not supported when ignoreNulls=true")
+ return None
+ }
+ val child = first.children.head
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(first.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val builder = ExprOuterClass.First.newBuilder()
+ builder.setChild(childExpr.get)
+ builder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setFirst(builder)
+ .build())
+ } else if (dataType.isEmpty) {
+ withInfo(aggExpr, s"datatype ${first.dataType} is not supported", child)
+ None
+ } else {
+ withInfo(aggExpr, child)
+ None
+ }
+ }
+}
+
+object CometLast extends CometAggregateExpressionSerde {
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val last = expr.asInstanceOf[Last]
+ if (last.ignoreNulls) {
+ // DataFusion doesn't support ignoreNulls true
+ withInfo(aggExpr, "Last not supported when ignoreNulls=true")
+ return None
+ }
+ val child = last.children.head
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(last.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val builder = ExprOuterClass.Last.newBuilder()
+ builder.setChild(childExpr.get)
+ builder.setDatatype(dataType.get)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setLast(builder)
+ .build())
+ } else if (dataType.isEmpty) {
+ withInfo(aggExpr, s"datatype ${last.dataType} is not supported", child)
+ None
+ } else {
+ withInfo(aggExpr, child)
+ None
+ }
+ }
+}
+
+object CometBitAndAgg extends CometAggregateExpressionSerde {
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val bitAnd = expr.asInstanceOf[BitAndAgg]
+ if (!AggSerde.bitwiseAggTypeSupported(bitAnd.dataType)) {
+ withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+ return None
+ }
+ val child = bitAnd.child
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(bitAnd.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val builder = ExprOuterClass.BitAndAgg.newBuilder()
+ builder.setChild(childExpr.get)
+ builder.setDatatype(dataType.get)
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setBitAndAgg(builder)
+ .build())
+ } else if (dataType.isEmpty) {
+ withInfo(aggExpr, s"datatype ${bitAnd.dataType} is not supported", child)
+ None
+ } else {
+ withInfo(aggExpr, child)
+ None
+ }
+ }
+}
+
+object CometBitOrAgg extends CometAggregateExpressionSerde {
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val bitOr = expr.asInstanceOf[BitOrAgg]
+ if (!AggSerde.bitwiseAggTypeSupported(bitOr.dataType)) {
+ withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+ return None
+ }
+ val child = bitOr.child
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(bitOr.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val builder = ExprOuterClass.BitOrAgg.newBuilder()
+ builder.setChild(childExpr.get)
+ builder.setDatatype(dataType.get)
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setBitOrAgg(builder)
+ .build())
+ } else if (dataType.isEmpty) {
+ withInfo(aggExpr, s"datatype ${bitOr.dataType} is not supported", child)
+ None
+ } else {
+ withInfo(aggExpr, child)
+ None
+ }
+ }
+}
+
+object CometBitXOrAgg extends CometAggregateExpressionSerde {
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val bitXor = expr.asInstanceOf[BitXorAgg]
+ if (!AggSerde.bitwiseAggTypeSupported(bitXor.dataType)) {
+ withInfo(aggExpr, s"Unsupported data type: ${expr.dataType}")
+ return None
+ }
+ val child = bitXor.child
+ val childExpr = exprToProto(child, inputs, binding)
+ val dataType = serializeDataType(bitXor.dataType)
+
+ if (childExpr.isDefined && dataType.isDefined) {
+ val builder = ExprOuterClass.BitXorAgg.newBuilder()
+ builder.setChild(childExpr.get)
+ builder.setDatatype(dataType.get)
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setBitXorAgg(builder)
+ .build())
+ } else if (dataType.isEmpty) {
+ withInfo(aggExpr, s"datatype ${bitXor.dataType} is not supported", child)
+ None
+ } else {
+ withInfo(aggExpr, child)
+ None
+ }
+ }
+}
+
+trait CometCovBase extends CometAggregateExpressionSerde {
+ def convertCov(
+ aggExpr: AggregateExpression,
+ cov: Covariance,
+ nullOnDivideByZero: Boolean,
+ statsType: Int,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val child1Expr = exprToProto(cov.left, inputs, binding)
+ val child2Expr = exprToProto(cov.right, inputs, binding)
+ val dataType = serializeDataType(cov.dataType)
+
+ if (child1Expr.isDefined && child2Expr.isDefined && dataType.isDefined) {
+ val builder = ExprOuterClass.Covariance.newBuilder()
+ builder.setChild1(child1Expr.get)
+ builder.setChild2(child2Expr.get)
+ builder.setNullOnDivideByZero(nullOnDivideByZero)
+ builder.setDatatype(dataType.get)
+ builder.setStatsTypeValue(statsType)
+
+ Some(
+ ExprOuterClass.AggExpr
+ .newBuilder()
+ .setCovariance(builder)
+ .build())
+ } else {
+ withInfo(aggExpr, "Child expression or data type not supported")
+ None
+ }
+ }
+}
+
+object CometCovSample extends CometCovBase {
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val covSample = expr.asInstanceOf[CovSample]
+ convertCov(
+ aggExpr,
+ covSample,
+ covSample.nullOnDivideByZero,
+ 0,
+ inputs,
+ binding,
+ conf: SQLConf)
+ }
+}
+
+object CometCovPopulation extends CometCovBase {
+ override def convert(
+ aggExpr: AggregateExpression,
+ expr: Expression,
+ inputs: Seq[Attribute],
+ binding: Boolean,
+ conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+ val covPopulation = expr.asInstanceOf[CovPopulation]
+ convertCov(
+ aggExpr,
+ covPopulation,
+ covPopulation.nullOnDivideByZero,
+ 1,
+ inputs,
+ binding,
+ conf: SQLConf)
+ }
+}
+
+trait CometVariance extends CometAggregateExpressionSerde {
+
+  /**
+   * Shared serde logic for Spark's variance aggregates. `statsType` selects the
+   * protobuf variant (0 = sample, 1 = population). Returns None (after recording
+   * fallback info via `withInfo`) when the child or data type cannot be serialized.
+   */
+  def convertVariance(
+      aggExpr: AggregateExpression,
+      expr: CentralMomentAgg,
+      nullOnDivideByZero: Boolean,
+      statsType: Int,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.AggExpr] = {
+    (exprToProto(expr.child, inputs, binding), serializeDataType(expr.dataType)) match {
+      case (Some(childProto), Some(dt)) =>
+        val variance = ExprOuterClass.Variance
+          .newBuilder()
+          .setChild(childProto)
+          .setNullOnDivideByZero(nullOnDivideByZero)
+          .setDatatype(dt)
+          .setStatsTypeValue(statsType)
+        Some(ExprOuterClass.AggExpr.newBuilder().setVariance(variance).build())
+      case _ =>
+        // Record the unsupported child so it shows up in Comet's fallback explanation.
+        withInfo(aggExpr, expr.child)
+        None
+    }
+  }
+}
+
+object CometVarianceSamp extends CometVariance {
+
+  /** Serde for Spark's `var_samp` aggregate (stats type 0 = sample variance). */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val varSamp = expr.asInstanceOf[VarianceSamp]
+    convertVariance(aggExpr, varSamp, varSamp.nullOnDivideByZero, 0, inputs, binding)
+  }
+}
+
+object CometVariancePop extends CometVariance {
+
+  /** Serde for Spark's `var_pop` aggregate (stats type 1 = population variance). */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val varPop = expr.asInstanceOf[VariancePop]
+    convertVariance(aggExpr, varPop, varPop.nullOnDivideByZero, 1, inputs, binding)
+  }
+}
+
+trait CometStddev extends CometAggregateExpressionSerde {
+
+  /**
+   * Shared serde logic for Spark's stddev aggregates. `statsType` selects the
+   * protobuf variant (0 = sample, 1 = population). Gated behind
+   * `CometConf.COMET_EXPR_STDDEV_ENABLED` because the native implementation can
+   * be slower than Spark's.
+   */
+  def convertStddev(
+      aggExpr: AggregateExpression,
+      stddev: CentralMomentAgg,
+      nullOnDivideByZero: Boolean,
+      statsType: Int,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val child = stddev.child
+    if (CometConf.COMET_EXPR_STDDEV_ENABLED.get(conf)) {
+      (exprToProto(child, inputs, binding), serializeDataType(stddev.dataType)) match {
+        case (Some(childProto), Some(dt)) =>
+          val builder = ExprOuterClass.Stddev
+            .newBuilder()
+            .setChild(childProto)
+            .setNullOnDivideByZero(nullOnDivideByZero)
+            .setDatatype(dt)
+            .setStatsTypeValue(statsType)
+          Some(ExprOuterClass.AggExpr.newBuilder().setStddev(builder).build())
+        case _ =>
+          // Record the unsupported child for Comet's fallback explanation.
+          withInfo(aggExpr, child)
+          None
+      }
+    } else {
+      withInfo(
+        aggExpr,
+        "stddev disabled by default because it can be slower than Spark. " +
+          s"Set ${CometConf.COMET_EXPR_STDDEV_ENABLED}=true to enable it.",
+        child)
+      None
+    }
+  }
+}
+
+object CometStddevSamp extends CometStddev {
+
+  /**
+   * Serde for Spark's `stddev_samp` aggregate. Delegates to `convertStddev` with
+   * stats type 0 (sample standard deviation).
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    // Renamed from the misleading `variance`: this is a StddevSamp. Also pass
+    // `conf` directly; the previous `conf: SQLConf` was a redundant type ascription.
+    val stddevSamp = expr.asInstanceOf[StddevSamp]
+    convertStddev(
+      aggExpr,
+      stddevSamp,
+      stddevSamp.nullOnDivideByZero,
+      0,
+      inputs,
+      binding,
+      conf)
+  }
+}
+
+object CometStddevPop extends CometStddev {
+
+  /**
+   * Serde for Spark's `stddev_pop` aggregate. Delegates to `convertStddev` with
+   * stats type 1 (population standard deviation).
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    // Renamed from the misleading `variance`: this is a StddevPop. Also pass
+    // `conf` directly; the previous `conf: SQLConf` was a redundant type ascription.
+    val stddevPop = expr.asInstanceOf[StddevPop]
+    convertStddev(
+      aggExpr,
+      stddevPop,
+      stddevPop.nullOnDivideByZero,
+      1,
+      inputs,
+      binding,
+      conf)
+  }
+}
+
+object CometCorr extends CometAggregateExpressionSerde {
+
+  /** Serde for Spark's `corr` (Pearson correlation) aggregate. */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val corr = expr.asInstanceOf[Corr]
+    val xProto = exprToProto(corr.x, inputs, binding)
+    val yProto = exprToProto(corr.y, inputs, binding)
+    val dt = serializeDataType(corr.dataType)
+
+    (xProto, yProto, dt) match {
+      case (Some(x), Some(y), Some(dataType)) =>
+        val correlation = ExprOuterClass.Correlation
+          .newBuilder()
+          .setChild1(x)
+          .setChild2(y)
+          .setNullOnDivideByZero(corr.nullOnDivideByZero)
+          .setDatatype(dataType)
+        Some(ExprOuterClass.AggExpr.newBuilder().setCorrelation(correlation).build())
+      case _ =>
+        // Record both children for Comet's fallback explanation.
+        withInfo(aggExpr, corr.x, corr.y)
+        None
+    }
+  }
+}
+
+object CometBloomFilterAggregate extends CometAggregateExpressionSerde {
+
+  /**
+   * Serde for Spark's `BloomFilterAggregate`.
+   *
+   * We ignore mutableAggBufferOffset and inputAggBufferOffset because they are
+   * implementation details for Spark's ObjectHashAggregate.
+   */
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val bloomFilter = expr.asInstanceOf[BloomFilterAggregate]
+    val childExpr = exprToProto(bloomFilter.child, inputs, binding)
+    val numItemsExpr = exprToProto(bloomFilter.estimatedNumItemsExpression, inputs, binding)
+    val numBitsExpr = exprToProto(bloomFilter.numBitsExpression, inputs, binding)
+    val dataType = serializeDataType(bloomFilter.dataType)
+
+    // A single pattern match replaces the previous chain of nine isInstanceOf
+    // calls; the accepted input types are unchanged.
+    val childTypeSupported = bloomFilter.child.dataType match {
+      case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: StringType =>
+        true
+      case _ => false
+    }
+
+    if (childExpr.isDefined && childTypeSupported &&
+      numItemsExpr.isDefined && numBitsExpr.isDefined && dataType.isDefined) {
+      val builder = ExprOuterClass.BloomFilterAgg.newBuilder()
+      builder.setChild(childExpr.get)
+      builder.setNumItems(numItemsExpr.get)
+      builder.setNumBits(numBitsExpr.get)
+      builder.setDatatype(dataType.get)
+
+      Some(
+        ExprOuterClass.AggExpr
+          .newBuilder()
+          .setBloomFilterAgg(builder)
+          .build())
+    } else {
+      withInfo(
+        aggExpr,
+        bloomFilter.child,
+        bloomFilter.estimatedNumItemsExpression,
+        bloomFilter.numBitsExpression)
+      None
+    }
+  }
+}
+
+object AggSerde {
+  import org.apache.spark.sql.types._
+
+  // Integral types accepted by the bitwise aggregates and shared by the
+  // numeric predicates below.
+  private def isIntegral(dt: DataType): Boolean = dt match {
+    case ByteType | ShortType | IntegerType | LongType => true
+    case _ => false
+  }
+
+  // Integral, floating-point, or decimal — the common "numeric" set.
+  private def isNumeric(dt: DataType): Boolean = dt match {
+    case FloatType | DoubleType => true
+    case _: DecimalType => true
+    case other => isIntegral(other)
+  }
+
+  /** min/max accept every numeric type plus boolean, date, and timestamp. */
+  def minMaxDataTypeSupported(dt: DataType): Boolean = dt match {
+    case BooleanType | DateType | TimestampType => true
+    case other => isNumeric(other)
+  }
+
+  /** avg accepts integral, floating-point, and decimal inputs. */
+  def avgDataTypeSupported(dt: DataType): Boolean = isNumeric(dt)
+
+  /** sum accepts integral, floating-point, and decimal inputs. */
+  def sumDataTypeSupported(dt: DataType): Boolean = isNumeric(dt)
+
+  /** Bitwise aggregates (bit_and/bit_or/bit_xor) accept only integral inputs. */
+  def bitwiseAggTypeSupported(dt: DataType): Boolean = isIntegral(dt)
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]