This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new ad17997cd chore: Refactor some of the scan and sink handling in `CometExecRule` to reduce duplicate code (#2844)
ad17997cd is described below

commit ad17997cd788c107f71309fd026143c3c97290f0
Author: Andy Grove <[email protected]>
AuthorDate: Thu Dec 4 11:56:08 2025 -0700

    chore: Refactor some of the scan and sink handling in `CometExecRule` to reduce duplicate code (#2844)
---
 .../org/apache/comet/rules/CometExecRule.scala     | 274 +++++++--------------
 .../apache/comet/serde/operator/CometSink.scala    |  18 ++
 .../spark/sql/comet/CometSparkToColumnarExec.scala |  10 +-
 .../org/apache/spark/sql/comet/operators.scala     |   6 +
 4 files changed, 122 insertions(+), 186 deletions(-)
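
The heart of the change is visible in the first hunk below: every conversion arm in `convertNode` now delegates to a shared helper instead of building the protobuf operator and the Comet exec node inline. A simplified excerpt of the resulting dispatch pattern (taken from the diff that follows, trimmed for brevity):

    // Each arm converts via its serde handler and falls back to the original
    // Spark plan when conversion is not possible.
    case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
      convertToComet(scan, CometNativeScan).getOrElse(scan)

    case op: DataWritingCommandExec =>
      convertToComet(op, CometDataWritingCommand).getOrElse(op)

    // Exchange sinks additionally require all children to be native.
    case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
      convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)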

diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 9152b9f78..8e8098fd0 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -19,8 +19,6 @@
 
 package org.apache.comet.rules
 
-import scala.jdk.CollectionConverters._
-
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
@@ -46,7 +44,6 @@ import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.rules.CometExecRule.allExecs
 import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, QueryPlanSerde, Unsupported}
 import org.apache.comet.serde.OperatorOuterClass.Operator
-import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataType}
 import org.apache.comet.serde.operator._
 import org.apache.comet.serde.operator.CometDataWritingCommand
 
@@ -71,13 +68,6 @@ object CometExecRule {
       classOf[LocalTableScanExec] -> CometLocalTableScanExec,
       classOf[WindowExec] -> CometWindowExec)
 
-  /**
-   * DataWritingCommandExec is handled separately in convertNode since it doesn't follow the
-   * standard pattern of having CometNativeExec children.
-   */
-  val writeExecs: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] =
-    Map(classOf[DataWritingCommandExec] -> CometDataWritingCommand)
-
   /**
    * Sinks that have a native plan of ScanExec.
    */
@@ -186,57 +176,33 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     def convertNode(op: SparkPlan): SparkPlan = op match {
       // Fully native scan for V1
       case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
-        val nativeOp = operator2Proto(scan).get
-        CometNativeScan.createExec(nativeOp, scan)
+        convertToComet(scan, CometNativeScan).getOrElse(scan)
 
       // Fully native Iceberg scan for V2 (iceberg-rust path)
       // Only handle scans with native metadata; SupportsComet scans fall through to isCometScan
       // Config checks (COMET_ICEBERG_NATIVE_ENABLED, COMET_EXEC_ENABLED) are done in CometScanRule
       case scan: CometBatchScanExec if scan.nativeIcebergScanMetadata.isDefined =>
-        operator2Proto(scan) match {
-          case Some(nativeOp) =>
-            CometIcebergNativeScan.createExec(nativeOp, scan)
-          case None =>
-            // Serialization failed, fall back to CometBatchScanExec
-            scan
-        }
+        convertToComet(scan, CometIcebergNativeScan).getOrElse(scan)
 
       // Comet JVM + native scan for V1 and V2
       case op if isCometScan(op) =>
-        val nativeOp = operator2Proto(op)
-        CometScanWrapper(nativeOp.get, op)
+        convertToComet(op, CometScanWrapper).getOrElse(op)
 
       case op if shouldApplySparkToColumnar(conf, op) =>
-        val cometOp = CometSparkToColumnarExec(op)
-        val nativeOp = operator2Proto(cometOp)
-        CometScanWrapper(nativeOp.get, cometOp)
-
-      // Handle DataWritingCommandExec specially since it doesn't follow the standard pattern
-      case exec: DataWritingCommandExec =>
-        CometExecRule.writeExecs.get(classOf[DataWritingCommandExec]) match {
-          case Some(handler) if isOperatorEnabled(handler, exec) =>
-            val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(exec.id)
-            handler
-              .asInstanceOf[CometOperatorSerde[DataWritingCommandExec]]
-              .convert(exec, builder)
-              .map(nativeOp =>
-                handler
-                  .asInstanceOf[CometOperatorSerde[DataWritingCommandExec]]
-                  .createExec(nativeOp, exec))
-              .getOrElse(exec)
-          case _ =>
-            exec
-        }
+        convertToComet(op, CometSparkToColumnarExec).getOrElse(op)
+
+      case op: DataWritingCommandExec =>
+        convertToComet(op, CometDataWritingCommand).getOrElse(op)
 
       // For AQE broadcast stage on a Comet broadcast exchange
       case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
-        newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+        convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
 
       case s @ BroadcastQueryStageExec(
             _,
             ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
             _) =>
-        newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+        convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
 
       // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
       // exchange. It is only used for Comet native execution. We only transform Spark broadcast
@@ -273,13 +239,13 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
       // For AQE shuffle stage on a Comet shuffle exchange
       case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
-        newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+        convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
 
       // For AQE shuffle stage on a reused Comet shuffle exchange
       // Note that we don't need to handle `ReusedExchangeExec` for non-AQE case, because
       // the query plan won't be re-optimized/planned in non-AQE mode.
       case s @ ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) =>
-        newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+        convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
 
       case s: ShuffleExchangeExec =>
         // try native shuffle first, then columnar shuffle, then fall back to Spark
@@ -287,23 +253,12 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         tryNativeShuffle(s).orElse(tryColumnarShuffle(s)).getOrElse(s)
 
       case op =>
-        allExecs
+        val handler = allExecs
           .get(op.getClass)
-          .map(_.asInstanceOf[CometOperatorSerde[SparkPlan]]) match {
+          .map(_.asInstanceOf[CometOperatorSerde[SparkPlan]])
+        handler match {
           case Some(handler) =>
-            if (op.children.forall(isCometNative)) {
-              if (isOperatorEnabled(handler, op)) {
-                val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
-                val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
-                childOp.foreach(builder.addChildren)
-                return handler
-                  .convert(op, builder, childOp: _*)
-                  .map(handler.createExec(_, op))
-                  .getOrElse(op)
-              }
-            } else {
-              return op
-            }
+            return convertToCometIfAllChildrenAreNative(op, handler).getOrElse(op)
           case _ =>
         }
 
@@ -332,25 +287,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  private def operator2ProtoIfAllChildrenAreNative(op: SparkPlan): Option[Operator] = {
-    if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
-      operator2Proto(op, op.children.map(_.asInstanceOf[CometNativeExec].nativeOp): _*)
-    } else {
-      None
-    }
-  }
-
-  /**
-   * Convert operator to proto and then apply a transformation to wrap the proto in a new plan.
-   */
-  private def newPlanWithProto(op: SparkPlan, fun: Operator => SparkPlan): SparkPlan = {
-    operator2ProtoIfAllChildrenAreNative(op).map(fun).getOrElse(op)
-  }
-
   private def tryNativeShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
     Some(s)
-      .filter(_ => nativeShuffleSupported(s))
-      .flatMap(_ => operator2ProtoIfAllChildrenAreNative(s))
+      .filter(nativeShuffleSupported)
+      .filter(_.children.forall(_.isInstanceOf[CometNativeExec]))
+      .flatMap(_ => operator2Proto(s))
       .map { nativeOp =>
         // Switch to use Decimal128 regardless of precision, since Arrow native execution
         // doesn't support Decimal32 and Decimal64 yet.
@@ -366,7 +307,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
     // convert it to CometColumnarShuffle,
     Some(s)
-      .filter(_ => columnarShuffleSupported(s))
+      .filter(columnarShuffleSupported)
       .flatMap(_ => operator2Proto(s))
       .flatMap { nativeOp =>
         s.child match {
@@ -819,84 +760,45 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   }
 
   /**
-   * Convert a Spark plan operator to a protobuf Comet operator.
-   *
-   * @param op
-   *   Spark plan operator
-   * @param childOp
-   *   previously converted protobuf Comet operators, which will be consumed by the Spark plan
-   *   operator as its children
-   * @return
-   *   The converted Comet native operator for the input `op`, or `None` if the `op` cannot be
-   *   converted to a native operator.
+   * Fallback for handling sinks that have not been handled explicitly. This method should
+   * eventually be removed once CometExecRule fully uses the operator serde framework.
    */
   private def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = {
+
+    def isCometSink(op: SparkPlan): Boolean = {
+      op match {
+        case _: CometSparkToColumnarExec => true
+        case _: CometSinkPlaceHolder => true
+        case _ => false
+      }
+    }
+
+    def isExchangeSink(op: SparkPlan): Boolean = {
+      op match {
+        case _: ShuffleExchangeExec => true
+        case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
+        case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) =>
+          true
+        case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
+        case BroadcastQueryStageExec(
+              _,
+              ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
+              _) =>
+          true
+        case _: BroadcastExchangeExec => true
+        case _ => false
+      }
+    }
+
     val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
     childOp.foreach(builder.addChildren)
 
     op match {
-
-      // Fully native scan for V1
-      case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
-        CometNativeScan.convert(scan, builder, childOp: _*)
-
-      // Fully native Iceberg scan for V2 (iceberg-rust path)
-      case scan: CometBatchScanExec if scan.nativeIcebergScanMetadata.isDefined =>
-        CometIcebergNativeScan.convert(scan, builder, childOp: _*)
+      case op if isExchangeSink(op) =>
+        CometExchangeSink.convert(op, builder, childOp: _*)
 
       case op if isCometSink(op) =>
-        val supportedTypes =
-          op.output.forall(a => supportedDataType(a.dataType, allowComplex = true))
-
-        if (!supportedTypes) {
-          withInfo(op, "Unsupported data type")
-          return None
-        }
-
-        // These operators are source of Comet native execution chain
-        val scanBuilder = OperatorOuterClass.Scan.newBuilder()
-        val source = op.simpleStringWithNodeId()
-        if (source.isEmpty) {
-          scanBuilder.setSource(op.getClass.getSimpleName)
-        } else {
-          scanBuilder.setSource(source)
-        }
-
-        val ffiSafe = op match {
-          case _ if isExchangeSink(op) =>
-            // Source of broadcast exchange batches is ArrowStreamReader
-            // Source of shuffle exchange batches is NativeBatchDecoderIterator
-            true
-          case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_COMET =>
-            // native_comet scan reuses mutable buffers
-            false
-          case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT =>
-            // native_iceberg_compat scan reuses mutable buffers for constant columns
-            // https://github.com/apache/datafusion-comet/issues/2152
-            false
-          case _ =>
-            false
-        }
-        scanBuilder.setArrowFfiSafe(ffiSafe)
-
-        val scanTypes = op.output.flatten { attr =>
-          serializeDataType(attr.dataType)
-        }
-
-        if (scanTypes.length == op.output.length) {
-          scanBuilder.addAllFields(scanTypes.asJava)
-
-          // Sink operators don't have children
-          builder.clearChildren()
-
-          Some(builder.setScan(scanBuilder).build())
-        } else {
-          // There are unsupported scan type
-          withInfo(
-            op,
-            s"unsupported Comet operator: ${op.nodeName}, due to unsupported 
data types above")
-          None
-        }
+        CometScanWrapper.convert(op, builder, childOp: _*)
 
       case _ =>
         // Emit warning if:
@@ -910,12 +812,46 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  private def isOperatorEnabled(handler: CometOperatorSerde[_], op: SparkPlan): Boolean = {
-    val enabled = handler.enabledConfig.forall(_.get(op.conf))
+  /**
+   * Convert a Spark plan to a Comet plan using the specified serde handler, but only if all
+   * children are native.
+   */
+  private def convertToCometIfAllChildrenAreNative(
+      op: SparkPlan,
+      handler: CometOperatorSerde[_]): Option[SparkPlan] = {
+    if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
+      convertToComet(op, handler)
+    } else {
+      None
+    }
+  }
+
+  /** Convert a Spark plan to a Comet plan using the specified serde handler */
+  private def convertToComet(op: SparkPlan, handler: CometOperatorSerde[_]): Option[SparkPlan] = {
+    val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
+    if (isOperatorEnabled(serde, op)) {
+      val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
+      if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
+        val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
+        childOp.foreach(builder.addChildren)
+        return serde
+          .convert(op, builder, childOp: _*)
+          .map(nativeOp => serde.createExec(nativeOp, op))
+      } else {
+        return serde
+          .convert(op, builder)
+          .map(nativeOp => serde.createExec(nativeOp, op))
+      }
+    }
+    None
+  }
+
+  private def isOperatorEnabled(
+      handler: CometOperatorSerde[SparkPlan],
+      op: SparkPlan): Boolean = {
     val opName = op.getClass.getSimpleName
-    if (enabled) {
-      val opSerde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
-      opSerde.getSupportLevel(op) match {
+    if (handler.enabledConfig.forall(_.get(op.conf))) {
+      handler.getSupportLevel(op) match {
         case Unsupported(notes) =>
           withInfo(op, notes.getOrElse(""))
           false
@@ -952,36 +888,4 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       false
     }
   }
-
-  /**
-   * Whether the input Spark operator `op` can be considered as a Comet sink, i.e., the start of
-   * native execution. If it is true, we'll wrap `op` with `CometScanWrapper` or
-   * `CometSinkPlaceHolder` later in `CometSparkSessionExtensions` after `operator2proto` is
-   * called.
-   */
-  private def isCometSink(op: SparkPlan): Boolean = {
-    if (isExchangeSink(op)) {
-      return true
-    }
-    op match {
-      case s if isCometScan(s) => true
-      case _: CometSparkToColumnarExec => true
-      case _: CometSinkPlaceHolder => true
-      case _ => false
-    }
-  }
-
-  private def isExchangeSink(op: SparkPlan): Boolean = {
-    op match {
-      case _: ShuffleExchangeExec => true
-      case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
-      case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
-      case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
-      case BroadcastQueryStageExec(_, ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) =>
-        true
-      case _: BroadcastExchangeExec => true
-      case _ => false
-    }
-  }
-
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala
index 8dfc323bc..ca9dbdad7 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala
@@ -21,9 +21,11 @@ package org.apache.comet.serde.operator
 
 import scala.jdk.CollectionConverters._
 
+import org.apache.spark.sql.comet.{CometNativeExec, CometSinkPlaceHolder}
 import org.apache.spark.sql.execution.SparkPlan
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.ConfigEntry
 import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataType}
@@ -37,6 +39,8 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] {
   /** Whether the data produced by the Comet operator is FFI safe */
   def isFfiSafe: Boolean = false
 
+  override def enabledConfig: Option[ConfigEntry[Boolean]] = None
+
   override def convert(
       op: T,
       builder: Operator.Builder,
@@ -78,5 +82,19 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] {
       None
     }
   }
+}
+
+object CometExchangeSink extends CometSink[SparkPlan] {
+
+  /**
+   * Exchange data is FFI safe because there is no use of mutable buffers involved.
+   *
+   * Source of broadcast exchange batches is ArrowStreamReader.
+   *
+   * Source of shuffle exchange batches is NativeBatchDecoderIterator.
+   */
+  override def isFfiSafe: Boolean = true
 
+  override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec =
+    CometSinkPlaceHolder(nativeOp, op, op)
 }
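
With `CometSink` now supplying `convert` and a default `enabledConfig`, a concrete sink only needs to provide `createExec`, plus `isFfiSafe` when its batches are safe to hand across Arrow FFI. A minimal sketch of a hypothetical sink following the same pattern as `CometExchangeSink` above (the object name is illustrative, not part of this commit):

    // Hypothetical example only: wraps the operator in a CometSinkPlaceHolder
    // and keeps the default isFfiSafe = false inherited from CometSink.
    object MyExampleSink extends CometSink[SparkPlan] {
      override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec =
        CometSinkPlaceHolder(nativeOp, op, op)
    }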
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
index bcf891857..a8a61e7a7 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
@@ -34,6 +34,8 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.comet.{CometConf, DataTypeSupport}
+import org.apache.comet.serde.OperatorOuterClass
+import org.apache.comet.serde.operator.CometSink
 
 case class CometSparkToColumnarExec(child: SparkPlan)
     extends RowToColumnarTransition
@@ -136,7 +138,13 @@ case class CometSparkToColumnarExec(child: SparkPlan)
 
 }
 
-object CometSparkToColumnarExec extends DataTypeSupport {
+object CometSparkToColumnarExec extends CometSink[SparkPlan] with DataTypeSupport {
+  override def createExec(
+      nativeOp: OperatorOuterClass.Operator,
+      op: SparkPlan): CometNativeExec = {
+    CometScanWrapper(nativeOp, CometSparkToColumnarExec(op))
+  }
+
   override def isTypeSupported(
       dt: DataType,
       name: String,
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index c955f79d9..6eebc53d5 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -1728,6 +1728,12 @@ case class CometSortMergeJoinExec(
     CometMetricNode.sortMergeJoinMetrics(sparkContext)
 }
 
+object CometScanWrapper extends CometSink[SparkPlan] {
+  override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = {
+    CometScanWrapper(nativeOp, op)
+  }
+}
+
 case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan)
     extends CometNativeExec
     with LeafExecNode {

