This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 9059ff685f [GLUTEN-8060][CORE] GlutenShuffleManager as a registry of 
shuffle managers (#8084)
9059ff685f is described below

commit 9059ff685f37367af8d84c6cc26bae05ec2a225f
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Dec 3 09:36:15 2024 +0800

    [GLUTEN-8060][CORE] GlutenShuffleManager as a registry of shuffle managers 
(#8084)
---
 .../spark/shuffle/GlutenShuffleManager.scala       |  71 +++++
 .../scala/org/apache/spark/shuffle/LookupKey.scala |  27 ++
 .../spark/shuffle/ShuffleManagerLookup.scala       |  38 +++
 .../spark/shuffle/ShuffleManagerRegistry.scala     |  94 ++++++
 .../spark/shuffle/ShuffleManagerRouter.scala       | 137 +++++++++
 .../spark/shuffle/GlutenShuffleManagerSuite.scala  | 315 +++++++++++++++++++++
 6 files changed, 682 insertions(+)

diff --git 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/GlutenShuffleManager.scala
 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/GlutenShuffleManager.scala
new file mode 100644
index 0000000000..d38781675b
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/GlutenShuffleManager.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
+import org.apache.spark.annotation.Experimental
+
/**
 * Shuffle manager that routes shuffle API calls to different shuffle managers registered by
 * different backends.
 *
 * A SPIP may cause refactoring of this class in the future:
 * https://issues.apache.org/jira/browse/SPARK-45792
 */
@Experimental
class GlutenShuffleManager(conf: SparkConf, isDriver: Boolean) extends ShuffleManager {
  // Builder handed out by the global registry. It caches the routing shuffle manager and the
  // registry invalidates that cache whenever a new backend shuffle manager is registered.
  private val routerBuilder = ShuffleManagerRegistry.get().newRouterBuilder(conf, isDriver)

  // Resolve the router through the builder on every call: a registration made after this
  // manager was created must be visible to subsequent shuffle API calls.
  private def router(): ShuffleManagerRouter = routerBuilder.getOrBuild()

  override def registerShuffle[K, V, C](
      shuffleId: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle =
    router().registerShuffle(shuffleId, dependency)

  override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] =
    router().getWriter(handle, mapId, context, metrics)

  override def getReader[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] =
    router()
      .getReader(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, metrics)

  override def unregisterShuffle(shuffleId: Int): Boolean =
    router().unregisterShuffle(shuffleId)

  override def shuffleBlockResolver: ShuffleBlockResolver =
    router().shuffleBlockResolver

  override def stop(): Unit =
    router().stop()
}
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/LookupKey.scala 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/LookupKey.scala
new file mode 100644
index 0000000000..502dd92efe
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/LookupKey.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.ShuffleDependency
+
/**
 * Required during shuffle manager registration to determine whether the shuffle manager should be
 * used for the particular shuffle dependency.
 */
trait LookupKey {
  /**
   * Returns true if the shuffle manager registered alongside this key should handle the given
   * shuffle dependency. When several registered keys accept the same dependency, the most
   * recently registered one wins (see ShuffleManagerLookup).
   */
  def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean
}
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala
 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala
new file mode 100644
index 0000000000..8b060c9818
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.ShuffleDependency
+
/**
 * Resolves the backend [[ShuffleManager]] that should serve a given shuffle dependency.
 * Registrations are consulted newest-first, so the most recently registered manager whose
 * [[LookupKey]] accepts the dependency wins.
 */
private class ShuffleManagerLookup(all: Seq[(LookupKey, ShuffleManager)]) {
  // Newest-first view of the registrations; lookup precedence follows this order.
  private val newestFirst = all.reverse

  def findShuffleManager[K, V, C](dependency: ShuffleDependency[K, V, C]): ShuffleManager =
    this.synchronized {
      // The latest shuffle manager registered is consulted first.
      newestFirst.collectFirst {
        case (key, manager) if key.accepts(dependency) => manager
      } match {
        case Some(manager) => manager
        case None => throw new IllegalStateException(s"No ShuffleManager found for $dependency")
      }
    }

  // All registered managers, in their original registration order.
  def all(): Seq[ShuffleManager] =
    this.synchronized {
      newestFirst.reverse.map(_._2)
    }
}
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala
 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala
new file mode 100644
index 0000000000..4310054caa
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.SparkConf
+import org.apache.spark.util.Utils
+
+import scala.collection.mutable
+
/**
 * Process-wide registry that maps [[LookupKey]]s to backend shuffle manager class names.
 * [[GlutenShuffleManager]] routes each shuffle to one of the managers registered here.
 */
class ShuffleManagerRegistry private[ShuffleManagerRegistry] {
  import ShuffleManagerRegistry._
  // (lookup key, shuffle manager class name) pairs, kept in registration order.
  private val all: mutable.Buffer[(LookupKey, String)] = mutable.Buffer()
  // Every router builder handed out so far; their cached routers are invalidated on any change.
  private val routerBuilders: mutable.Buffer[RouterBuilder] = mutable.Buffer()
  // Class names already registered, used to reject duplicate registrations.
  private val classDeDup: mutable.Set[String] = mutable.Set()

  /**
   * Registers a shuffle manager implementation class under the given lookup key.
   *
   * @throws IllegalArgumentException if the class is GlutenShuffleManager itself (or a subclass),
   *                                  is not a Spark ShuffleManager, or was already registered.
   */
  def register(lookupKey: LookupKey, shuffleManagerClass: String): Unit = {
    val clazz = Utils.classForName(shuffleManagerClass)
    // Reject GlutenShuffleManager and its subclasses: routing back into the router would
    // recurse endlessly. (Note: the check direction is supertype-of-candidate on purpose;
    // the original `clazz.isAssignableFrom(classOf[GlutenShuffleManager])` caught
    // superclasses instead of subclasses.)
    require(
      !classOf[GlutenShuffleManager].isAssignableFrom(clazz),
      "It's not allowed to register GlutenShuffleManager recursively")
    require(
      classOf[ShuffleManager].isAssignableFrom(clazz),
      s"Shuffle manager class to register is not an implementation of Spark ShuffleManager: " +
        s"$shuffleManagerClass"
    )
    this.synchronized {
      // The duplication check and the insertion must happen under the same lock, otherwise two
      // concurrent registrations of the same class could both pass the check.
      require(
        !classDeDup.contains(shuffleManagerClass),
        s"Shuffle manager class already registered: $shuffleManagerClass")
      classDeDup += shuffleManagerClass
      all += lookupKey -> shuffleManagerClass
      // Invalidate all shuffle managers cached in each alive router builder instance.
      // Then, once the router builder is accessed, a new router will be forced to create.
      routerBuilders.foreach(_.invalidateCache())
    }
  }

  // Visible for testing: drops all registrations and invalidates cached routers.
  private[shuffle] def clear(): Unit = {
    this.synchronized {
      classDeDup.clear()
      all.clear()
      routerBuilders.foreach(_.invalidateCache())
    }
  }

  // Hands out a builder whose cached router tracks this registry's contents.
  private[shuffle] def newRouterBuilder(conf: SparkConf, isDriver: Boolean): RouterBuilder =
    this.synchronized {
      val out = new RouterBuilder(this, conf, isDriver)
      routerBuilders += out
      out
    }
}
+
object ShuffleManagerRegistry {
  // Process-wide singleton registry instance.
  private val instance = new ShuffleManagerRegistry()

  def get(): ShuffleManagerRegistry = instance

  /**
   * Builds and caches the [[ShuffleManagerRouter]] used by one GlutenShuffleManager instance.
   * The cached router is dropped when the registry changes, so the next access rebuilds it
   * against the latest registrations.
   */
  class RouterBuilder(registry: ShuffleManagerRegistry, conf: SparkConf, isDriver: Boolean) {
    private var router: Option[ShuffleManagerRouter] = None

    private[ShuffleManagerRegistry] def invalidateCache(): Unit = synchronized {
      router = None
    }

    private[shuffle] def getOrBuild(): ShuffleManagerRouter = synchronized {
      if (router.isEmpty) {
        // Snapshot the registrations under the registry's own lock: `registry.all` is a
        // mutable buffer that concurrent #register calls may be appending to. `.toList`
        // guarantees a defensive copy on both Scala 2.12 and 2.13.
        val registered = registry.synchronized {
          registry.all.toList
        }
        val instances = registered.map { case (key, clazz) => key -> instantiate(clazz, conf, isDriver) }
        router = Some(new ShuffleManagerRouter(new ShuffleManagerLookup(instances)))
      }
      router.get
    }

    // Instantiates the shuffle manager class via Spark's standard reflective constructor lookup.
    private def instantiate(clazz: String, conf: SparkConf, isDriver: Boolean): ShuffleManager = {
      Utils
        .instantiateSerializerOrShuffleManager[ShuffleManager](clazz, conf, isDriver)
    }
  }
}
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala
 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala
new file mode 100644
index 0000000000..80aa9d8047
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+import org.apache.spark.{ShuffleDependency, TaskContext}
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.MergedBlockMeta
+import org.apache.spark.storage.{BlockId, ShuffleBlockBatchId, ShuffleBlockId, 
ShuffleMergedBlockId}
+
/** The internal shuffle manager instance used by GlutenShuffleManager. */
private class ShuffleManagerRouter(lookup: ShuffleManagerLookup) extends ShuffleManager {
  import ShuffleManagerRouter._
  // Maps each live shuffle id to the backend shuffle manager that owns it.
  private val cache = new Cache()
  // Single resolver facade that forwards block lookups to the owning manager's resolver.
  private val resolver = new BlockResolver(cache)

  override def registerShuffle[K, V, C](
      shuffleId: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
    // Choose the backend manager for this dependency, remember the association, then let the
    // chosen manager perform the actual registration.
    val manager = lookup.findShuffleManager(dependency)
    cache.store(shuffleId, manager).registerShuffle(shuffleId, dependency)
  }

  override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
    val manager = cache.get(handle.shuffleId)
    manager.getWriter(handle, mapId, context, metrics)
  }

  override def getReader[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    val manager = cache.get(handle.shuffleId)
    manager.getReader(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, metrics)
  }

  override def unregisterShuffle(shuffleId: Int): Boolean = {
    // Drop the association first; the owning manager performs the real cleanup.
    cache.remove(shuffleId).unregisterShuffle(shuffleId)
  }

  override def shuffleBlockResolver: ShuffleBlockResolver = resolver

  override def stop(): Unit = {
    // All shuffles are expected to have been unregistered before the router stops.
    assert(cache.size() == 0)
    // Stop the backend managers in reverse registration order.
    lookup.all().reverse.foreach(_.stop())
  }
}
+
private object ShuffleManagerRouter {

  /** Thread-safe shuffle-id -> shuffle-manager association table. */
  private class Cache {
    private val cache: java.util.Map[Int, ShuffleManager] =
      new java.util.concurrent.ConcurrentHashMap()

    // Associates the shuffle id with the manager atomically; a second association for the same
    // id is a programming error. Returns the stored manager so calls can be chained.
    def store(shuffleId: Int, manager: ShuffleManager): ShuffleManager = {
      cache.compute(
        shuffleId,
        (id, existing) => {
          assert(existing == null, s"Shuffle manager was already cached for shuffle id: $id")
          manager
        })
    }

    // Returns the manager owning the shuffle id; the id must have been stored before.
    def get(shuffleId: Int): ShuffleManager = {
      val found = cache.get(shuffleId)
      assert(found != null, s"Shuffle manager not registered for shuffle id: $shuffleId")
      found
    }

    // Removes and returns the manager owning the shuffle id; the id must have been stored before.
    def remove(shuffleId: Int): ShuffleManager = {
      val removed = cache.remove(shuffleId)
      assert(removed != null, s"Shuffle manager not registered for shuffle id: $shuffleId")
      removed
    }

    def size(): Int = cache.size()

    def clear(): Unit = cache.clear()
  }

  /**
   * ShuffleBlockResolver facade that routes each block request to the resolver of the backend
   * manager owning the block's shuffle id.
   */
  private class BlockResolver(cache: Cache) extends ShuffleBlockResolver {
    override def getBlockData(blockId: BlockId, dirs: Option[Array[String]]): ManagedBuffer =
      resolverOf(shuffleIdOf(blockId)).getBlockData(blockId, dirs)

    override def getMergedBlockData(
        blockId: ShuffleMergedBlockId,
        dirs: Option[Array[String]]): Seq[ManagedBuffer] =
      resolverOf(blockId.shuffleId).getMergedBlockData(blockId, dirs)

    override def getMergedBlockMeta(
        blockId: ShuffleMergedBlockId,
        dirs: Option[Array[String]]): MergedBlockMeta =
      resolverOf(blockId.shuffleId).getMergedBlockMeta(blockId, dirs)

    override def stop(): Unit = {
      // This facade owns no resources of its own; the real resolvers are stopped by their
      // owning shuffle managers.
      throw new UnsupportedOperationException(
        s"BlockResolver ${getClass.getSimpleName} doesn't need to be explicitly stopped")
    }

    // Extracts the shuffle id from the supported shuffle block id types.
    private def shuffleIdOf(blockId: BlockId): Int = blockId match {
      case id: ShuffleBlockId => id.shuffleId
      case batchId: ShuffleBlockBatchId => batchId.shuffleId
      case other =>
        throw new IllegalArgumentException(
          "GlutenShuffleManager: Unsupported shuffle block id: " + other)
    }

    // Looks up the resolver of the backend manager owning the given shuffle id.
    private def resolverOf(shuffleId: Int): ShuffleBlockResolver =
      cache.get(shuffleId).shuffleBlockResolver
  }
}
diff --git 
a/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala
 
b/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala
new file mode 100644
index 0000000000..640fc0ab07
--- /dev/null
+++ 
b/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.{Partitioner, ShuffleDependency, SparkConf, 
TaskContext}
+import org.apache.spark.internal.config.SHUFFLE_MANAGER
+import org.apache.spark.rdd.EmptyRDD
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.sql.test.SharedSparkSession
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable
+
class GlutenShuffleManagerSuite extends SharedSparkSession {
  import GlutenShuffleManagerSuite._
  // Route all shuffles of the shared session through GlutenShuffleManager.
  override protected def sparkConf: SparkConf = {
    super.sparkConf
      .set(SHUFFLE_MANAGER.key, classOf[GlutenShuffleManager].getName)
  }

  // Reset the global registry and the shared invocation counters between tests, since both are
  // process-wide singletons.
  // NOTE(review): super.afterEach() is not invoked here — confirm SharedSparkSession requires no
  // per-test cleanup of its own.
  override protected def afterEach(): Unit = {
    val registry = ShuffleManagerRegistry.get()
    registry.clear()
    counter1.clear()
    counter2.clear()
  }

  // A single registered manager receives every delegated call, once per call.
  test("register one") {
    val registry = ShuffleManagerRegistry.get()

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager1].getName)

    val gm = spark.sparkContext.env.shuffleManager
    assert(counter1.count("stop") == 0)
    gm.stop()
    assert(counter1.count("stop") == 1)
    // stop() is delegated on every call; repeated stops keep incrementing the counter.
    gm.stop()
    gm.stop()
    assert(counter1.count("stop") == 3)
  }

  // With two accepting managers, the one registered later wins the lookup, but stop() reaches
  // both.
  test("register two") {
    val registry = ShuffleManagerRegistry.get()

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager1].getName)
    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager2].getName)

    val gm = spark.sparkContext.env.shuffleManager
    assert(counter1.count("registerShuffle") == 0)
    assert(counter2.count("registerShuffle") == 0)
    // The statement calls #registerShuffle internally.
    val dep =
      new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner)
    gm.unregisterShuffle(dep.shuffleId)
    // Manager 2 was registered last, so it handles the shuffle; manager 1 is bypassed.
    assert(counter1.count("registerShuffle") == 0)
    assert(counter2.count("registerShuffle") == 1)

    assert(counter1.count("stop") == 0)
    assert(counter2.count("stop") == 0)
    gm.stop()
    assert(counter1.count("stop") == 1)
    assert(counter2.count("stop") == 1)
  }

  // A manager registered after the router was already built must be picked up for subsequent
  // shuffles (the registry invalidates cached routers on registration).
  test("register two - disordered registration") {
    val registry = ShuffleManagerRegistry.get()

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager1].getName)

    val gm = spark.sparkContext.env.shuffleManager
    assert(counter1.count("registerShuffle") == 0)
    assert(counter2.count("registerShuffle") == 0)
    // First shuffle: only manager 1 is registered, so it handles the call.
    val dep1 =
      new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner)
    gm.unregisterShuffle(dep1.shuffleId)
    assert(counter1.count("registerShuffle") == 1)
    assert(counter2.count("registerShuffle") == 0)

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager2].getName)

    // The statement calls #registerShuffle internally.
    val dep2 =
      new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner)
    gm.unregisterShuffle(dep2.shuffleId)
    // Second shuffle: manager 2, registered later, now takes precedence.
    assert(counter1.count("registerShuffle") == 1)
    assert(counter2.count("registerShuffle") == 1)

    assert(counter1.count("stop") == 0)
    assert(counter2.count("stop") == 0)
    gm.stop()
    assert(counter1.count("stop") == 1)
    assert(counter2.count("stop") == 1)
  }

  // A manager whose lookup key rejects the dependency is skipped even if registered last.
  test("register two - with empty key") {
    val registry = ShuffleManagerRegistry.get()

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager1].getName)
    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = false
      },
      classOf[ShuffleManager2].getName)

    val gm = spark.sparkContext.env.shuffleManager
    assert(counter1.count("registerShuffle") == 0)
    assert(counter2.count("registerShuffle") == 0)
    // The statement calls #registerShuffle internally.
    val dep =
      new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner)
    gm.unregisterShuffle(dep.shuffleId)
    assert(counter1.count("registerShuffle") == 1)
    assert(counter2.count("registerShuffle") == 0)
  }

  // Registering GlutenShuffleManager itself must be rejected (would recurse endlessly).
  test("register recursively") {
    val registry = ShuffleManagerRegistry.get()

    assertThrows[IllegalArgumentException](
      registry.register(
        new LookupKey {
          override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
        },
        classOf[GlutenShuffleManager].getName))
  }

  // Registering the same shuffle manager class twice must be rejected.
  test("register duplicated") {
    val registry = ShuffleManagerRegistry.get()

    registry.register(
      new LookupKey {
        override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
      },
      classOf[ShuffleManager1].getName)
    assertThrows[IllegalArgumentException](
      registry.register(
        new LookupKey {
          override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true
        },
        classOf[ShuffleManager1].getName))
  }
}
+
object GlutenShuffleManagerSuite {
  // One shared invocation counter per test shuffle manager implementation.
  private val counter1 = new InvocationCounter
  private val counter2 = new InvocationCounter

