This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 4d2db5be [Refactor] Optimize creating shuffle handlers (#259)
4d2db5be is described below
commit 4d2db5befede8b94199c73622cc8cbc27ba36f8a
Author: Junfan Zhang <[email protected]>
AuthorDate: Wed Oct 12 10:10:55 2022 +0800
[Refactor] Optimize creating shuffle handlers (#259)
### What changes were proposed in this pull request?
[Refactor] Optimize creating shuffle handlers
### Why are the changes needed?
When creating a shuffle handler, the code is too duplicated and complex.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UTs
---
.../hadoop/mapred/RssMapOutputCollector.java | 4 +-
.../spark/shuffle/writer/RssShuffleWriter.java | 4 +-
.../spark/shuffle/writer/RssShuffleWriter.java | 4 +-
.../storage/factory/ShuffleHandlerFactory.java | 218 ++++++++-------------
.../apache/uniffle/storage/util/StorageType.java | 35 +++-
.../uniffle/storage/util/StorageTypeTest.java | 68 +++++++
6 files changed, 184 insertions(+), 149 deletions(-)
diff --git
a/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
b/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
index 64d08e8e..308a560c 100644
---
a/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
+++
b/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
@@ -196,8 +196,6 @@ public class RssMapOutputCollector<K extends Object, V
extends Object>
}
private boolean isMemoryShuffleEnabled(String storageType) {
- return StorageType.MEMORY_LOCALFILE.name().equals(storageType)
- || StorageType.MEMORY_HDFS.name().equals(storageType)
- || StorageType.MEMORY_LOCALFILE_HDFS.name().equals(storageType);
+ return StorageType.withMemory(StorageType.valueOf(storageType));
}
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 982b4eea..8313ebb2 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -119,9 +119,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
private boolean isMemoryShuffleEnabled(String storageType) {
- return StorageType.MEMORY_LOCALFILE.name().equals(storageType)
- || StorageType.MEMORY_HDFS.name().equals(storageType)
- || StorageType.MEMORY_LOCALFILE_HDFS.name().equals(storageType);
+ return StorageType.withMemory(StorageType.valueOf(storageType));
}
/**
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index dbebc338..0a6cc324 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -122,9 +122,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
private boolean isMemoryShuffleEnabled(String storageType) {
- return StorageType.MEMORY_LOCALFILE.name().equals(storageType)
- || StorageType.MEMORY_HDFS.name().equals(storageType)
- || StorageType.MEMORY_LOCALFILE_HDFS.name().equals(storageType);
+ return StorageType.withMemory(StorageType.valueOf(storageType));
}
@Override
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index c885cca9..9dcb280c 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -17,13 +17,16 @@
package org.apache.uniffle.storage.factory;
+import java.util.ArrayList;
import java.util.List;
+import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import org.apache.uniffle.client.api.ShuffleServerClient;
import org.apache.uniffle.client.factory.ShuffleServerClientFactory;
import org.apache.uniffle.client.util.ClientType;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.storage.handler.api.ClientReadHandler;
import org.apache.uniffle.storage.handler.api.ShuffleDeleteHandler;
import org.apache.uniffle.storage.handler.impl.ComposedClientReadHandler;
@@ -51,142 +54,89 @@ public class ShuffleHandlerFactory {
}
public ClientReadHandler
createShuffleReadHandler(CreateShuffleReadHandlerRequest request) {
- if (StorageType.HDFS.name().equals(request.getStorageType())) {
- return new HdfsClientReadHandler(
- request.getAppId(),
- request.getShuffleId(),
- request.getPartitionId(),
- request.getIndexReadLimit(),
- request.getPartitionNumPerRange(),
- request.getPartitionNum(),
- request.getReadBufferSize(),
- request.getExpectBlockIds(),
- request.getProcessBlockIds(),
- request.getStorageBasePath(),
- request.getHadoopConf());
- } else if (StorageType.LOCALFILE.name().equals(request.getStorageType())) {
- List<ShuffleServerInfo> shuffleServerInfoList =
request.getShuffleServerInfoList();
- List<ShuffleServerClient> shuffleServerClients =
shuffleServerInfoList.stream().map(
- ssi ->
ShuffleServerClientFactory.getInstance().getShuffleServerClient(ClientType.GRPC.name(),
ssi)).collect(
- Collectors.toList());
- return new LocalFileQuorumClientReadHandler(request.getAppId(),
request.getShuffleId(), request.getPartitionId(),
- request.getIndexReadLimit(), request.getPartitionNumPerRange(),
request.getPartitionNum(),
- request.getReadBufferSize(), request.getExpectBlockIds(),
request.getProcessBlockIds(),
- shuffleServerClients);
- } else if
(StorageType.LOCALFILE_HDFS.name().equals(request.getStorageType())) {
- List<ShuffleServerInfo> shuffleServerInfoList =
request.getShuffleServerInfoList();
- List<ShuffleServerClient> shuffleServerClients =
shuffleServerInfoList.stream().map(
- ssi ->
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
- ClientType.GRPC.name(), ssi)).collect(
- Collectors.toList());
- return new ComposedClientReadHandler(() -> {
- return new LocalFileQuorumClientReadHandler(
- request.getAppId(),
- request.getShuffleId(),
- request.getPartitionId(),
- request.getIndexReadLimit(),
- request.getPartitionNumPerRange(),
- request.getPartitionNum(),
- request.getReadBufferSize(),
- request.getExpectBlockIds(),
- request.getProcessBlockIds(),
- shuffleServerClients);
- }, () -> {
- return new HdfsClientReadHandler(
- request.getAppId(),
- request.getShuffleId(),
- request.getPartitionId(),
- request.getIndexReadLimit(),
- request.getPartitionNumPerRange(),
- request.getPartitionNum(),
- request.getReadBufferSize(),
- request.getExpectBlockIds(),
- request.getProcessBlockIds(),
- request.getStorageBasePath(),
- request.getHadoopConf());
- });
- } else if
(StorageType.MEMORY_LOCALFILE.name().equals(request.getStorageType())) {
- List<ShuffleServerInfo> shuffleServerInfoList =
request.getShuffleServerInfoList();
- List<ShuffleServerClient> shuffleServerClients =
shuffleServerInfoList.stream().map(
- ssi ->
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
- ClientType.GRPC.name(), ssi)).collect(
- Collectors.toList());
- ClientReadHandler memoryClientReadHandler = new
MemoryQuorumClientReadHandler(
- request.getAppId(),
- request.getShuffleId(),
- request.getPartitionId(),
- request.getReadBufferSize(),
- shuffleServerClients);
- ClientReadHandler localClientReadHandler = new
LocalFileQuorumClientReadHandler(request.getAppId(),
- request.getShuffleId(), request.getPartitionId(),
request.getIndexReadLimit(),
- request.getPartitionNumPerRange(), request.getPartitionNum(),
- request.getReadBufferSize(), request.getExpectBlockIds(),
request.getProcessBlockIds(),
- shuffleServerClients);
- return new ComposedClientReadHandler(memoryClientReadHandler,
localClientReadHandler);
- } else if
(StorageType.MEMORY_HDFS.name().equals(request.getStorageType())) {
- List<ShuffleServerInfo> shuffleServerInfoList =
request.getShuffleServerInfoList();
- List<ShuffleServerClient> shuffleServerClients =
shuffleServerInfoList.stream().map(
- ssi ->
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
- ClientType.GRPC.name(), ssi)).collect(
- Collectors.toList());
- return new ComposedClientReadHandler(() -> {
- return new MemoryQuorumClientReadHandler(
- request.getAppId(),
- request.getShuffleId(),
- request.getPartitionId(),
- request.getReadBufferSize(),
- shuffleServerClients);
- }, () -> {
- return new HdfsClientReadHandler(
- request.getAppId(),
- request.getShuffleId(),
- request.getPartitionId(),
- request.getIndexReadLimit(),
- request.getPartitionNumPerRange(),
- request.getPartitionNum(),
- request.getReadBufferSize(),
- request.getExpectBlockIds(),
- request.getProcessBlockIds(),
- request.getStorageBasePath(),
- request.getHadoopConf());
- });
- } else if
(StorageType.MEMORY_LOCALFILE_HDFS.name().equals(request.getStorageType())) {
- List<ShuffleServerInfo> shuffleServerInfoList =
request.getShuffleServerInfoList();
- List<ShuffleServerClient> shuffleServerClients =
shuffleServerInfoList.stream().map(
- ssi ->
ShuffleServerClientFactory.getInstance().getShuffleServerClient(
- ClientType.GRPC.name(), ssi)).collect(
- Collectors.toList());
- return new ComposedClientReadHandler(() -> {
- return new MemoryQuorumClientReadHandler(
- request.getAppId(),
- request.getShuffleId(),
- request.getPartitionId(),
- request.getReadBufferSize(),
- shuffleServerClients);
- }, () -> {
- return new LocalFileQuorumClientReadHandler(request.getAppId(),
- request.getShuffleId(), request.getPartitionId(),
request.getIndexReadLimit(),
- request.getPartitionNumPerRange(), request.getPartitionNum(),
- request.getReadBufferSize(), request.getExpectBlockIds(),
request.getProcessBlockIds(),
- shuffleServerClients);
- }, () -> {
- return new HdfsClientReadHandler(
- request.getAppId(),
- request.getShuffleId(),
- request.getPartitionId(),
- request.getIndexReadLimit(),
- request.getPartitionNumPerRange(),
- request.getPartitionNum(),
- request.getReadBufferSize(),
- request.getExpectBlockIds(),
- request.getProcessBlockIds(),
- request.getStorageBasePath(),
- request.getHadoopConf());
- });
- } else {
+ String storageType = request.getStorageType();
+ StorageType type = StorageType.valueOf(storageType);
+
+ if (StorageType.MEMORY == type) {
throw new UnsupportedOperationException(
- "Doesn't support storage type for client read handler:" +
request.getStorageType());
+ "Doesn't support storage type for client read handler:" +
storageType);
+ }
+
+ if (StorageType.HDFS == type) {
+ return getHdfsClientReadHandler(request);
+ }
+ if (StorageType.LOCALFILE == type) {
+ return getLocalfileClientReaderHandler(request);
+ }
+
+ List<ClientReadHandler> handlers = new ArrayList<>();
+ if (StorageType.withMemory(type)) {
+ handlers.add(
+ getMemoryClientReadHandler(request)
+ );
+ }
+ if (StorageType.withLocalfile(type)) {
+ handlers.add(
+ getLocalfileClientReaderHandler(request)
+ );
}
+ if (StorageType.withHDFS(type)) {
+ handlers.add(
+ getHdfsClientReadHandler(request)
+ );
+ }
+ if (handlers.isEmpty()) {
+ throw new RssException("This should not happen due to the unknown
storage type: " + storageType);
+ }
+
+ Callable<ClientReadHandler>[] callables =
+ handlers
+ .stream()
+ .map(x -> (Callable<ClientReadHandler>) () -> x)
+ .collect(Collectors.toList())
+ .toArray(new Callable[handlers.size()]);
+ return new ComposedClientReadHandler(callables);
+ }
+
+ private ClientReadHandler
getMemoryClientReadHandler(CreateShuffleReadHandlerRequest request) {
+ List<ShuffleServerInfo> shuffleServerInfoList =
request.getShuffleServerInfoList();
+ List<ShuffleServerClient> shuffleServerClients =
shuffleServerInfoList.stream().map(
+ ssi -> ShuffleServerClientFactory.getInstance().getShuffleServerClient(
+ ClientType.GRPC.name(), ssi)).collect(
+ Collectors.toList());
+ ClientReadHandler memoryClientReadHandler = new
MemoryQuorumClientReadHandler(
+ request.getAppId(),
+ request.getShuffleId(),
+ request.getPartitionId(),
+ request.getReadBufferSize(),
+ shuffleServerClients);
+ return memoryClientReadHandler;
+ }
+
+ private ClientReadHandler
getLocalfileClientReaderHandler(CreateShuffleReadHandlerRequest request) {
+ List<ShuffleServerInfo> shuffleServerInfoList =
request.getShuffleServerInfoList();
+ List<ShuffleServerClient> shuffleServerClients =
shuffleServerInfoList.stream().map(
+ ssi ->
ShuffleServerClientFactory.getInstance().getShuffleServerClient(ClientType.GRPC.name(),
ssi)).collect(
+ Collectors.toList());
+ return new LocalFileQuorumClientReadHandler(request.getAppId(),
request.getShuffleId(), request.getPartitionId(),
+ request.getIndexReadLimit(), request.getPartitionNumPerRange(),
request.getPartitionNum(),
+ request.getReadBufferSize(), request.getExpectBlockIds(),
request.getProcessBlockIds(),
+ shuffleServerClients);
+ }
+
+ private ClientReadHandler
getHdfsClientReadHandler(CreateShuffleReadHandlerRequest request) {
+ return new HdfsClientReadHandler(
+ request.getAppId(),
+ request.getShuffleId(),
+ request.getPartitionId(),
+ request.getIndexReadLimit(),
+ request.getPartitionNumPerRange(),
+ request.getPartitionNum(),
+ request.getReadBufferSize(),
+ request.getExpectBlockIds(),
+ request.getProcessBlockIds(),
+ request.getStorageBasePath(),
+ request.getHadoopConf());
}
public ShuffleDeleteHandler
createShuffleDeleteHandler(CreateShuffleDeleteHandlerRequest request) {
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/util/StorageType.java
b/storage/src/main/java/org/apache/uniffle/storage/util/StorageType.java
index 80c155e8..37493548 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/util/StorageType.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/util/StorageType.java
@@ -18,10 +18,33 @@
package org.apache.uniffle.storage.util;
public enum StorageType {
- HDFS,
- LOCALFILE,
- LOCALFILE_HDFS,
- MEMORY_LOCALFILE,
- MEMORY_HDFS,
- MEMORY_LOCALFILE_HDFS
+ MEMORY(1),
+ LOCALFILE(2),
+ HDFS(4),
+ LOCALFILE_HDFS(6),
+ MEMORY_LOCALFILE(3),
+ MEMORY_HDFS(5),
+ MEMORY_LOCALFILE_HDFS(7);
+
+ private int val;
+
+ StorageType(int val) {
+ this.val = val;
+ }
+
+ private int getVal() {
+ return val;
+ }
+
+ public static boolean withMemory(StorageType storageType) {
+ return (storageType.getVal() & MEMORY.getVal()) != 0;
+ }
+
+ public static boolean withLocalfile(StorageType storageType) {
+ return (storageType.getVal() & LOCALFILE.getVal()) != 0;
+ }
+
+ public static boolean withHDFS(StorageType storageType) {
+ return (storageType.getVal() & HDFS.getVal()) != 0;
+ }
}
diff --git
a/storage/src/test/java/org/apache/uniffle/storage/util/StorageTypeTest.java
b/storage/src/test/java/org/apache/uniffle/storage/util/StorageTypeTest.java
new file mode 100644
index 00000000..afaf4ab0
--- /dev/null
+++ b/storage/src/test/java/org/apache/uniffle/storage/util/StorageTypeTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.storage.util;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class StorageTypeTest {
+
+ @Test
+ public void commonTest() {
+ String type = "HDFS";
+ assertEquals(StorageType.valueOf(type), StorageType.HDFS);
+
+ StorageType storageType = StorageType.MEMORY;
+ assertTrue(StorageType.withMemory(storageType));
+ assertFalse(StorageType.withLocalfile(storageType));
+ assertFalse(StorageType.withHDFS(storageType));
+
+ storageType = StorageType.LOCALFILE;
+ assertFalse(StorageType.withMemory(storageType));
+ assertTrue(StorageType.withLocalfile(storageType));
+ assertFalse(StorageType.withHDFS(storageType));
+
+ storageType = StorageType.HDFS;
+ assertFalse(StorageType.withMemory(storageType));
+ assertFalse(StorageType.withLocalfile(storageType));
+ assertTrue(StorageType.withHDFS(storageType));
+
+ storageType = StorageType.MEMORY_HDFS;
+ assertTrue(StorageType.withMemory(storageType));
+ assertFalse(StorageType.withLocalfile(storageType));
+ assertTrue(StorageType.withHDFS(storageType));
+
+ storageType = StorageType.MEMORY_LOCALFILE;
+ assertTrue(StorageType.withMemory(storageType));
+ assertTrue(StorageType.withLocalfile(storageType));
+ assertFalse(StorageType.withHDFS(storageType));
+
+ storageType = StorageType.MEMORY_LOCALFILE_HDFS;
+ assertTrue(StorageType.withMemory(storageType));
+ assertTrue(StorageType.withLocalfile(storageType));
+ assertTrue(StorageType.withHDFS(storageType));
+
+ storageType = StorageType.LOCALFILE_HDFS;
+ assertFalse(StorageType.withMemory(storageType));
+ assertTrue(StorageType.withLocalfile(storageType));
+ assertTrue(StorageType.withHDFS(storageType));
+ }
+}