This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 60f6f878 [CELEBORN-11] ShuffleClient supports MapPartition shuffle
write:pushdata (#1036)
60f6f878 is described below
commit 60f6f8783271508f64a4007193eeef6c76faa2cd
Author: zhongqiangczq <[email protected]>
AuthorDate: Thu Dec 8 12:31:47 2022 +0800
[CELEBORN-11] ShuffleClient supports MapPartition shuffle write:pushdata
(#1036)
---
.../org/apache/celeborn/client/ShuffleClient.java | 16 +++
.../apache/celeborn/client/ShuffleClientImpl.java | 131 +++++++++++++++++-
.../apache/celeborn/client/DummyShuffleClient.java | 15 +++
.../celeborn/client/ShuffleClientBaseSuiteJ.java | 108 +++++++++++++++
.../celeborn/client/ShuffleClientImplSuiteJ.java | 147 +++++++++++++++++++++
5 files changed, 416 insertions(+), 1 deletion(-)
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index d4a471e9..4beaf00c 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -19,7 +19,9 @@ package org.apache.celeborn.client;
import java.io.IOException;
import java.util.Optional;
+import java.util.function.BooleanSupplier;
+import io.netty.buffer.ByteBuf;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
@@ -185,6 +187,20 @@ public abstract class ShuffleClient implements Cloneable {
public abstract void shutDown();
+  // Write data to a specific map partition; the input data's type is ByteBuf.
+  // The data is passed as a ByteBuf to avoid a copy between the application and netty.
+  // closeCallBack will do some clean-up operations such as memory release.
+  public abstract int pushDataToLocation(
+      String applicationId,
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      ByteBuf data,
+      PartitionLocation location,
+      BooleanSupplier closeCallBack)
+      throws IOException;
+
public abstract Optional<PartitionLocation> regionStart(
String applicationId,
int shuffleId,
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index a3d0c6b1..86f0c5b4 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -25,11 +25,13 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
+import java.util.function.BooleanSupplier;
import scala.reflect.ClassTag$;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.Uninterruptibles;
+import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
@@ -90,6 +92,8 @@ public class ShuffleClientImpl extends ShuffleClient {
protected TransportClientFactory dataClientFactory;
+  // Size in bytes of the batch header: four 4-byte ints (mapId, attemptId, batchId, length).
+  static final int BATCH_HEADER_SIZE = 4 * 4;
+
// key: shuffleId, value: (partitionId, PartitionLocation)
private final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>>
reducePartitionMap =
new ConcurrentHashMap<>();
@@ -565,7 +569,7 @@ public class ShuffleClientImpl extends ShuffleClient {
compressor.compress(data, offset, length);
final int compressedTotalSize = compressor.getCompressedTotalSize();
- final int BATCH_HEADER_SIZE = 4 * 4;
+
final byte[] body = new byte[BATCH_HEADER_SIZE + compressedTotalSize];
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId);
Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId);
@@ -1249,6 +1253,131 @@ public class ShuffleClientImpl extends ShuffleClient {
|| (message.startsWith("Failed to send RPC "));
}
+  /**
+   * Pushes one batch, already held in a netty {@link ByteBuf}, to the given partition location.
+   *
+   * <p>The first {@code BATCH_HEADER_SIZE} bytes of {@code data} are overwritten in place with
+   * the batch header (mapId, attemptId, batchId, payload length), so callers must leave them
+   * free. The push is asynchronous: the response is handled by a callback that runs {@code
+   * closeCallBack} on both success and failure.
+   *
+   * @param applicationId application id, used to build the shuffle key
+   * @param shuffleId shuffle id
+   * @param mapId id of the map task producing the data
+   * @param attemptId attempt id of the map task
+   * @param partitionId partition id, passed to the client factory when creating the client
+   * @param data buffer whose readable bytes are the header area followed by the payload
+   * @param location partition location whose host and push port receive the data
+   * @param closeCallBack invoked from the response callback; its boolean result is ignored. It
+   *     is not invoked when this method returns early because the mapper has already ended.
+   * @return the readable byte count of {@code data} (header included), or 0 if the mapper has
+   *     already ended and nothing was pushed
+   * @throws IOException declared for the in-flight limit check; presumably raised there when an
+   *     earlier push of this mapper has failed -- confirm against limitMaxInFlight
+   */
+  @Override
+  public int pushDataToLocation(
+      String applicationId,
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      ByteBuf data,
+      PartitionLocation location,
+      BooleanSupplier closeCallBack)
+      throws IOException {
+    // mapKey
+    final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
+    final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
+    // return if shuffle stage already ended
+    if (mapperEnded(shuffleId, mapId, attemptId)) {
+      logger.debug(
+          "The mapper(shuffle {} map {} attempt {}) has already ended while"
+              + " pushing data byteBuf.",
+          shuffleId,
+          mapId,
+          attemptId);
+      PushState pushState = pushStates.get(mapKey);
+      if (pushState != null) {
+        pushState.cancelFutures();
+      }
+      return 0;
+    }
+
+    PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
+
+    // increment batchId
+    final int nextBatchId = pushState.batchId.addAndGet(1);
+    int totalLength = data.readableBytes();
+    // Temporarily rewind the writer index to fill in the header at the start of the buffer,
+    // then restore it so the payload that follows the header stays readable.
+    data.markWriterIndex();
+    data.writerIndex(0);
+    data.writeInt(mapId);
+    data.writeInt(attemptId);
+    data.writeInt(nextBatchId);
+    data.writeInt(totalLength - BATCH_HEADER_SIZE);
+    data.resetWriterIndex();
+    logger.debug(
+        "Do push data byteBuf for app {} shuffle {} map {} attempt {} reduce {} batch {}.",
+        applicationId,
+        shuffleId,
+        mapId,
+        attemptId,
+        partitionId,
+        nextBatchId);
+    // check limit
+    limitMaxInFlight(mapKey, pushState, maxInFlight);
+
+    // add inFlight requests
+    pushState.inFlightBatches.put(nextBatchId, location);
+
+    // build PushData request
+    NettyManagedBuffer buffer = new NettyManagedBuffer(data);
+    PushData pushData = new PushData(MASTER_MODE, shuffleKey, location.getUniqueId(), buffer);
+
+    // build callback
+    RpcResponseCallback callback =
+        new RpcResponseCallback() {
+          @Override
+          public void onSuccess(ByteBuffer response) {
+            closeCallBack.getAsBoolean();
+            pushState.inFlightBatches.remove(nextBatchId);
+            pushState.removeFuture(nextBatchId);
+            // A non-empty response carries a status byte; only STAGE_ENDED is acted upon here.
+            if (response.remaining() > 0) {
+              byte reason = response.get();
+              if (reason == StatusCode.STAGE_ENDED.getValue()) {
+                mapperEndMap
+                    .computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
+                    .add(mapKey);
+              }
+            }
+            logger.debug(
+                "Push data byteBuf to {}:{} success for map {} attempt {} batch {}.",
+                location.getHost(),
+                location.getPushPort(),
+                mapId,
+                attemptId,
+                nextBatchId);
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            closeCallBack.getAsBoolean();
+            pushState.inFlightBatches.remove(nextBatchId);
+            pushState.removeFuture(nextBatchId);
+            // Keep only the first recorded failure for this mapper.
+            if (pushState.exception.get() != null) {
+              return;
+            }
+            if (!mapperEnded(shuffleId, mapId, attemptId)) {
+              pushState.exception.compareAndSet(
+                  null, new IOException("PushData byteBuf failed!", e));
+              logger.error(
+                  "Push data byteBuf to {}:{} failed for map {} attempt {} batch {}.",
+                  location.getHost(),
+                  location.getPushPort(),
+                  mapId,
+                  attemptId,
+                  nextBatchId,
+                  e);
+            } else {
+              logger.warn(
+                  "Mapper shuffleId:{} mapId:{} attempt:{} already ended, remove batchId:{}.",
+                  shuffleId,
+                  mapId,
+                  attemptId,
+                  nextBatchId);
+            }
+          }
+        };
+    // do push data
+    try {
+      TransportClient client =
+          dataClientFactory.createClient(location.getHost(), location.getPushPort(), partitionId);
+      ChannelFuture future = client.pushData(pushData, callback);
+      pushState.addFuture(nextBatchId, future);
+    } catch (Exception e) {
+      // Route synchronous failures through the same callback so in-flight state is cleaned up.
+      logger.warn("PushData byteBuf failed", e);
+      callback.onFailure(new Exception(getPushDataFailCause(e.getMessage()).toString(), e));
+    }
+    return totalLength;
+  }
+
@Override
public void pushDataHandShake(
String applicationId,
diff --git
a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index cccfd20b..d1911fa3 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -23,7 +23,9 @@ import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Optional;
+import java.util.function.BooleanSupplier;
+import io.netty.buffer.ByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -135,6 +137,19 @@ public class DummyShuffleClient extends ShuffleClient {
}
}
+ // Stub implementation: ignores every argument, pushes nothing and always reports 0 bytes.
+ @Override
+ public int pushDataToLocation(
+ String applicationId,
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ ByteBuf data,
+ PartitionLocation location,
+ BooleanSupplier closeCallBack) {
+ return 0;
+ }
+
@Override
public Optional<PartitionLocation> regionStart(
String applicationId,
diff --git
a/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
b/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
new file mode 100644
index 00000000..ead6a010
--- /dev/null
+++
b/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+
+import scala.reflect.ClassTag$;
+
+import io.netty.channel.ChannelFuture;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.protocol.CompressionCodec;
+import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbRegisterShuffleResponse;
+import org.apache.celeborn.common.protocol.message.ControlMessages;
+import org.apache.celeborn.common.protocol.message.StatusCode;
+import org.apache.celeborn.common.rpc.RpcEndpointRef;
+
+/**
+ * Shared fixture for ShuffleClientImpl tests: holds mocked RPC/transport collaborators, a
+ * master/slave pair of partition locations on localhost, and a setupEnv helper that wires them
+ * into a freshly created client.
+ */
+public abstract class ShuffleClientBaseSuiteJ {
+ // Client under test; assigned by setupEnv.
+ protected ShuffleClientImpl shuffleClient = null;
+ // Mockito mocks standing in for the meta-service endpoint and the transport layer.
+ protected static final RpcEndpointRef endpointRef =
mock(RpcEndpointRef.class);
+ protected static final TransportClientFactory clientFactory =
mock(TransportClientFactory.class);
+ protected final TransportClient client = mock(TransportClient.class);
+
+ protected static final String TEST_APPLICATION_ID = "testapp1";
+ protected static final int TEST_SHUFFLE_ID = 1;
+ protected static final int TEST_ATTEMPT_ID = 0;
+ // NOTE(review): "REDUCRE" is presumably a typo for "REDUCE"; subclasses reference this name.
+ protected static final int TEST_REDUCRE_ID = 0;
+
+ protected static final int MASTER_RPC_PORT = 1234;
+ protected static final int MASTER_PUSH_PORT = 1235;
+ protected static final int MASTER_FETCH_PORT = 1236;
+ protected static final int MASTER_REPLICATE_PORT = 1237;
+ protected static final int SLAVE_RPC_PORT = 4321;
+ protected static final int SLAVE_PUSH_PORT = 4322;
+ protected static final int SLAVE_FETCH_PORT = 4323;
+ protected static final int SLAVE_REPLICATE_PORT = 4324;
+ // Master-mode location on localhost using the MASTER_* ports.
+ protected static final PartitionLocation masterLocation =
+ new PartitionLocation(
+ 0,
+ 1,
+ "localhost",
+ MASTER_RPC_PORT,
+ MASTER_PUSH_PORT,
+ MASTER_FETCH_PORT,
+ MASTER_REPLICATE_PORT,
+ PartitionLocation.Mode.MASTER);
+ // Slave-mode location on localhost using the SLAVE_* ports; set as the master's peer in setupEnv.
+ protected static final PartitionLocation slaveLocation =
+ new PartitionLocation(
+ 0,
+ 1,
+ "localhost",
+ SLAVE_RPC_PORT,
+ SLAVE_PUSH_PORT,
+ SLAVE_FETCH_PORT,
+ SLAVE_REPLICATE_PORT,
+ PartitionLocation.Mode.SLAVE);
+
+ // 16 bytes; presumably mirrors the batch header size used by ShuffleClientImpl -- confirm.
+ protected final int BATCH_HEADER_SIZE = 4 * 4;
+ // Future returned by the stubbed pushData calls in subclasses.
+ protected ChannelFuture mockedFuture = mock(ChannelFuture.class);
+
+ // Builds a CelebornConf for the given codec, creates the client under test and wires in the
+ // mocks: RegisterShuffle answers SUCCESS with masterLocation, and createClient for the master's
+ // host/push port/TEST_REDUCRE_ID yields the mocked TransportClient. Returns the conf it built.
+ protected CelebornConf setupEnv(CompressionCodec codec) throws IOException,
InterruptedException {
+ CelebornConf conf = new CelebornConf();
+ conf.set("celeborn.shuffle.compression.codec", codec.name());
+ conf.set("celeborn.push.retry.threads", "1");
+ conf.set("celeborn.push.buffer.size", "1K");
+ shuffleClient = new ShuffleClientImpl(conf, new UserIdentifier("mock",
"mock"));
+ // Mutates the shared static masterLocation; repeated calls set the same peer again.
+ masterLocation.setPeer(slaveLocation);
+
+ // Stub the synchronous RegisterShuffle request for (app, shuffle, 1, 1).
+ when(endpointRef.askSync(
+ ControlMessages.RegisterShuffle$.MODULE$.apply(
+ TEST_APPLICATION_ID, TEST_SHUFFLE_ID, 1, 1),
+ ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)))
+ .thenAnswer(
+ t ->
+ ControlMessages.RegisterShuffleResponse$.MODULE$.apply(
+ StatusCode.SUCCESS, new PartitionLocation[]
{masterLocation}));
+
+ shuffleClient.setupMetaServiceRef(endpointRef);
+ when(clientFactory.createClient(
+ masterLocation.getHost(), masterLocation.getPushPort(),
TEST_REDUCRE_ID))
+ .thenAnswer(t -> client);
+
+ // Replace the client's data-client factory with the mock so pushes reach the mocked client.
+ shuffleClient.dataClientFactory = clientFactory;
+ return conf;
+ }
+}
diff --git
a/client/src/test/java/org/apache/celeborn/client/ShuffleClientImplSuiteJ.java
b/client/src/test/java/org/apache/celeborn/client/ShuffleClientImplSuiteJ.java
new file mode 100644
index 00000000..ae368d2f
--- /dev/null
+++
b/client/src/test/java/org/apache/celeborn/client/ShuffleClientImplSuiteJ.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
+import org.apache.celeborn.common.protocol.CompressionCodec;
+import org.apache.celeborn.common.protocol.message.StatusCode;
+
+/**
+ * Tests for {@code ShuffleClientImpl.pushDataToLocation} against the mocked transport client
+ * provided by {@link ShuffleClientBaseSuiteJ}. Each test stubs {@code client.pushData} to invoke
+ * the response callback synchronously with a chosen outcome.
+ */
+public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
+  static int BufferSize = 64;
+  static byte[] TEST_BUF1 = new byte[BufferSize];
+  static CelebornConf conf;
+
+  @Before
+  public void setup() throws IOException, InterruptedException {
+    conf = setupEnv(CompressionCodec.LZ4);
+  }
+
+  /**
+   * Wraps {@code TEST_BUF1} in a ByteBuf whose payload bytes (everything after the header area)
+   * are set to 1 and whose writer index is at {@code BufferSize}.
+   */
+  public ByteBuf createByteBuf() {
+    for (int i = BATCH_HEADER_SIZE; i < BufferSize; i++) {
+      TEST_BUF1[i] = 1;
+    }
+    ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
+    byteBuf.writerIndex(BufferSize);
+    return byteBuf;
+  }
+
+  /** An empty success response: the push reports the full buffer length. */
+  @Test
+  public void testPushDataByteBufSuccess() throws IOException {
+    ByteBuf byteBuf = createByteBuf();
+    when(client.pushData(any(), any()))
+        .thenAnswer(
+            t -> {
+              RpcResponseCallback rpcResponseCallback =
+                  t.getArgumentAt(1, RpcResponseCallback.class);
+              ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[0]);
+              rpcResponseCallback.onSuccess(byteBuffer);
+              return mockedFuture;
+            });
+
+    int pushDataLen =
+        shuffleClient.pushDataToLocation(
+            TEST_APPLICATION_ID,
+            TEST_SHUFFLE_ID,
+            TEST_ATTEMPT_ID,
+            TEST_ATTEMPT_ID,
+            TEST_REDUCRE_ID,
+            byteBuf,
+            masterLocation,
+            () -> true);
+    Assert.assertEquals(BufferSize, pushDataLen);
+  }
+
+  /** A success response carrying HARD_SPLIT: the push still reports the full buffer length. */
+  @Test
+  public void testPushDataByteBufHardSplit() throws IOException {
+    ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
+    when(client.pushData(any(), any()))
+        .thenAnswer(
+            t -> {
+              RpcResponseCallback rpcResponseCallback =
+                  t.getArgumentAt(1, RpcResponseCallback.class);
+              ByteBuffer byteBuffer =
+                  ByteBuffer.wrap(new byte[] {StatusCode.HARD_SPLIT.getValue()});
+              rpcResponseCallback.onSuccess(byteBuffer);
+              return mockedFuture;
+            });
+    int pushDataLen =
+        shuffleClient.pushDataToLocation(
+            TEST_APPLICATION_ID,
+            TEST_SHUFFLE_ID,
+            TEST_ATTEMPT_ID,
+            TEST_ATTEMPT_ID,
+            TEST_REDUCRE_ID,
+            byteBuf,
+            masterLocation,
+            () -> true);
+    // The return value is the buffer's readable length regardless of the response status byte.
+    Assert.assertEquals(BufferSize, pushDataLen);
+  }
+
+  /** A failed push records an exception; the following push must surface it as IOException. */
+  @Test
+  public void testPushDataByteBufFail() throws IOException {
+    ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
+    when(client.pushData(any(), any()))
+        .thenAnswer(
+            t -> {
+              RpcResponseCallback rpcResponseCallback =
+                  t.getArgumentAt(1, RpcResponseCallback.class);
+              rpcResponseCallback.onFailure(new Exception("pushDataFailed"));
+              return mockedFuture;
+            });
+    // first push just set pushdata.exception
+    shuffleClient.pushDataToLocation(
+        TEST_APPLICATION_ID,
+        TEST_SHUFFLE_ID,
+        TEST_ATTEMPT_ID,
+        TEST_ATTEMPT_ID,
+        TEST_REDUCRE_ID,
+        byteBuf,
+        masterLocation,
+        () -> true);
+
+    // second push will throw exception
+    try {
+      shuffleClient.pushDataToLocation(
+          TEST_APPLICATION_ID,
+          TEST_SHUFFLE_ID,
+          TEST_ATTEMPT_ID,
+          TEST_ATTEMPT_ID,
+          TEST_REDUCRE_ID,
+          byteBuf,
+          masterLocation,
+          () -> true);
+      // Reached only when no exception was thrown; AssertionError is not caught below.
+      Assert.fail("should failed");
+    } catch (IOException expected) {
+      // Expected: the failure recorded by the first push is reported on the next push.
+    }
+  }
+}