This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 9059ff685f [GLUTEN-8060][CORE] GlutenShuffleManager as a registry of
shuffle managers (#8084)
9059ff685f is described below
commit 9059ff685f37367af8d84c6cc26bae05ec2a225f
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Dec 3 09:36:15 2024 +0800
[GLUTEN-8060][CORE] GlutenShuffleManager as a registry of shuffle managers
(#8084)
---
.../spark/shuffle/GlutenShuffleManager.scala | 71 +++++
.../scala/org/apache/spark/shuffle/LookupKey.scala | 27 ++
.../spark/shuffle/ShuffleManagerLookup.scala | 38 +++
.../spark/shuffle/ShuffleManagerRegistry.scala | 94 ++++++
.../spark/shuffle/ShuffleManagerRouter.scala | 137 +++++++++
.../spark/shuffle/GlutenShuffleManagerSuite.scala | 315 +++++++++++++++++++++
6 files changed, 682 insertions(+)
diff --git
a/gluten-core/src/main/scala/org/apache/spark/shuffle/GlutenShuffleManager.scala
b/gluten-core/src/main/scala/org/apache/spark/shuffle/GlutenShuffleManager.scala
new file mode 100644
index 0000000000..d38781675b
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/spark/shuffle/GlutenShuffleManager.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
+import org.apache.spark.annotation.Experimental
+
+/**
+ * Shuffle manager that routes shuffle API calls to different shuffle managers
registered by
+ * different backends.
+ *
+ * A SPIP may cause refactoring of this class in the future:
+ * https://issues.apache.org/jira/browse/SPARK-45792
+ */
@Experimental
class GlutenShuffleManager(conf: SparkConf, isDriver: Boolean) extends ShuffleManager {
  // Router builder obtained from the process-wide registry. `getOrBuild()`
  // returns a cached router; the registry invalidates that cache whenever a
  // new shuffle manager is registered, so the router is rebuilt on next access.
  private val routerBuilder = ShuffleManagerRegistry.get().newRouterBuilder(conf, isDriver)

  /** Routes the registration to whichever registered manager accepts `dependency`. */
  override def registerShuffle[K, V, C](
      shuffleId: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
    routerBuilder.getOrBuild().registerShuffle(shuffleId, dependency)
  }

  /** Delegates to the manager that registered `handle.shuffleId`. */
  override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
    routerBuilder.getOrBuild().getWriter(handle, mapId, context, metrics)
  }

  /** Delegates to the manager that registered `handle.shuffleId`. */
  override def getReader[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    routerBuilder
      .getOrBuild()
      .getReader(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, metrics)
  }

  /** Unregisters via the manager that registered `shuffleId`. */
  override def unregisterShuffle(shuffleId: Int): Boolean = {
    routerBuilder.getOrBuild().unregisterShuffle(shuffleId)
  }

  /** Returns the router's resolver, which dispatches per shuffle id. */
  override def shuffleBlockResolver: ShuffleBlockResolver = {
    routerBuilder.getOrBuild().shuffleBlockResolver
  }

  // Stops every registered shuffle manager (the router forwards stop to all of
  // them, in reverse registration order).
  override def stop(): Unit = {
    routerBuilder.getOrBuild().stop()
  }
}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/shuffle/LookupKey.scala
b/gluten-core/src/main/scala/org/apache/spark/shuffle/LookupKey.scala
new file mode 100644
index 0000000000..502dd92efe
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/LookupKey.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.ShuffleDependency
+
+/**
+ * Required during shuffle manager registration to determine whether the
shuffle manager should be
+ * used for the particular shuffle dependency.
+ */
trait LookupKey {
  /**
   * Returns true if the shuffle manager registered under this key should
   * handle the given shuffle dependency.
   */
  def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean
}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala
new file mode 100644
index 0000000000..8b060c9818
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.ShuffleDependency
+
/** Maps a shuffle dependency to the registered shuffle manager accepting it. */
private class ShuffleManagerLookup(all: Seq[(LookupKey, ShuffleManager)]) {
  // Registration order reversed once up front, so the most recently registered
  // shuffle manager is consulted first during lookup.
  private val allReversed = all.reverse

  /**
   * Returns the first (most recently registered) shuffle manager whose lookup
   * key accepts `dependency`; fails if none does.
   */
  def findShuffleManager[K, V, C](dependency: ShuffleDependency[K, V, C]): ShuffleManager = {
    this.synchronized {
      allReversed
        .collectFirst { case (key, manager) if key.accepts(dependency) => manager }
        .getOrElse {
          throw new IllegalStateException(s"No ShuffleManager found for $dependency")
        }
    }
  }

  /** All registered shuffle managers, in registration order. */
  def all(): Seq[ShuffleManager] = {
    this.synchronized {
      all.map(_._2)
    }
  }
}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala
new file mode 100644
index 0000000000..4310054caa
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.SparkConf
+import org.apache.spark.util.Utils
+
+import scala.collection.mutable
+
/**
 * Process-wide registry through which backends register their shuffle manager
 * implementations for use by [[GlutenShuffleManager]].
 */
class ShuffleManagerRegistry private[ShuffleManagerRegistry] {
  import ShuffleManagerRegistry._
  // Registered (lookup key, shuffle manager class name) pairs, in registration order.
  private val all: mutable.Buffer[(LookupKey, String)] = mutable.Buffer()
  // Every router builder handed out so far; each is notified when registrations change.
  private val routerBuilders: mutable.Buffer[RouterBuilder] = mutable.Buffer()
  // Class names registered so far, used to reject duplicate registrations.
  private val classDeDup: mutable.Set[String] = mutable.Set()

  /**
   * Registers a shuffle manager implementation class under the given lookup key.
   *
   * The class must implement Spark's ShuffleManager, must not be
   * GlutenShuffleManager itself or a subclass of it, and must not have been
   * registered before; IllegalArgumentException is thrown otherwise.
   */
  def register(lookupKey: LookupKey, shuffleManagerClass: String): Unit = {
    val clazz = Utils.classForName(shuffleManagerClass)
    // Reject GlutenShuffleManager and any subclass of it: routing such a class
    // back through the router would recurse endlessly. (Note the direction:
    // the registered class is checked for being a GlutenShuffleManager, not
    // the other way around.)
    require(
      !classOf[GlutenShuffleManager].isAssignableFrom(clazz),
      "It's not allowed to register GlutenShuffleManager recursively")
    require(
      classOf[ShuffleManager].isAssignableFrom(clazz),
      s"Shuffle manager class to register is not an implementation of Spark ShuffleManager: " +
        s"$shuffleManagerClass"
    )
    this.synchronized {
      // The duplicate check and the insertion must be atomic under the same
      // lock; otherwise two concurrent registrations of the same class could
      // both pass the check.
      require(
        !classDeDup.contains(shuffleManagerClass),
        s"Shuffle manager class already registered: $shuffleManagerClass")
      classDeDup += shuffleManagerClass
      all += lookupKey -> shuffleManagerClass
      // Invalidate all shuffle managers cached in each alive router builder instance.
      // Then, once the router builder is accessed, a new router will be forced to create.
      routerBuilders.foreach(_.invalidateCache())
    }
  }

  // Visible for testing: drops all registrations and invalidates cached routers.
  private[shuffle] def clear(): Unit = {
    this.synchronized {
      classDeDup.clear()
      all.clear()
      routerBuilders.foreach(_.invalidateCache())
    }
  }

  // Creates a RouterBuilder bound to this registry and remembers it so future
  // registrations can invalidate its cached router.
  private[shuffle] def newRouterBuilder(conf: SparkConf, isDriver: Boolean): RouterBuilder =
    this.synchronized {
      val out = new RouterBuilder(this, conf, isDriver)
      routerBuilders += out
      out
    }
}
+
object ShuffleManagerRegistry {
  // Process-wide singleton registry instance.
  private val instance = new ShuffleManagerRegistry()

  def get(): ShuffleManagerRegistry = instance

  /**
   * Lazily builds and caches the ShuffleManagerRouter used by
   * GlutenShuffleManager. The cache is invalidated by the registry whenever a
   * new shuffle manager is registered.
   */
  class RouterBuilder(registry: ShuffleManagerRegistry, conf: SparkConf, isDriver: Boolean) {
    // Cached router; None after invalidation, rebuilt lazily by getOrBuild().
    private var router: Option[ShuffleManagerRouter] = None

    // Called by the registry (while it holds its own lock) when registrations change.
    private[ShuffleManagerRegistry] def invalidateCache(): Unit = synchronized {
      router = None
    }

    // Returns the cached router, instantiating every registered shuffle
    // manager class afresh when the cache is empty.
    private[shuffle] def getOrBuild(): ShuffleManagerRouter = synchronized {
      if (router.isEmpty) {
        // NOTE(review): `registry.all` is a mutable Buffer mutated under the
        // registry's lock in #register / #clear, but it is read here without
        // that lock — looks racy. Confirm whether a snapshot under the
        // registry lock is needed (beware lock ordering vs. invalidateCache,
        // which is called while the registry lock is held).
        val instances = registry.all.map(key => key._1 -> instantiate(key._2, conf, isDriver))
        router = Some(new ShuffleManagerRouter(new ShuffleManagerLookup(instances.toSeq)))
      }
      router.get
    }

    // Reflectively instantiates the shuffle manager class with (conf, isDriver).
    private def instantiate(clazz: String, conf: SparkConf, isDriver: Boolean): ShuffleManager = {
      Utils
        .instantiateSerializerOrShuffleManager[ShuffleManager](clazz, conf, isDriver)
    }
  }
}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala
new file mode 100644
index 0000000000..80aa9d8047
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+import org.apache.spark.{ShuffleDependency, TaskContext}
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.MergedBlockMeta
+import org.apache.spark.storage.{BlockId, ShuffleBlockBatchId, ShuffleBlockId,
ShuffleMergedBlockId}
+
/** The internal shuffle manager instance used by GlutenShuffleManager. */
private class ShuffleManagerRouter(lookup: ShuffleManagerLookup) extends ShuffleManager {
  import ShuffleManagerRouter._

  // Maps each live shuffle id to the shuffle manager that registered it.
  private val cache = new Cache()
  // Resolver facade that forwards block requests to the owning manager's resolver.
  private val resolver = new BlockResolver(cache)

  override def registerShuffle[K, V, C](
      shuffleId: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
    // Pick the manager accepting this dependency, remember the association for
    // the shuffle's lifetime, then let that manager perform the registration.
    val manager = lookup.findShuffleManager(dependency)
    val stored = cache.store(shuffleId, manager)
    stored.registerShuffle(shuffleId, dependency)
  }

  override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
    val owner = cache.get(handle.shuffleId)
    owner.getWriter(handle, mapId, context, metrics)
  }

  override def getReader[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    val owner = cache.get(handle.shuffleId)
    owner.getReader(
      handle,
      startMapIndex,
      endMapIndex,
      startPartition,
      endPartition,
      context,
      metrics)
  }

  override def unregisterShuffle(shuffleId: Int): Boolean = {
    // Drop the cached association before forwarding the call to the owner.
    val owner = cache.remove(shuffleId)
    owner.unregisterShuffle(shuffleId)
  }

  override def shuffleBlockResolver: ShuffleBlockResolver = resolver

  override def stop(): Unit = {
    // All shuffles should have been unregistered by now.
    assert(cache.size() == 0)
    // Stop the registered managers in reverse registration order.
    lookup.all().reverse.foreach(_.stop())
  }
}
+
private object ShuffleManagerRouter {
  /** Thread-safe map from active shuffle id to the shuffle manager that owns it. */
  private class Cache {
    private val cache: java.util.Map[Int, ShuffleManager] =
      new java.util.concurrent.ConcurrentHashMap()