  /** Test shuffle manager delegating to Spark's SortShuffleManager while counting every call. */
  class ShuffleManager1(conf: SparkConf) extends ShuffleManager {
    private val delegate = new SortShuffleManager(conf)
    private val counter = counter1
    override def registerShuffle[K, V, C](
        shuffleId: Int,
        dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
      counter.increment("registerShuffle")
      delegate.registerShuffle(shuffleId, dependency)
    }

    override def getWriter[K, V](
        handle: ShuffleHandle,
        mapId: Long,
        context: TaskContext,
        metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
      counter.increment("getWriter")
      delegate.getWriter(handle, mapId, context, metrics)
    }

    override def getReader[K, C](
        handle: ShuffleHandle,
        startMapIndex: Int,
        endMapIndex: Int,
        startPartition: Int,
        endPartition: Int,
        context: TaskContext,
        metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
      counter.increment("getReader")
      delegate.getReader(
        handle,
        startMapIndex,
        endMapIndex,
        startPartition,
        endPartition,
        context,
        metrics)
    }

    override def unregisterShuffle(shuffleId: Int): Boolean = {
      counter.increment("unregisterShuffle")
      delegate.unregisterShuffle(shuffleId)
    }

    override def shuffleBlockResolver: ShuffleBlockResolver = {
      counter.increment("shuffleBlockResolver")
      delegate.shuffleBlockResolver
    }

