This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new d3d40f73 [CELEBORN-106] flink-plugin supports shufflewrite:OutputGate
(#1051)
d3d40f73 is described below
commit d3d40f730ce64618169be662febd7b3d140cecb8
Author: zhongqiangczq <[email protected]>
AuthorDate: Thu Dec 8 11:24:37 2022 +0800
[CELEBORN-106] flink-plugin supports shufflewrite:OutputGate (#1051)
---
client-flink/flink-1.14/pom.xml | 6 +
.../plugin/flink/RemoteShuffleOutputGate.java | 226 +++++++++++++++++++++
client-flink/{flink-1.14 => flink-common}/pom.xml | 9 +-
.../apache/celeborn/plugin/flink/utils/Utils.java | 73 +++++++
.../org/apache/celeborn/client/ShuffleClient.java | 3 +
.../apache/celeborn/client/DummyShuffleClient.java | 6 +
pom.xml | 1 +
7 files changed, 316 insertions(+), 8 deletions(-)
diff --git a/client-flink/flink-1.14/pom.xml b/client-flink/flink-1.14/pom.xml
index 4aa5b01e..30681778 100644
--- a/client-flink/flink-1.14/pom.xml
+++ b/client-flink/flink-1.14/pom.xml
@@ -49,5 +49,11 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.celeborn</groupId>
+ <artifactId>celeborn-client-flink-common-${flink.version}_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>compile</scope>
+ </dependency>
</dependencies>
</project>
diff --git
a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
new file mode 100644
index 00000000..f30dfadc
--- /dev/null
+++
b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import java.io.IOException;
+import java.util.Optional;
+
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.plugin.flink.utils.Utils;
+
+/**
+ * A transportation gate used to spill buffers from {@link ResultPartitionWriter} to remote shuffle
+ * worker.
+ *
+ * <p>Expected call order, as far as this class shows: {@link #setup()} (which performs the first
+ * {@link #handshake()}), then pairs of {@link #regionStart(boolean)} / {@link #regionFinish()},
+ * then {@link #finish()} and finally {@link #close()}. The data-writing methods are still stubs
+ * (their bodies are commented out).
+ */
+public class RemoteShuffleOutputGate {
+
+  private final RemoteShuffleDescriptor shuffleDesc;
+  protected final int numSubs;
+  private final ShuffleClient shuffleWriteClient;
+  protected final SupplierWithException<BufferPool, IOException> bufferPoolFactory;
+  protected BufferPool bufferPool;
+  private final CelebornConf celebornConf;
+  private final int numMappers;
+  // Assigned lazily by handshake() and replaced by regionStart() when the client reports a new one.
+  private PartitionLocation partitionLocation;
+
+  // Incremented by regionFinish(); reset to 0 by every handshake().
+  private int currentRegionIndex = 0;
+
+  private final int bufferSize;
+  private final String applicationId;
+  private final int shuffleId;
+  private final int mapId;
+  private final int attemptId;
+  private final String rssMetaServiceHost;
+  private final int rssMetaServicePort;
+  // NOTE(review): never assigned in this class, so ShuffleClient.get(...) receives null here.
+  // TODO confirm where the user identity is supposed to come from.
+  private UserIdentifier userIdentifier;
+
+  /**
+   * @param shuffleDesc Describes shuffle meta and shuffle worker address.
+   * @param numSubs Number of subpartitions of the corresponding {@link ResultPartitionWriter}.
+   * @param bufferSize Buffer size forwarded to the shuffle client during the handshake.
+   * @param bufferPoolFactory {@link BufferPool} provider.
+   * @param celebornConf Configuration handed to {@link ShuffleClient#get}.
+   * @param numMappers Number of mappers, forwarded when registering the map partition task and
+   *     when signalling the end of the mapper.
+   */
+  public RemoteShuffleOutputGate(
+      RemoteShuffleDescriptor shuffleDesc,
+      int numSubs,
+      int bufferSize,
+      SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+      CelebornConf celebornConf,
+      int numMappers) {
+
+    this.shuffleDesc = shuffleDesc;
+    this.numSubs = numSubs;
+    this.bufferPoolFactory = bufferPoolFactory;
+    // this.bufferPacker = new BufferPacker(this::write);
+    this.celebornConf = celebornConf;
+    this.numMappers = numMappers;
+    this.bufferSize = bufferSize;
+    this.applicationId = shuffleDesc.getJobID().toString();
+    this.shuffleId =
+        shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getShuffleId();
+    this.mapId = shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getMapId();
+    this.attemptId =
+        shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getAttemptId();
+    this.rssMetaServiceHost =
+        ((RemoteShuffleResource) shuffleDesc.getShuffleResource()).getRssMetaServiceHost();
+    this.rssMetaServicePort =
+        ((RemoteShuffleResource) shuffleDesc.getShuffleResource()).getRssMetaServicePort();
+
+    // Must stay last: createWriteClient() reads celebornConf and the meta service host/port.
+    // Creating the client before those fields are assigned would pass null/0 to the client.
+    this.shuffleWriteClient = createWriteClient();
+  }
+
+  /**
+   * Initialize transportation gate: obtains the buffer pool from the factory, validates that it
+   * requires at least two memory segments, and performs the initial handshake.
+   *
+   * @throws NullPointerException if the factory yields a null pool
+   * @throws IllegalArgumentException if the pool requires fewer than two memory segments
+   */
+  public void setup() throws IOException, InterruptedException {
+    bufferPool = Utils.checkNotNull(bufferPoolFactory.get());
+    Utils.checkArgument(
+        bufferPool.getNumberOfRequiredMemorySegments() >= 2,
+        "Too few buffers for transfer, the minimum valid required size is 2.");
+
+    // guarantee that we have at least one buffer
+    // BufferUtils.reserveNumRequiredBuffers(bufferPool, 1);
+
+    // handshake
+    handshake();
+  }
+
+  /** Get transportation buffer pool; {@code null} until {@link #setup()} has run. */
+  public BufferPool getBufferPool() {
+    return bufferPool;
+  }
+
+  /** Writes a {@link Buffer} to a subpartition. Currently a no-op placeholder. */
+  public void write(Buffer buffer, int subIdx) throws InterruptedException {
+    // bufferPacker.process(buffer, subIdx);
+  }
+
+  /**
+   * Indicates the start of a region. A region of buffers guarantees the records inside are
+   * completed.
+   *
+   * <p>If the client answers with a new {@link PartitionLocation}, that location is adopted, the
+   * handshake is repeated against it and the region start is sent once more. The result of that
+   * second region start is not inspected.
+   *
+   * @param isBroadcast Whether it's a broadcast region.
+   */
+  public void regionStart(boolean isBroadcast) {
+    try {
+      Optional<PartitionLocation> newPartitionLoc =
+          shuffleWriteClient.regionStart(
+              applicationId,
+              shuffleId,
+              mapId,
+              attemptId,
+              partitionLocation,
+              currentRegionIndex,
+              isBroadcast);
+      // revived
+      if (newPartitionLoc.isPresent()) {
+        partitionLocation = newPartitionLoc.get();
+        // send handshake again; note that handshake() resets currentRegionIndex to 0
+        handshake();
+        // send regionstart again
+        shuffleWriteClient.regionStart(
+            applicationId,
+            shuffleId,
+            mapId,
+            attemptId,
+            partitionLocation,
+            currentRegionIndex,
+            isBroadcast);
+      }
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+
+  /**
+   * Indicates the finish of a region. A region is always bounded by a pair of region-start and
+   * region-finish. The region index only advances when the client call succeeds.
+   */
+  public void regionFinish() throws InterruptedException {
+    // bufferPacker.drain();
+    try {
+      shuffleWriteClient.regionFinish(
+          applicationId, shuffleId, mapId, attemptId, partitionLocation);
+      currentRegionIndex++;
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+
+  /** Indicates the writing/spilling is finished. */
+  public void finish() throws InterruptedException, IOException {
+    shuffleWriteClient.mapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers);
+  }
+
+  /** Close the transportation gate: destroys the buffer pool (if any) and shuts the client down. */
+  public void close() throws IOException {
+    try {
+      if (bufferPool != null) {
+        bufferPool.lazyDestroy();
+      }
+      // bufferPacker.close();
+    } finally {
+      // Shut the client down even when releasing the buffer pool fails.
+      shuffleWriteClient.shutDown();
+    }
+  }
+
+  /** Returns shuffle descriptor. */
+  public RemoteShuffleDescriptor getShuffleDesc() {
+    return shuffleDesc;
+  }
+
+  /**
+   * Obtains the shuffle client for the configured meta service. Reads instance fields, so it must
+   * only be called after they have been assigned.
+   */
+  private ShuffleClient createWriteClient() {
+    return ShuffleClient.get(rssMetaServiceHost, rssMetaServicePort, celebornConf, userIdentifier);
+  }
+
+  /** Writes a piece of data to a subpartition. Currently a no-op placeholder. */
+  public void write(ByteBuf byteBuf, int subIdx) throws InterruptedException {
+    //    try {
+    //      byteBuf.retain();
+    //      shuffleWriteClient.pushData(
+    //          applicationId,
+    //          shuffleId,
+    //          mapId,
+    //          attemptId,
+    //          subIdx,
+    //          io.netty.buffer.Unpooled.wrappedBuffer(byteBuf.nioBuffer()),
+    //          partitionLocation,
+    //          () -> byteBuf.release());
+    //    } catch (IOException e) {
+    //      Utils.rethrowAsRuntimeException(e);
+    //    }
+  }
+
+  /**
+   * Performs the push-data handshake for the current partition location. On the first call (no
+   * location known yet) the map partition task is registered to obtain one. Every call resets the
+   * region index to 0.
+   */
+  public void handshake() {
+    if (partitionLocation == null) {
+      partitionLocation =
+          shuffleWriteClient.registerMapPartitionTask(
+              applicationId, shuffleId, numMappers, mapId, attemptId);
+    }
+    currentRegionIndex = 0;
+    try {
+      shuffleWriteClient.pushDataHandShake(
+          applicationId, shuffleId, mapId, attemptId, numSubs, bufferSize, partitionLocation);
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+}
diff --git a/client-flink/flink-1.14/pom.xml b/client-flink/flink-common/pom.xml
similarity index 86%
copy from client-flink/flink-1.14/pom.xml
copy to client-flink/flink-common/pom.xml
index 4aa5b01e..3b37b482 100644
--- a/client-flink/flink-1.14/pom.xml
+++ b/client-flink/flink-common/pom.xml
@@ -24,9 +24,7 @@
<relativePath>../../pom.xml</relativePath>
</parent>
- <artifactId>celeborn-client-flink-${flink.version}_${scala.binary.version}</artifactId>
- <packaging>jar</packaging>
- <name>Celeborn Client for flink</name>
+ <artifactId>celeborn-client-flink-common-${flink.version}_${scala.binary.version}</artifactId>
<dependencies>
<dependency>
@@ -44,10 +42,5 @@
<artifactId>celeborn-client_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
- <dependency>
- <groupId>junit</groupId>
- <artifactId>junit</artifactId>
- <scope>test</scope>
- </dependency>
</dependencies>
</project>
diff --git
a/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/utils/Utils.java
b/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/utils/Utils.java
new file mode 100644
index 00000000..f4ebbb8b
--- /dev/null
+++
b/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/utils/Utils.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.utils;
+
+/** General-purpose static helpers (precondition checks, checked casts, rethrowing). */
+public class Utils {
+  /**
+   * Returns the given object unchanged when it is non-null.
+   *
+   * @throws NullPointerException if the given object is null
+   */
+  public static <T> T checkNotNull(T object) {
+    if (object != null) {
+      return object;
+    }
+    throw new NullPointerException("Must be not null.");
+  }
+
+  /**
+   * Validates a method argument.
+   *
+   * @throws IllegalArgumentException carrying {@code message} when {@code condition} is false
+   */
+  public static void checkArgument(boolean condition, String message) {
+    if (condition) {
+      return;
+    }
+    throw new IllegalArgumentException(message);
+  }
+
+  /**
+   * Validates program state.
+   *
+   * @throws IllegalStateException carrying {@code message} when {@code condition} is false
+   */
+  public static void checkState(boolean condition, String message) {
+    if (condition) {
+      return;
+    }
+    throw new IllegalStateException(message);
+  }
+
+  /**
+   * Narrows a long to an int, refusing values outside the int range.
+   *
+   * @throws IllegalArgumentException if {@code value} does not fit into an int
+   */
+  public static int checkedDownCast(long value) {
+    boolean fitsInInt = value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE;
+    if (!fitsInInt) {
+      throw new IllegalArgumentException("Cannot downcast long value " + value + " to integer.");
+    }
+    return (int) value;
+  }
+
+  /**
+   * Always throws: an {@link Error} or {@link RuntimeException} is rethrown as-is, anything else
+   * is wrapped in a new {@link RuntimeException} with the original as its cause.
+   */
+  public static void rethrowAsRuntimeException(Throwable t) {
+    if (t instanceof Error) {
+      throw (Error) t;
+    }
+    if (t instanceof RuntimeException) {
+      throw (RuntimeException) t;
+    }
+    throw new RuntimeException(t);
+  }
+}
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index 3b76333f..d4a471e9 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -208,4 +208,7 @@ public abstract class ShuffleClient implements Cloneable {
int bufferSize,
PartitionLocation location)
throws IOException;
+
+  /**
+   * Registers a map partition task and returns a {@link PartitionLocation} for it.
+   *
+   * <p>NOTE(review): presumably the returned location is where the given map attempt should push
+   * its data; RemoteShuffleOutputGate#handshake passes it on to pushDataHandShake. Confirm against
+   * the concrete implementation.
+   */
+ public abstract PartitionLocation registerMapPartitionTask(
+ String appId, int shuffleId, int numMappers, int mapId, int attemptId);
}
diff --git
a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index 95470f6c..cccfd20b 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -163,4 +163,10 @@ public class DummyShuffleClient extends ShuffleClient {
int bufferSize,
PartitionLocation location)
throws IOException {}
+
+  // Test stub: ignores all arguments and always returns null.
+ @Override
+ public PartitionLocation registerMapPartitionTask(
+ String appId, int shuffleId, int numMappers, int mapId, int attemptId) {
+ return null;
+ }
}
diff --git a/pom.xml b/pom.xml
index 829a2469..982f1187 100644
--- a/pom.xml
+++ b/pom.xml
@@ -977,6 +977,7 @@
<id>flink-1.14</id>
<modules>
<module>client-flink/flink-1.14</module>
+ <module>client-flink/flink-common</module>
<module>client-flink/flink-shaded</module>
</modules>
<properties>