    /**
     * Atomically associates `shuffleId` with `manager` and returns the
     * manager; fails if the shuffle id already has an owner. Uses
     * ConcurrentHashMap#compute so the absence check and insert are one
     * atomic step.
     */
    def store(shuffleId: Int, manager: ShuffleManager): ShuffleManager = {
      cache.compute(
        shuffleId,
        (id, m) => {
          assert(m == null, s"Shuffle manager was already cached for shuffle id: $id")
          manager
        })
    }

    /** Returns the manager owning `shuffleId`; fails if the id is unknown. */
    def get(shuffleId: Int): ShuffleManager = {
      val manager = cache.get(shuffleId)
      assert(manager != null, s"Shuffle manager not registered for shuffle id: $shuffleId")
      manager
    }

    /** Removes and returns the manager owning `shuffleId`; fails if unknown. */
    def remove(shuffleId: Int): ShuffleManager = {
      val manager = cache.remove(shuffleId)
      assert(manager != null, s"Shuffle manager not registered for shuffle id: $shuffleId")
      manager
    }

    /** Number of shuffles currently tracked. */
    def size(): Int = {
      cache.size()
    }

    def clear(): Unit = {
      cache.clear()
    }
  }

  /**
   * Resolver facade that routes each block request to the resolver of the
   * shuffle manager owning the block's shuffle id.
   */
  private class BlockResolver(cache: Cache) extends ShuffleBlockResolver {
    override def getBlockData(blockId: BlockId, dirs: Option[Array[String]]): ManagedBuffer = {
      // Only block id types that carry a shuffle id can be routed.
      val shuffleId = blockId match {
        case id: ShuffleBlockId =>
          id.shuffleId
        case batchId: ShuffleBlockBatchId =>
          batchId.shuffleId
        case _ =>
          throw new IllegalArgumentException(
            "GlutenShuffleManager: Unsupported shuffle block id: " + blockId)
      }
      cache.get(shuffleId).shuffleBlockResolver.getBlockData(blockId, dirs)
    }

