This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d3d7b8  fix: Remove redundant data copy in columnar shuffle (#233)
9d3d7b8 is described below

commit 9d3d7b8fbfcad8c857463fe8d81bd98b2f45d067
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Mar 27 15:23:21 2024 -0700

    fix: Remove redundant data copy in columnar shuffle (#233)
    
    * fix: Remove redundant data copy in columnar shuffle
    
    * Fix flaky test
---
 .../shuffle/CometShuffleExchangeExec.scala         | 214 +++++++++++++++++++--
 .../comet/exec/CometColumnarShuffleSuite.scala     |   1 +
 2 files changed, 204 insertions(+), 11 deletions(-)

diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 47e6dc7..232b6bf 100644
--- 
a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -21,17 +21,21 @@ package org.apache.spark.sql.comet.execution.shuffle
 
 import java.nio.{ByteBuffer, ByteOrder}
 import java.nio.file.{Files, Paths}
+import java.util.function.Supplier
 
 import scala.collection.JavaConverters.asJavaIterableConverter
 import scala.concurrent.Future
 
 import org.apache.spark._
+import org.apache.spark.internal.config
 import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, 
ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
+import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
UnsafeProjection, UnsafeRow}
+import 
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometPlan}
@@ -39,8 +43,12 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, 
ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, 
SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.MutablePair
+import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, 
RecordComparator}
+import org.apache.spark.util.random.XORShiftRandom
 
 import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, 
QueryPlanSerde}
 import org.apache.comet.serde.OperatorOuterClass.Operator
@@ -208,6 +216,50 @@ object CometShuffleExchangeExec extends 
ShimCometShuffleExchangeExec {
     dependency
   }
 
