This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new ad17997cd chore: Refactor some of the scan and sink handling in `CometExecRule` to reduce duplicate code (#2844)
ad17997cd is described below
commit ad17997cd788c107f71309fd026143c3c97290f0
Author: Andy Grove <[email protected]>
AuthorDate: Thu Dec 4 11:56:08 2025 -0700
chore: Refactor some of the scan and sink handling in `CometExecRule` to reduce duplicate code (#2844)
---
.../org/apache/comet/rules/CometExecRule.scala | 274 +++++++--------------
.../apache/comet/serde/operator/CometSink.scala | 18 ++
.../spark/sql/comet/CometSparkToColumnarExec.scala | 10 +-
.../org/apache/spark/sql/comet/operators.scala | 6 +
4 files changed, 122 insertions(+), 186 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 9152b9f78..8e8098fd0 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -19,8 +19,6 @@
package org.apache.comet.rules
-import scala.jdk.CollectionConverters._
-
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
@@ -46,7 +44,6 @@ import org.apache.comet.CometSparkSessionExtensions._
import org.apache.comet.rules.CometExecRule.allExecs
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, QueryPlanSerde, Unsupported}
import org.apache.comet.serde.OperatorOuterClass.Operator
-import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataType}
import org.apache.comet.serde.operator._
import org.apache.comet.serde.operator.CometDataWritingCommand
@@ -71,13 +68,6 @@ object CometExecRule {
classOf[LocalTableScanExec] -> CometLocalTableScanExec,
classOf[WindowExec] -> CometWindowExec)
- /**
- * DataWritingCommandExec is handled separately in convertNode since it doesn't follow the
- * standard pattern of having CometNativeExec children.
- */
- val writeExecs: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] =
- Map(classOf[DataWritingCommandExec] -> CometDataWritingCommand)
-
/**
* Sinks that have a native plan of ScanExec.
*/
@@ -186,57 +176,33 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
def convertNode(op: SparkPlan): SparkPlan = op match {
// Fully native scan for V1
case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
- val nativeOp = operator2Proto(scan).get
- CometNativeScan.createExec(nativeOp, scan)
+ convertToComet(scan, CometNativeScan).getOrElse(scan)
// Fully native Iceberg scan for V2 (iceberg-rust path)
// Only handle scans with native metadata; SupportsComet scans fall through to isCometScan
// Config checks (COMET_ICEBERG_NATIVE_ENABLED, COMET_EXEC_ENABLED) are done in CometScanRule
case scan: CometBatchScanExec if scan.nativeIcebergScanMetadata.isDefined =>
- operator2Proto(scan) match {
- case Some(nativeOp) =>
- CometIcebergNativeScan.createExec(nativeOp, scan)
- case None =>
- // Serialization failed, fall back to CometBatchScanExec
- scan
- }
+ convertToComet(scan, CometIcebergNativeScan).getOrElse(scan)
// Comet JVM + native scan for V1 and V2
case op if isCometScan(op) =>
- val nativeOp = operator2Proto(op)
- CometScanWrapper(nativeOp.get, op)
+ convertToComet(op, CometScanWrapper).getOrElse(op)
case op if shouldApplySparkToColumnar(conf, op) =>
- val cometOp = CometSparkToColumnarExec(op)
- val nativeOp = operator2Proto(cometOp)
- CometScanWrapper(nativeOp.get, cometOp)
-
- // Handle DataWritingCommandExec specially since it doesn't follow the standard pattern
- case exec: DataWritingCommandExec =>
- CometExecRule.writeExecs.get(classOf[DataWritingCommandExec]) match {
- case Some(handler) if isOperatorEnabled(handler, exec) =>
- val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(exec.id)
- handler
- .asInstanceOf[CometOperatorSerde[DataWritingCommandExec]]
- .convert(exec, builder)
- .map(nativeOp =>
- handler
- .asInstanceOf[CometOperatorSerde[DataWritingCommandExec]]
- .createExec(nativeOp, exec))
- .getOrElse(exec)
- case _ =>
- exec
- }
+ convertToComet(op, CometSparkToColumnarExec).getOrElse(op)
+
+ case op: DataWritingCommandExec =>
+ convertToComet(op, CometDataWritingCommand).getOrElse(op)
// For AQE broadcast stage on a Comet broadcast exchange
case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
- newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+ convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
case s @ BroadcastQueryStageExec(
_,
ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
_) =>
- newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+ convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
// `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
// exchange. It is only used for Comet native execution. We only transform Spark broadcast
@@ -273,13 +239,13 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
// For AQE shuffle stage on a Comet shuffle exchange
case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
- newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+ convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
// For AQE shuffle stage on a reused Comet shuffle exchange
// Note that we don't need to handle `ReusedExchangeExec` for non-AQE case, because
// the query plan won't be re-optimized/planned in non-AQE mode.
case s @ ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) =>
- newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
+ convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
case s: ShuffleExchangeExec =>
// try native shuffle first, then columnar shuffle, then fall back to Spark
@@ -287,23 +253,12 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
tryNativeShuffle(s).orElse(tryColumnarShuffle(s)).getOrElse(s)
case op =>
- allExecs
+ val handler = allExecs
.get(op.getClass)
- .map(_.asInstanceOf[CometOperatorSerde[SparkPlan]]) match {
+ .map(_.asInstanceOf[CometOperatorSerde[SparkPlan]])
+ handler match {
case Some(handler) =>
- if (op.children.forall(isCometNative)) {
- if (isOperatorEnabled(handler, op)) {
- val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
- val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
- childOp.foreach(builder.addChildren)
- return handler
- .convert(op, builder, childOp: _*)
- .map(handler.createExec(_, op))
- .getOrElse(op)
- }
- } else {
- return op
- }
+ return convertToCometIfAllChildrenAreNative(op, handler).getOrElse(op)
case _ =>
}
@@ -332,25 +287,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
}
}
- private def operator2ProtoIfAllChildrenAreNative(op: SparkPlan): Option[Operator] = {
- if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
- operator2Proto(op, op.children.map(_.asInstanceOf[CometNativeExec].nativeOp): _*)
- } else {
- None
- }
- }
-
- /**
- * Convert operator to proto and then apply a transformation to wrap the proto in a new plan.
- */
- private def newPlanWithProto(op: SparkPlan, fun: Operator => SparkPlan): SparkPlan = {
- operator2ProtoIfAllChildrenAreNative(op).map(fun).getOrElse(op)
- }
-
private def tryNativeShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
Some(s)
- .filter(_ => nativeShuffleSupported(s))
- .flatMap(_ => operator2ProtoIfAllChildrenAreNative(s))
+ .filter(nativeShuffleSupported)
+ .filter(_.children.forall(_.isInstanceOf[CometNativeExec]))
+ .flatMap(_ => operator2Proto(s))
.map { nativeOp =>
// Switch to use Decimal128 regardless of precision, since Arrow native execution
// doesn't support Decimal32 and Decimal64 yet.
@@ -366,7 +307,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
// If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
// convert it to CometColumnarShuffle,
Some(s)
- .filter(_ => columnarShuffleSupported(s))
+ .filter(columnarShuffleSupported)
.flatMap(_ => operator2Proto(s))
.flatMap { nativeOp =>
s.child match {
@@ -819,84 +760,45 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
}
/**
- * Convert a Spark plan operator to a protobuf Comet operator.
- *
- * @param op
- * Spark plan operator
- * @param childOp
- * previously converted protobuf Comet operators, which will be consumed by the Spark plan
- * operator as its children
- * @return
- * The converted Comet native operator for the input `op`, or `None` if the `op` cannot be
- * converted to a native operator.
+ * Fallback for handling sinks that have not been handled explicitly. This method should
+ * eventually be removed once CometExecRule fully uses the operator serde framework.
*/
private def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = {
+
+ def isCometSink(op: SparkPlan): Boolean = {
+ op match {
+ case _: CometSparkToColumnarExec => true
+ case _: CometSinkPlaceHolder => true
+ case _ => false
+ }
+ }
+
+ def isExchangeSink(op: SparkPlan): Boolean = {
+ op match {
+ case _: ShuffleExchangeExec => true
+ case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
+ case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) =>
+ true
+ case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
+ case BroadcastQueryStageExec(
+ _,
+ ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
+ _) =>
+ true
+ case _: BroadcastExchangeExec => true
+ case _ => false
+ }
+ }
+
val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
childOp.foreach(builder.addChildren)
op match {
-
- // Fully native scan for V1
- case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
- CometNativeScan.convert(scan, builder, childOp: _*)
-
- // Fully native Iceberg scan for V2 (iceberg-rust path)
- case scan: CometBatchScanExec if scan.nativeIcebergScanMetadata.isDefined =>
- CometIcebergNativeScan.convert(scan, builder, childOp: _*)
+ case op if isExchangeSink(op) =>
+ CometExchangeSink.convert(op, builder, childOp: _*)
case op if isCometSink(op) =>
- val supportedTypes =
- op.output.forall(a => supportedDataType(a.dataType, allowComplex = true))
-
- if (!supportedTypes) {
- withInfo(op, "Unsupported data type")
- return None
- }
-
- // These operators are source of Comet native execution chain
- val scanBuilder = OperatorOuterClass.Scan.newBuilder()
- val source = op.simpleStringWithNodeId()
- if (source.isEmpty) {
- scanBuilder.setSource(op.getClass.getSimpleName)
- } else {
- scanBuilder.setSource(source)
- }
-
- val ffiSafe = op match {
- case _ if isExchangeSink(op) =>
- // Source of broadcast exchange batches is ArrowStreamReader
- // Source of shuffle exchange batches is NativeBatchDecoderIterator
- true
- case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_COMET =>
- // native_comet scan reuses mutable buffers
- false
- case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT =>
- // native_iceberg_compat scan reuses mutable buffers for constant columns
- // https://github.com/apache/datafusion-comet/issues/2152
- false
- case _ =>
- false
- }
- scanBuilder.setArrowFfiSafe(ffiSafe)
-
- val scanTypes = op.output.flatten { attr =>
- serializeDataType(attr.dataType)
- }
-
- if (scanTypes.length == op.output.length) {
- scanBuilder.addAllFields(scanTypes.asJava)
-
- // Sink operators don't have children
- builder.clearChildren()
-
- Some(builder.setScan(scanBuilder).build())
- } else {
- // There are unsupported scan type
- withInfo(
- op,
- s"unsupported Comet operator: ${op.nodeName}, due to unsupported
data types above")
- None
- }
+ CometScanWrapper.convert(op, builder, childOp: _*)
case _ =>
// Emit warning if:
@@ -910,12 +812,46 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
}
}
- private def isOperatorEnabled(handler: CometOperatorSerde[_], op: SparkPlan): Boolean = {
- val enabled = handler.enabledConfig.forall(_.get(op.conf))
+ /**
+ * Convert a Spark plan to a Comet plan using the specified serde handler, but only if all
+ * children are native.
+ */
+ private def convertToCometIfAllChildrenAreNative(
+ op: SparkPlan,
+ handler: CometOperatorSerde[_]): Option[SparkPlan] = {
+ if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
+ convertToComet(op, handler)
+ } else {
+ None
+ }
+ }
+
+ /** Convert a Spark plan to a Comet plan using the specified serde handler */
+ private def convertToComet(op: SparkPlan, handler: CometOperatorSerde[_]): Option[SparkPlan] = {
+ val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
+ if (isOperatorEnabled(serde, op)) {
+ val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
+ if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
+ val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
+ childOp.foreach(builder.addChildren)
+ return serde
+ .convert(op, builder, childOp: _*)
+ .map(nativeOp => serde.createExec(nativeOp, op))
+ } else {
+ return serde
+ .convert(op, builder)
+ .map(nativeOp => serde.createExec(nativeOp, op))
+ }
+ }
+ None
+ }
+
+ private def isOperatorEnabled(
+ handler: CometOperatorSerde[SparkPlan],
+ op: SparkPlan): Boolean = {
val opName = op.getClass.getSimpleName
- if (enabled) {
- val opSerde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
- opSerde.getSupportLevel(op) match {
+ if (handler.enabledConfig.forall(_.get(op.conf))) {
+ handler.getSupportLevel(op) match {
case Unsupported(notes) =>
withInfo(op, notes.getOrElse(""))
false
@@ -952,36 +888,4 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
false
}
}
-
- /**
- * Whether the input Spark operator `op` can be considered as a Comet sink, i.e., the start of
- * native execution. If it is true, we'll wrap `op` with `CometScanWrapper` or
- * `CometSinkPlaceHolder` later in `CometSparkSessionExtensions` after `operator2proto` is
- * called.
- */
- private def isCometSink(op: SparkPlan): Boolean = {
- if (isExchangeSink(op)) {
- return true
- }
- op match {
- case s if isCometScan(s) => true
- case _: CometSparkToColumnarExec => true
- case _: CometSinkPlaceHolder => true
- case _ => false
- }
- }
-
- private def isExchangeSink(op: SparkPlan): Boolean = {
- op match {
- case _: ShuffleExchangeExec => true
- case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
- case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
- case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
- case BroadcastQueryStageExec(_, ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) =>
- true
- case _: BroadcastExchangeExec => true
- case _ => false
- }
- }
-
}
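
To make the shape of the refactor easier to follow, here is a minimal standalone sketch of the lookup-then-convert flow that `convertToComet` and `convertToCometIfAllChildrenAreNative` now share. The `Plan`, `NativePlan`, `NativeOp`, and `Serde` types are simplified stand-ins invented for illustration only, not the real Spark or Comet classes:

    object ConvertSketch extends App {

      // Simplified stand-ins (not the real Spark/Comet types).
      case class NativeOp(name: String, children: Seq[NativeOp] = Nil)

      sealed trait Plan { def children: Seq[Plan] }
      case class NativePlan(nativeOp: NativeOp, children: Seq[Plan] = Nil) extends Plan
      case class SparkOnlyPlan(name: String, children: Seq[Plan] = Nil) extends Plan

      // Analogue of CometOperatorSerde: serialize a plan node, then wrap the result in an exec.
      trait Serde {
        def convert(op: Plan, childOps: Seq[NativeOp]): Option[NativeOp]
        def createExec(nativeOp: NativeOp, op: Plan): Plan = NativePlan(nativeOp, op.children)
      }

      object DemoSerde extends Serde {
        def convert(op: Plan, childOps: Seq[NativeOp]): Option[NativeOp] = op match {
          case SparkOnlyPlan(name, _) => Some(NativeOp(name, childOps))
          case NativePlan(nativeOp, _) => Some(nativeOp)
        }
      }

      // Analogue of convertToComet: serialize and wrap, or return None so the caller falls back.
      def convertToComet(op: Plan, serde: Serde): Option[Plan] = {
        val childOps = op.children.collect { case n: NativePlan => n.nativeOp }
        serde.convert(op, childOps).map(serde.createExec(_, op))
      }

      // Analogue of convertToCometIfAllChildrenAreNative: only convert when every child is native.
      def convertToCometIfAllChildrenAreNative(op: Plan, serde: Serde): Option[Plan] =
        if (op.children.forall(_.isInstanceOf[NativePlan])) convertToComet(op, serde) else None

      // Usage mirroring the new rule body: try to convert, otherwise keep the original plan.
      val plan: Plan = SparkOnlyPlan("Project", Seq(NativePlan(NativeOp("Scan"))))
      println(convertToCometIfAllChildrenAreNative(plan, DemoSerde).getOrElse(plan))
    }

The real rule additionally gates conversion on `isOperatorEnabled`, which consults the serde's `enabledConfig` and `getSupportLevel` before any serialization is attempted.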
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala
index 8dfc323bc..ca9dbdad7 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala
@@ -21,9 +21,11 @@ package org.apache.comet.serde.operator
import scala.jdk.CollectionConverters._
+import org.apache.spark.sql.comet.{CometNativeExec, CometSinkPlaceHolder}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.ConfigEntry
import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataType}
@@ -37,6 +39,8 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] {
/** Whether the data produced by the Comet operator is FFI safe */
def isFfiSafe: Boolean = false
+ override def enabledConfig: Option[ConfigEntry[Boolean]] = None
+
override def convert(
op: T,
builder: Operator.Builder,
@@ -78,5 +82,19 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] {
None
}
}
+}
+
+object CometExchangeSink extends CometSink[SparkPlan] {
+
+ /**
+ * Exchange data is FFI safe because there is no use of mutable buffers involved.
+ *
+ * Source of broadcast exchange batches is ArrowStreamReader.
+ *
+ * Source of shuffle exchange batches is NativeBatchDecoderIterator.
+ */
+ override def isFfiSafe: Boolean = true
+ override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec =
+ CometSinkPlaceHolder(nativeOp, op, op)
}
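
The division of labor introduced here can be sketched in isolation: the `CometSink` base class keeps the generic sink-to-Scan serialization, while each concrete sink object only declares whether its batches are FFI safe and which exec wraps the serialized operator. The types below are simplified stand-ins for illustration, not the actual classes:

    object SinkSketch extends App {

      // Simplified stand-ins (not the real Spark/Comet types).
      case class NativeOp(kind: String, source: String, ffiSafe: Boolean)
      case class Plan(name: String)

      // Analogue of CometSink: the base class owns the generic sink serialization,
      // concrete sinks only pick the FFI-safety flag and the wrapping exec.
      abstract class Sink {
        def isFfiSafe: Boolean = false
        def createExec(nativeOp: NativeOp, op: Plan): Plan

        def convert(op: Plan): Option[NativeOp] =
          Some(NativeOp(kind = "Scan", source = op.name, ffiSafe = isFfiSafe))
      }

      // Analogue of CometExchangeSink: exchange batches are FFI safe and the exec is a placeholder.
      object ExchangeSink extends Sink {
        override def isFfiSafe: Boolean = true
        override def createExec(nativeOp: NativeOp, op: Plan): Plan =
          Plan(s"SinkPlaceHolder(${op.name})")
      }

      val op = Plan("ShuffleExchange")
      println(ExchangeSink.convert(op).map(ExchangeSink.createExec(_, op)))
    }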
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
index bcf891857..a8a61e7a7 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
@@ -34,6 +34,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.comet.{CometConf, DataTypeSupport}
+import org.apache.comet.serde.OperatorOuterClass
+import org.apache.comet.serde.operator.CometSink
case class CometSparkToColumnarExec(child: SparkPlan)
extends RowToColumnarTransition
@@ -136,7 +138,13 @@ case class CometSparkToColumnarExec(child: SparkPlan)
}
-object CometSparkToColumnarExec extends DataTypeSupport {
+object CometSparkToColumnarExec extends CometSink[SparkPlan] with DataTypeSupport {
+ override def createExec(
+ nativeOp: OperatorOuterClass.Operator,
+ op: SparkPlan): CometNativeExec = {
+ CometScanWrapper(nativeOp, CometSparkToColumnarExec(op))
+ }
+
override def isTypeSupported(
dt: DataType,
name: String,
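
One detail worth noting in this change is that the serde's `createExec` re-wraps the original plan in a new `CometSparkToColumnarExec` before handing it to `CometScanWrapper`, so the row-to-columnar transition is inserted by the serde rather than by `CometExecRule`. A small sketch of that shape, again with stand-in types rather than the real ones:

    object ColumnarSinkSketch extends App {

      // Simplified stand-ins (not the real Spark/Comet types).
      case class NativeOp(source: String)
      sealed trait Plan
      case class RowPlan(name: String) extends Plan
      case class SparkToColumnar(child: Plan) extends Plan
      case class ScanWrapper(nativeOp: NativeOp, original: Plan) extends Plan

      // Analogue of the createExec added here: the serde inserts the row-to-columnar
      // transition itself before wrapping the plan for native execution.
      def createExec(nativeOp: NativeOp, op: Plan): Plan =
        ScanWrapper(nativeOp, SparkToColumnar(op))

      println(createExec(NativeOp("RowScan"), RowPlan("RowScan")))
    }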
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index c955f79d9..6eebc53d5 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -1728,6 +1728,12 @@ case class CometSortMergeJoinExec(
CometMetricNode.sortMergeJoinMetrics(sparkContext)
}
+object CometScanWrapper extends CometSink[SparkPlan] {
+ override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = {
+ CometScanWrapper(nativeOp, op)
+ }
+}
+
case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan)
extends CometNativeExec
with LeafExecNode {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]