    override def getMergedBlockData(
        blockId: ShuffleMergedBlockId,
        dirs: Option[Array[String]]): Seq[ManagedBuffer] = {
      val shuffleId = blockId.shuffleId
      cache.get(shuffleId).shuffleBlockResolver.getMergedBlockData(blockId, dirs)
    }

    override def getMergedBlockMeta(
        blockId: ShuffleMergedBlockId,
        dirs: Option[Array[String]]): MergedBlockMeta = {
      val shuffleId = blockId.shuffleId
      cache.get(shuffleId).shuffleBlockResolver.getMergedBlockMeta(blockId, dirs)
    }

    // The delegated resolvers are owned (and stopped) by their respective
    // shuffle managers, so this shared facade must never be stopped directly.
    override def stop(): Unit = {
      throw new UnsupportedOperationException(
        s"BlockResolver ${getClass.getSimpleName} doesn't need to be explicitly stopped")
    }
  }
}
diff --git
a/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala
b/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala
new file mode 100644
index 0000000000..640fc0ab07
--- /dev/null
+++
b/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.{Partitioner, ShuffleDependency, SparkConf,
TaskContext}
+import org.apache.spark.internal.config.SHUFFLE_MANAGER
+import org.apache.spark.rdd.EmptyRDD
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.sql.test.SharedSparkSession
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable
+
/** Tests routing behavior of GlutenShuffleManager over the global registry. */
class GlutenShuffleManagerSuite extends SharedSparkSession {
  import GlutenShuffleManagerSuite._

