This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 612cfd6be Refactor (#2780)
612cfd6be is described below

commit 612cfd6beaf9cadb803d2e0ce56d65ec10b701de
Author: Andy Grove <[email protected]>
AuthorDate: Fri Nov 14 11:33:45 2025 -0700

    Refactor (#2780)
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 172 --------------------
 .../apache/spark/sql/comet/CometWindowExec.scala   | 181 ++++++++++++++++++++-
 2 files changed, 178 insertions(+), 175 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 98d9cf9e3..2ef2a1ae2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -367,178 +367,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
     Some(dataType)
   }
 
-  def windowExprToProto(
-      windowExpr: WindowExpression,
-      output: Seq[Attribute],
-      conf: SQLConf): Option[OperatorOuterClass.WindowExpr] = {
-
-    val aggregateExpressions: Array[AggregateExpression] = windowExpr.flatMap { expr =>
-      expr match {
-        case agg: AggregateExpression =>
-          agg.aggregateFunction match {
-            case _: Count =>
-              Some(agg)
-            case min: Min =>
-              if (AggSerde.minMaxDataTypeSupported(min.dataType)) {
-                Some(agg)
-              } else {
-                withInfo(windowExpr, s"datatype ${min.dataType} is not supported", expr)
-                None
-              }
-            case max: Max =>
-              if (AggSerde.minMaxDataTypeSupported(max.dataType)) {
-                Some(agg)
-              } else {
-                withInfo(windowExpr, s"datatype ${max.dataType} is not supported", expr)
-                None
-              }
-            case s: Sum =>
-              if (AggSerde.sumDataTypeSupported(s.dataType) && !s.dataType
-                  .isInstanceOf[DecimalType]) {
-                Some(agg)
-              } else {
-                withInfo(windowExpr, s"datatype ${s.dataType} is not supported", expr)
-                None
-              }
-            case _ =>
-              withInfo(
-                windowExpr,
-                s"aggregate ${agg.aggregateFunction}" +
-                  " is not supported for window function",
-                expr)
-              None
-          }
-        case _ =>
-          None
-      }
-    }.toArray
-
-    val (aggExpr, builtinFunc) = if (aggregateExpressions.nonEmpty) {
-      val modes = aggregateExpressions.map(_.mode).distinct
-      assert(modes.size == 1 && modes.head == Complete)
-      (aggExprToProto(aggregateExpressions.head, output, true, conf), None)
-    } else {
-      (None, exprToProto(windowExpr.windowFunction, output))
-    }
-
-    if (aggExpr.isEmpty && builtinFunc.isEmpty) {
-      return None
-    }
-
-    val f = windowExpr.windowSpec.frameSpecification
-
-    val (frameType, lowerBound, upperBound) = f match {
-      case SpecifiedWindowFrame(frameType, lBound, uBound) =>
-        val frameProto = frameType match {
-          case RowFrame => OperatorOuterClass.WindowFrameType.Rows
-          case RangeFrame => OperatorOuterClass.WindowFrameType.Range
-        }
-
-        val lBoundProto = lBound match {
-          case UnboundedPreceding =>
-            OperatorOuterClass.LowerWindowFrameBound
-              .newBuilder()
-              .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
-              .build()
-          case CurrentRow =>
-            OperatorOuterClass.LowerWindowFrameBound
-              .newBuilder()
-              .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
-              .build()
-          case e if frameType == RowFrame =>
-            val offset = e.eval() match {
-              case i: Integer => i.toLong
-              case l: Long => l
-              case _ => return None
-            }
-            OperatorOuterClass.LowerWindowFrameBound
-              .newBuilder()
-              .setPreceding(
-                OperatorOuterClass.Preceding
-                  .newBuilder()
-                  .setOffset(offset)
-                  .build())
-              .build()
-          case _ =>
-            // TODO add support for numeric and temporal RANGE BETWEEN expressions
-            // see https://github.com/apache/datafusion-comet/issues/1246
-            return None
-        }
-
-        val uBoundProto = uBound match {
-          case UnboundedFollowing =>
-            OperatorOuterClass.UpperWindowFrameBound
-              .newBuilder()
-              .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
-              .build()
-          case CurrentRow =>
-            OperatorOuterClass.UpperWindowFrameBound
-              .newBuilder()
-              .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
-              .build()
-          case e if frameType == RowFrame =>
-            val offset = e.eval() match {
-              case i: Integer => i.toLong
-              case l: Long => l
-              case _ => return None
-            }
-            OperatorOuterClass.UpperWindowFrameBound
-              .newBuilder()
-              .setFollowing(
-                OperatorOuterClass.Following
-                  .newBuilder()
-                  .setOffset(offset)
-                  .build())
-              .build()
-          case _ =>
-            // TODO add support for numeric and temporal RANGE BETWEEN expressions
-            // see https://github.com/apache/datafusion-comet/issues/1246
-            return None
-        }
-
-        (frameProto, lBoundProto, uBoundProto)
-      case _ =>
-        (
-          OperatorOuterClass.WindowFrameType.Rows,
-          OperatorOuterClass.LowerWindowFrameBound
-            .newBuilder()
-            .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
-            .build(),
-          OperatorOuterClass.UpperWindowFrameBound
-            .newBuilder()
-            .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
-            .build())
-    }
-
-    val frame = OperatorOuterClass.WindowFrame
-      .newBuilder()
-      .setFrameType(frameType)
-      .setLowerBound(lowerBound)
-      .setUpperBound(upperBound)
-      .build()
-
-    val spec =
-      OperatorOuterClass.WindowSpecDefinition.newBuilder().setFrameSpecification(frame).build()
-
-    if (builtinFunc.isDefined) {
-      Some(
-        OperatorOuterClass.WindowExpr
-          .newBuilder()
-          .setBuiltInWindowFunction(builtinFunc.get)
-          .setSpec(spec)
-          .build())
-    } else if (aggExpr.isDefined) {
-      Some(
-        OperatorOuterClass.WindowExpr
-          .newBuilder()
-          .setAggFunc(aggExpr.get)
-          .setSpec(spec)
-          .build())
-    } else {
-      None
-    }
-  }
-
   def aggExprToProto(
       aggExpr: AggregateExpression,
       inputs: Seq[Attribute],
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
index 159e4df6b..0a783b922 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
@@ -21,19 +21,22 @@ package org.apache.spark.sql.comet
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression, SortOrder, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CurrentRow, Expression, NamedExpression, RangeFrame, RowFrame, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max, Min, Sum}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DecimalType
 
 import com.google.common.base.Objects
 
 import org.apache.comet.{CometConf, ConfigEntry}
 import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.{CometOperatorSerde, Incompatible, OperatorOuterClass, SupportLevel}
+import org.apache.comet.serde.{AggSerde, CometOperatorSerde, Incompatible, OperatorOuterClass, SupportLevel}
 import org.apache.comet.serde.OperatorOuterClass.Operator
-import org.apache.comet.serde.QueryPlanSerde.{exprToProto, windowExprToProto}
+import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto}
 
 object CometWindowExec extends CometOperatorSerde[WindowExec] {
 
@@ -92,6 +95,178 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
 
   }
 