    override def stop(): Unit = {
      counter.increment("stop")
      delegate.stop()
    }
  }

  /**
   * Same as [[ShuffleManager1]] but counting into counter2 and using the alternative
   * (conf, isDriver) constructor shape supported by Spark's reflective instantiation.
   */
  class ShuffleManager2(conf: SparkConf, isDriver: Boolean) extends ShuffleManager {
    private val delegate = new SortShuffleManager(conf)
    private val counter = counter2
    override def registerShuffle[K, V, C](
        shuffleId: Int,
        dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
      counter.increment("registerShuffle")
      delegate.registerShuffle(shuffleId, dependency)
    }

    override def getWriter[K, V](
        handle: ShuffleHandle,
        mapId: Long,
        context: TaskContext,
        metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
      counter.increment("getWriter")
      delegate.getWriter(handle, mapId, context, metrics)
    }

    override def getReader[K, C](
        handle: ShuffleHandle,
        startMapIndex: Int,
        endMapIndex: Int,
        startPartition: Int,
        endPartition: Int,
        context: TaskContext,
        metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
      counter.increment("getReader")
      delegate.getReader(
        handle,
        startMapIndex,
        endMapIndex,
        startPartition,
        endPartition,
        context,
        metrics)
    }

    override def unregisterShuffle(shuffleId: Int): Boolean = {
      counter.increment("unregisterShuffle")
      delegate.unregisterShuffle(shuffleId)
    }

    override def shuffleBlockResolver: ShuffleBlockResolver = {
      counter.increment("shuffleBlockResolver")
      delegate.shuffleBlockResolver
    }

    override def stop(): Unit = {
      counter.increment("stop")
      delegate.stop()
    }
  }

  /** Thread-safe recorder of how many times each named method was invoked. */
  private class InvocationCounter {
    private val counter: mutable.Map[String, AtomicInteger] = mutable.Map()

    def increment(name: String): Unit = synchronized {
      counter.getOrElseUpdate(name, new AtomicInteger()).incrementAndGet()
    }

    // Fix: reads must synchronize too — a mutable.Map is not safe to read while another
    // thread (e.g. a task calling increment) is structurally modifying it.
    def count(name: String): Int = synchronized {
      counter.get(name).fold(0)(_.get())
    }

    // Fix: clear() also participates in the same lock for the same reason.
    def clear(): Unit = synchronized {
      counter.clear()
    }
  }

  // Partitioner stub used only to construct ShuffleDependency instances in tests.
  private object DummyPartitioner extends Partitioner {
    override def numPartitions: Int = 0
    override def getPartition(key: Any): Int = 0
  }
}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to