  // Make the session under test route every shuffle through GlutenShuffleManager.
  override protected def sparkConf: SparkConf = {
    super.sparkConf
      .set(SHUFFLE_MANAGER.key, classOf[GlutenShuffleManager].getName)
  }

  override protected def afterEach(): Unit = {
    // Reset global state (registry + counters) between tests. super.afterEach()
    // runs in a finally block so SharedSparkSession's own per-test cleanup is
    // never skipped, even if the reset fails (ScalaTest BeforeAndAfterEach
    // convention).
    try {
      val registry = ShuffleManagerRegistry.get()
      registry.clear()
      counter1.clear()
      counter2.clear()
    } finally {
      super.afterEach()
    }
  }

  test("register one") {
    val registry = ShuffleManagerRegistry.get()

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager1].getName)

    val gm = spark.sparkContext.env.shuffleManager
    assert(counter1.count("stop") == 0)
    gm.stop()
    assert(counter1.count("stop") == 1)
    // Stop is forwarded on every call; the router does not guard against
    // repeated stops.
    gm.stop()
    gm.stop()
    assert(counter1.count("stop") == 3)
  }

  test("register two") {
    val registry = ShuffleManagerRegistry.get()

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager1].getName)
    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager2].getName)

    val gm = spark.sparkContext.env.shuffleManager
    assert(counter1.count("registerShuffle") == 0)
    assert(counter2.count("registerShuffle") == 0)
    // The statement calls #registerShuffle internally. The manager registered
    // last (ShuffleManager2) wins the lookup.
    val dep =
      new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner)
    gm.unregisterShuffle(dep.shuffleId)
    assert(counter1.count("registerShuffle") == 0)
    assert(counter2.count("registerShuffle") == 1)

    // Stop is forwarded to every registered manager, used or not.
    assert(counter1.count("stop") == 0)
    assert(counter2.count("stop") == 0)
    gm.stop()
    assert(counter1.count("stop") == 1)
    assert(counter2.count("stop") == 1)
  }

  test("register two - disordered registration") {
    val registry = ShuffleManagerRegistry.get()

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager1].getName)

    val gm = spark.sparkContext.env.shuffleManager
    assert(counter1.count("registerShuffle") == 0)
    assert(counter2.count("registerShuffle") == 0)
    // Only ShuffleManager1 is registered at this point.
    val dep1 =
      new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner)
    gm.unregisterShuffle(dep1.shuffleId)
    assert(counter1.count("registerShuffle") == 1)
    assert(counter2.count("registerShuffle") == 0)

    // Registering a second manager invalidates the cached router; subsequent
    // shuffles should be routed to the newly registered manager.
    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager2].getName)

    // The statement calls #registerShuffle internally.
    val dep2 =
      new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner)
    gm.unregisterShuffle(dep2.shuffleId)
    assert(counter1.count("registerShuffle") == 1)
    assert(counter2.count("registerShuffle") == 1)

    assert(counter1.count("stop") == 0)
    assert(counter2.count("stop") == 0)
    gm.stop()
    assert(counter1.count("stop") == 1)
    assert(counter2.count("stop") == 1)
  }

  test("register two - with empty key") {
    val registry = ShuffleManagerRegistry.get()

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager1].getName)
    // ShuffleManager2's key rejects everything, so lookups fall through to
    // ShuffleManager1 even though it was registered earlier.
    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = false
      },
      classOf[ShuffleManager2].getName)

    val gm = spark.sparkContext.env.shuffleManager
    assert(counter1.count("registerShuffle") == 0)
    assert(counter2.count("registerShuffle") == 0)
    // The statement calls #registerShuffle internally.
    val dep =
      new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner)
    gm.unregisterShuffle(dep.shuffleId)
    assert(counter1.count("registerShuffle") == 1)
    assert(counter2.count("registerShuffle") == 0)
  }

  test("register recursively") {
    val registry = ShuffleManagerRegistry.get()

    // Registering GlutenShuffleManager itself must be rejected to avoid
    // the router routing back into itself.
    assertThrows[IllegalArgumentException](
      registry.register(
        new LookupKey {
          override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
        },
        classOf[GlutenShuffleManager].getName))
  }

  test("register duplicated") {
    val registry = ShuffleManagerRegistry.get()

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager1].getName)
    // The same class may only be registered once.
    assertThrows[IllegalArgumentException](
      registry.register(
        new LookupKey {
          override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
        },
        classOf[ShuffleManager1].getName))
  }
}
+
object GlutenShuffleManagerSuite {
  // Shared invocation counters, one per test shuffle manager class. They are
  // global because the manager instances are created reflectively by the
  // registry, so the test has no direct handle on them.
  private val counter1 = new InvocationCounter
  private val counter2 = new InvocationCounter