+  private def windowExprToProto(
+      windowExpr: WindowExpression,
+      output: Seq[Attribute],
+      conf: SQLConf): Option[OperatorOuterClass.WindowExpr] = {
+
+    val aggregateExpressions: Array[AggregateExpression] = windowExpr.flatMap { expr =>
+      expr match {
+        case agg: AggregateExpression =>
+          agg.aggregateFunction match {
+            case _: Count =>
+              Some(agg)
+            case min: Min =>
+              if (AggSerde.minMaxDataTypeSupported(min.dataType)) {
+                Some(agg)
+              } else {
+                withInfo(windowExpr, s"datatype ${min.dataType} is not supported", expr)
+                None
+              }
+            case max: Max =>
+              if (AggSerde.minMaxDataTypeSupported(max.dataType)) {
+                Some(agg)
+              } else {
+                withInfo(windowExpr, s"datatype ${max.dataType} is not supported", expr)
+                None
+              }
+            case s: Sum =>
+              if (AggSerde.sumDataTypeSupported(s.dataType) && !s.dataType
+                  .isInstanceOf[DecimalType]) {
+                Some(agg)
+              } else {
+                withInfo(windowExpr, s"datatype ${s.dataType} is not supported", expr)
+                None
+              }
+            case _ =>
+              withInfo(
+                windowExpr,
+                s"aggregate ${agg.aggregateFunction}" +
+                  " is not supported for window function",
+                expr)
+              None
+          }
+        case _ =>
+          None
+      }
+    }.toArray
+
+    val (aggExpr, builtinFunc) = if (aggregateExpressions.nonEmpty) {
+      val modes = aggregateExpressions.map(_.mode).distinct
+      assert(modes.size == 1 && modes.head == Complete)
+      (aggExprToProto(aggregateExpressions.head, output, true, conf), None)
+    } else {
+      (None, exprToProto(windowExpr.windowFunction, output))
+    }
+
+    if (aggExpr.isEmpty && builtinFunc.isEmpty) {
+      return None
+    }
+
+    val f = windowExpr.windowSpec.frameSpecification
+
+    val (frameType, lowerBound, upperBound) = f match {
+      case SpecifiedWindowFrame(frameType, lBound, uBound) =>
+        val frameProto = frameType match {
+          case RowFrame => OperatorOuterClass.WindowFrameType.Rows
+          case RangeFrame => OperatorOuterClass.WindowFrameType.Range
+        }
+
+        val lBoundProto = lBound match {
+          case UnboundedPreceding =>
+            OperatorOuterClass.LowerWindowFrameBound
+              .newBuilder()
+              .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
+              .build()
+          case CurrentRow =>
+            OperatorOuterClass.LowerWindowFrameBound
+              .newBuilder()
+              .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
+              .build()
+          case e if frameType == RowFrame =>
+            val offset = e.eval() match {
+              case i: Integer => i.toLong
+              case l: Long => l
+              case _ => return None
+            }
+            OperatorOuterClass.LowerWindowFrameBound
+              .newBuilder()
+              .setPreceding(
+                OperatorOuterClass.Preceding
+                  .newBuilder()
+                  .setOffset(offset)
+                  .build())
+              .build()
+          case _ =>
+            // TODO add support for numeric and temporal RANGE BETWEEN expressions
+            // see https://github.com/apache/datafusion-comet/issues/1246
+            return None
+        }
+
+        val uBoundProto = uBound match {
+          case UnboundedFollowing =>
+            OperatorOuterClass.UpperWindowFrameBound
+              .newBuilder()
+              .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
+              .build()
+          case CurrentRow =>
+            OperatorOuterClass.UpperWindowFrameBound
+              .newBuilder()
+              .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
+              .build()
+          case e if frameType == RowFrame =>
+            val offset = e.eval() match {
+              case i: Integer => i.toLong
+              case l: Long => l
+              case _ => return None
+            }
+            OperatorOuterClass.UpperWindowFrameBound
+              .newBuilder()
+              .setFollowing(
+                OperatorOuterClass.Following
+                  .newBuilder()
+                  .setOffset(offset)
+                  .build())
+              .build()
+          case _ =>
+            // TODO add support for numeric and temporal RANGE BETWEEN expressions
+            // see https://github.com/apache/datafusion-comet/issues/1246
+            return None
+        }
+
+        (frameProto, lBoundProto, uBoundProto)
+      case _ =>
+        (
+          OperatorOuterClass.WindowFrameType.Rows,
+          OperatorOuterClass.LowerWindowFrameBound
+            .newBuilder()
+            .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
+            .build(),
+          OperatorOuterClass.UpperWindowFrameBound
+            .newBuilder()
+            .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
+            .build())
+    }
+
+    val frame = OperatorOuterClass.WindowFrame
+      .newBuilder()
+      .setFrameType(frameType)
+      .setLowerBound(lowerBound)
+      .setUpperBound(upperBound)
+      .build()
+
+    val spec =
+      OperatorOuterClass.WindowSpecDefinition.newBuilder().setFrameSpecification(frame).build()
+
+    if (builtinFunc.isDefined) {
+      Some(
+        OperatorOuterClass.WindowExpr
+          .newBuilder()
+          .setBuiltInWindowFunction(builtinFunc.get)
+          .setSpec(spec)
+          .build())
+    } else if (aggExpr.isDefined) {
+      Some(
+        OperatorOuterClass.WindowExpr
+          .newBuilder()
+          .setAggFunc(aggExpr.get)
+          .setSpec(spec)
+          .build())
+    } else {
+      None
+    }
+  }
+
   override def createExec(nativeOp: Operator, op: WindowExec): CometNativeExec = {
     CometWindowExec(
       nativeOp,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to