This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 4d2db5be [Refactor] Optimize creating shuffle handlers (#259)
4d2db5be is described below

commit 4d2db5befede8b94199c73622cc8cbc27ba36f8a
Author: Junfan Zhang <[email protected]>
AuthorDate: Wed Oct 12 10:10:55 2022 +0800

    [Refactor] Optimize creating shuffle handlers (#259)
    
    ### What changes were proposed in this pull request?
    [Refactor] Optimize creating shuffle handlers
    
    ### Why are the changes needed?
    When creating a shuffle handler, the code is too duplicated and complex.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UTs
---
 .../hadoop/mapred/RssMapOutputCollector.java       |   4 +-
 .../spark/shuffle/writer/RssShuffleWriter.java     |   4 +-
 .../spark/shuffle/writer/RssShuffleWriter.java     |   4 +-
 .../storage/factory/ShuffleHandlerFactory.java     | 218 ++++++++-------------
 .../apache/uniffle/storage/util/StorageType.java   |  35 +++-
 .../uniffle/storage/util/StorageTypeTest.java      |  68 +++++++
 6 files changed, 184 insertions(+), 149 deletions(-)

diff --git 
a/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java 
b/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
index 64d08e8e..308a560c 100644
--- 
a/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
+++ 
b/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
@@ -196,8 +196,6 @@ public class RssMapOutputCollector<K extends Object, V 
extends Object>
   }
 
   private boolean isMemoryShuffleEnabled(String storageType) {
-    return StorageType.MEMORY_LOCALFILE.name().equals(storageType)
-        || StorageType.MEMORY_HDFS.name().equals(storageType)
-        || StorageType.MEMORY_LOCALFILE_HDFS.name().equals(storageType);
+    return StorageType.withMemory(StorageType.valueOf(storageType));
   }
 }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 982b4eea..8313ebb2 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -119,9 +119,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   }
 
   private boolean isMemoryShuffleEnabled(String storageType) {
-    return StorageType.MEMORY_LOCALFILE.name().equals(storageType)
-        || StorageType.MEMORY_HDFS.name().equals(storageType)
-        || StorageType.MEMORY_LOCALFILE_HDFS.name().equals(storageType);
+    return StorageType.withMemory(StorageType.valueOf(storageType));
   }
 
   /**
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index dbebc338..0a6cc324 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -122,9 +122,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   }
 
   private boolean isMemoryShuffleEnabled(String storageType) {
-    return StorageType.MEMORY_LOCALFILE.name().equals(storageType)
-        || StorageType.MEMORY_HDFS.name().equals(storageType)
-        || StorageType.MEMORY_LOCALFILE_HDFS.name().equals(storageType);
+    return StorageType.withMemory(StorageType.valueOf(storageType));
   }
 
   @Override
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
 
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index c885cca9..9dcb280c 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -17,13 +17,16 @@
 
 package org.apache.uniffle.storage.factory;
 
+import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.Callable;
 import java.util.stream.Collectors;
 
 import org.apache.uniffle.client.api.ShuffleServerClient;
 import org.apache.uniffle.client.factory.ShuffleServerClientFactory;
 import org.apache.uniffle.client.util.ClientType;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.storage.handler.api.ClientReadHandler;
 import org.apache.uniffle.storage.handler.api.ShuffleDeleteHandler;
 import org.apache.uniffle.storage.handler.impl.ComposedClientReadHandler;
@@ -51,142 +54,89 @@ public class ShuffleHandlerFactory {
   }
 
   public ClientReadHandler 
createShuffleReadHandler(CreateShuffleReadHandlerRequest request) {
-    if (StorageType.HDFS.name().equals(request.getStorageType())) {
-      return new HdfsClientReadHandler(
-          request.getAppId(),
-          request.getShuffleId(),
-          request.getPartitionId(),
-          request.getIndexReadLimit(),
-          request.getPartitionNumPerRange(),
-          request.getPartitionNum(),
-          request.getReadBufferSize(),
-          request.getExpectBlockIds(),
-          request.getProcessBlockIds(),
-          request.getStorageBasePath(),
-          request.getHadoopConf());
-    } else if (StorageType.LOCALFILE.name().equals(request.getStorageType())) {
-      List<ShuffleServerInfo> shuffleServerInfoList = 
request.getShuffleServerInfoList();
-      List<ShuffleServerClient> shuffleServerClients = 
shuffleServerInfoList.stream().map(
-          ssi -> 
ShuffleServerClientFactory.getInstance().getShuffleServerClient(ClientType.GRPC.name(),
 ssi)).collect(
-          Collectors.toList());
-      return new LocalFileQuorumClientReadHandler(request.getAppId(), 
request.getShuffleId(), request.getPartitionId(),
-          request.getIndexReadLimit(), request.getPartitionNumPerRange(), 
request.getPartitionNum(),
-          request.getReadBufferSize(), request.getExpectBlockIds(), 
request.getProcessBlockIds(),
-          shuffleServerClients);
-    } else if 
(StorageType.LOCALFILE_HDFS.name().equals(request.getStorageType())) {
-      List<ShuffleServerInfo> shuffleServerInfoList = 
request.getShuffleServerInfoList();
-      List<ShuffleServerClient> shuffleServerClients = 
shuffleServerInfoList.stream().map(
-          ssi -> 
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
-              ClientType.GRPC.name(), ssi)).collect(
-          Collectors.toList());
-      return new ComposedClientReadHandler(() -> {
-        return new LocalFileQuorumClientReadHandler(
-            request.getAppId(),
-            request.getShuffleId(),
-            request.getPartitionId(),
-            request.getIndexReadLimit(),
-            request.getPartitionNumPerRange(),
-            request.getPartitionNum(),
-            request.getReadBufferSize(),
-            request.getExpectBlockIds(),
-            request.getProcessBlockIds(),
-            shuffleServerClients);
-      }, () -> {
-        return new HdfsClientReadHandler(
-            request.getAppId(),
-            request.getShuffleId(),
-            request.getPartitionId(),
-            request.getIndexReadLimit(),
-            request.getPartitionNumPerRange(),
-            request.getPartitionNum(),
-            request.getReadBufferSize(),
-            request.getExpectBlockIds(),
-            request.getProcessBlockIds(),
-            request.getStorageBasePath(),
-            request.getHadoopConf());
-      });
-    } else if 
(StorageType.MEMORY_LOCALFILE.name().equals(request.getStorageType())) {
-      List<ShuffleServerInfo> shuffleServerInfoList = 
request.getShuffleServerInfoList();
-      List<ShuffleServerClient> shuffleServerClients = 
shuffleServerInfoList.stream().map(
-          ssi -> 
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
-              ClientType.GRPC.name(), ssi)).collect(
-          Collectors.toList());
-      ClientReadHandler memoryClientReadHandler = new 
MemoryQuorumClientReadHandler(
-          request.getAppId(),
-          request.getShuffleId(),
-          request.getPartitionId(),
-          request.getReadBufferSize(),
-          shuffleServerClients);
-      ClientReadHandler localClientReadHandler = new 
LocalFileQuorumClientReadHandler(request.getAppId(),
-          request.getShuffleId(), request.getPartitionId(), 
request.getIndexReadLimit(),
-          request.getPartitionNumPerRange(), request.getPartitionNum(),
-          request.getReadBufferSize(), request.getExpectBlockIds(), 
request.getProcessBlockIds(),
-          shuffleServerClients);
-      return new ComposedClientReadHandler(memoryClientReadHandler, 
localClientReadHandler);
-    } else if 
(StorageType.MEMORY_HDFS.name().equals(request.getStorageType())) {
-      List<ShuffleServerInfo> shuffleServerInfoList = 
request.getShuffleServerInfoList();
-      List<ShuffleServerClient> shuffleServerClients = 
shuffleServerInfoList.stream().map(
-          ssi -> 
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
-              ClientType.GRPC.name(), ssi)).collect(
-          Collectors.toList());
-      return new ComposedClientReadHandler(() -> {
-        return new MemoryQuorumClientReadHandler(
-          request.getAppId(),
-          request.getShuffleId(),
-          request.getPartitionId(),
-          request.getReadBufferSize(),
-          shuffleServerClients);
-      }, () -> {
-        return new HdfsClientReadHandler(
-            request.getAppId(),
-            request.getShuffleId(),
-            request.getPartitionId(),
-            request.getIndexReadLimit(),
-            request.getPartitionNumPerRange(),
-            request.getPartitionNum(),
-            request.getReadBufferSize(),
-            request.getExpectBlockIds(),
-            request.getProcessBlockIds(),
-            request.getStorageBasePath(),
-            request.getHadoopConf());
-      });
-    } else if 
(StorageType.MEMORY_LOCALFILE_HDFS.name().equals(request.getStorageType())) {
-      List<ShuffleServerInfo> shuffleServerInfoList = 
request.getShuffleServerInfoList();
-      List<ShuffleServerClient> shuffleServerClients = 
shuffleServerInfoList.stream().map(
-          ssi -> 
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
-              ClientType.GRPC.name(), ssi)).collect(
-          Collectors.toList());
-      return new ComposedClientReadHandler(() -> {
-        return new MemoryQuorumClientReadHandler(
-            request.getAppId(),
-            request.getShuffleId(),
-            request.getPartitionId(),
-            request.getReadBufferSize(),
-            shuffleServerClients);
-      }, () -> {
-        return new LocalFileQuorumClientReadHandler(request.getAppId(),
-            request.getShuffleId(), request.getPartitionId(), 
request.getIndexReadLimit(),
-            request.getPartitionNumPerRange(), request.getPartitionNum(),
-            request.getReadBufferSize(), request.getExpectBlockIds(), 
request.getProcessBlockIds(),
-            shuffleServerClients);
-      }, () -> {
-        return  new HdfsClientReadHandler(
-            request.getAppId(),
-            request.getShuffleId(),
-            request.getPartitionId(),
-            request.getIndexReadLimit(),
-            request.getPartitionNumPerRange(),
-            request.getPartitionNum(),
-            request.getReadBufferSize(),
-            request.getExpectBlockIds(),
-            request.getProcessBlockIds(),
-            request.getStorageBasePath(),
-            request.getHadoopConf());
-      });
-    } else {
+    String storageType = request.getStorageType();
+    StorageType type = StorageType.valueOf(storageType);
+
+    if (StorageType.MEMORY == type) {
       throw new UnsupportedOperationException(
-          "Doesn't support storage type for client read handler:" + 
request.getStorageType());
+          "Doesn't support storage type for client read handler:" + 
storageType);
+    }
+
+    if (StorageType.HDFS == type) {
+      return getHdfsClientReadHandler(request);
+    }
+    if (StorageType.LOCALFILE == type) {
+      return getLocalfileClientReaderHandler(request);
+    }
+
+    List<ClientReadHandler> handlers = new ArrayList<>();
+    if (StorageType.withMemory(type)) {
+      handlers.add(
+          getMemoryClientReadHandler(request)
+      );
+    }
+    if (StorageType.withLocalfile(type)) {
+      handlers.add(
+          getLocalfileClientReaderHandler(request)
+      );
     }
+    if (StorageType.withHDFS(type)) {
+      handlers.add(
+          getHdfsClientReadHandler(request)
+      );
+    }
+    if (handlers.isEmpty()) {
+      throw new RssException("This should not happen due to the unknown 
storage type: " + storageType);
+    }
+
+    Callable<ClientReadHandler>[] callables =
+        handlers
+            .stream()
+            .map(x -> (Callable<ClientReadHandler>) () -> x)
+            .collect(Collectors.toList())
+            .toArray(new Callable[handlers.size()]);
+    return new ComposedClientReadHandler(callables);
+  }
+
+  private ClientReadHandler 
getMemoryClientReadHandler(CreateShuffleReadHandlerRequest request) {
+    List<ShuffleServerInfo> shuffleServerInfoList = 
request.getShuffleServerInfoList();
+    List<ShuffleServerClient> shuffleServerClients = 
shuffleServerInfoList.stream().map(
+        ssi -> ShuffleServerClientFactory.getInstance().getShuffleServerClient(
+            ClientType.GRPC.name(), ssi)).collect(
+        Collectors.toList());
+    ClientReadHandler memoryClientReadHandler = new 
MemoryQuorumClientReadHandler(
+        request.getAppId(),
+        request.getShuffleId(),
+        request.getPartitionId(),
+        request.getReadBufferSize(),
+        shuffleServerClients);
+    return memoryClientReadHandler;
+  }
+
+  private ClientReadHandler 
getLocalfileClientReaderHandler(CreateShuffleReadHandlerRequest request) {
+    List<ShuffleServerInfo> shuffleServerInfoList = 
request.getShuffleServerInfoList();
+    List<ShuffleServerClient> shuffleServerClients = 
shuffleServerInfoList.stream().map(
+        ssi -> 
ShuffleServerClientFactory.getInstance().getShuffleServerClient(ClientType.GRPC.name(),
 ssi)).collect(
+        Collectors.toList());
+    return new LocalFileQuorumClientReadHandler(request.getAppId(), 
request.getShuffleId(), request.getPartitionId(),
+        request.getIndexReadLimit(), request.getPartitionNumPerRange(), 
request.getPartitionNum(),
+        request.getReadBufferSize(), request.getExpectBlockIds(), 
request.getProcessBlockIds(),
+        shuffleServerClients);
+  }
+
+  private ClientReadHandler 
getHdfsClientReadHandler(CreateShuffleReadHandlerRequest request) {
+    return new HdfsClientReadHandler(
+        request.getAppId(),
+        request.getShuffleId(),
+        request.getPartitionId(),
+        request.getIndexReadLimit(),
+        request.getPartitionNumPerRange(),
+        request.getPartitionNum(),
+        request.getReadBufferSize(),
+        request.getExpectBlockIds(),
+        request.getProcessBlockIds(),
+        request.getStorageBasePath(),
+        request.getHadoopConf());
   }
 
   public ShuffleDeleteHandler 
createShuffleDeleteHandler(CreateShuffleDeleteHandlerRequest request) {
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/util/StorageType.java 
b/storage/src/main/java/org/apache/uniffle/storage/util/StorageType.java
index 80c155e8..37493548 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/util/StorageType.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/util/StorageType.java
@@ -18,10 +18,33 @@
 package org.apache.uniffle.storage.util;
 
 public enum StorageType {
-  HDFS,
-  LOCALFILE,
-  LOCALFILE_HDFS,
-  MEMORY_LOCALFILE,
-  MEMORY_HDFS,
-  MEMORY_LOCALFILE_HDFS
+  MEMORY(1),
+  LOCALFILE(2),
+  HDFS(4),
+  LOCALFILE_HDFS(6),
+  MEMORY_LOCALFILE(3),
+  MEMORY_HDFS(5),
+  MEMORY_LOCALFILE_HDFS(7);
+
+  private int val;
+
+  StorageType(int val) {
+    this.val = val;
+  }
+
+  private int getVal() {
+    return val;
+  }
+
+  public static boolean withMemory(StorageType storageType) {
+    return (storageType.getVal() & MEMORY.getVal()) != 0;
+  }
+
+  public static boolean withLocalfile(StorageType storageType) {
+    return (storageType.getVal() & LOCALFILE.getVal()) != 0;
+  }
+
+  public static boolean withHDFS(StorageType storageType) {
+    return (storageType.getVal() & HDFS.getVal()) != 0;
+  }
 }
diff --git 
a/storage/src/test/java/org/apache/uniffle/storage/util/StorageTypeTest.java 
b/storage/src/test/java/org/apache/uniffle/storage/util/StorageTypeTest.java
new file mode 100644
index 00000000..afaf4ab0
--- /dev/null
+++ b/storage/src/test/java/org/apache/uniffle/storage/util/StorageTypeTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.storage.util;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class StorageTypeTest {
+
+  @Test
+  public void commonTest() {
+    String type = "HDFS";
+    assertEquals(StorageType.valueOf(type), StorageType.HDFS);
+
+    StorageType storageType = StorageType.MEMORY;
+    assertTrue(StorageType.withMemory(storageType));
+    assertFalse(StorageType.withLocalfile(storageType));
+    assertFalse(StorageType.withHDFS(storageType));
+
+    storageType = StorageType.LOCALFILE;
+    assertFalse(StorageType.withMemory(storageType));
+    assertTrue(StorageType.withLocalfile(storageType));
+    assertFalse(StorageType.withHDFS(storageType));
+
+    storageType = StorageType.HDFS;
+    assertFalse(StorageType.withMemory(storageType));
+    assertFalse(StorageType.withLocalfile(storageType));
+    assertTrue(StorageType.withHDFS(storageType));
+
+    storageType = StorageType.MEMORY_HDFS;
+    assertTrue(StorageType.withMemory(storageType));
+    assertFalse(StorageType.withLocalfile(storageType));
+    assertTrue(StorageType.withHDFS(storageType));
+
+    storageType = StorageType.MEMORY_LOCALFILE;
+    assertTrue(StorageType.withMemory(storageType));
+    assertTrue(StorageType.withLocalfile(storageType));
+    assertFalse(StorageType.withHDFS(storageType));
+
+    storageType = StorageType.MEMORY_LOCALFILE_HDFS;
+    assertTrue(StorageType.withMemory(storageType));
+    assertTrue(StorageType.withLocalfile(storageType));
+    assertTrue(StorageType.withHDFS(storageType));
+
+    storageType = StorageType.LOCALFILE_HDFS;
+    assertFalse(StorageType.withMemory(storageType));
+    assertTrue(StorageType.withLocalfile(storageType));
+    assertTrue(StorageType.withHDFS(storageType));
+  }
+}

Reply via email to