  /**
   * Test shuffle manager that records every API invocation in `counter1` and
   * delegates the actual work to Spark's SortShuffleManager. Uses the
   * single-argument (conf) constructor form.
   */
  class ShuffleManager1(conf: SparkConf) extends ShuffleManager {
    private val delegate = new SortShuffleManager(conf)
    private val counter = counter1
    override def registerShuffle[K, V, C](
        shuffleId: Int,
        dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
      counter.increment("registerShuffle")
      delegate.registerShuffle(shuffleId, dependency)
    }

    override def getWriter[K, V](
        handle: ShuffleHandle,
        mapId: Long,
        context: TaskContext,
        metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
      counter.increment("getWriter")
      delegate.getWriter(handle, mapId, context, metrics)
    }

    override def getReader[K, C](
        handle: ShuffleHandle,
        startMapIndex: Int,
        endMapIndex: Int,
        startPartition: Int,
        endPartition: Int,
        context: TaskContext,
        metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
      counter.increment("getReader")
      delegate.getReader(
        handle,
        startMapIndex,
        endMapIndex,
        startPartition,
        endPartition,
        context,
        metrics)
    }

    override def unregisterShuffle(shuffleId: Int): Boolean = {
      counter.increment("unregisterShuffle")
      delegate.unregisterShuffle(shuffleId)
    }

