This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 9d3d7b8 fix: Remove redundant data copy in columnar shuffle (#233)
9d3d7b8 is described below
commit 9d3d7b8fbfcad8c857463fe8d81bd98b2f45d067
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Mar 27 15:23:21 2024 -0700
fix: Remove redundant data copy in columnar shuffle (#233)
* fix: Remove redundant data copy in columnar shuffle
* Fix flaky test
---
.../shuffle/CometShuffleExchangeExec.scala | 214 +++++++++++++++++++--
.../comet/exec/CometColumnarShuffleSuite.scala | 1 +
2 files changed, 204 insertions(+), 11 deletions(-)
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 47e6dc7..232b6bf 100644
---
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -21,17 +21,21 @@ package org.apache.spark.sql.comet.execution.shuffle
import java.nio.{ByteBuffer, ByteOrder}
import java.nio.file.{Files, Paths}
+import java.util.function.Supplier
import scala.collection.JavaConverters.asJavaIterableConverter
import scala.concurrent.Future
import org.apache.spark._
+import org.apache.spark.internal.config
import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver,
ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
+import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference,
UnsafeProjection, UnsafeRow}
+import
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan}
@@ -39,8 +43,12 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS,
ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics,
SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.MutablePair
+import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators,
RecordComparator}
+import org.apache.spark.util.random.XORShiftRandom
import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass,
QueryPlanSerde}
import org.apache.comet.serde.OperatorOuterClass.Operator
@@ -208,6 +216,50 @@ object CometShuffleExchangeExec extends
ShimCometShuffleExchangeExec {
dependency
}
+ /**
+ * This is copied from Spark
`ShuffleExchangeExec.needToCopyObjectsBeforeShuffle`. The only
+ * difference is that we use `CometShuffleManager` instead of
`SortShuffleManager`.
+ *
+ * When the active shuffle manager is a `CometShuffleManager`, every branch below returns
+ * `false` regardless of the partition count; any other shuffle manager yields `true`.
+ *
+ * @param partitioner partitioner whose `numPartitions` is compared against the thresholds
+ * @return whether rows must be copied before being paired with their partition id
+ */
+ private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner):
Boolean = {
+ // Note: even though we only use the partitioner's `numPartitions` field,
we require it to be
+ // passed instead of directly passing the number of partitions in order to
guard against
+ // corner-cases where a partitioner constructed with `numPartitions`
partitions may output
+ // fewer partitions (like RangePartitioner, for example).
+ val conf = SparkEnv.get.conf
+ val shuffleManager = SparkEnv.get.shuffleManager
+ val sortBasedShuffleOn = shuffleManager.isInstanceOf[CometShuffleManager]
+ val bypassMergeThreshold =
conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
+ val numParts = partitioner.numPartitions
+ if (sortBasedShuffleOn) {
+ if (numParts <= bypassMergeThreshold) {
+ // If we're using the original SortShuffleManager and the number of
output partitions is
+ // sufficiently small, then Spark will fall back to the hash-based
shuffle write path, which
+ // doesn't buffer deserialized records.
+ // Note that we'll have to remove this case if we fix SPARK-6026 and
remove this bypass.
+ false
+ } else if (numParts <=
SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
+ // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize
individual records
+ // prior to sorting them. This optimization is only applied in cases
where shuffle
+ // dependency does not specify an aggregator or ordering and the
record serializer has
+ // certain properties and the number of partitions doesn't exceed the
limitation. If this
+ // optimization is enabled, we can safely avoid the copy.
+ //
+ // Exchange never configures its ShuffledRDDs with aggregators or key
orderings, and the
+ // serializer in Spark SQL always satisfies the properties, so we only
need to check whether
+ // the number of partitions exceeds the limitation.
+ false
+ } else {
+ // This is different from Spark `SortShuffleManager`.
+ // Comet doesn't use Spark `ExternalSorter` to buffer records in
memory, so we don't need to
+ // copy.
+ false
+ }
+ } else {
+ // Catch-all case to safely handle any future ShuffleManager
implementations.
+ true
+ }
+ }
+
/**
* Returns a [[ShuffleDependency]] that will partition rows of its child
based on the
* partitioning scheme defined in `newPartitioning`. Those partitions of the
returned
@@ -219,21 +271,146 @@ object CometShuffleExchangeExec extends
ShimCometShuffleExchangeExec {
newPartitioning: Partitioning,
serializer: Serializer,
writeMetrics: Map[String, SQLMetric]): ShuffleDependency[Int,
InternalRow, InternalRow] = {
- val sparkShuffleDep = ShuffleExchangeExec.prepareShuffleDependency(
- rdd,
- outputAttributes,
- newPartitioning,
- serializer,
- writeMetrics)
+ val part: Partitioner = newPartitioning match {
+ case RoundRobinPartitioning(numPartitions) => new
HashPartitioner(numPartitions)
+ case HashPartitioning(_, n) =>
+ // For HashPartitioning, the partitioning key is already a valid
partition ID, as we use
+ // `HashPartitioning.partitionIdExpression` to produce partitioning
key.
+ new PartitionIdPassthrough(n)
+ case RangePartitioning(sortingExpressions, numPartitions) =>
+ // Extract only fields used for sorting to avoid collecting large
fields that do not
+ // affect sorting result when deciding partition bounds in
RangePartitioner
+ val rddForSampling = rdd.mapPartitionsInternal { iter =>
+ val projection =
+ UnsafeProjection.create(sortingExpressions.map(_.child),
outputAttributes)
+ val mutablePair = new MutablePair[InternalRow, Null]()
+ // Internally, RangePartitioner runs a job on the RDD that samples
keys to compute
+ // partition bounds. To get accurate samples, we need to copy the
mutable keys.
+ iter.map(row => mutablePair.update(projection(row).copy(), null))
+ }
+ // Construct ordering on extracted sort key.
+ val orderingAttributes = sortingExpressions.zipWithIndex.map { case
(ord, i) =>
+ ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
+ }
+ implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
+ new RangePartitioner(
+ numPartitions,
+ rddForSampling,
+ ascending = true,
+ samplePointsPerPartitionHint =
SQLConf.get.rangeExchangeSampleSizePerPartition)
+ case SinglePartition => new ConstantPartitioner
+ case _ => throw new IllegalStateException(s"Exchange not implemented for
$newPartitioning")
+ // TODO: Handle BroadcastPartitioning.
+ }
+ def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match
{
+ case RoundRobinPartitioning(numPartitions) =>
+ // Distributes elements evenly across output partitions, starting from
a random partition.
+ // nextInt(numPartitions) implementation has a special case when bound
is a power of 2,
+ // which is basically taking several highest bits from the initial
seed, with only a
+ // minimal scrambling. Due to deterministic seed, using the generator
only once,
+ // and lack of scrambling, the position values for power-of-two
numPartitions always
+ // end up being almost the same regardless of the index. Substantially
scrambling the
+ // seed by hashing will help. Refer to SPARK-21782 for more details.
+ val partitionId = TaskContext.get().partitionId()
+ var position = new XORShiftRandom(partitionId).nextInt(numPartitions)
+ (row: InternalRow) => {
+ // The HashPartitioner will handle the `mod` by the number of
partitions
+ position += 1
+ position
+ }
+ case h: HashPartitioning =>
+ val projection = UnsafeProjection.create(h.partitionIdExpression ::
Nil, outputAttributes)
+ row => projection(row).getInt(0)
+ case RangePartitioning(sortingExpressions, _) =>
+ val projection =
+ UnsafeProjection.create(sortingExpressions.map(_.child),
outputAttributes)
+ row => projection(row)
+ case SinglePartition => identity
+ case _ => throw new IllegalStateException(s"Exchange not implemented for
$newPartitioning")
+ }
+
+ val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
+ newPartitioning.numPartitions > 1
+
+ val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
+ // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning
is deterministic,
+ // otherwise a retry task may output different rows and thus lead to
data loss.
+ //
+ // Currently we follow the most straightforward approach, which is to perform a
local sort before
+ // partitioning.
+ //
+ // Note that we don't perform local sort if the new partitioning has
only 1 partition, under
+ // that case all output rows go to the same partition.
+ val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
+ rdd.mapPartitionsInternal { iter =>
+ val recordComparatorSupplier = new Supplier[RecordComparator] {
+ override def get: RecordComparator = new RecordBinaryComparator()
+ }
+ // The comparator for comparing row hashcode, which should always be
Integer.
+ val prefixComparator = PrefixComparators.LONG
+
+ // The prefix computer generates row hashcode as the prefix, so we
may decrease the
+ // probability that the prefixes are equal when input rows choose
column values from a
+ // limited range.
+ val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+ private val result = new
UnsafeExternalRowSorter.PrefixComputer.Prefix
+ override def computePrefix(
+ row: InternalRow):
UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+ // The hashcode generated from the binary form of a
[[UnsafeRow]] should not be null.
+ result.isNull = false
+ result.value = row.hashCode()
+ result
+ }
+ }
+ val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+
+ val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
+ StructType.fromAttributes(outputAttributes),
+ recordComparatorSupplier,
+ prefixComparator,
+ prefixComputer,
+ pageSize,
+ // We are comparing binary here, which does not support radix sort.
+ // See more details in SPARK-28699.
+ false)
+ sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+ }
+ } else {
+ rdd
+ }
+ // round-robin function is order sensitive if we don't sort the input.
+ val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
+ if (CometShuffleExchangeExec.needToCopyObjectsBeforeShuffle(part)) {
+ newRdd.mapPartitionsWithIndexInternal(
+ (_, iter) => {
+ val getPartitionKey = getPartitionKeyExtractor()
+ iter.map { row => (part.getPartition(getPartitionKey(row)),
row.copy()) }
+ },
+ isOrderSensitive = isOrderSensitive)
+ } else {
+ newRdd.mapPartitionsWithIndexInternal(
+ (_, iter) => {
+ val getPartitionKey = getPartitionKeyExtractor()
+ val mutablePair = new MutablePair[Int, InternalRow]()
+ iter.map { row =>
mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
+ },
+ isOrderSensitive = isOrderSensitive)
+ }
+ }
+
+ // Now, we manually create a ShuffleDependency. Because pairs in
rddWithPartitionIds
+ // are in the form of (partitionId, row) and every partitionId is in the
expected range
+ // [0, part.numPartitions - 1], the partitioner of this dependency is a
PartitionIdPassthrough.
val dependency =
new CometShuffleDependency[Int, InternalRow, InternalRow](
- sparkShuffleDep.rdd,
- sparkShuffleDep.partitioner,
- sparkShuffleDep.serializer,
- shuffleWriterProcessor = sparkShuffleDep.shuffleWriterProcessor,
+ rddWithPartitionIds,
+ new PartitionIdPassthrough(part.numPartitions),
+ serializer,
+ shuffleWriterProcessor =
ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
shuffleType = CometColumnarShuffle,
schema = Some(StructType.fromAttributes(outputAttributes)))
+
dependency
}
}
@@ -379,3 +556,18 @@ class CometShuffleWriteProcessor(
}
}
}
+
+/**
+ * Copied from Spark `PartitionIdPassthrough` as it is private in Spark 3.2.
+ *
+ * The key passed to `getPartition` is cast to `Int` and returned unchanged, so callers
+ * must supply a key that already is the target partition id.
+ *
+ * @param numPartitions number of partitions reported by this partitioner
+ */
+private[spark] class PartitionIdPassthrough(override val numPartitions: Int)
extends Partitioner {
+ override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+}
+
+/**
+ * Copied from Spark `ConstantPartitioner` as it doesn't exist in Spark 3.2.
+ *
+ * Reports exactly one partition and maps every key to partition 0.
+ */
+private[spark] class ConstantPartitioner extends Partitioner {
+ override def numPartitions: Int = 1
+ override def getPartition(key: Any): Int = 0
+}
diff --git
a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index 216b690..4b4f60a 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -950,6 +950,7 @@ class CometShuffleSuite extends CometColumnarShuffleSuite {
.filter($"a" > 4)
.repartition(10)
.sortWithinPartitions($"a")
+ .filter($"a" >= 10)
checkSparkAnswerAndOperator(shuffled, classOf[ShuffleQueryStageExec])
}
}