This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new d0ca41471 minor: Refactor to move some shuffle-related logic from `QueryPlanSerde` to `CometExecRule` (#2015)
d0ca41471 is described below

commit d0ca41471bfda53cf1856b6e3fbd5101b424037e
Author: Andy Grove <agr...@apache.org>
AuthorDate: Fri Jul 11 13:30:06 2025 -0600

    minor: Refactor to move some shuffle-related logic from `QueryPlanSerde` to `CometExecRule` (#2015)
---
 .../org/apache/comet/rules/CometExecRule.scala    | 172 ++++++++++++++++++---
 .../org/apache/comet/serde/QueryPlanSerde.scala   | 135 ----------------
 2 files changed, 152 insertions(+), 155 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 2383dd844..a1464a2e0 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
@@ -34,13 +35,15 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.types.{DoubleType, FloatType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}

 import org.apache.comet.{CometConf, ExtendedExplainInfo}
 import org.apache.comet.CometConf.COMET_ANSI_MODE_ENABLED
 import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
+import org.apache.comet.serde.QueryPlanSerde.emitWarning

 /**
  * Spark physical optimizer rule for replacing Spark operators with Comet operators.
@@ -53,7 +56,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     plan.transformUp {
       case s: ShuffleExchangeExec
           if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
-            QueryPlanSerde.nativeShuffleSupported(s)._1 =>
+            nativeShuffleSupported(s)._1 =>
         logInfo("Comet extension enabled for Native Shuffle")

         // Switch to use Decimal128 regardless of precision, since Arrow native execution
@@ -65,7 +68,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       // (if configured)
       case s: ShuffleExchangeExec
           if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(conf) &&
-            QueryPlanSerde.columnarShuffleSupported(s)._1 &&
+            columnarShuffleSupported(s)._1 &&
            !isShuffleOperator(s.child) =>
         logInfo("Comet extension enabled for JVM Columnar Shuffle")
         CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
@@ -490,7 +493,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       case s: ShuffleExchangeExec =>
         val nativePrecondition = isCometShuffleEnabled(conf) &&
           isCometNativeShuffleMode(conf) &&
-          QueryPlanSerde.nativeShuffleSupported(s)._1
+          nativeShuffleSupported(s)._1

         val nativeShuffle: Option[SparkPlan] =
           if (nativePrecondition) {
@@ -517,7 +520,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
         // convert it to CometColumnarShuffle,
         if (isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
-          QueryPlanSerde.columnarShuffleSupported(s)._1 &&
+          columnarShuffleSupported(s)._1 &&
          !isShuffleOperator(s.child)) {

           val newOp = QueryPlanSerde.operator2Proto(s)
@@ -547,18 +550,12 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
       val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
       val msg2 = createMessage(
-        isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde
-          .nativeShuffleSupported(s)
-          ._1,
+        isShuffleEnabled && !columnarShuffleEnabled && !nativeShuffleSupported(s)._1,
         "Native shuffle: " +
-          s"${QueryPlanSerde.nativeShuffleSupported(s)._2}")
-      val typeInfo = QueryPlanSerde
-        .columnarShuffleSupported(s)
-        ._2
+          s"${nativeShuffleSupported(s)._2}")
+      val typeInfo = columnarShuffleSupported(s)._2
       val msg3 = createMessage(
-        isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
-          .columnarShuffleSupported(s)
-          ._1,
+        isShuffleEnabled && columnarShuffleEnabled && !columnarShuffleSupported(s)._1,
         "JVM shuffle: " +
           s"$typeInfo")
       withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
@@ -578,7 +575,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }

-  def normalizePlan(plan: SparkPlan): SparkPlan = {
+  private def normalizePlan(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
       case p: ProjectExec =>
         val newProjectList = p.projectList.map(normalize(_).asInstanceOf[NamedExpression])
@@ -595,7 +592,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   // because they are handled well in Spark (e.g., `SQLOrderingUtil.compareFloats`). But the
   // comparison functions in arrow-rs do not normalize NaN and zero. So we need to normalize NaN
   // and zero for comparison operators in Comet.
-  def normalize(expr: Expression): Expression = {
+  private def normalize(expr: Expression): Expression = {
     expr.transformUp {
       case EqualTo(left, right) =>
         EqualTo(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
@@ -616,7 +613,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }

-  def normalizeNaNAndZero(expr: Expression): Expression = {
+  private def normalizeNaNAndZero(expr: Expression): Expression = {
     expr match {
       case _: KnownFloatingPointNormalized => expr
       case FloatLiteral(f) if !f.equals(-0.0f) => expr
@@ -755,7 +752,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
    * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
    * partial mode, it will return None.
    */
-  def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+  private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
     plan.collectFirst {
       case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
         Some(agg)
@@ -770,7 +767,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   /**
    * Returns true if a given spark plan is Comet shuffle operator.
    */
-  def isShuffleOperator(op: SparkPlan): Boolean = {
+  private def isShuffleOperator(op: SparkPlan): Boolean = {
     op match {
       case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
       case _: CometShuffleExchangeExec => true
@@ -778,4 +775,139 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       case _ => false
     }
   }
+
+  /**
+   * Whether the given Spark partitioning is supported by Comet native shuffle.
+   */
+  private def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+
+    /**
+     * Determine which data types are supported as hash-partition keys in native shuffle.
+     *
+     * Hash Partition Key determines how data should be collocated for operations like
+     * `groupByKey`, `reduceByKey` or `join`.
+     */
+    def supportedHashPartitionKeyDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
+        true
+      case _ =>
+        false
+    }
+
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
+    val conf = SQLConf.get
+    var msg = ""
+    val supported = partitioning match {
+      case HashPartitioning(expressions, _) =>
+        // native shuffle currently does not support complex types as partition keys
+        // due to lack of hashing support for those types
+        val supported =
+          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
+            CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)
+        if (!supported) {
+          msg = s"unsupported Spark partitioning: $expressions"
+        }
+        supported
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RangePartitioning(ordering, _) =>
+        val supported = ordering.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+          inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
+          CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)
+        if (!supported) {
+          msg = s"unsupported Spark partitioning: $ordering"
+        }
+        supported
+      case _ =>
+        msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
+        false
+    }
+
+    if (!supported) {
+      emitWarning(msg)
+      (false, msg)
+    } else {
+      (true, null)
+    }
+  }
+
+  /**
+   * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
+   * which supports struct/array.
+   */
+  private def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
+    var msg = ""
+    val supported = partitioning match {
+      case HashPartitioning(expressions, _) =>
+        // columnar shuffle supports the same data types (including complex types) both for
+        // partition keys and for other columns
+        val supported =
+          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+        if (!supported) {
+          msg = s"unsupported Spark partitioning expressions: $expressions"
+        }
+        supported
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RoundRobinPartitioning(_) =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RangePartitioning(orderings, _) =>
+        val supported =
+          orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+        if (!supported) {
+          msg = s"unsupported Spark partitioning expressions: $orderings"
+        }
+        supported
+      case _ =>
+        msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
+        false
+    }
+
+    if (!supported) {
+      emitWarning(msg)
+      (false, msg)
+    } else {
+      (true, null)
+    }
+  }
+
+  /**
+   * Determine which data types are supported in a shuffle.
+   */
+  private def supportedShuffleDataType(dt: DataType): Boolean = dt match {
+    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+        _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+        _: TimestampNTZType | _: DecimalType | _: DateType =>
+      true
+    case StructType(fields) =>
+      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
+      // Java Arrow stream reader cannot work on duplicate field name
+      fields.map(f => f.name).distinct.length == fields.length
+    case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
+    case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
+    case ArrayType(elementType, _) =>
+      supportedShuffleDataType(elementType)
+    case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
+    case MapType(_, MapType(_, _, _), _) => false
+    case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
+    case MapType(_, StructType(_), _) => false
+    case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
+    case MapType(_, ArrayType(_, _), _) => false
+    case MapType(keyType, valueType, _) =>
+      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
+    case _ =>
+      false
+  }
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 077faeb41..970329b28 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
 import org.apache.spark.sql.comet._
@@ -2725,140 +2724,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
     }
   }

-  /**
-   * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
-   * which supports struct/array.
-   */
-  def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
-    val inputs = s.child.output
-    val partitioning = s.outputPartitioning
-    var msg = ""
-    val supported = partitioning match {
-      case HashPartitioning(expressions, _) =>
-        // columnar shuffle supports the same data types (including complex types) both for
-        // partition keys and for other columns
-        val supported =
-          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
-            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
-        if (!supported) {
-          msg = s"unsupported Spark partitioning expressions: $expressions"
-        }
-        supported
-      case SinglePartition =>
-        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
-      case RoundRobinPartitioning(_) =>
-        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
-      case RangePartitioning(orderings, _) =>
-        val supported =
-          orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
-            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
-        if (!supported) {
-          msg = s"unsupported Spark partitioning expressions: $orderings"
-        }
-        supported
-      case _ =>
-        msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
-        false
-    }
-
-    if (!supported) {
-      emitWarning(msg)
-      (false, msg)
-    } else {
-      (true, null)
-    }
-  }
-
-  /**
-   * Whether the given Spark partitioning is supported by Comet native shuffle.
-   */
-  def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
-
-    /**
-     * Determine which data types are supported as hash-partition keys in native shuffle.
-     *
-     * Hash Partition Key determines how data should be collocated for operations like
-     * `groupByKey`, `reduceByKey` or `join`.
-     */
-    def supportedHashPartitionKeyDataType(dt: DataType): Boolean = dt match {
-      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-          _: TimestampNTZType | _: DecimalType | _: DateType =>
-        true
-      case _ =>
-        false
-    }
-
-    val inputs = s.child.output
-    val partitioning = s.outputPartitioning
-    val conf = SQLConf.get
-    var msg = ""
-    val supported = partitioning match {
-      case HashPartitioning(expressions, _) =>
-        // native shuffle currently does not support complex types as partition keys
-        // due to lack of hashing support for those types
-        val supported =
-          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
-            inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
-            CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)
-        if (!supported) {
-          msg = s"unsupported Spark partitioning: $expressions"
-        }
-        supported
-      case SinglePartition =>
-        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
-      case RangePartitioning(ordering, _) =>
-        val supported = ordering.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-          inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
-          CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)
-        if (!supported) {
-          msg = s"unsupported Spark partitioning: $ordering"
-        }
-        supported
-      case _ =>
-        msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
-        false
-    }
-
-    if (!supported) {
-      emitWarning(msg)
-      (false, msg)
-    } else {
-      (true, null)
-    }
-  }
-
-  /**
-   * Determine which data types are supported in a shuffle.
-   */
-  def supportedShuffleDataType(dt: DataType): Boolean = dt match {
-    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
-        _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
-        _: TimestampNTZType | _: DecimalType | _: DateType =>
-      true
-    case StructType(fields) =>
-      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
-      // Java Arrow stream reader cannot work on duplicate field name
-      fields.map(f => f.name).distinct.length == fields.length
-    case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
-    case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
-    case ArrayType(elementType, _) =>
-      supportedShuffleDataType(elementType)
-    case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
-    case MapType(_, MapType(_, _, _), _) => false
-    case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
-    case MapType(_, StructType(_), _) => false
-    case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
-    case MapType(_, ArrayType(_, _), _) => false
-    case MapType(keyType, valueType, _) =>
-      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
-    case _ =>
-      false
-  }
-
   // Utility method. Adds explain info if the result of calling exprToProto is None
   def optExprWithInfo(
       optExpr: Option[Expr],

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org
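Aside, not part of the commit above: a minimal, self-contained Scala sketch (hypothetical names, simplified types, no Spark dependencies) of the decision that the relocated `nativeShuffleSupported` and `columnarShuffleSupported` checks feed inside `CometExecRule`. Native shuffle is tried first when its partitioning and type checks pass, JVM columnar shuffle is tried second, and otherwise the plan keeps Spark's own shuffle; the `(Boolean, String)` shape mirrors the helpers in the diff, where the flag drives the rewrite and the message is surfaced through `withInfo` when a shuffle stays on Spark.

// Illustrative sketch only, not from the commit: simplified stand-ins for the
// ShuffleExchangeExec checks that CometExecRule performs.
object ShuffleDecisionSketch {

  sealed trait ShuffleChoice
  case object NativeShuffle extends ShuffleChoice
  case object ColumnarShuffle extends ShuffleChoice
  case object SparkShuffle extends ShuffleChoice

  // Stand-in for nativeShuffleSupported: (supported, reason-if-not).
  def nativeSupported(partitionKeysOk: Boolean, typesOk: Boolean): (Boolean, String) =
    if (partitionKeysOk && typesOk) (true, null)
    else (false, "unsupported partitioning or data type")

  // Stand-in for columnarShuffleSupported, which also accepts complex types.
  def columnarSupported(typesOk: Boolean): (Boolean, String) =
    if (typesOk) (true, null) else (false, "unsupported data type")

  // Mirrors the ordering of the ShuffleExchangeExec cases in CometExecRule:
  // native shuffle first, then JVM columnar shuffle, otherwise Spark's shuffle.
  def choose(
      nativeModeEnabled: Boolean,
      jvmModeEnabled: Boolean,
      partitionKeysOk: Boolean,
      typesOk: Boolean): ShuffleChoice =
    if (nativeModeEnabled && nativeSupported(partitionKeysOk, typesOk)._1) NativeShuffle
    else if (jvmModeEnabled && columnarSupported(typesOk)._1) ColumnarShuffle
    else SparkShuffle

  def main(args: Array[String]): Unit = {
    // Primitive hash-partition keys and primitive columns: native shuffle is chosen.
    println(choose(nativeModeEnabled = true, jvmModeEnabled = true, partitionKeysOk = true, typesOk = true))
    // Complex partition key (e.g. a struct): the native check fails, columnar shuffle is used.
    println(choose(nativeModeEnabled = true, jvmModeEnabled = true, partitionKeysOk = false, typesOk = true))
  }
}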