This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 612cfd6be Refactor (#2780)
612cfd6be is described below
commit 612cfd6beaf9cadb803d2e0ce56d65ec10b701de
Author: Andy Grove <[email protected]>
AuthorDate: Fri Nov 14 11:33:45 2025 -0700
Refactor (#2780)
---
.../org/apache/comet/serde/QueryPlanSerde.scala | 172 --------------------
.../apache/spark/sql/comet/CometWindowExec.scala | 181 ++++++++++++++++++++-
2 files changed, 178 insertions(+), 175 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 98d9cf9e3..2ef2a1ae2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -367,178 +367,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
Some(dataType)
}
- def windowExprToProto(
- windowExpr: WindowExpression,
- output: Seq[Attribute],
- conf: SQLConf): Option[OperatorOuterClass.WindowExpr] = {
-
- val aggregateExpressions: Array[AggregateExpression] = windowExpr.flatMap { expr =>
- expr match {
- case agg: AggregateExpression =>
- agg.aggregateFunction match {
- case _: Count =>
- Some(agg)
- case min: Min =>
- if (AggSerde.minMaxDataTypeSupported(min.dataType)) {
- Some(agg)
- } else {
- withInfo(windowExpr, s"datatype ${min.dataType} is not supported", expr)
- None
- }
- case max: Max =>
- if (AggSerde.minMaxDataTypeSupported(max.dataType)) {
- Some(agg)
- } else {
- withInfo(windowExpr, s"datatype ${max.dataType} is not supported", expr)
- None
- }
- case s: Sum =>
- if (AggSerde.sumDataTypeSupported(s.dataType) && !s.dataType
- .isInstanceOf[DecimalType]) {
- Some(agg)
- } else {
- withInfo(windowExpr, s"datatype ${s.dataType} is not supported", expr)
- None
- }
- case _ =>
- withInfo(
- windowExpr,
- s"aggregate ${agg.aggregateFunction}" +
- " is not supported for window function",
- expr)
- None
- }
- case _ =>
- None
- }
- }.toArray
-
- val (aggExpr, builtinFunc) = if (aggregateExpressions.nonEmpty) {
- val modes = aggregateExpressions.map(_.mode).distinct
- assert(modes.size == 1 && modes.head == Complete)
- (aggExprToProto(aggregateExpressions.head, output, true, conf), None)
- } else {
- (None, exprToProto(windowExpr.windowFunction, output))
- }
-
- if (aggExpr.isEmpty && builtinFunc.isEmpty) {
- return None
- }
-
- val f = windowExpr.windowSpec.frameSpecification
-
- val (frameType, lowerBound, upperBound) = f match {
- case SpecifiedWindowFrame(frameType, lBound, uBound) =>
- val frameProto = frameType match {
- case RowFrame => OperatorOuterClass.WindowFrameType.Rows
- case RangeFrame => OperatorOuterClass.WindowFrameType.Range
- }
-
- val lBoundProto = lBound match {
- case UnboundedPreceding =>
- OperatorOuterClass.LowerWindowFrameBound
- .newBuilder()
- .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
- .build()
- case CurrentRow =>
- OperatorOuterClass.LowerWindowFrameBound
- .newBuilder()
- .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
- .build()
- case e if frameType == RowFrame =>
- val offset = e.eval() match {
- case i: Integer => i.toLong
- case l: Long => l
- case _ => return None
- }
- OperatorOuterClass.LowerWindowFrameBound
- .newBuilder()
- .setPreceding(
- OperatorOuterClass.Preceding
- .newBuilder()
- .setOffset(offset)
- .build())
- .build()
- case _ =>
- // TODO add support for numeric and temporal RANGE BETWEEN expressions
- // see https://github.com/apache/datafusion-comet/issues/1246
- return None
- }
-
- val uBoundProto = uBound match {
- case UnboundedFollowing =>
- OperatorOuterClass.UpperWindowFrameBound
- .newBuilder()
- .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
- .build()
- case CurrentRow =>
- OperatorOuterClass.UpperWindowFrameBound
- .newBuilder()
- .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
- .build()
- case e if frameType == RowFrame =>
- val offset = e.eval() match {
- case i: Integer => i.toLong
- case l: Long => l
- case _ => return None
- }
- OperatorOuterClass.UpperWindowFrameBound
- .newBuilder()
- .setFollowing(
- OperatorOuterClass.Following
- .newBuilder()
- .setOffset(offset)
- .build())
- .build()
- case _ =>
- // TODO add support for numeric and temporal RANGE BETWEEN expressions
- // see https://github.com/apache/datafusion-comet/issues/1246
- return None
- }
-
- (frameProto, lBoundProto, uBoundProto)
- case _ =>
- (
- OperatorOuterClass.WindowFrameType.Rows,
- OperatorOuterClass.LowerWindowFrameBound
- .newBuilder()
- .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
- .build(),
- OperatorOuterClass.UpperWindowFrameBound
- .newBuilder()
- .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
- .build())
- }
-
- val frame = OperatorOuterClass.WindowFrame
- .newBuilder()
- .setFrameType(frameType)
- .setLowerBound(lowerBound)
- .setUpperBound(upperBound)
- .build()
-
- val spec =
- OperatorOuterClass.WindowSpecDefinition.newBuilder().setFrameSpecification(frame).build()
-
- if (builtinFunc.isDefined) {
- Some(
- OperatorOuterClass.WindowExpr
- .newBuilder()
- .setBuiltInWindowFunction(builtinFunc.get)
- .setSpec(spec)
- .build())
- } else if (aggExpr.isDefined) {
- Some(
- OperatorOuterClass.WindowExpr
- .newBuilder()
- .setAggFunc(aggExpr.get)
- .setSpec(spec)
- .build())
- } else {
- None
- }
- }
-
def aggExprToProto(
aggExpr: AggregateExpression,
inputs: Seq[Attribute],
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
index 159e4df6b..0a783b922 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
@@ -21,19 +21,22 @@ package org.apache.spark.sql.comet
import scala.jdk.CollectionConverters._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression, SortOrder, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CurrentRow, Expression, NamedExpression, RangeFrame, RowFrame, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max, Min, Sum}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DecimalType
import com.google.common.base.Objects
import org.apache.comet.{CometConf, ConfigEntry}
import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.{CometOperatorSerde, Incompatible, OperatorOuterClass, SupportLevel}
+import org.apache.comet.serde.{AggSerde, CometOperatorSerde, Incompatible, OperatorOuterClass, SupportLevel}
import org.apache.comet.serde.OperatorOuterClass.Operator
-import org.apache.comet.serde.QueryPlanSerde.{exprToProto, windowExprToProto}
+import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto}
object CometWindowExec extends CometOperatorSerde[WindowExec] {
@@ -92,6 +95,178 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
}
+ private def windowExprToProto(
+ windowExpr: WindowExpression,
+ output: Seq[Attribute],
+ conf: SQLConf): Option[OperatorOuterClass.WindowExpr] = {
+
+ val aggregateExpressions: Array[AggregateExpression] = windowExpr.flatMap { expr =>
+ expr match {
+ case agg: AggregateExpression =>
+ agg.aggregateFunction match {
+ case _: Count =>
+ Some(agg)
+ case min: Min =>
+ if (AggSerde.minMaxDataTypeSupported(min.dataType)) {
+ Some(agg)
+ } else {
+ withInfo(windowExpr, s"datatype ${min.dataType} is not supported", expr)
+ None
+ }
+ case max: Max =>
+ if (AggSerde.minMaxDataTypeSupported(max.dataType)) {
+ Some(agg)
+ } else {
+ withInfo(windowExpr, s"datatype ${max.dataType} is not supported", expr)
+ None
+ }
+ case s: Sum =>
+ if (AggSerde.sumDataTypeSupported(s.dataType) && !s.dataType
+ .isInstanceOf[DecimalType]) {
+ Some(agg)
+ } else {
+ withInfo(windowExpr, s"datatype ${s.dataType} is not supported", expr)
+ None
+ }
+ case _ =>
+ withInfo(
+ windowExpr,
+ s"aggregate ${agg.aggregateFunction}" +
+ " is not supported for window function",
+ expr)
+ None
+ }
+ case _ =>
+ None
+ }
+ }.toArray
+
+ val (aggExpr, builtinFunc) = if (aggregateExpressions.nonEmpty) {
+ val modes = aggregateExpressions.map(_.mode).distinct
+ assert(modes.size == 1 && modes.head == Complete)
+ (aggExprToProto(aggregateExpressions.head, output, true, conf), None)
+ } else {
+ (None, exprToProto(windowExpr.windowFunction, output))
+ }
+
+ if (aggExpr.isEmpty && builtinFunc.isEmpty) {
+ return None
+ }
+
+ val f = windowExpr.windowSpec.frameSpecification
+
+ val (frameType, lowerBound, upperBound) = f match {
+ case SpecifiedWindowFrame(frameType, lBound, uBound) =>
+ val frameProto = frameType match {
+ case RowFrame => OperatorOuterClass.WindowFrameType.Rows
+ case RangeFrame => OperatorOuterClass.WindowFrameType.Range
+ }
+
+ val lBoundProto = lBound match {
+ case UnboundedPreceding =>
+ OperatorOuterClass.LowerWindowFrameBound
+ .newBuilder()
+ .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
+ .build()
+ case CurrentRow =>
+ OperatorOuterClass.LowerWindowFrameBound
+ .newBuilder()
+ .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
+ .build()
+ case e if frameType == RowFrame =>
+ val offset = e.eval() match {
+ case i: Integer => i.toLong
+ case l: Long => l
+ case _ => return None
+ }
+ OperatorOuterClass.LowerWindowFrameBound
+ .newBuilder()
+ .setPreceding(
+ OperatorOuterClass.Preceding
+ .newBuilder()
+ .setOffset(offset)
+ .build())
+ .build()
+ case _ =>
+ // TODO add support for numeric and temporal RANGE BETWEEN expressions
+ // see https://github.com/apache/datafusion-comet/issues/1246
+ return None
+ }
+
+ val uBoundProto = uBound match {
+ case UnboundedFollowing =>
+ OperatorOuterClass.UpperWindowFrameBound
+ .newBuilder()
+ .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
+ .build()
+ case CurrentRow =>
+ OperatorOuterClass.UpperWindowFrameBound
+ .newBuilder()
+ .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
+ .build()
+ case e if frameType == RowFrame =>
+ val offset = e.eval() match {
+ case i: Integer => i.toLong
+ case l: Long => l
+ case _ => return None
+ }
+ OperatorOuterClass.UpperWindowFrameBound
+ .newBuilder()
+ .setFollowing(
+ OperatorOuterClass.Following
+ .newBuilder()
+ .setOffset(offset)
+ .build())
+ .build()
+ case _ =>
+ // TODO add support for numeric and temporal RANGE BETWEEN expressions
+ // see https://github.com/apache/datafusion-comet/issues/1246
+ return None
+ }
+
+ (frameProto, lBoundProto, uBoundProto)
+ case _ =>
+ (
+ OperatorOuterClass.WindowFrameType.Rows,
+ OperatorOuterClass.LowerWindowFrameBound
+ .newBuilder()
+ .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
+ .build(),
+ OperatorOuterClass.UpperWindowFrameBound
+ .newBuilder()
+ .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
+ .build())
+ }
+
+ val frame = OperatorOuterClass.WindowFrame
+ .newBuilder()
+ .setFrameType(frameType)
+ .setLowerBound(lowerBound)
+ .setUpperBound(upperBound)
+ .build()
+
+ val spec =
+ OperatorOuterClass.WindowSpecDefinition.newBuilder().setFrameSpecification(frame).build()
+
+ if (builtinFunc.isDefined) {
+ Some(
+ OperatorOuterClass.WindowExpr
+ .newBuilder()
+ .setBuiltInWindowFunction(builtinFunc.get)
+ .setSpec(spec)
+ .build())
+ } else if (aggExpr.isDefined) {
+ Some(
+ OperatorOuterClass.WindowExpr
+ .newBuilder()
+ .setAggFunc(aggExpr.get)
+ .setSpec(spec)
+ .build())
+ } else {
+ None
+ }
+ }
+
override def createExec(nativeOp: Operator, op: WindowExec): CometNativeExec = {
CometWindowExec(
nativeOp,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]