This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 60f6f878 [CELEBORN-11] ShuffleClient supports MapPartition shuffle 
write:pushdata (#1036)
60f6f878 is described below

commit 60f6f8783271508f64a4007193eeef6c76faa2cd
Author: zhongqiangczq <[email protected]>
AuthorDate: Thu Dec 8 12:31:47 2022 +0800

    [CELEBORN-11] ShuffleClient supports MapPartition shuffle write:pushdata 
(#1036)
---
 .../org/apache/celeborn/client/ShuffleClient.java  |  16 +++
 .../apache/celeborn/client/ShuffleClientImpl.java  | 131 +++++++++++++++++-
 .../apache/celeborn/client/DummyShuffleClient.java |  15 +++
 .../celeborn/client/ShuffleClientBaseSuiteJ.java   | 108 +++++++++++++++
 .../celeborn/client/ShuffleClientImplSuiteJ.java   | 147 +++++++++++++++++++++
 5 files changed, 416 insertions(+), 1 deletion(-)

diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index d4a471e9..4beaf00c 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -19,7 +19,9 @@ package org.apache.celeborn.client;
 
 import java.io.IOException;
 import java.util.Optional;
+import java.util.function.BooleanSupplier;
 
+import io.netty.buffer.ByteBuf;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 
@@ -185,6 +187,20 @@ public abstract class ShuffleClient implements Cloneable {
 
   public abstract void shutDown();
 
+  // Writes data to a specific map partition location. The payload is passed as a netty
+  // ByteBuf to avoid an extra copy between the application and netty.
+  // closeCallBack performs clean-up operations such as releasing memory.
+  public abstract int pushDataToLocation(
+      String applicationId,
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      ByteBuf data,
+      PartitionLocation location,
+      BooleanSupplier closeCallBack)
+      throws IOException;
+
   public abstract Optional<PartitionLocation> regionStart(
       String applicationId,
       int shuffleId,
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index a3d0c6b1..86f0c5b4 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -25,11 +25,13 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.function.BooleanSupplier;
 
 import scala.reflect.ClassTag$;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.util.concurrent.Uninterruptibles;
+import io.netty.buffer.ByteBuf;
 import io.netty.buffer.CompositeByteBuf;
 import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelFuture;
@@ -90,6 +92,8 @@ public class ShuffleClientImpl extends ShuffleClient {
 
   protected TransportClientFactory dataClientFactory;
 
+  // Size in bytes of the batch header: four 4-byte ints (mapId, attemptId, batchId, length).
+  static final int BATCH_HEADER_SIZE = 4 * 4;
+
   // key: shuffleId, value: (partitionId, PartitionLocation)
   private final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>> 
reducePartitionMap =
       new ConcurrentHashMap<>();
@@ -565,7 +569,7 @@ public class ShuffleClientImpl extends ShuffleClient {
     compressor.compress(data, offset, length);
 
     final int compressedTotalSize = compressor.getCompressedTotalSize();
-    final int BATCH_HEADER_SIZE = 4 * 4;
+
     final byte[] body = new byte[BATCH_HEADER_SIZE + compressedTotalSize];
     Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId);
     Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId);
@@ -1249,6 +1253,131 @@ public class ShuffleClientImpl extends ShuffleClient {
         || (message.startsWith("Failed to send RPC "));
   }
 
+  public int pushDataToLocation(
+      String applicationId,
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      ByteBuf data,
+      PartitionLocation location,
+      BooleanSupplier closeCallBack)
+      throws IOException {
+    // mapKey
+    final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
+    final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
+    // return if shuffle stage already ended
+    if (mapperEnded(shuffleId, mapId, attemptId)) {
+      logger.debug(
+          "The mapper(shuffle {} map {} attempt {}) has already ended while"
+              + " pushing data byteBuf.",
+          shuffleId,
+          mapId,
+          attemptId);
+      PushState pushState = pushStates.get(mapKey);
+      if (pushState != null) {
+        pushState.cancelFutures();
+      }
+      return 0;
+    }
+
+    PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new 
PushState(conf));
+
+    // increment batchId
+    final int nextBatchId = pushState.batchId.addAndGet(1);
+    int totalLength = data.readableBytes();
+    data.markWriterIndex();
+    data.writerIndex(0);
+    data.writeInt(mapId);
+    data.writeInt(attemptId);
+    data.writeInt(nextBatchId);
+    data.writeInt(totalLength - BATCH_HEADER_SIZE);
+    data.resetWriterIndex();
+    logger.debug(
+        "Do push data byteBuf for app {} shuffle {} map {} attempt {} reduce 
{} batch {}.",
+        applicationId,
+        shuffleId,
+        mapId,
+        attemptId,
+        partitionId,
+        nextBatchId);
+    // check limit
+    limitMaxInFlight(mapKey, pushState, maxInFlight);
+
+    // add inFlight requests
+    pushState.inFlightBatches.put(nextBatchId, location);
+
+    // build PushData request
+    NettyManagedBuffer buffer = new NettyManagedBuffer(data);
+    PushData pushData = new PushData(MASTER_MODE, shuffleKey, 
location.getUniqueId(), buffer);
+
+    // build callback
+    RpcResponseCallback callback =
+        new RpcResponseCallback() {
+          @Override
+          public void onSuccess(ByteBuffer response) {
+            closeCallBack.getAsBoolean();
+            pushState.inFlightBatches.remove(nextBatchId);
+            pushState.removeFuture(nextBatchId);
+            if (response.remaining() > 0) {
+              byte reason = response.get();
+              if (reason == StatusCode.STAGE_ENDED.getValue()) {
+                mapperEndMap
+                    .computeIfAbsent(shuffleId, (id) -> 
ConcurrentHashMap.newKeySet())
+                    .add(mapKey);
+              }
+            }
+            logger.debug(
+                "Push data byteBuf to {}:{} success for map {} attempt {} 
batch {}.",
+                location.getHost(),
+                location.getPushPort(),
+                mapId,
+                attemptId,
+                nextBatchId);
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            closeCallBack.getAsBoolean();
+            pushState.inFlightBatches.remove(nextBatchId);
+            pushState.removeFuture(nextBatchId);
+            if (pushState.exception.get() != null) {
+              return;
+            }
+            if (!mapperEnded(shuffleId, mapId, attemptId)) {
+              pushState.exception.compareAndSet(
+                  null, new IOException("PushData byteBuf failed!", e));
+              logger.error(
+                  "Push data byteBuf to {}:{} failed for map {} attempt {} 
batch {}.",
+                  location.getHost(),
+                  location.getPushPort(),
+                  mapId,
+                  attemptId,
+                  nextBatchId,
+                  e);
+            } else {
+              logger.warn(
+                  "Mapper shuffleId:{} mapId:{} attempt:{} already ended, 
remove batchId:{}.",
+                  shuffleId,
+                  mapId,
+                  attemptId,
+                  nextBatchId);
+            }
+          }
+        };
+    // do push data
+    try {
+      TransportClient client =
+          dataClientFactory.createClient(location.getHost(), 
location.getPushPort(), partitionId);
+      ChannelFuture future = client.pushData(pushData, callback);
+      pushState.addFuture(nextBatchId, future);
+    } catch (Exception e) {
+      logger.warn("PushData byteBuf failed", e);
+      callback.onFailure(new 
Exception(getPushDataFailCause(e.getMessage()).toString(), e));
+    }
+    return totalLength;
+  }
+
   @Override
   public void pushDataHandShake(
       String applicationId,
diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index cccfd20b..d1911fa3 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -23,7 +23,9 @@ import java.io.FileOutputStream;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.util.Optional;
+import java.util.function.BooleanSupplier;
 
+import io.netty.buffer.ByteBuf;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -135,6 +137,19 @@ public class DummyShuffleClient extends ShuffleClient {
     }
   }
 
+  /**
+   * Stub implementation: ignores every argument, never invokes {@code closeCallBack}, and always
+   * returns 0.
+   */
+  @Override
+  public int pushDataToLocation(
+      String applicationId,
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      ByteBuf data,
+      PartitionLocation location,
+      BooleanSupplier closeCallBack) {
+    return 0;
+  }
+
   @Override
   public Optional<PartitionLocation> regionStart(
       String applicationId,
diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
new file mode 100644
index 00000000..ead6a010
--- /dev/null
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientBaseSuiteJ.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+
+import scala.reflect.ClassTag$;
+
+import io.netty.channel.ChannelFuture;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.protocol.CompressionCodec;
+import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbRegisterShuffleResponse;
+import org.apache.celeborn.common.protocol.message.ControlMessages;
+import org.apache.celeborn.common.protocol.message.StatusCode;
+import org.apache.celeborn.common.rpc.RpcEndpointRef;
+
+/**
+ * Shared fixture for {@link ShuffleClientImpl} tests. It defines a master/slave
+ * {@link PartitionLocation} pair and builds a client whose meta service endpoint and data client
+ * factory are Mockito mocks.
+ */
+public abstract class ShuffleClientBaseSuiteJ {
+  protected ShuffleClientImpl shuffleClient = null;
+  // The endpoint and factory mocks are static (shared by all tests); the client mock is per test.
+  protected static final RpcEndpointRef endpointRef = mock(RpcEndpointRef.class);
+  protected static final TransportClientFactory clientFactory = mock(TransportClientFactory.class);
+  protected final TransportClient client = mock(TransportClient.class);
+
+  protected static final String TEST_APPLICATION_ID = "testapp1";
+  protected static final int TEST_SHUFFLE_ID = 1;
+  protected static final int TEST_ATTEMPT_ID = 0;
+  // NOTE: the name is misspelled ("REDUCRE"); it is kept because subclasses reference it.
+  protected static final int TEST_REDUCRE_ID = 0;
+
+  protected static final int MASTER_RPC_PORT = 1234;
+  protected static final int MASTER_PUSH_PORT = 1235;
+  protected static final int MASTER_FETCH_PORT = 1236;
+  protected static final int MASTER_REPLICATE_PORT = 1237;
+  protected static final int SLAVE_RPC_PORT = 4321;
+  protected static final int SLAVE_PUSH_PORT = 4322;
+  protected static final int SLAVE_FETCH_PORT = 4323;
+  protected static final int SLAVE_REPLICATE_PORT = 4324;
+  protected static final PartitionLocation masterLocation =
+      new PartitionLocation(
+          0,
+          1,
+          "localhost",
+          MASTER_RPC_PORT,
+          MASTER_PUSH_PORT,
+          MASTER_FETCH_PORT,
+          MASTER_REPLICATE_PORT,
+          PartitionLocation.Mode.MASTER);
+  protected static final PartitionLocation slaveLocation =
+      new PartitionLocation(
+          0,
+          1,
+          "localhost",
+          SLAVE_RPC_PORT,
+          SLAVE_PUSH_PORT,
+          SLAVE_FETCH_PORT,
+          SLAVE_REPLICATE_PORT,
+          PartitionLocation.Mode.SLAVE);
+
+  // Bytes at the start of a test buffer that the tests leave unfilled (four 4-byte ints).
+  protected static final int BATCH_HEADER_SIZE = 4 * 4;
+  protected ChannelFuture mockedFuture = mock(ChannelFuture.class);
+
+  /**
+   * Creates a fresh {@link ShuffleClientImpl} for the given codec and wires the mocks into it.
+   *
+   * <p>The mocked endpoint answers the RegisterShuffle request for the test application and
+   * shuffle with {@code SUCCESS} and {@code masterLocation}; the mocked factory hands out
+   * {@code client} for the master's host and push port with {@code TEST_REDUCRE_ID}.
+   *
+   * @param codec compression codec written to {@code celeborn.shuffle.compression.codec}
+   * @return the configuration the client was created with
+   */
+  protected CelebornConf setupEnv(CompressionCodec codec) throws IOException, InterruptedException {
+    CelebornConf conf = new CelebornConf();
+    conf.set("celeborn.shuffle.compression.codec", codec.name());
+    conf.set("celeborn.push.retry.threads", "1");
+    conf.set("celeborn.push.buffer.size", "1K");
+    shuffleClient = new ShuffleClientImpl(conf, new UserIdentifier("mock", "mock"));
+    masterLocation.setPeer(slaveLocation);
+
+    when(endpointRef.askSync(
+            ControlMessages.RegisterShuffle$.MODULE$.apply(
+                TEST_APPLICATION_ID, TEST_SHUFFLE_ID, 1, 1),
+            ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)))
+        .thenAnswer(
+            t ->
+                ControlMessages.RegisterShuffleResponse$.MODULE$.apply(
+                    StatusCode.SUCCESS, new PartitionLocation[] {masterLocation}));
+
+    shuffleClient.setupMetaServiceRef(endpointRef);
+    when(clientFactory.createClient(
+            masterLocation.getHost(), masterLocation.getPushPort(), TEST_REDUCRE_ID))
+        .thenAnswer(t -> client);
+
+    shuffleClient.dataClientFactory = clientFactory;
+    return conf;
+  }
+}
diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientImplSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientImplSuiteJ.java
new file mode 100644
index 00000000..ae368d2f
--- /dev/null
+++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientImplSuiteJ.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
+import org.apache.celeborn.common.protocol.CompressionCodec;
+import org.apache.celeborn.common.protocol.message.StatusCode;
+
+/** Tests for {@code ShuffleClientImpl.pushDataToLocation} against a mocked transport client. */
+public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
+  static final int BufferSize = 64;
+  static final byte[] TEST_BUF1 = new byte[BufferSize];
+  static CelebornConf conf;
+  // Map id passed to every push; named separately so it is not confused with the attempt id.
+  private static final int TEST_MAP_ID = 0;
+
+  @Before
+  public void setup() throws IOException, InterruptedException {
+    conf = setupEnv(CompressionCodec.LZ4);
+  }
+
+  /**
+   * Wraps {@code TEST_BUF1} in a ByteBuf after filling everything past the header space with 1s.
+   * The first {@code BATCH_HEADER_SIZE} bytes are left for the client to write the header into.
+   */
+  public ByteBuf createByteBuf() {
+    for (int i = BATCH_HEADER_SIZE; i < BufferSize; i++) {
+      TEST_BUF1[i] = 1;
+    }
+    ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
+    byteBuf.writerIndex(BufferSize);
+    return byteBuf;
+  }
+
+  /** An empty success response: the push reports the full buffer length. */
+  @Test
+  public void testPushDataByteBufSuccess() throws IOException {
+    ByteBuf byteBuf = createByteBuf();
+    when(client.pushData(any(), any()))
+        .thenAnswer(
+            t -> {
+              RpcResponseCallback rpcResponseCallback =
+                  t.getArgumentAt(1, RpcResponseCallback.class);
+              ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[0]);
+              rpcResponseCallback.onSuccess(byteBuffer);
+              return mockedFuture;
+            });
+
+    int pushDataLen =
+        shuffleClient.pushDataToLocation(
+            TEST_APPLICATION_ID,
+            TEST_SHUFFLE_ID,
+            TEST_MAP_ID,
+            TEST_ATTEMPT_ID,
+            TEST_REDUCRE_ID,
+            byteBuf,
+            masterLocation,
+            () -> true);
+    Assert.assertEquals(BufferSize, pushDataLen);
+  }
+
+  /** A success response carrying HARD_SPLIT still reports the full buffer length. */
+  @Test
+  public void testPushDataByteBufHardSplit() throws IOException {
+    ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
+    when(client.pushData(any(), any()))
+        .thenAnswer(
+            t -> {
+              RpcResponseCallback rpcResponseCallback =
+                  t.getArgumentAt(1, RpcResponseCallback.class);
+              ByteBuffer byteBuffer =
+                  ByteBuffer.wrap(new byte[] {StatusCode.HARD_SPLIT.getValue()});
+              rpcResponseCallback.onSuccess(byteBuffer);
+              return mockedFuture;
+            });
+    int pushDataLen =
+        shuffleClient.pushDataToLocation(
+            TEST_APPLICATION_ID,
+            TEST_SHUFFLE_ID,
+            TEST_MAP_ID,
+            TEST_ATTEMPT_ID,
+            TEST_REDUCRE_ID,
+            byteBuf,
+            masterLocation,
+            () -> true);
+    Assert.assertEquals(BufferSize, pushDataLen);
+  }
+
+  /** A failed push is recorded, and the next push for the same map attempt throws. */
+  @Test
+  public void testPushDataByteBufFail() throws IOException {
+    ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
+    when(client.pushData(any(), any()))
+        .thenAnswer(
+            t -> {
+              RpcResponseCallback rpcResponseCallback =
+                  t.getArgumentAt(1, RpcResponseCallback.class);
+              rpcResponseCallback.onFailure(new Exception("pushDataFailed"));
+              return mockedFuture;
+            });
+    // the first push only records the failure in the push state
+    shuffleClient.pushDataToLocation(
+        TEST_APPLICATION_ID,
+        TEST_SHUFFLE_ID,
+        TEST_MAP_ID,
+        TEST_ATTEMPT_ID,
+        TEST_REDUCRE_ID,
+        byteBuf,
+        masterLocation,
+        () -> true);
+
+    // the second push must surface the recorded failure as an IOException; any other
+    // exception propagates instead of being masked by an assertion in a finally block
+    try {
+      shuffleClient.pushDataToLocation(
+          TEST_APPLICATION_ID,
+          TEST_SHUFFLE_ID,
+          TEST_MAP_ID,
+          TEST_ATTEMPT_ID,
+          TEST_REDUCRE_ID,
+          byteBuf,
+          masterLocation,
+          () -> true);
+      Assert.fail("second push should fail with an IOException");
+    } catch (IOException expected) {
+      // expected: the failure from the first push is reported here
+    }
+  }
+}

Reply via email to