This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 5c6a0271e minor: Move shuffle logic from `CometExecRule` to `CometShuffleExchangeExec` serde implementation (#2853)
5c6a0271e is described below
commit 5c6a0271e55acc647b984f9feead993fec83d616
Author: Andy Grove <[email protected]>
AuthorDate: Fri Dec 5 16:57:04 2025 -0700
minor: Move shuffle logic from `CometExecRule` to `CometShuffleExchangeExec` serde implementation (#2853)
---
.../apache/comet/CometSparkSessionExtensions.scala | 16 -
.../org/apache/comet/rules/CometExecRule.scala | 313 +-------------------
.../shuffle/CometShuffleExchangeExec.scala | 326 ++++++++++++++++++++-
3 files changed, 328 insertions(+), 327 deletions(-)
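For context, this change moves the shuffle eligibility checks behind the same serde-style interface that other Comet operators already use: CometExecRule now asks the operator's companion object for a SupportLevel via getSupportLevel and, when compatible, builds the replacement node via createExec. The following self-contained Scala sketch illustrates only that dispatch pattern; apart from the getSupportLevel/createExec/SupportLevel names taken from the diff below, every type here is a simplified stand-in, not Comet's actual API.

    // Simplified stand-ins; only the method names mirror the diff below.
    sealed trait SupportLevel
    case class Compatible() extends SupportLevel
    case class Unsupported() extends SupportLevel

    trait OperatorSerde[T] {
      def getSupportLevel(op: T): SupportLevel
      def createExec(op: T): String // stand-in for building the Comet exec node
    }

    case class FakeShuffle(partitioning: String)

    object FakeShuffleSerde extends OperatorSerde[FakeShuffle] {
      override def getSupportLevel(op: FakeShuffle): SupportLevel =
        if (op.partitioning == "hash") Compatible() else Unsupported()
      override def createExec(op: FakeShuffle): String =
        s"CometShuffleExchangeExec(${op.partitioning})"
    }

    object SerdeDemo extends App {
      val s = FakeShuffle("hash")
      // The rule delegates: check support, convert if compatible, else keep Spark's node.
      val plan = FakeShuffleSerde.getSupportLevel(s) match {
        case Compatible()  => FakeShuffleSerde.createExec(s)
        case Unsupported() => s"ShuffleExchangeExec(${s.partitioning})"
      }
      println(plan) // prints CometShuffleExchangeExec(hash)
    }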
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 71ce8d311..01a11bf0d 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -157,22 +157,6 @@ object CometSparkSessionExtensions extends Logging {
COMET_EXEC_ENABLED.get(conf)
}
- private[comet] def isCometNativeShuffleMode(conf: SQLConf): Boolean = {
- COMET_SHUFFLE_MODE.get(conf) match {
- case "native" => true
- case "auto" => true
- case _ => false
- }
- }
-
- private[comet] def isCometJVMShuffleMode(conf: SQLConf): Boolean = {
- COMET_SHUFFLE_MODE.get(conf) match {
- case "jvm" => true
- case "auto" => true
- case _ => false
- }
- }
-
def isCometScan(op: SparkPlan): Boolean = {
op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
}
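A note on the two helpers deleted above (they reappear as private methods on the CometShuffleExchangeExec companion object later in this diff): COMET_SHUFFLE_MODE selects between the two shuffle implementations, and "auto" enables both candidates, so native shuffle is tried first and columnar shuffle second. A minimal sketch of that mapping, assuming only the three mode strings shown in the match expressions:

    object ShuffleModeDemo extends App {
      // Mirrors the match logic of the moved helpers; "auto" satisfies both predicates.
      def isNativeShuffleMode(mode: String): Boolean = mode match {
        case "native" | "auto" => true
        case _ => false
      }
      def isJvmShuffleMode(mode: String): Boolean = mode match {
        case "jvm" | "auto" => true
        case _ => false
      }
      for (mode <- Seq("native", "jvm", "auto", "unknown")) {
        println(s"$mode -> native=${isNativeShuffleMode(mode)} jvm=${isJvmShuffleMode(mode)}")
      }
    }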
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 0ef10ec09..6879fba4e 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -22,11 +22,10 @@ package org.apache.comet.rules
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.comet._
-import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager}
+import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
@@ -35,14 +34,12 @@ import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.comet.{CometConf, ExtendedExplainInfo}
-import org.apache.comet.CometConf.COMET_EXEC_SHUFFLE_ENABLED
import org.apache.comet.CometSparkSessionExtensions._
import org.apache.comet.rules.CometExecRule.allExecs
-import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, QueryPlanSerde, Unsupported}
+import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.operator._
import org.apache.comet.serde.operator.CometDataWritingCommand
@@ -92,21 +89,19 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
plan.transformUp {
- case s: ShuffleExchangeExec if nativeShuffleSupported(s) =>
+ case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
// Switch to use Decimal128 regardless of precision, since Arrow native execution
// doesn't support Decimal32 and Decimal64 yet.
conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
- case s: ShuffleExchangeExec if columnarShuffleSupported(s) =>
+ case s: ShuffleExchangeExec if CometShuffleExchangeExec.columnarShuffleSupported(s) =>
// Columnar shuffle for regular Spark operators (not Comet) and Comet operators
// (if configured)
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
}
}
- private def isCometPlan(op: SparkPlan): Boolean = op.isInstanceOf[CometPlan]
-
private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec]
// spotless:off
@@ -249,9 +244,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
case s: ShuffleExchangeExec =>
- // try native shuffle first, then columnar shuffle, then fall back to Spark
- // if neither is supported
- tryNativeShuffle(s).orElse(tryColumnarShuffle(s)).getOrElse(s)
+ convertToComet(s, CometShuffleExchangeExec).getOrElse(s)
case op =>
val handler = allExecs
@@ -288,39 +281,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
}
}
- private def tryNativeShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
- Some(s)
- .filter(nativeShuffleSupported)
- .filter(_.children.forall(_.isInstanceOf[CometNativeExec]))
- .flatMap(_ => operator2Proto(s))
- .map { nativeOp =>
- // Switch to use Decimal128 regardless of precision, since Arrow native execution
- // doesn't support Decimal32 and Decimal64 yet.
- conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
- val cometOp = CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
- CometSinkPlaceHolder(nativeOp, s, cometOp)
- }
- }
-
- private def tryColumnarShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
- // Columnar shuffle for regular Spark operators (not Comet) and Comet operators
- // (if configured).
- // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
- // convert it to CometColumnarShuffle.
- Some(s)
- .filter(columnarShuffleSupported)
- .flatMap(_ => operator2Proto(s))
- .flatMap { nativeOp =>
- s.child match {
- case n if n.isInstanceOf[CometNativeExec] || !n.supportsColumnar =>
- val cometOp = CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
- Some(CometSinkPlaceHolder(nativeOp, s, cometOp))
- case _ =>
- None
- }
- }
- }
-
private def normalizePlan(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case p: ProjectExec =>
@@ -497,269 +457,6 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
}
}
- /**
- * Returns true if a given Spark plan is a Comet shuffle operator.
- */
- private def isShuffleOperator(op: SparkPlan): Boolean = {
- op match {
- case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
- case _: CometShuffleExchangeExec => true
- case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
- case _ => false
- }
- }
-
- def isCometShuffleEnabledWithInfo(op: SparkPlan): Boolean = {
- if (!COMET_EXEC_SHUFFLE_ENABLED.get(op.conf)) {
- withInfo(
- op,
- s"Comet shuffle is not enabled: ${COMET_EXEC_SHUFFLE_ENABLED.key} is not enabled")
- false
- } else if (!isCometShuffleManagerEnabled(op.conf)) {
- withInfo(op, s"spark.shuffle.manager is not set to ${classOf[CometShuffleManager].getName}")
- false
- } else {
- true
- }
- }
-
- /**
- * Whether the given Spark partitioning is supported by Comet native shuffle.
- */
- private def nativeShuffleSupported(s: ShuffleExchangeExec): Boolean = {
-
- /**
- * Determine which data types are supported as partition columns in native shuffle.
- *
- * For HashPartitioning this defines the key that determines how data should be collocated for
- * operations like `groupByKey`, `reduceByKey`, or `join`. Native code does not support
- * hashing complex types, see hash_funcs/utils.rs
- */
- def supportedHashPartitioningDataType(dt: DataType): Boolean = dt match {
- case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
- _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
- _: TimestampNTZType | _: DecimalType | _: DateType =>
- true
- case _ =>
- false
- }
-
- /**
- * Determine which data types are supported as partition columns in native shuffle.
- *
- * For RangePartitioning this defines the key that determines how data should be collocated
- * for operations like `orderBy`, `repartitionByRange`. Native code does not support sorting
- * complex types.
- */
- def supportedRangePartitioningDataType(dt: DataType): Boolean = dt match {
- case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
- _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
- _: TimestampNTZType | _: DecimalType | _: DateType =>
- true
- case _ =>
- false
- }
-
- /**
- * Determine which data types are supported as data columns in native shuffle.
- *
- * Native shuffle relies on the Arrow IPC writer to serialize batches to disk, so it should
- * support all types that Comet supports.
- */
- def supportedSerializableDataType(dt: DataType): Boolean = dt match {
- case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
- _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
- _: TimestampNTZType | _: DecimalType | _: DateType =>
- true
- case StructType(fields) =>
- fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType))
- case ArrayType(elementType, _) =>
- supportedSerializableDataType(elementType)
- case MapType(keyType, valueType, _) =>
- supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
- case _ =>
- false
- }
-
- if (!isCometShuffleEnabledWithInfo(s)) {
- return false
- }
-
- if (!isCometNativeShuffleMode(s.conf)) {
- withInfo(s, "Comet native shuffle not enabled")
- return false
- }
-
- if (!isCometPlan(s.child)) {
- // we do not need to report a fallback reason if the child plan is not a Comet plan
- return false
- }
-
- val inputs = s.child.output
-
- for (input <- inputs) {
- if (!supportedSerializableDataType(input.dataType)) {
- withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input")
- return false
- }
- }
-
- val partitioning = s.outputPartitioning
- val conf = SQLConf.get
- partitioning match {
- case HashPartitioning(expressions, _) =>
- var supported = true
- if (!CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)) {
- withInfo(
- s,
- s"${CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.key} is disabled")
- supported = false
- }
- for (expr <- expressions) {
- if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
- withInfo(s, s"unsupported hash partitioning expression: $expr")
- supported = false
- // We don't short-circuit in case there is more than one unsupported expression
- // to provide info for.
- }
- }
- for (dt <- expressions.map(_.dataType).distinct) {
- if (!supportedHashPartitioningDataType(dt)) {
- withInfo(s, s"unsupported hash partitioning data type for native shuffle: $dt")
- supported = false
- }
- }
- supported
- case SinglePartition =>
- // we already checked that the input types are supported
- true
- case RangePartitioning(orderings, _) =>
- if (!CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)) {
- withInfo(
- s,
- s"${CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key} is disabled")
- return false
- }
- var supported = true
- for (o <- orderings) {
- if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
- withInfo(s, s"unsupported range partitioning sort order: $o", o)
- supported = false
- // We don't short-circuit in case there is more than one unsupported expression
- // to provide info for.
- }
- }
- for (dt <- orderings.map(_.dataType).distinct) {
- if (!supportedRangePartitioningDataType(dt)) {
- withInfo(s, s"unsupported range partitioning data type for native shuffle: $dt")
- supported = false
- }
- }
- supported
- case _ =>
- withInfo(
- s,
- s"unsupported Spark partitioning for native shuffle: ${partitioning.getClass.getName}")
- false
- }
- }
-
- /**
- * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
- * which supports struct/array.
- */
- private def columnarShuffleSupported(s: ShuffleExchangeExec): Boolean = {
-
- /**
- * Determine which data types are supported as data columns in columnar shuffle.
- *
- * Comet columnar shuffle uses native code to convert Spark unsafe rows to Arrow batches, see
- * shuffle/row.rs
- */
- def supportedSerializableDataType(dt: DataType): Boolean = dt match {
- case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
- _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
- _: TimestampNTZType | _: DecimalType | _: DateType =>
- true
- case StructType(fields) =>
- fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) &&
- // Java Arrow stream reader cannot work on duplicate field name
- fields.map(f => f.name).distinct.length == fields.length
- case ArrayType(elementType, _) =>
- supportedSerializableDataType(elementType)
- case MapType(keyType, valueType, _) =>
- supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
- case _ =>
- false
- }
-
- if (!isCometShuffleEnabledWithInfo(s)) {
- return false
- }
-
- if (!isCometJVMShuffleMode(s.conf)) {
- withInfo(s, "Comet columnar shuffle not enabled")
- return false
- }
-
- if (isShuffleOperator(s.child)) {
- withInfo(s, s"Child ${s.child.getClass.getName} is a shuffle operator")
- return false
- }
-
- if (!(!s.child.supportsColumnar || isCometPlan(s.child))) {
- withInfo(s, s"Child ${s.child.getClass.getName} is neither row-based nor a Comet operator")
- return false
- }
-
- val inputs = s.child.output
-
- for (input <- inputs) {
- if (!supportedSerializableDataType(input.dataType)) {
- withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input")
- return false
- }
- }
-
- val partitioning = s.outputPartitioning
- partitioning match {
- case HashPartitioning(expressions, _) =>
- var supported = true
- for (expr <- expressions) {
- if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
- withInfo(s, s"unsupported hash partitioning expression: $expr")
- supported = false
- // We don't short-circuit in case there is more than one unsupported expression
- // to provide info for.
- }
- }
- supported
- case SinglePartition =>
- // we already checked that the input types are supported
- true
- case RoundRobinPartitioning(_) =>
- // we already checked that the input types are supported
- true
- case RangePartitioning(orderings, _) =>
- var supported = true
- for (o <- orderings) {
- if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
- withInfo(s, s"unsupported range partitioning sort order: $o")
- supported = false
- // We don't short-circuit in case there is more than one unsupported expression
- // to provide info for.
- }
- }
- supported
- case _ =>
- withInfo(
- s,
- s"unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}")
- false
- }
- }
-
/**
 * Fallback for handling sinks that have not been handled explicitly. This method should
 * eventually be removed once CometExecRule fully uses the operator serde framework.
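Before the second half of the diff: both shuffle paths gate on a recursive walk over the input schema. The native variant accepts the primitive types plus non-empty structs, arrays, and maps whose nested types all qualify. A hedged sketch of that recursion, assuming a Spark dependency on the classpath; the object and method names are illustrative, and only the type cases mirror the supportedSerializableDataType definitions in the diff:

    import org.apache.spark.sql.types._

    object SerializableTypeDemo extends App {
      // Primitive leaves pass; containers pass only if every nested type passes.
      def serializable(dt: DataType): Boolean = dt match {
        case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType |
            DoubleType | StringType | BinaryType | TimestampType | TimestampNTZType |
            DateType => true
        case _: DecimalType => true
        case StructType(fields) => fields.nonEmpty && fields.forall(f => serializable(f.dataType))
        case ArrayType(et, _) => serializable(et)
        case MapType(k, v, _) => serializable(k) && serializable(v)
        case _ => false
      }

      // A struct holding an array of ints qualifies; an interval column does not.
      println(serializable(new StructType().add("xs", ArrayType(IntegerType, true)))) // true
      println(serializable(CalendarIntervalType)) // false
    }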
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 1f7d37a10..2e6ab9aff 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -29,16 +29,18 @@ import org.apache.spark.internal.config
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.comet.{CometMetricNode, CometPlan}
+import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, CometSinkPlaceHolder}
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
@@ -47,6 +49,10 @@ import org.apache.spark.util.random.XORShiftRandom
import com.google.common.base.Objects
import org.apache.comet.CometConf
+import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE}
+import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo}
+import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported}
+import org.apache.comet.serde.operator.CometSink
import org.apache.comet.shims.ShimCometShuffleExchangeExec
/**
@@ -210,7 +216,321 @@ case class CometShuffleExchangeExec(
Iterator(outputPartitioning, shuffleOrigin, shuffleType, child) ++
Iterator(s"[plan_id=$id]")
}
-object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
+object CometShuffleExchangeExec
+ extends CometSink[ShuffleExchangeExec]
+ with ShimCometShuffleExchangeExec
+ with SQLConfHelper {
+
+ override def getSupportLevel(op: ShuffleExchangeExec): SupportLevel = {
+ if (nativeShuffleSupported(op) || columnarShuffleSupported(op)) {
+ Compatible()
+ } else {
+ Unsupported()
+ }
+ }
+
+ override def createExec(
+ nativeOp: OperatorOuterClass.Operator,
+ op: ShuffleExchangeExec): CometNativeExec = {
+ if (nativeShuffleSupported(op) && op.children.forall(_.isInstanceOf[CometNativeExec])) {
+ // Switch to use Decimal128 regardless of precision, since Arrow native execution
+ // doesn't support Decimal32 and Decimal64 yet.
+ conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
+ CometSinkPlaceHolder(
+ nativeOp,
+ op,
+ CometShuffleExchangeExec(op, shuffleType = CometNativeShuffle))
+
+ } else if (columnarShuffleSupported(op)) {
+ CometSinkPlaceHolder(
+ nativeOp,
+ op,
+ CometShuffleExchangeExec(op, shuffleType = CometColumnarShuffle))
+ } else {
+ throw new IllegalStateException()
+ }
+ }
+
+ /**
+ * Whether the given Spark partitioning is supported by Comet native shuffle.
+ */
+ def nativeShuffleSupported(s: ShuffleExchangeExec): Boolean = {
+
+ /**
+ * Determine which data types are supported as partition columns in native shuffle.
+ *
+ * For HashPartitioning this defines the key that determines how data should be collocated for
+ * operations like `groupByKey`, `reduceByKey`, or `join`. Native code does not support
+ * hashing complex types, see hash_funcs/utils.rs
+ */
+ def supportedHashPartitioningDataType(dt: DataType): Boolean = dt match {
+ case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+ _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+ _: TimestampNTZType | _: DecimalType | _: DateType =>
+ true
+ case _ =>
+ false
+ }
+
+ /**
+ * Determine which data types are supported as partition columns in native shuffle.
+ *
+ * For RangePartitioning this defines the key that determines how data should be collocated
+ * for operations like `orderBy`, `repartitionByRange`. Native code does not support sorting
+ * complex types.
+ */
+ def supportedRangePartitioningDataType(dt: DataType): Boolean = dt match {
+ case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+ _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+ _: TimestampNTZType | _: DecimalType | _: DateType =>
+ true
+ case _ =>
+ false
+ }
+
+ /**
+ * Determine which data types are supported as data columns in native shuffle.
+ *
+ * Native shuffle relies on the Arrow IPC writer to serialize batches to disk, so it should
+ * support all types that Comet supports.
+ */
+ def supportedSerializableDataType(dt: DataType): Boolean = dt match {
+ case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+ _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+ _: TimestampNTZType | _: DecimalType | _: DateType =>
+ true
+ case StructType(fields) =>
+ fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType))
+ case ArrayType(elementType, _) =>
+ supportedSerializableDataType(elementType)
+ case MapType(keyType, valueType, _) =>
+ supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
+ case _ =>
+ false
+ }
+
+ if (!isCometShuffleEnabledWithInfo(s)) {
+ return false
+ }
+
+ if (!isCometNativeShuffleMode(s.conf)) {
+ withInfo(s, "Comet native shuffle not enabled")
+ return false
+ }
+
+ if (!isCometPlan(s.child)) {
+ // we do not need to report a fallback reason if the child plan is not a Comet plan
+ return false
+ }
+
+ val inputs = s.child.output
+
+ for (input <- inputs) {
+ if (!supportedSerializableDataType(input.dataType)) {
+ withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input")
+ return false
+ }
+ }
+
+ val partitioning = s.outputPartitioning
+ val conf = SQLConf.get
+ partitioning match {
+ case HashPartitioning(expressions, _) =>
+ var supported = true
+ if (!CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)) {
+ withInfo(
+ s,
+ s"${CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.key} is disabled")
+ supported = false
+ }
+ for (expr <- expressions) {
+ if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
+ withInfo(s, s"unsupported hash partitioning expression: $expr")
+ supported = false
+ // We don't short-circuit in case there is more than one unsupported expression
+ // to provide info for.
+ }
+ }
+ for (dt <- expressions.map(_.dataType).distinct) {
+ if (!supportedHashPartitioningDataType(dt)) {
+ withInfo(s, s"unsupported hash partitioning data type for native shuffle: $dt")
+ supported = false
+ }
+ }
+ supported
+ case SinglePartition =>
+ // we already checked that the input types are supported
+ true
+ case RangePartitioning(orderings, _) =>
+ if (!CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)) {
+ withInfo(
+ s,
+ s"${CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.key} is disabled")
+ return false
+ }
+ var supported = true
+ for (o <- orderings) {
+ if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
+ withInfo(s, s"unsupported range partitioning sort order: $o", o)
+ supported = false
+ // We don't short-circuit in case there is more than one unsupported expression
+ // to provide info for.
+ }
+ }
+ for (dt <- orderings.map(_.dataType).distinct) {
+ if (!supportedRangePartitioningDataType(dt)) {
+ withInfo(s, s"unsupported range partitioning data type for native shuffle: $dt")
+ supported = false
+ }
+ }
+ supported
+ case _ =>
+ withInfo(
+ s,
+ s"unsupported Spark partitioning for native shuffle: ${partitioning.getClass.getName}")
+ false
+ }
+ }
+
+ /**
+ * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
+ * which supports struct/array.
+ */
+ def columnarShuffleSupported(s: ShuffleExchangeExec): Boolean = {
+
+ /**
+ * Determine which data types are supported as data columns in columnar shuffle.
+ *
+ * Comet columnar shuffle uses native code to convert Spark unsafe rows to Arrow batches, see
+ * shuffle/row.rs
+ */
+ def supportedSerializableDataType(dt: DataType): Boolean = dt match {
+ case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+ _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+ _: TimestampNTZType | _: DecimalType | _: DateType =>
+ true
+ case StructType(fields) =>
+ fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) &&
+ // Java Arrow stream reader cannot work on duplicate field name
+ fields.map(f => f.name).distinct.length == fields.length
+ case ArrayType(elementType, _) =>
+ supportedSerializableDataType(elementType)
+ case MapType(keyType, valueType, _) =>
+ supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
+ case _ =>
+ false
+ }
+
+ if (!isCometShuffleEnabledWithInfo(s)) {
+ return false
+ }
+
+ if (!isCometJVMShuffleMode(s.conf)) {
+ withInfo(s, "Comet columnar shuffle not enabled")
+ return false
+ }
+
+ if (isShuffleOperator(s.child)) {
+ withInfo(s, s"Child ${s.child.getClass.getName} is a shuffle operator")
+ return false
+ }
+
+ if (!(!s.child.supportsColumnar || isCometPlan(s.child))) {
+ withInfo(s, s"Child ${s.child.getClass.getName} is neither row-based nor a Comet operator")
+ return false
+ }
+
+ val inputs = s.child.output
+
+ for (input <- inputs) {
+ if (!supportedSerializableDataType(input.dataType)) {
+ withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input")
+ return false
+ }
+ }
+
+ val partitioning = s.outputPartitioning
+ partitioning match {
+ case HashPartitioning(expressions, _) =>
+ var supported = true
+ for (expr <- expressions) {
+ if (QueryPlanSerde.exprToProto(expr, inputs).isEmpty) {
+ withInfo(s, s"unsupported hash partitioning expression: $expr")
+ supported = false
+ // We don't short-circuit in case there is more than one unsupported expression
+ // to provide info for.
+ }
+ }
+ supported
+ case SinglePartition =>
+ // we already checked that the input types are supported
+ true
+ case RoundRobinPartitioning(_) =>
+ // we already checked that the input types are supported
+ true
+ case RangePartitioning(orderings, _) =>
+ var supported = true
+ for (o <- orderings) {
+ if (QueryPlanSerde.exprToProto(o, inputs).isEmpty) {
+ withInfo(s, s"unsupported range partitioning sort order: $o")
+ supported = false
+ // We don't short-circuit in case there is more than one unsupported expression
+ // to provide info for.
+ }
+ }
+ supported
+ case _ =>
+ withInfo(
+ s,
+ s"unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}")
+ false
+ }
+ }
+
+ private def isCometNativeShuffleMode(conf: SQLConf): Boolean = {
+ COMET_SHUFFLE_MODE.get(conf) match {
+ case "native" => true
+ case "auto" => true
+ case _ => false
+ }
+ }
+
+ private def isCometJVMShuffleMode(conf: SQLConf): Boolean = {
+ COMET_SHUFFLE_MODE.get(conf) match {
+ case "jvm" => true
+ case "auto" => true
+ case _ => false
+ }
+ }
+
+ private def isCometPlan(op: SparkPlan): Boolean = op.isInstanceOf[CometPlan]
+
+ /**
+ * Returns true if a given Spark plan is a Comet shuffle operator.
+ */
+ private def isShuffleOperator(op: SparkPlan): Boolean = {
+ op match {
+ case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
+ case _: CometShuffleExchangeExec => true
+ case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
+ case _ => false
+ }
+ }
+
+ def isCometShuffleEnabledWithInfo(op: SparkPlan): Boolean = {
+ if (!COMET_EXEC_SHUFFLE_ENABLED.get(op.conf)) {
+ withInfo(
+ op,
+ s"Comet shuffle is not enabled: ${COMET_EXEC_SHUFFLE_ENABLED.key} is not enabled")
+ false
+ } else if (!isCometShuffleManagerEnabled(op.conf)) {
+ withInfo(op, s"spark.shuffle.manager is not set to ${classOf[CometShuffleManager].getName}")
+ false
+ } else {
+ true
+ }
+ }
def prepareShuffleDependency(
rdd: RDD[ColumnarBatch],
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]