    override def shuffleBlockResolver: ShuffleBlockResolver = {
      counter.increment("shuffleBlockResolver")
      delegate.shuffleBlockResolver
    }

    override def stop(): Unit = {
      counter.increment("stop")
      delegate.stop()
    }
  }

  /**
   * Same as ShuffleManager1 but records into `counter2` and uses the
   * two-argument (conf, isDriver) constructor form, exercising the registry's
   * reflective instantiation of both constructor shapes.
   */
  class ShuffleManager2(conf: SparkConf, isDriver: Boolean) extends ShuffleManager {
    private val delegate = new SortShuffleManager(conf)
    private val counter = counter2
    override def registerShuffle[K, V, C](
        shuffleId: Int,
        dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
      counter.increment("registerShuffle")
      delegate.registerShuffle(shuffleId, dependency)
    }

    override def getWriter[K, V](
        handle: ShuffleHandle,
        mapId: Long,
        context: TaskContext,
        metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
      counter.increment("getWriter")
      delegate.getWriter(handle, mapId, context, metrics)
    }

    override def getReader[K, C](
        handle: ShuffleHandle,
        startMapIndex: Int,
        endMapIndex: Int,
        startPartition: Int,
        endPartition: Int,
        context: TaskContext,
        metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
      counter.increment("getReader")
      delegate.getReader(
        handle,
        startMapIndex,
        endMapIndex,
        startPartition,
        endPartition,
        context,
        metrics)
    }

    override def unregisterShuffle(shuffleId: Int): Boolean = {
      counter.increment("unregisterShuffle")
      delegate.unregisterShuffle(shuffleId)
    }

    override def shuffleBlockResolver: ShuffleBlockResolver = {
      counter.increment("shuffleBlockResolver")
      delegate.shuffleBlockResolver
    }

    override def stop(): Unit = {
      counter.increment("stop")
      delegate.stop()
    }
  }

  /** Thread-safe per-name invocation counter. */
  private class InvocationCounter {
    // Plain mutable.Map guarded by this counter's lock: increment, count and
    // clear must all synchronize, since the map itself is not thread-safe.
    private val counter: mutable.Map[String, AtomicInteger] = mutable.Map()

    def increment(name: String): Unit = synchronized {
      counter.getOrElseUpdate(name, new AtomicInteger()).incrementAndGet()
    }

    // Synchronized: reading the map concurrently with getOrElseUpdate in
    // #increment would be a data race. Returns 0 for a never-seen name.
    def count(name: String): Int = synchronized {
      counter.getOrElse(name, new AtomicInteger()).get()
    }

    def clear(): Unit = synchronized {
      counter.clear()
    }
  }

  // Minimal partitioner stub; the shuffles in these tests are registered and
  // unregistered but never actually executed.
  private object DummyPartitioner extends Partitioner {
    override def numPartitions: Int = 0
    override def getPartition(key: Any): Int = 0
  }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]