This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new af21ae93a feat: add getSupportLevel for aggregates (#2777)
af21ae93a is described below
commit af21ae93ab43e9471f4fb28523564eb9fc09e657
Author: Andy Grove <[email protected]>
AuthorDate: Fri Nov 14 13:54:32 2025 -0700
feat: add getSupportLevel for aggregates (#2777)
---
.../serde/CometAggregateExpressionSerde.scala | 10 ++++
.../org/apache/comet/serde/QueryPlanSerde.scala | 30 +++++++++++-
.../scala/org/apache/comet/serde/aggregates.scala | 53 ++++++++++------------
3 files changed, 64 insertions(+), 29 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala
index c0c2b0728..0a5a2770b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala
@@ -39,6 +39,16 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] {
*/
def getExprConfigName(expr: T): String = expr.getClass.getSimpleName
+ /**
+ * Determine the support level of the expression based on its attributes.
+ *
+ * @param expr
+ * The Spark expression.
+ * @return
+ * Support level (Compatible, Incompatible, or Unsupported).
+ */
+ def getSupportLevel(expr: T): SupportLevel = Compatible(None)
+
/**
* Convert a Spark expression into a protocol buffer representation that can
be passed into
* native code.
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 2ef2a1ae2..3f62cd7f9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -398,7 +398,35 @@ object QueryPlanSerde extends Logging with CometExprShim {
s"${CometConf.getExprEnabledConfigKey(exprConfName)}=true to
enable it.")
return None
}
- aggHandler.convert(aggExpr, fn, inputs, binding, conf)
+ aggHandler.getSupportLevel(fn) match {
+ case Unsupported(notes) =>
+ withInfo(fn, notes.getOrElse(""))
+ None
+ case Incompatible(notes) =>
+ val exprAllowIncompat = CometConf.isExprAllowIncompat(exprConfName)
+ if (exprAllowIncompat ||
CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get()) {
+ if (notes.isDefined) {
+ logWarning(
+ s"Comet supports $fn when
${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true " +
+ s"but has notes: ${notes.get}")
+ }
+ aggHandler.convert(aggExpr, fn, inputs, binding, conf)
+ } else {
+ val optionalNotes = notes.map(str => s" ($str)").getOrElse("")
+ withInfo(
+ fn,
+ s"$fn is not fully compatible with Spark$optionalNotes. To
enable it anyway, " +
+ s"set
${CometConf.getExprAllowIncompatConfigKey(exprConfName)}=true, or set " +
+ s"${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true to
enable all " +
+ s"incompatible expressions. ${CometConf.COMPAT_GUIDE}.")
+ None
+ }
+ case Compatible(notes) =>
+ if (notes.isDefined) {
+ logWarning(s"Comet supports $fn but has notes: ${notes.get}")
+ }
+ aggHandler.convert(aggExpr, fn, inputs, binding, conf)
+ }
case _ =>
withInfo(
aggExpr,
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 48344e061..d00bbf4df 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -149,6 +149,18 @@ object CometCount extends CometAggregateExpressionSerde[Count] {
}
object CometAverage extends CometAggregateExpressionSerde[Average] {
+ // ANSI is opt-in via the allow-incompatible configs; TRY is never converted (Unsupported).
+ override def getSupportLevel(avg: Average): SupportLevel = {
+ avg.evalMode match {
+ case EvalMode.ANSI =>
+ Incompatible(Some("ANSI mode is not supported"))
+ case EvalMode.TRY =>
+ Unsupported(Some("TRY mode is not supported"))
+ case _ =>
+ Compatible()
+ }
+ }
+
override def convert(
aggExpr: AggregateExpression,
avg: Average,
@@ -161,20 +173,6 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
return None
}
- avg.evalMode match {
- case EvalMode.ANSI if !CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get() =>
- withInfo(
- aggExpr,
- "ANSI mode is not supported. Set " +
- s"${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true to allow it
anyway")
- return None
- case EvalMode.TRY =>
- withInfo(aggExpr, "TRY mode is not supported")
- return None
- case _ =>
- // supported
- }
-
val child = avg.child
val childExpr = exprToProto(child, inputs, binding)
val dataType = serializeDataType(avg.dataType)
@@ -211,7 +209,20 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
}
}
}
+
object CometSum extends CometAggregateExpressionSerde[Sum] {
+ // ANSI is opt-in via the allow-incompatible configs; TRY is never converted (Unsupported).
+ override def getSupportLevel(sum: Sum): SupportLevel = {
+ sum.evalMode match {
+ case EvalMode.ANSI =>
+ Incompatible(Some("ANSI mode is not supported"))
+ case EvalMode.TRY =>
+ Unsupported(Some("TRY mode is not supported"))
+ case _ =>
+ Compatible()
+ }
+ }
+
override def convert(
aggExpr: AggregateExpression,
sum: Sum,
@@ -224,20 +235,6 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
return None
}
- sum.evalMode match {
- case EvalMode.ANSI if !CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get() =>
- withInfo(
- aggExpr,
- "ANSI mode is not supported. Set " +
- s"${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true to allow it
anyway")
- return None
- case EvalMode.TRY =>
- withInfo(aggExpr, "TRY mode is not supported")
- return None
- case _ =>
- // supported
- }
-
val childExpr = exprToProto(sum.child, inputs, binding)
val dataType = serializeDataType(sum.dataType)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]