+  /**
+   * This is copied from Spark 
`ShuffleExchangeExec.needToCopyObjectsBeforeShuffle`. The only
+   * difference is that we use `CometShuffleManager` instead of 
`SortShuffleManager`.
+   */
+  private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): 
Boolean = {
+    // Note: even though we only use the partitioner's `numPartitions` field, 
we require it to be
+    // passed instead of directly passing the number of partitions in order to 
guard against
+    // corner-cases where a partitioner constructed with `numPartitions` 
partitions may output
+    // fewer partitions (like RangePartitioner, for example).
+    val conf = SparkEnv.get.conf
+    val shuffleManager = SparkEnv.get.shuffleManager
+    val sortBasedShuffleOn = shuffleManager.isInstanceOf[CometShuffleManager]
+    val bypassMergeThreshold = 
conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
+    val numParts = partitioner.numPartitions
+    if (sortBasedShuffleOn) {
+      if (numParts <= bypassMergeThreshold) {
+        // If we're using the original SortShuffleManager and the number of 
output partitions is
+        // sufficiently small, then Spark will fall back to the hash-based 
shuffle write path, which
+        // doesn't buffer deserialized records.
+        // Note that we'll have to remove this case if we fix SPARK-6026 and 
remove this bypass.
+        false
+      } else if (numParts <= 
SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
+        // SPARK-4550 and  SPARK-7081 extended sort-based shuffle to serialize 
individual records
+        // prior to sorting them. This optimization is only applied in cases 
where shuffle
+        // dependency does not specify an aggregator or ordering and the 
record serializer has
+        // certain properties and the number of partitions doesn't exceed the 
limitation. If this
+        // optimization is enabled, we can safely avoid the copy.
+        //
+        // Exchange never configures its ShuffledRDDs with aggregators or key 
orderings, and the
+        // serializer in Spark SQL always satisfies the properties, so we only 
need to check whether
+        // the number of partitions exceeds the limitation.
+        false
+      } else {
+        // This is different from Spark `SortShuffleManager`.
+        // Comet doesn't use Spark `ExternalSorter` to buffer records in 
memory, so we don't need to
+        // copy.
+        false
+      }
+    } else {
+      // Catch-all case to safely handle any future ShuffleManager 
implementations.
+      true
+    }
+  }
+
   /**
    * Returns a [[ShuffleDependency]] that will partition rows of its child 
based on the
    * partitioning scheme defined in `newPartitioning`. Those partitions of the 
returned
@@ -219,21 +271,146 @@ object CometShuffleExchangeExec extends 
ShimCometShuffleExchangeExec {
       newPartitioning: Partitioning,
       serializer: Serializer,
       writeMetrics: Map[String, SQLMetric]): ShuffleDependency[Int, 
InternalRow, InternalRow] = {
-    val sparkShuffleDep = ShuffleExchangeExec.prepareShuffleDependency(
-      rdd,
-      outputAttributes,
-      newPartitioning,
-      serializer,
-      writeMetrics)
+    val part: Partitioner = newPartitioning match {
+      case RoundRobinPartitioning(numPartitions) => new 
HashPartitioner(numPartitions)
+      case HashPartitioning(_, n) =>
+        // For HashPartitioning, the partitioning key is already a valid 
partition ID, as we use
+        // `HashPartitioning.partitionIdExpression` to produce partitioning 
key.
+        new PartitionIdPassthrough(n)
+      case RangePartitioning(sortingExpressions, numPartitions) =>
+        // Extract only fields used for sorting to avoid collecting large 
fields that do not
+        // affect sorting result when deciding partition bounds in 
RangePartitioner
+        val rddForSampling = rdd.mapPartitionsInternal { iter =>
+          val projection =
+            UnsafeProjection.create(sortingExpressions.map(_.child), 
outputAttributes)
+          val mutablePair = new MutablePair[InternalRow, Null]()
+          // Internally, RangePartitioner runs a job on the RDD that samples 
keys to compute
+          // partition bounds. To get accurate samples, we need to copy the 
mutable keys.
+          iter.map(row => mutablePair.update(projection(row).copy(), null))
+        }
+        // Construct ordering on extracted sort key.
+        val orderingAttributes = sortingExpressions.zipWithIndex.map { case 
(ord, i) =>
+          ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
+        }
+        implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
+        new RangePartitioner(
+          numPartitions,
+          rddForSampling,
+          ascending = true,
+          samplePointsPerPartitionHint = 
SQLConf.get.rangeExchangeSampleSizePerPartition)
+      case SinglePartition => new ConstantPartitioner
+      case _ => throw new IllegalStateException(s"Exchange not implemented for 
$newPartitioning")
+      // TODO: Handle BroadcastPartitioning.
+    }
+    def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match 
{
+      case RoundRobinPartitioning(numPartitions) =>
+        // Distributes elements evenly across output partitions, starting from 
a random partition.
+        // nextInt(numPartitions) implementation has a special case when bound 
is a power of 2,
+        // which is basically taking several highest bits from the initial 
seed, with only a
+        // minimal scrambling. Due to deterministic seed, using the generator 
only once,
+        // and lack of scrambling, the position values for power-of-two 
numPartitions always
+        // end up being almost the same regardless of the index. substantially 
scrambling the
+        // seed by hashing will help. Refer to SPARK-21782 for more details.
+        val partitionId = TaskContext.get().partitionId()
+        var position = new XORShiftRandom(partitionId).nextInt(numPartitions)
+        (row: InternalRow) => {
+          // The HashPartitioner will handle the `mod` by the number of 
partitions
+          position += 1
+          position
+        }
+      case h: HashPartitioning =>
+        val projection = UnsafeProjection.create(h.partitionIdExpression :: 
Nil, outputAttributes)
+        row => projection(row).getInt(0)
+      case RangePartitioning(sortingExpressions, _) =>
+        val projection =
+          UnsafeProjection.create(sortingExpressions.map(_.child), 
outputAttributes)
+        row => projection(row)
+      case SinglePartition => identity
+      case _ => throw new IllegalStateException(s"Exchange not implemented for 
$newPartitioning")
+    }
+
+    val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
+      newPartitioning.numPartitions > 1
+
+    val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
+      // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning 
is deterministic,
+      // otherwise a retry task may output different rows and thus lead to 
data loss.
+      //
+      // Currently we follow the most straightforward approach, which performs a 
local sort before
+      // partitioning.
+      //
+      // Note that we don't perform local sort if the new partitioning has 
only 1 partition, under
+      // that case all output rows go to the same partition.
+      val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
+        rdd.mapPartitionsInternal { iter =>
+          val recordComparatorSupplier = new Supplier[RecordComparator] {
+            override def get: RecordComparator = new RecordBinaryComparator()
+          }
+          // The comparator for comparing row hashcode, which should always be 
Integer.
+          val prefixComparator = PrefixComparators.LONG
+
+          // The prefix computer generates row hashcode as the prefix, so we 
may decrease the
+          // probability that the prefixes are equal when input rows choose 
column values from a
+          // limited range.
+          val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+            private val result = new 
UnsafeExternalRowSorter.PrefixComputer.Prefix
+            override def computePrefix(
+                row: InternalRow): 
UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+              // The hashcode generated from the binary form of a 
[[UnsafeRow]] should not be null.
+              result.isNull = false
+              result.value = row.hashCode()
+              result
+            }
+          }
+          val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+
+          val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
+            StructType.fromAttributes(outputAttributes),
+            recordComparatorSupplier,
+            prefixComparator,
+            prefixComputer,
+            pageSize,
+            // We are comparing binary here, which does not support radix sort.
+            // See more details in SPARK-28699.
+            false)
+          sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+        }
+      } else {
+        rdd
+      }
 
+      // round-robin function is order sensitive if we don't sort the input.
+      val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
+      if (CometShuffleExchangeExec.needToCopyObjectsBeforeShuffle(part)) {
+        newRdd.mapPartitionsWithIndexInternal(
+          (_, iter) => {
+            val getPartitionKey = getPartitionKeyExtractor()
+            iter.map { row => (part.getPartition(getPartitionKey(row)), 
row.copy()) }
+          },
+          isOrderSensitive = isOrderSensitive)
+      } else {
+        newRdd.mapPartitionsWithIndexInternal(
+          (_, iter) => {
+            val getPartitionKey = getPartitionKeyExtractor()
+            val mutablePair = new MutablePair[Int, InternalRow]()
+            iter.map { row => 
mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
+          },
+          isOrderSensitive = isOrderSensitive)
+      }
+    }
+
+    // Now, we manually create a ShuffleDependency. Pairs in 
rddWithPartitionIds
+    // are in the form of (partitionId, row) and every partitionId is in the 
expected range
+    // [0, part.numPartitions - 1], so the partitioner of this dependency is a 
PartitionIdPassthrough.
     val dependency =
       new CometShuffleDependency[Int, InternalRow, InternalRow](
-        sparkShuffleDep.rdd,
-        sparkShuffleDep.partitioner,
-        sparkShuffleDep.serializer,
-        shuffleWriterProcessor = sparkShuffleDep.shuffleWriterProcessor,
+        rddWithPartitionIds,
+        new PartitionIdPassthrough(part.numPartitions),
+        serializer,
+        shuffleWriterProcessor = 
ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
         shuffleType = CometColumnarShuffle,
         schema = Some(StructType.fromAttributes(outputAttributes)))
+
     dependency
   }
 }
@@ -379,3 +556,18 @@ class CometShuffleWriteProcessor(
     }
   }
 }
+
+/**
+ * Copied from Spark `PartitionIdPassthrough` as it is private in Spark 3.2.
+ */
+private[spark] class PartitionIdPassthrough(override val numPartitions: Int) 
extends Partitioner {
+  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+}
+
+/**
+ * Copied from Spark `ConstantPartitioner` as it doesn't exist in Spark 3.2.
+ */
+private[spark] class ConstantPartitioner extends Partitioner {
+  override def numPartitions: Int = 1
+  override def getPartition(key: Any): Int = 0
+}
diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index 216b690..4b4f60a 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -950,6 +950,7 @@ class CometShuffleSuite extends CometColumnarShuffleSuite {
           .filter($"a" > 4)
           .repartition(10)
           .sortWithinPartitions($"a")
+          .filter($"a" >= 10)
         checkSparkAnswerAndOperator(shuffled, classOf[ShuffleQueryStageExec])
       }
     }

Reply via email to