This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
     new 2ae80afbc [CORE] Use SortShuffleManager instance in ColumnarShuffleManager (#6022)
2ae80afbc is described below
commit 2ae80afbca24e92fca5c9c5d0849a37a5b5c15fd
Author: Ankita Victor <[email protected]>
AuthorDate: Thu Jun 13 06:27:10 2024 +0530
[CORE] Use SortShuffleManager instance in ColumnarShuffleManager (#6022)
---
.../apache/gluten/execution/FallbackSuite.scala | 26 ++++-
.../shuffle/sort/ColumnarShuffleManager.scala | 121 +++++++--------------
2 files changed, 67 insertions(+), 80 deletions(-)
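In short: instead of re-implementing the bypass-merge-sort and serialized-shuffle handle selection that SortShuffleManager already provides, ColumnarShuffleManager now keeps only the columnar path and delegates every other shuffle to a private SortShuffleManager instance. A minimal, Spark-free sketch of that delegation pattern (all names below are hypothetical, for illustration only):

    // Handle the special case locally; fall through to the wrapped default
    // manager for everything else.
    trait HandleLike
    case class ColumnarHandle(id: Int) extends HandleLike
    case class RowHandle(id: Int) extends HandleLike

    trait ManagerLike {
      def register(id: Int, columnar: Boolean): HandleLike
    }

    class DefaultManager extends ManagerLike {
      def register(id: Int, columnar: Boolean): HandleLike = RowHandle(id)
    }

    class ColumnarManager(delegate: ManagerLike = new DefaultManager) extends ManagerLike {
      def register(id: Int, columnar: Boolean): HandleLike =
        if (columnar) ColumnarHandle(id)     // columnar path stays here
        else delegate.register(id, columnar) // everything else is delegated
    }

Delegating rather than copying keeps the two managers from drifting apart as Spark's shuffle internals evolve, at the cost of routing through one extra object.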
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
index 15a71ceb5..27d191b9e 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
@@ -20,8 +20,9 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.extension.GlutenPlan
import org.apache.spark.SparkConf
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec}
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {
protected val rootPath: String = getClass.getResource("/").getPath
@@ -71,6 +72,29 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
collect(plan) { case v: VeloxColumnarToRowExec => v }.size
}
+ private def collectColumnarShuffleExchange(plan: SparkPlan): Int = {
+ collect(plan) { case c: ColumnarShuffleExchangeExec => c }.size
+ }
+
+ private def collectShuffleExchange(plan: SparkPlan): Int = {
+ collect(plan) { case c: ShuffleExchangeExec => c }.size
+ }
+
+ test("fallback with shuffle manager") {
+ withSQLConf(GlutenConfig.COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+ runQueryAndCompare("select c1, count(*) from tmp1 group by c1") {
+ df =>
+ val plan = df.queryExecution.executedPlan
+
+ assert(collectColumnarShuffleExchange(plan) == 0)
+ assert(collectShuffleExchange(plan) == 1)
+
+ val wholeQueryColumnarToRow = collectColumnarToRow(plan)
+ assert(wholeQueryColumnarToRow == 2)
+ }
+ }
+ }
+
test("fallback with collect") {
withSQLConf(GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1") {
runQueryAndCompare("SELECT count(*) FROM tmp1") {
diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
index d8ba78cb9..06c6e6c0e 100644
--- a/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/sort/ColumnarShuffleManager.scala
@@ -20,7 +20,6 @@ import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle._
-import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
import org.apache.spark.storage.BlockId
import org.apache.spark.util.collection.OpenHashSet
@@ -28,13 +27,12 @@ import org.apache.spark.util.collection.OpenHashSet
import java.io.InputStream
import java.util.concurrent.ConcurrentHashMap
-import scala.collection.JavaConverters._
-
class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
import ColumnarShuffleManager._
-  private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+  private[this] lazy val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+
override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
/** A mapping from shuffle ids to the number of mappers producing output for those shuffles. */
@@ -49,23 +47,9 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
new ColumnarShuffleHandle[K, V](
shuffleId,
dependency.asInstanceOf[ColumnarShuffleDependency[K, V, V]])
- } else if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
-      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
-      // need map-side aggregation, then write numPartitions files directly and just concatenate
-      // them at the end. This avoids doing serialization and deserialization twice to merge
-      // together the spilled files, which would happen with the normal code path. The downside is
-      // having multiple files open at a time and thus more memory allocated to buffers.
- new BypassMergeSortShuffleHandle[K, V](
- shuffleId,
- dependency.asInstanceOf[ShuffleDependency[K, V, V]])
- } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
-      // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
- new SerializedShuffleHandle[K, V](
- shuffleId,
- dependency.asInstanceOf[ShuffleDependency[K, V, V]])
} else {
- // Otherwise, buffer map outputs in a deserialized form:
- new BaseShuffleHandle(shuffleId, dependency)
+ // Otherwise call default SortShuffleManager
+ sortShuffleManager.registerShuffle(shuffleId, dependency)
}
}
@@ -75,39 +59,19 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
mapId: Long,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
- val mapTaskIds =
-      taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
- mapTaskIds.synchronized {
- mapTaskIds.add(context.taskAttemptId())
- }
- val env = SparkEnv.get
handle match {
case columnarShuffleHandle: ColumnarShuffleHandle[K @unchecked, V @unchecked] =>
+ val mapTaskIds =
+          taskIdMapsForShuffle.computeIfAbsent(handle.shuffleId, _ => new OpenHashSet[Long](16))
+ mapTaskIds.synchronized {
+ mapTaskIds.add(context.taskAttemptId())
+ }
GlutenShuffleWriterWrapper.genColumnarShuffleWriter(
shuffleBlockResolver,
columnarShuffleHandle,
mapId,
metrics)
-      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
- new UnsafeShuffleWriter(
- env.blockManager,
- context.taskMemoryManager(),
- unsafeShuffleHandle,
- mapId,
- context,
- env.conf,
- metrics,
- shuffleExecutorComponents)
-      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
- new BypassMergeSortShuffleWriter(
- env.blockManager,
- bypassMergeSortHandle,
- mapId,
- env.conf,
- metrics,
- shuffleExecutorComponents)
- case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
- new SortShuffleWriter(other, mapId, context, shuffleExecutorComponents)
+ case _ => sortShuffleManager.getWriter(handle, mapId, context, metrics)
}
}
@@ -123,17 +87,17 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
- val (blocksByAddress, canEnableBatchFetch) = {
- GlutenShuffleUtils.getReaderParam(
- handle,
- startMapIndex,
- endMapIndex,
- startPartition,
- endPartition)
- }
- val shouldBatchFetch =
-      canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
if (handle.isInstanceOf[ColumnarShuffleHandle[_, _]]) {
+ val (blocksByAddress, canEnableBatchFetch) = {
+ GlutenShuffleUtils.getReaderParam(
+ handle,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition)
+ }
+ val shouldBatchFetch =
+        canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
blocksByAddress,
@@ -143,44 +107,43 @@ class ColumnarShuffleManager(conf: SparkConf) extends ShuffleManager with Loggin
shouldBatchFetch = shouldBatchFetch
)
} else {
- new BlockStoreShuffleReader(
- handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
- blocksByAddress,
+ sortShuffleManager.getReader(
+ handle,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition,
context,
- metrics,
- shouldBatchFetch = shouldBatchFetch
- )
+ metrics)
}
}
/** Remove a shuffle's metadata from the ShuffleManager. */
override def unregisterShuffle(shuffleId: Int): Boolean = {
- Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
- mapTaskIds =>
- mapTaskIds.iterator.foreach {
- mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
- }
+ if (taskIdMapsForShuffle.contains(shuffleId)) {
+ Option(taskIdMapsForShuffle.remove(shuffleId)).foreach {
+ mapTaskIds =>
+ mapTaskIds.iterator.foreach {
+ mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+ }
+ }
+ true
+ } else {
+ sortShuffleManager.unregisterShuffle(shuffleId)
}
- true
}
/** Shut down this ShuffleManager. */
override def stop(): Unit = {
- shuffleBlockResolver.stop()
+ if (!taskIdMapsForShuffle.isEmpty) {
+ shuffleBlockResolver.stop()
+ } else {
+ sortShuffleManager.stop
+ }
}
}
object ColumnarShuffleManager extends Logging {
-  private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
-    val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
-    val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX).toMap
- executorComponents.initializeExecutor(
- conf.getAppId,
- SparkEnv.get.executorId,
- extraConfigs.asJava)
- executorComponents
- }
-
private def bypassDecompressionSerializerManger =
new SerializerManager(
SparkEnv.get.serializer,
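Since ColumnarShuffleManager is plugged in through Spark's standard shuffle-manager hook, the delegation above means every non-columnar shuffle now takes exactly the code path a stock SortShuffleManager would, including future upstream fixes. A minimal wiring sketch; the remaining Gluten settings (plugin class, off-heap memory, and so on) are assumed to be configured elsewhere:

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession

    // Register the columnar manager via Spark's standard hook; internally a
    // plain SortShuffleManager now backs every non-columnar shuffle.
    val conf = new SparkConf()
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")

    val spark = SparkSession.builder().config(conf).getOrCreate()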