This is an automated email from the ASF dual-hosted git repository.

angerszhuuuu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new c4103922 [CELEBORN-265] Integration with Spark3.0 class cast exception 
of ShuffleHandler (#1197)
c4103922 is described below

commit c410392284f556360c0cb1492cb23d2101e0a396
Author: Angerszhuuuu <[email protected]>
AuthorDate: Thu Feb 2 11:52:51 2023 +0800

    [CELEBORN-265] Integration with Spark3.0 class cast exception of 
ShuffleHandler (#1197)
    
    * [CELEBORN-265] Integration with Spark3.0 class cast exception of 
ShuffleHandler
---
 .../spark/shuffle/celeborn/RssShuffleManager.java  | 25 ++++++---
 .../tests/spark/ShuffleFallbackSuite.scala         | 62 ++++++++++++++++++++++
 2 files changed, 80 insertions(+), 7 deletions(-)

diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
index d2b1769e..d3d2b0bc 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java
@@ -244,16 +244,27 @@ public class RssShuffleManager implements ShuffleManager {
       int endPartition,
       TaskContext context,
       ShuffleReadMetricsReporter metrics) {
-    @SuppressWarnings("unchecked")
-    RssShuffleHandle<K, ?, C> h = (RssShuffleHandle<K, ?, C>) handle;
-    return new RssShuffleReader<>(
-        h,
-        startPartition,
-        endPartition,
+    if (handle instanceof RssShuffleHandle) {
+      @SuppressWarnings("unchecked")
+      RssShuffleHandle<K, ?, C> h = (RssShuffleHandle<K, ?, C>) handle;
+      return new RssShuffleReader<>(
+          h,
+          startPartition,
+          endPartition,
+          startMapIndex,
+          endMapIndex,
+          context,
+          celebornConf,
+          metrics);
+    }
+    return SparkUtils.getReader(
+        sortShuffleManager(),
+        handle,
         startMapIndex,
         endMapIndex,
+        startPartition,
+        endPartition,
         context,
-        celebornConf,
         metrics);
   }
 }
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleFallbackSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleFallbackSuite.scala
new file mode 100644
index 00000000..6aa529df
--- /dev/null
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleFallbackSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark
+
+import scala.util.Random
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.protocol.CompressionCodec
+
+class ShuffleFallbackSuite extends AnyFunSuite
+  with SparkTestBase
+  with BeforeAndAfterEach {
+
+  override def beforeEach(): Unit = {
+    ShuffleClient.reset()
+  }
+
+  override def afterEach(): Unit = {
+    System.gc()
+  }
+
+  private def enableRss(conf: SparkConf) = {
+    conf.set("spark.shuffle.manager", 
"org.apache.spark.shuffle.celeborn.RssShuffleManager")
+      .set("spark.rss.master.address", masterInfo._1.rpcEnv.address.toString)
+      .set("spark.rss.shuffle.split.threshold", "10MB")
+  }
+
+  test(s"celeborn spark integration test - fallback") {
+    val sparkConf = new SparkConf().setAppName("rss-demo")
+      .setMaster("local[4]")
+      .set("spark.celeborn.shuffle.forceFallback.enabled", "true")
+
+    enableRss(sparkConf)
+
+    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
+    import sparkSession.implicits._
+    val df = sparkSession.sparkContext.parallelize(1 to 120000, 8)
+      .repartition(100)
+    df.collect()
+    sparkSession.stop()
+  }
+}

Reply via email to