This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new bd29da836 [CELEBORN-1490][CIP-6] Introduce tier consumer for hybrid
shuffle
bd29da836 is described below
commit bd29da83635bd97b690174d1d0551fd114dbde50
Author: Weijie Guo <[email protected]>
AuthorDate: Thu Oct 17 10:46:35 2024 +0800
[CELEBORN-1490][CIP-6] Introduce tier consumer for hybrid shuffle
### What changes were proposed in this pull request?
Introduce tier consumer for hybrid shuffle
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
unit test
Closes #2786 from reswqa/cip6-7-pr-new.
Authored-by: Weijie Guo <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../plugin/flink/network/ReadClientHandler.java | 2 -
.../celeborn/plugin/flink/protocol/ReadData.java | 6 +-
.../flink/protocol/SubPartitionReadData.java | 36 +-
.../flink/readclient/CelebornBufferStream.java | 59 ++-
.../flink/readclient/FlinkShuffleClientImpl.java | 14 +-
...nsportFrameDecoderWithBufferSupplierSuiteJ.java | 39 +-
.../flink/tiered/CelebornChannelBufferManager.java | 164 ++++++
.../flink/tiered/CelebornChannelBufferReader.java | 329 ++++++++++++
.../flink/tiered/CelebornTierConsumerAgent.java | 550 +++++++++++++++++++++
.../plugin/flink/tiered/CelebornTierFactory.java | 3 +-
.../celeborn/common/network/protocol/ReadData.java | 2 +-
.../network/protocol/SubPartitionReadData.java | 20 +-
12 files changed, 1157 insertions(+), 67 deletions(-)
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
index c1d0982fb..1ef2b7781 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
@@ -66,8 +66,6 @@ public class ReadClientHandler extends BaseMessageHandler {
} else {
if (msg != null && msg instanceof ReadData) {
((ReadData) msg).getFlinkBuffer().release();
- } else if (msg != null && msg instanceof SubPartitionReadData) {
- ((SubPartitionReadData) msg).getFlinkBuffer().release();
}
logger.warn("Unexpected streamId received: {}", streamId);
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java
index a8f7fcf84..ba507e88e 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java
@@ -23,9 +23,9 @@ import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.celeborn.common.network.protocol.RequestMessage;
-public final class ReadData extends RequestMessage {
- private final long streamId;
- private ByteBuf flinkBuffer;
+public class ReadData extends RequestMessage {
+ protected final long streamId;
+ protected ByteBuf flinkBuffer;
@Override
public boolean needCopyOut() {
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
index b12f24d9b..fc08de75f 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/SubPartitionReadData.java
@@ -20,46 +20,30 @@ package org.apache.celeborn.plugin.flink.protocol;
import java.util.Objects;
-import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
-
-import org.apache.celeborn.common.network.protocol.ReadData;
-import org.apache.celeborn.common.network.protocol.RequestMessage;
-
/**
* Compared with {@link ReadData}, this class has an additional field of
subpartitionId. This class is
* added to keep backward compatibility.
*/
-public class SubPartitionReadData extends RequestMessage {
- private final long streamId;
+public class SubPartitionReadData extends ReadData {
private final int subPartitionId;
- private ByteBuf flinkBuffer;
-
- @Override
- public boolean needCopyOut() {
- return true;
- }
public SubPartitionReadData(long streamId, int subPartitionId) {
+ super(streamId);
this.subPartitionId = subPartitionId;
- this.streamId = streamId;
}
@Override
public int encodedLength() {
- return 8 + 4;
+ return super.encodedLength() + 4;
}
// This method will not be called because ReadData won't be created at flink
client.
@Override
public void encode(io.netty.buffer.ByteBuf buf) {
- buf.writeLong(streamId);
+ super.encode(buf);
buf.writeInt(subPartitionId);
}
- public long getStreamId() {
- return streamId;
- }
-
public int getSubPartitionId() {
return subPartitionId;
}
@@ -74,8 +58,8 @@ public class SubPartitionReadData extends RequestMessage {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SubPartitionReadData readData = (SubPartitionReadData) o;
- return streamId == readData.streamId
- && subPartitionId == readData.subPartitionId
+ return streamId == readData.getStreamId()
+ && subPartitionId == readData.getSubPartitionId()
&& Objects.equals(flinkBuffer, readData.flinkBuffer);
}
@@ -95,12 +79,4 @@ public class SubPartitionReadData extends RequestMessage {
+ flinkBuffer
+ '}';
}
-
- public ByteBuf getFlinkBuffer() {
- return flinkBuffer;
- }
-
- public void setFlinkBuffer(ByteBuf flinkBuffer) {
- this.flinkBuffer = flinkBuffer;
- }
}
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
index 7f478143d..849895fef 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
@@ -20,9 +20,12 @@ package org.apache.celeborn.plugin.flink.readclient;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import javax.annotation.Nullable;
+
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,6 +39,7 @@ import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
+import org.apache.celeborn.common.protocol.PbNotifyRequiredSegment;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.common.protocol.PbStreamHandler;
@@ -110,10 +114,38 @@ public class CelebornBufferStream {
});
}
+ public void notifyRequiredSegment(PbNotifyRequiredSegment
pbNotifyRequiredSegment) {
+ this.client.sendRpc(
+ new TransportMessage(
+ MessageType.NOTIFY_REQUIRED_SEGMENT,
pbNotifyRequiredSegment.toByteArray())
+ .toByteBuffer(),
+ new RpcResponseCallback() {
+
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ // Sending PbNotifyRequiredSegment does not expect a response.
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ logger.error(
+ "Send PbNotifyRequiredSegment to {} failed, streamId {},
detail {}",
+ NettyUtils.getRemoteAddress(client.getChannel()),
+ streamId,
+ e.getCause());
+ messageConsumer.accept(new TransportableError(streamId, e));
+ }
+ });
+ }
+
public static CelebornBufferStream empty() {
return EMPTY_CELEBORN_BUFFER_STREAM;
}
+ public static boolean isEmptyStream(CelebornBufferStream stream) {
+ return stream == null || stream == EMPTY_CELEBORN_BUFFER_STREAM;
+ }
+
public long getStreamId() {
return streamId;
}
@@ -167,6 +199,11 @@ public class CelebornBufferStream {
}
public void moveToNextPartitionIfPossible(long endedStreamId) {
+ moveToNextPartitionIfPossible(endedStreamId, null);
+ }
+
+ public void moveToNextPartitionIfPossible(
+ long endedStreamId, @Nullable BiConsumer<Long, Integer>
requiredSegmentIdConsumer) {
logger.debug(
"MoveToNextPartitionIfPossible in this:{}, endedStreamId: {},
currentLocationIndex: {}, currentSteamId:{}, locationsLength:{}",
this,
@@ -178,9 +215,10 @@ public class CelebornBufferStream {
logger.debug("Get end streamId {}", endedStreamId);
cleanStream(endedStreamId);
}
+
if (currentLocationIndex.get() < locations.length) {
try {
- openStreamInternal();
+ openStreamInternal(requiredSegmentIdConsumer);
logger.debug(
"MoveToNextPartitionIfPossible after openStream this:{},
endedStreamId: {}, currentLocationIndex: {}, currentSteamId:{},
locationsLength:{}",
this,
@@ -195,7 +233,12 @@ public class CelebornBufferStream {
}
}
- private void openStreamInternal() throws IOException, InterruptedException {
+ /**
+ * Open the stream. Note that if the requiredSegmentIdConsumer is not null,
it will
+ * be invoked for every subPartition when the stream is opened successfully.
+ */
+ private void openStreamInternal(@Nullable BiConsumer<Long, Integer>
requiredSegmentIdConsumer)
+ throws IOException, InterruptedException {
this.client =
clientFactory.createClientWithRetry(
locations[currentLocationIndex.get()].getHost(),
@@ -210,6 +253,7 @@ public class CelebornBufferStream {
.setStartIndex(subIndexStart)
.setEndIndex(subIndexEnd)
.setInitialCredit(initialCredit)
+ .setRequireSubpartitionId(true)
.build()
.toByteArray());
client.sendRpc(
@@ -230,6 +274,13 @@ public class CelebornBufferStream {
.getReadClientHandler()
.registerHandler(streamId, messageConsumer, client);
isOpenSuccess = true;
+ if (requiredSegmentIdConsumer != null) {
+ for (int subPartitionId = subIndexStart;
+ subPartitionId <= subIndexEnd;
+ subPartitionId++) {
+ requiredSegmentIdConsumer.accept(streamId,
subPartitionId);
+ }
+ }
logger.debug(
"open stream success from remote:{}, stream id:{},
fileName: {}",
client.getSocketAddress(),
@@ -269,4 +320,8 @@ public class CelebornBufferStream {
public TransportClient getClient() {
return client;
}
+
+ public boolean isOpened() {
+ return isOpenSuccess;
+ }
}
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
index 4fb65777e..a3f8a1a9b 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
@@ -181,15 +181,11 @@ public class FlinkShuffleClientImpl extends
ShuffleClientImpl {
shuffleId,
partitionId,
isSegmentGranularityVisible);
- if (isSegmentGranularityVisible) {
- // When the downstream reduce tasks start early than upstream map
tasks, the shuffle
- // partition locations may be found empty, should retry until the
upstream task started
- return CelebornBufferStream.empty();
- } else {
- throw new PartitionUnRetryAbleException(
- String.format(
- "Shuffle data lost for shuffle %d partition %d.", shuffleId,
partitionId));
- }
+ // TODO: in segment granularity visible scenarios, when the downstream
reduce tasks start earlier
+ // than upstream map tasks, the shuffle
+ // partition locations may be found empty, should retry until the
upstream task started
+ throw new PartitionUnRetryAbleException(
+ String.format("Shuffle data lost for shuffle %d partition %d.",
shuffleId, partitionId));
} else {
Arrays.sort(partitionLocations,
Comparator.comparingInt(PartitionLocation::getEpoch));
logger.debug(
diff --git
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
index 64696abec..c7c8440c8 100644
---
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
+++
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplierSuiteJ.java
@@ -21,6 +21,8 @@ import static
org.apache.celeborn.common.network.client.TransportClient.requestI
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
@@ -31,19 +33,40 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import org.junit.Assert;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
import org.mockito.Mockito;
import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.protocol.ReadData;
import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.SubPartitionReadData;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
import org.apache.celeborn.common.util.JavaUtils;
+@RunWith(Parameterized.class)
public class TransportFrameDecoderWithBufferSupplierSuiteJ {
+ enum TestReadDataType {
+ READ_DATA,
+ SUBPARTITION_READ_DATA,
+ }
+
+ private TestReadDataType testReadDataType;
+
+ public TransportFrameDecoderWithBufferSupplierSuiteJ(TestReadDataType
testReadDataType) {
+ this.testReadDataType = testReadDataType;
+ }
+
+ @Parameterized.Parameters
+ public static Collection prepareData() {
+ Object[][] object = {{TestReadDataType.READ_DATA},
{TestReadDataType.SUBPARTITION_READ_DATA}};
+ return Arrays.asList(object);
+ }
+
@Test
public void testDropUnusedBytes() throws IOException {
ConcurrentHashMap<Long,
Supplier<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf>>
@@ -64,11 +87,11 @@ public class TransportFrameDecoderWithBufferSupplierSuiteJ {
ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
RpcRequest announcement = createBacklogAnnouncement(0, 0);
- ReadData unUsedReadData = new ReadData(1, generateData(1024));
- ReadData readData = new ReadData(2, generateData(1024));
+ ReadData unUsedReadData = generateReadDataMessage(1, 0,
generateData(1024));
+ ReadData readData = generateReadDataMessage(2, 0, generateData(1024));
RpcRequest announcement1 = createBacklogAnnouncement(0, 0);
- ReadData unUsedReadData1 = new ReadData(1, generateData(1024));
- ReadData readData1 = new ReadData(2, generateData(8));
+ ReadData unUsedReadData1 = generateReadDataMessage(1, 0,
generateData(1024));
+ ReadData readData1 = generateReadDataMessage(2, 0, generateData(8));
ByteBuf buffer = Unpooled.buffer(5000);
encodeMessage(announcement, buffer);
@@ -145,4 +168,12 @@ public class TransportFrameDecoderWithBufferSupplierSuiteJ
{
return data;
}
+
+ private ReadData generateReadDataMessage(long streamId, int subPartitionId,
ByteBuf buf) {
+ if (testReadDataType == TestReadDataType.READ_DATA) {
+ return new ReadData(streamId, buf);
+ } else {
+ return new SubPartitionReadData(streamId, subPartitionId, buf);
+ }
+ }
}
diff --git
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferManager.java
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferManager.java
new file mode 100644
index 000000000..8fe9e4bbb
--- /dev/null
+++
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferManager.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.plugin.flink.tiered;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+import java.util.LinkedList;
+import java.util.Queue;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferListener;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
+import org.apache.flink.util.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class CelebornChannelBufferManager implements BufferListener,
BufferRecycler {
+
+ private static Logger logger =
LoggerFactory.getLogger(CelebornChannelBufferManager.class);
+
+ /** The queue to hold the available buffer when the reader is waiting for
buffers. */
+ private final Queue<Buffer> bufferQueue;
+
+ private final TieredStorageMemoryManager memoryManager;
+
+ private final CelebornChannelBufferReader bufferReader;
+
+ /** The tag indicates whether it is waiting for buffers from the buffer
pool. */
+ @GuardedBy("bufferQueue")
+ private boolean isWaitingForFloatingBuffers;
+
+ /** The total number of required buffers for the respective input channel. */
+ @GuardedBy("bufferQueue")
+ private int numRequiredBuffers = 0;
+
+ public CelebornChannelBufferManager(
+ TieredStorageMemoryManager memoryManager, CelebornChannelBufferReader
bufferReader) {
+ this.memoryManager = checkNotNull(memoryManager);
+ this.bufferReader = checkNotNull(bufferReader);
+ this.bufferQueue = new LinkedList<>();
+ }
+
+ @Override
+ public boolean notifyBufferAvailable(Buffer buffer) {
+ if (bufferReader.isClosed()) {
+ return false;
+ }
+ int numBuffers = 0;
+ boolean isBufferUsed = false;
+ try {
+ synchronized (bufferQueue) {
+ if (!isWaitingForFloatingBuffers) {
+ logger.warn("This channel should be waiting for floating buffers.");
+ return false;
+ }
+ isWaitingForFloatingBuffers = false;
+ if (bufferReader.isClosed() || bufferQueue.size() >=
numRequiredBuffers) {
+ return false;
+ }
+ bufferQueue.add(buffer);
+ isBufferUsed = true;
+ numBuffers = 1 + tryRequestBuffers();
+ }
+ bufferReader.notifyAvailableCredits(numBuffers);
+ } catch (Throwable t) {
+ bufferReader.errorReceived(t.getLocalizedMessage());
+ }
+ return isBufferUsed;
+ }
+
+ public void decreaseRequiredCredits(int numCredits) {
+ synchronized (bufferQueue) {
+ numRequiredBuffers -= numCredits;
+ }
+ }
+
+ @Override
+ public void notifyBufferDestroyed() {
+ // noop
+ }
+
+ @Override
+ public void recycle(MemorySegment segment) {
+ try {
+ memoryManager.getBufferPool().recycle(segment);
+ } catch (Throwable t) {
+ ExceptionUtils.rethrow(t);
+ }
+ }
+
+ Buffer requestBuffer() {
+ synchronized (bufferQueue) {
+ return bufferQueue.poll();
+ }
+ }
+
+ int requestBuffers(int numRequired) {
+ int numRequestedBuffers = 0;
+ synchronized (bufferQueue) {
+ if (bufferReader.isClosed()) {
+ return numRequestedBuffers;
+ }
+ numRequiredBuffers += numRequired;
+ numRequestedBuffers = tryRequestBuffers();
+ }
+ return numRequestedBuffers;
+ }
+
+ int tryRequestBuffersIfNeeded() {
+ synchronized (bufferQueue) {
+ if (numRequiredBuffers > 0 && !isWaitingForFloatingBuffers &&
bufferQueue.isEmpty()) {
+ return tryRequestBuffers();
+ }
+ return 0;
+ }
+ }
+
+ void close() {
+ synchronized (bufferQueue) {
+ for (Buffer buffer : bufferQueue) {
+ buffer.recycleBuffer();
+ }
+ bufferQueue.clear();
+ }
+ }
+
+ @GuardedBy("bufferQueue")
+ private int tryRequestBuffers() {
+ assert Thread.holdsLock(bufferQueue);
+ int numRequestedBuffers = 0;
+ while (bufferQueue.size() < numRequiredBuffers &&
!isWaitingForFloatingBuffers) {
+ BufferPool bufferPool = memoryManager.getBufferPool();
+ Buffer buffer = bufferPool.requestBuffer();
+ if (buffer != null) {
+ bufferQueue.add(buffer);
+ numRequestedBuffers++;
+ } else if (bufferPool.addBufferListener(this)) {
+ isWaitingForFloatingBuffers = true;
+ break;
+ }
+ }
+ return numRequestedBuffers;
+ }
+}
diff --git
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferReader.java
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferReader.java
new file mode 100644
index 000000000..527617c96
--- /dev/null
+++
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferReader.java
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.tiered;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+
+import java.io.IOException;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageInputChannelId;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
+import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.TransportableError;
+import org.apache.celeborn.common.network.util.NettyUtils;
+import org.apache.celeborn.common.protocol.PbNotifyRequiredSegment;
+import org.apache.celeborn.common.protocol.PbReadAddCredit;
+import org.apache.celeborn.common.util.JavaUtils;
+import org.apache.celeborn.plugin.flink.ShuffleResourceDescriptor;
+import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;
+import org.apache.celeborn.plugin.flink.readclient.CelebornBufferStream;
+import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
+
+/**
+ * Wrap the {@link CelebornBufferStream}, utilize in flink hybrid shuffle
integration strategy now.
+ */
+public class CelebornChannelBufferReader {
+ private static final Logger LOG =
LoggerFactory.getLogger(CelebornChannelBufferReader.class);
+
+ private CelebornChannelBufferManager bufferManager;
+
+ private final FlinkShuffleClientImpl client;
+
+ private final int shuffleId;
+
+ private final int partitionId;
+
+ private final TieredStorageInputChannelId inputChannelId;
+
+ private final int subPartitionIndexStart;
+
+ private final int subPartitionIndexEnd;
+
+ private final BiConsumer<ByteBuf, TieredStorageSubpartitionId> dataListener;
+
+ private final BiConsumer<Throwable, TieredStorageSubpartitionId>
failureListener;
+
+ private final Consumer<RequestMessage> messageConsumer;
+
+ private CelebornBufferStream bufferStream;
+
+ private boolean isOpened;
+
+ private volatile boolean closed = false;
+
+ private volatile ConcurrentHashMap<Integer, Integer>
subPartitionRequiredSegmentIds;
+
+ /** Note this field is to record the number of backlog before the read is
set up. */
+ private int numBackLog = 0;
+
+ public CelebornChannelBufferReader(
+ FlinkShuffleClientImpl client,
+ ShuffleResourceDescriptor shuffleDescriptor,
+ TieredStorageInputChannelId inputChannelId,
+ int startSubIdx,
+ int endSubIdx,
+ BiConsumer<ByteBuf, TieredStorageSubpartitionId> dataListener,
+ BiConsumer<Throwable, TieredStorageSubpartitionId> failureListener) {
+ this.client = client;
+ this.shuffleId = shuffleDescriptor.getShuffleId();
+ this.partitionId = shuffleDescriptor.getPartitionId();
+ this.inputChannelId = inputChannelId;
+ this.subPartitionIndexStart = startSubIdx;
+ this.subPartitionIndexEnd = endSubIdx;
+ this.dataListener = dataListener;
+ this.failureListener = failureListener;
+ this.subPartitionRequiredSegmentIds = JavaUtils.newConcurrentHashMap();
+ for (int subPartitionId = subPartitionIndexStart;
+ subPartitionId <= subPartitionIndexEnd;
+ subPartitionId++) {
+ subPartitionRequiredSegmentIds.put(subPartitionId, -1);
+ }
+ this.messageConsumer =
+ requestMessage -> {
+ // Note that we need to use SubPartitionReadData because the
isSegmentGranularityVisible
+ // is set as true when opening stream
+ if (requestMessage instanceof SubPartitionReadData) {
+ dataReceived((SubPartitionReadData) requestMessage);
+ } else if (requestMessage instanceof BacklogAnnouncement) {
+ backlogReceived(((BacklogAnnouncement)
requestMessage).getBacklog());
+ } else if (requestMessage instanceof TransportableError) {
+ errorReceived(((TransportableError)
requestMessage).getErrorMessage());
+ } else if (requestMessage instanceof BufferStreamEnd) {
+ onStreamEnd((BufferStreamEnd) requestMessage);
+ }
+ };
+ }
+
+ public void setup(TieredStorageMemoryManager memoryManager) {
+ this.bufferManager = new CelebornChannelBufferManager(memoryManager, this);
+ if (numBackLog > 0) {
+ notifyAvailableCredits(bufferManager.requestBuffers(numBackLog));
+ numBackLog = 0;
+ }
+ }
+
+ public void open(int initialCredit) {
+ try {
+ bufferStream =
+ client.readBufferedPartition(
+ shuffleId, partitionId, subPartitionIndexStart,
subPartitionIndexEnd, true);
+ bufferStream.open(this::requestBuffer, initialCredit, messageConsumer);
+ this.isOpened = bufferStream.isOpened();
+ } catch (Exception e) {
+ messageConsumer.accept(new TransportableError(0L, e));
+ LOG.error("Failed to open reader", e);
+ }
+ }
+
+ public void close() {
+ // It may be called multiple times because subPartitions can share the same
reader, as a single
+ // reader can consume multiple subPartitions
+ if (closed) {
+ return;
+ }
+
+ // need set closed first before remove Handler
+ closed = true;
+ if (!CelebornBufferStream.isEmptyStream(bufferStream)) {
+ bufferStream.close();
+ bufferStream = null;
+ } else {
+ LOG.warn(
+ "bufferStream is null when closed, shuffleId: {}, partitionId: {}",
+ shuffleId,
+ partitionId);
+ }
+
+ try {
+ if (bufferManager != null) {
+ bufferManager.close();
+ bufferManager = null;
+ }
+ } catch (Throwable throwable) {
+ LOG.warn("Failed to close buffer manager.", throwable);
+ }
+
+ subPartitionRequiredSegmentIds.clear();
+ subPartitionRequiredSegmentIds = null;
+ }
+
+ public boolean isOpened() {
+ return isOpened;
+ }
+
+ boolean isClosed() {
+ return closed;
+ }
+
+ public void notifyAvailableCredits(int numCredits) {
+ if (numCredits <= 0) {
+ return;
+ }
+ if (!closed && !CelebornBufferStream.isEmptyStream(bufferStream)) {
+ bufferStream.addCredit(
+ PbReadAddCredit.newBuilder()
+ .setStreamId(bufferStream.getStreamId())
+ .setCredit(numCredits)
+ .build());
+ bufferManager.decreaseRequiredCredits(numCredits);
+ return;
+ }
+ LOG.warn(
+ "The buffer stream is null or closed, ignore the credits for
shuffleId: {}, partitionId: {}",
+ shuffleId,
+ partitionId);
+ }
+
+ public void notifyRequiredSegmentIfNeeded(int requiredSegmentId, int
subPartitionId) {
+ Integer lastRequiredSegmentId =
+ subPartitionRequiredSegmentIds.computeIfAbsent(subPartitionId, id ->
-1);
+ if (requiredSegmentId >= 0 && requiredSegmentId != lastRequiredSegmentId) {
+ LOG.debug(
+ "Notify required segment id {} for {} {}, the last segment id is {}",
+ requiredSegmentId,
+ partitionId,
+ subPartitionId,
+ lastRequiredSegmentId);
+ subPartitionRequiredSegmentIds.put(subPartitionId, requiredSegmentId);
+ if (!this.notifyRequiredSegment(requiredSegmentId, subPartitionId)) {
+ // if fail to notify reader segment, restore the last required segment
id
+ subPartitionRequiredSegmentIds.putIfAbsent(subPartitionId,
lastRequiredSegmentId);
+ }
+ }
+ }
+
+ public boolean notifyRequiredSegment(int requiredSegmentId, int
subPartitionId) {
+ this.subPartitionRequiredSegmentIds.put(subPartitionId, requiredSegmentId);
+ if (!closed && !CelebornBufferStream.isEmptyStream(bufferStream)) {
+ LOG.debug(
+ "Notify required segmentId {} for {} {} {}",
+ requiredSegmentId,
+ partitionId,
+ subPartitionId,
+ shuffleId);
+ PbNotifyRequiredSegment notifyRequiredSegment =
+ PbNotifyRequiredSegment.newBuilder()
+ .setStreamId(bufferStream.getStreamId())
+ .setRequiredSegmentId(requiredSegmentId)
+ .setSubPartitionId(subPartitionId)
+ .build();
+ bufferStream.notifyRequiredSegment(notifyRequiredSegment);
+ return true;
+ }
+ return false;
+ }
+
+ public ByteBuf requestBuffer() {
+ Buffer buffer = bufferManager.requestBuffer();
+ return buffer == null ? null : buffer.asByteBuf();
+ }
+
+ public void backlogReceived(int backlog) {
+ if (!closed) {
+ if (bufferManager == null) {
+ numBackLog += backlog;
+ return;
+ }
+ int numRequestedBuffers = bufferManager.requestBuffers(backlog);
+ if (numRequestedBuffers > 0) {
+ notifyAvailableCredits(numRequestedBuffers);
+ }
+ numBackLog = 0;
+ return;
+ }
+ LOG.warn(
+ "The buffer stream {} is null or closed, ignore the backlog for
shuffleId: {}, partitionId: {}",
+ bufferStream.getStreamId(),
+ shuffleId,
+ partitionId);
+ }
+
+ public void errorReceived(String errorMsg) {
+ if (!closed) {
+ closed = true;
+ LOG.debug("Error received, " + errorMsg);
+ if (!CelebornBufferStream.isEmptyStream(bufferStream) &&
bufferStream.getClient() != null) {
+ LOG.error(
+ "Received error from {} message {}",
+ NettyUtils.getRemoteAddress(bufferStream.getClient().getChannel()),
+ errorMsg);
+ }
+ for (int subPartitionId = subPartitionIndexStart;
+ subPartitionId <= subPartitionIndexEnd;
+ subPartitionId++) {
+ failureListener.accept(
+ new IOException(errorMsg), new
TieredStorageSubpartitionId(subPartitionId));
+ }
+ }
+ }
+
+ public void dataReceived(SubPartitionReadData readData) {
+ LOG.debug(
+ "Remote buffer stream reader get stream id {} subPartitionId {}
received readable bytes {}.",
+ readData.getStreamId(),
+ readData.getSubPartitionId(),
+ readData.getFlinkBuffer().readableBytes());
+ checkState(
+ readData.getSubPartitionId() >= subPartitionIndexStart
+ && readData.getSubPartitionId() <= subPartitionIndexEnd,
+ "Wrong sub partition id: " + readData.getSubPartitionId());
+ dataListener.accept(
+ readData.getFlinkBuffer(), new
TieredStorageSubpartitionId(readData.getSubPartitionId()));
+ int numRequested = bufferManager.tryRequestBuffersIfNeeded();
+ notifyAvailableCredits(numRequested);
+ }
+
+ public void onStreamEnd(BufferStreamEnd streamEnd) {
+ long streamId = streamEnd.getStreamId();
+ LOG.debug("Buffer stream reader get stream end for {}", streamId);
+ if (!closed && !CelebornBufferStream.isEmptyStream(bufferStream)) {
+ // TODO: Update the partition locations here to support reading and
writing shuffle data
+ // simultaneously
+ bufferStream.moveToNextPartitionIfPossible(streamId,
this::sendRequireSegmentId);
+ }
+ }
+
+ public TieredStorageInputChannelId getInputChannelId() {
+ return inputChannelId;
+ }
+
+ private void sendRequireSegmentId(long streamId, int subPartitionId) {
+ if (subPartitionRequiredSegmentIds.containsKey(subPartitionId)) {
+ int currentSegmentId =
subPartitionRequiredSegmentIds.get(subPartitionId);
+ if (bufferStream.isOpened() && currentSegmentId >= 0) {
+ LOG.debug(
+ "Buffer stream {} is opened, notify required segment id {} ",
+ streamId,
+ currentSegmentId);
+ notifyRequiredSegment(currentSegmentId, subPartitionId);
+ }
+ }
+ }
+}
diff --git
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
new file mode 100644
index 000000000..d858ae891
--- /dev/null
+++
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
@@ -0,0 +1,550 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.tiered;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkArgument;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkNotNull;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.Set;
+import java.util.function.BiConsumer;
+import java.util.stream.Collectors;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import
org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import
org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageIdMappingUtils;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageInputChannelId;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.AvailabilityNotifier;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageConsumerSpec;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierConsumerAgent;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.util.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.DriverChangedException;
+import org.apache.celeborn.common.exception.PartitionUnRetryAbleException;
+import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.plugin.flink.RemoteShuffleResource;
+import org.apache.celeborn.plugin.flink.ShuffleResourceDescriptor;
+import org.apache.celeborn.plugin.flink.buffer.ReceivedNoHeaderBufferPacker;
+import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
+
+public class CelebornTierConsumerAgent implements TierConsumerAgent {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(CelebornTierConsumerAgent.class);
+
+ private final CelebornConf conf;
+
+ private final int gateIndex;
+
+ private final List<TieredStorageConsumerSpec> consumerSpecs;
+
+ private final List<TierShuffleDescriptor> shuffleDescriptors;
+
+ /**
+ * partitionId -> subPartitionId -> reader, note that subPartitions may
share the same reader, as
+ * a single reader can consume multiple subPartitions to improvement
performance.
+ */
+ private final Map<
+ TieredStoragePartitionId, Map<TieredStorageSubpartitionId,
CelebornChannelBufferReader>>
+ bufferReaders;
+
+ /** Lock to protect {@link #receivedBuffers} and {@link #cause} and {@link
#closed}, etc. */
+ private final Object lock = new Object();
+
+ /** Received buffers from remote shuffle worker. It's consumed by upper
computing task. */
+ @GuardedBy("lock")
+ private final Map<TieredStoragePartitionId, Map<TieredStorageSubpartitionId,
Queue<Buffer>>>
+ receivedBuffers;
+
+ @GuardedBy("lock")
+ private final Set<Tuple2<TieredStoragePartitionId,
TieredStorageSubpartitionId>>
+ subPartitionsNeedNotifyAvailable;
+
+ @GuardedBy("lock")
+ private boolean started = false;
+
+ @GuardedBy("lock")
+ private Throwable cause;
+
+ /** Whether this remote input gate has been closed or not. */
+ @GuardedBy("lock")
+ private boolean closed;
+
+ private FlinkShuffleClientImpl shuffleClient;
+
+ /**
+ * The notify target is flink inputGate, used in notify input gate which
subPartition contain
+ * shuffle data that can to be read.
+ */
+ private AvailabilityNotifier availabilityNotifier;
+
+ private TieredStorageMemoryManager memoryManager;
+
+ public CelebornTierConsumerAgent(
+ CelebornConf conf,
+ List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs,
+ List<TierShuffleDescriptor> shuffleDescriptors) {
+ checkArgument(!shuffleDescriptors.isEmpty(), "Wrong shuffle descriptors
size.");
+ checkArgument(
+ tieredStorageConsumerSpecs.size() == shuffleDescriptors.size(),
+ "Wrong consumer spec size.");
+ this.conf = conf;
+ this.gateIndex = tieredStorageConsumerSpecs.get(0).getGateIndex();
+ this.consumerSpecs = tieredStorageConsumerSpecs;
+ this.shuffleDescriptors = shuffleDescriptors;
+ this.bufferReaders = new HashMap<>();
+ this.receivedBuffers = new HashMap<>();
+ this.subPartitionsNeedNotifyAvailable = new HashSet<>();
+ for (TierShuffleDescriptor shuffleDescriptor : shuffleDescriptors) {
+ if (shuffleDescriptor instanceof TierShuffleDescriptorImpl) {
+ initShuffleClient((TierShuffleDescriptorImpl) shuffleDescriptor);
+ break;
+ }
+ }
+ checkNotNull(this.shuffleClient);
+ initBufferReaders();
+ }
+
+ @Override
+ public void setup(TieredStorageMemoryManager memoryManager) {
+ this.memoryManager = memoryManager;
+ for (Map<TieredStorageSubpartitionId, CelebornChannelBufferReader>
subPartitionReaders :
+ bufferReaders.values()) {
+ subPartitionReaders.forEach((partitionId, reader) ->
reader.setup(memoryManager));
+ }
+ }
+
+ @Override
+ public void start() {
+ // notify input gate that some sub partitions are available
+ Set<Tuple2<TieredStoragePartitionId, TieredStorageSubpartitionId>>
needNotifyAvailable;
+ synchronized (lock) {
+ needNotifyAvailable = new HashSet<>(subPartitionsNeedNotifyAvailable);
+ subPartitionsNeedNotifyAvailable.clear();
+ started = true;
+ }
+ try {
+ needNotifyAvailable.forEach(
+ partitionIdTuple -> notifyAvailable(partitionIdTuple.f0,
partitionIdTuple.f1));
+ } catch (Throwable t) {
+ LOG.error("Error occurred when notifying sub partitions available", t);
+ recycleAllResources();
+ ExceptionUtils.rethrow(t);
+ }
+ needNotifyAvailable.clear();
+
+ // Require segment 0 when starting the client
+ for (TieredStorageConsumerSpec spec : consumerSpecs) {
+ for (int subpartitionId : spec.getSubpartitionIds().values()) {
+ CelebornChannelBufferReader bufferReader =
+ getBufferReader(spec.getPartitionId(), new
TieredStorageSubpartitionId(subpartitionId));
+ if (bufferReader == null) {
+ continue;
+ }
+ // TODO: if fail to open reader, may the downstream task start before
than upstream task,
+ // should retry open reader, rather than throw exception
+ boolean openReaderSuccess = openReader(bufferReader);
+ if (!openReaderSuccess) {
+ LOG.error("Failed to open reader.");
+ recycleAllResources();
+ ExceptionUtils.rethrow(new IOException("Failed to open reader."));
+ }
+ bufferReader.notifyRequiredSegmentIfNeeded(0, subpartitionId);
+ }
+ }
+ }
+
+ @Override
+ public int peekNextBufferSubpartitionId(
+ TieredStoragePartitionId tieredStoragePartitionId,
+ ResultSubpartitionIndexSet resultSubpartitionIndexSet) {
+ synchronized (lock) {
+ // check health
+ healthCheck();
+
+ // return the subPartitionId if already receive buffer from
corresponding subpartition
+ Map<TieredStorageSubpartitionId, Queue<Buffer>>
subPartitionReceivedBuffers =
+ receivedBuffers.get(tieredStoragePartitionId);
+ if (subPartitionReceivedBuffers == null) {
+ return -1;
+ }
+ for (int subPartitionIndex = resultSubpartitionIndexSet.getStartIndex();
+ subPartitionIndex <= resultSubpartitionIndexSet.getEndIndex();
+ subPartitionIndex++) {
+ Queue<Buffer> buffers =
+ subPartitionReceivedBuffers.get(new
TieredStorageSubpartitionId(subPartitionIndex));
+ if (buffers != null && !buffers.isEmpty()) {
+ return subPartitionIndex;
+ }
+ }
+ }
+ return -1;
+ }
+
+ @Override
+ public Optional<Buffer> getNextBuffer(
+ TieredStoragePartitionId tieredStoragePartitionId,
+ TieredStorageSubpartitionId tieredStorageSubpartitionId,
+ int segmentId) {
+ synchronized (lock) {
+ // check health
+ healthCheck();
+ }
+
+ // check reader status
+ if (!bufferReaders.containsKey(tieredStoragePartitionId)
+ ||
!bufferReaders.get(tieredStoragePartitionId).containsKey(tieredStorageSubpartitionId))
{
+ return Optional.empty();
+ }
+ try {
+ boolean openReaderSuccess = openReader(tieredStoragePartitionId,
tieredStorageSubpartitionId);
+ if (!openReaderSuccess) {
+ return Optional.empty();
+ }
+ } catch (Throwable throwable) {
+ LOG.error("Failed to open reader.", throwable);
+ recycleAllResources();
+ ExceptionUtils.rethrow(throwable);
+ }
+
+ synchronized (lock) {
+ CelebornChannelBufferReader bufferReader =
+ getBufferReader(tieredStoragePartitionId,
tieredStorageSubpartitionId);
+ bufferReader.notifyRequiredSegmentIfNeeded(
+ segmentId, tieredStorageSubpartitionId.getSubpartitionId());
+ Map<TieredStorageSubpartitionId, Queue<Buffer>> partitionBuffers =
+ receivedBuffers.get(tieredStoragePartitionId);
+ if (partitionBuffers == null || partitionBuffers.isEmpty()) {
+ return Optional.empty();
+ }
+ Queue<Buffer> subPartitionBuffers =
partitionBuffers.get(tieredStorageSubpartitionId);
+ if (subPartitionBuffers == null || subPartitionBuffers.isEmpty()) {
+ return Optional.empty();
+ }
+ return Optional.ofNullable(subPartitionBuffers.poll());
+ }
+ }
+
+ @Override
+ public void registerAvailabilityNotifier(AvailabilityNotifier
availabilityNotifier) {
+ this.availabilityNotifier = availabilityNotifier;
+ LOG.info("Registered availability notifier for gate {}.", gateIndex);
+ }
+
+ @Override
+ public void updateTierShuffleDescriptor(
+ TieredStoragePartitionId tieredStoragePartitionId,
+ TieredStorageInputChannelId tieredStorageInputChannelId,
+ TieredStorageSubpartitionId subpartitionId,
+ TierShuffleDescriptor tierShuffleDescriptor) {
+ if (!(tierShuffleDescriptor instanceof TierShuffleDescriptorImpl)) {
+ return;
+ }
+ TierShuffleDescriptorImpl shuffleDescriptor = (TierShuffleDescriptorImpl)
tierShuffleDescriptor;
+ checkState(
+
shuffleDescriptor.getResultPartitionID().equals(tieredStoragePartitionId.getPartitionID()),
+ "Wrong result partition id: " +
shuffleDescriptor.getResultPartitionID());
+ ResultSubpartitionIndexSet subpartitionIndexSet =
+ new ResultSubpartitionIndexSet(subpartitionId.getSubpartitionId());
+ if (!bufferReaders.containsKey(tieredStoragePartitionId)
+ ||
!bufferReaders.get(tieredStoragePartitionId).containsKey(subpartitionId)) {
+ ShuffleResourceDescriptor shuffleResourceDescriptor =
+
shuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
+ createBufferReader(
+ shuffleResourceDescriptor,
+ tieredStoragePartitionId,
+ tieredStorageInputChannelId,
+ subpartitionIndexSet);
+ CelebornChannelBufferReader bufferReader =
+ checkNotNull(getBufferReader(tieredStoragePartitionId,
subpartitionId));
+ bufferReader.setup(checkNotNull(memoryManager));
+ openReader(bufferReader);
+ }
+ }
+
+ @Override
+ public void close() {
+ Throwable closeException = null;
+ // Do not check closed flag, thus to allow calling this method from both
task thread and
+ // cancel thread.
+ try {
+ recycleAllResources();
+ } catch (Throwable throwable) {
+ closeException = throwable;
+ LOG.error("Failed to recycle all resources.", throwable);
+ }
+ if (closeException != null) {
+ ExceptionUtils.rethrow(closeException);
+ }
+ }
+
+ private void initShuffleClient(TierShuffleDescriptorImpl
remoteShuffleDescriptor) {
+ RemoteShuffleResource shuffleResource =
remoteShuffleDescriptor.getShuffleResource();
+ try {
+ String appUniqueId = remoteShuffleDescriptor.getCelebornAppId();
+ this.shuffleClient =
+ FlinkShuffleClientImpl.get(
+ appUniqueId,
+ shuffleResource.getLifecycleManagerHost(),
+ shuffleResource.getLifecycleManagerPort(),
+ shuffleResource.getLifecycleManagerTimestamp(),
+ conf,
+ new UserIdentifier("default", "default"));
+ } catch (DriverChangedException e) {
+ throw new RuntimeException(e.getMessage());
+ }
+ }
+
+ private CelebornChannelBufferReader getBufferReader(
+ TieredStoragePartitionId tieredStoragePartitionId,
+ TieredStorageSubpartitionId tieredStorageSubpartitionId) {
+ return
bufferReaders.get(tieredStoragePartitionId).get(tieredStorageSubpartitionId);
+ }
+
+ private void recycleAllResources() {
+ List<Buffer> buffersToRecycle = new ArrayList<>();
+ for (Map<TieredStorageSubpartitionId, CelebornChannelBufferReader>
subPartitionReaders :
+ bufferReaders.values()) {
+ subPartitionReaders.values().forEach(CelebornChannelBufferReader::close);
+ }
+ synchronized (lock) {
+ for (Map<TieredStorageSubpartitionId, Queue<Buffer>> subPartitionMap :
+ receivedBuffers.values()) {
+ buffersToRecycle.addAll(
+ subPartitionMap.values().stream()
+ .flatMap(Queue::stream)
+ .collect(Collectors.toCollection(LinkedList::new)));
+ }
+ receivedBuffers.clear();
+ bufferReaders.clear();
+ availabilityNotifier = null;
+ closed = true;
+ }
+ try {
+ buffersToRecycle.forEach(Buffer::recycleBuffer);
+ } catch (Throwable throwable) {
+ LOG.error("Failed to recycle buffers.", throwable);
+ throw throwable;
+ }
+ }
+
+ private boolean openReader(
+ TieredStoragePartitionId partitionId, TieredStorageSubpartitionId
subPartitionId) {
+ CelebornChannelBufferReader bufferReader =
+
checkNotNull(checkNotNull(bufferReaders.get(partitionId)).get(subPartitionId));
+ return openReader(bufferReader);
+ }
+
+ private boolean openReader(CelebornChannelBufferReader bufferReader) {
+ if (!bufferReader.isOpened()) {
+ try {
+ bufferReader.open(0);
+ } catch (Exception e) {
+ // may throw PartitionUnRetryAbleException
+ recycleAllResources();
+ ExceptionUtils.rethrow(e);
+ }
+ }
+
+ return bufferReader.isOpened();
+ }
+
+ private void initBufferReaders() {
+ for (int i = 0; i < shuffleDescriptors.size(); i++) {
+ if (!(shuffleDescriptors.get(i) instanceof TierShuffleDescriptorImpl)) {
+ continue;
+ }
+ TierShuffleDescriptorImpl shuffleDescriptor =
+ (TierShuffleDescriptorImpl) shuffleDescriptors.get(i);
+ ResultPartitionID resultPartitionID =
shuffleDescriptor.getResultPartitionID();
+ ShuffleResourceDescriptor shuffleResourceDescriptor =
+
shuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
+ TieredStoragePartitionId partitionId = new
TieredStoragePartitionId(resultPartitionID);
+ checkState(consumerSpecs.get(i).getPartitionId().equals(partitionId),
"Wrong partition id.");
+ ResultSubpartitionIndexSet subPartitionIdSet =
consumerSpecs.get(i).getSubpartitionIds();
+ LOG.debug(
+ "create shuffle reader for gate {} descriptor {} partitionId {},
subPartitionId start {} and end {}",
+ gateIndex,
+ shuffleResourceDescriptor,
+ partitionId,
+ subPartitionIdSet.getStartIndex(),
+ subPartitionIdSet.getEndIndex());
+ createBufferReader(
+ shuffleResourceDescriptor,
+ partitionId,
+ consumerSpecs.get(i).getInputChannelId(),
+ subPartitionIdSet);
+ }
+ }
+
+ private void createBufferReader(
+ ShuffleResourceDescriptor shuffleDescriptor,
+ TieredStoragePartitionId partitionId,
+ TieredStorageInputChannelId inputChannelId,
+ ResultSubpartitionIndexSet subPartitionIdSet) {
+ // create a single reader for multiple subPartitions to improvement
performance
+ CelebornChannelBufferReader reader =
+ new CelebornChannelBufferReader(
+ shuffleClient,
+ shuffleDescriptor,
+ inputChannelId,
+ subPartitionIdSet.getStartIndex(),
+ subPartitionIdSet.getEndIndex(),
+ getDataListener(partitionId),
+ getFailureListener(partitionId));
+
+ for (int id = subPartitionIdSet.getStartIndex(); id <=
subPartitionIdSet.getEndIndex(); id++) {
+ TieredStorageSubpartitionId subPartitionId = new
TieredStorageSubpartitionId(id);
+ checkState(
+ !bufferReaders.containsKey(partitionId)
+ || !bufferReaders.get(partitionId).containsKey(subPartitionId),
+ "Duplicate shuffle reader.");
+ bufferReaders
+ .computeIfAbsent(partitionId, partition -> new HashMap<>())
+ .put(subPartitionId, reader);
+ }
+ }
+
+ @GuardedBy("lock")
+ private void healthCheck() {
+ if (closed || cause != null) {
+ Exception e;
+ if (closed) {
+ e = new IOException("Celeborn consumer agent already closed.");
+ } else {
+ e = new IOException(cause);
+ }
+ recycleAllResources();
+ LOG.error("Failed to check health.", e);
+ ExceptionUtils.rethrow(e);
+ }
+ }
+
+ private void onBuffer(
+ TieredStoragePartitionId partitionId,
+ TieredStorageSubpartitionId subPartitionId,
+ Buffer buffer) {
+ boolean wasEmpty;
+ synchronized (lock) {
+ if (closed || cause != null) {
+ buffer.recycleBuffer();
+ recycleAllResources();
+ throw new IllegalStateException("Input gate already closed or
failed.");
+ }
+ Queue<Buffer> buffers =
+ receivedBuffers
+ .computeIfAbsent(partitionId, partition -> new HashMap<>())
+ .computeIfAbsent(subPartitionId, subpartition -> new
LinkedList<>());
+ wasEmpty = buffers.isEmpty();
+ buffers.add(buffer);
+ if (wasEmpty && !started) {
+ subPartitionsNeedNotifyAvailable.add(Tuple2.of(partitionId,
subPartitionId));
+ return;
+ }
+ }
+ if (wasEmpty) {
+ notifyAvailable(partitionId, subPartitionId);
+ }
+ }
+
+ private BiConsumer<ByteBuf, TieredStorageSubpartitionId> getDataListener(
+ TieredStoragePartitionId partitionId) {
+ return (byteBuf, subPartitionId) -> {
+ Queue<Buffer> unpackedBuffers = null;
+ try {
+ unpackedBuffers = ReceivedNoHeaderBufferPacker.unpack(byteBuf);
+ while (!unpackedBuffers.isEmpty()) {
+ onBuffer(partitionId, subPartitionId, unpackedBuffers.poll());
+ }
+ } catch (Throwable throwable) {
+ synchronized (lock) {
+ LOG.error(
+ "Failed to process the received buffer, cause: {} throwable {}.",
+ cause == null ? "" : cause,
+ throwable);
+ if (cause == null) {
+ cause = throwable;
+ }
+ }
+ notifyAvailable(partitionId, subPartitionId);
+ if (unpackedBuffers != null) {
+ unpackedBuffers.forEach(Buffer::recycleBuffer);
+ }
+ recycleAllResources();
+ }
+ };
+ }
+
+ private BiConsumer<Throwable, TieredStorageSubpartitionId>
getFailureListener(
+ TieredStoragePartitionId partitionId) {
+ return (throwable, subPartitionId) -> {
+ synchronized (lock) {
+ // only record and process the first exception
+ if (cause != null) {
+ return;
+ }
+ Class<?> clazz = PartitionUnRetryAbleException.class;
+ if (throwable.getMessage() != null &&
throwable.getMessage().contains(clazz.getName())) {
+ cause =
+ new
PartitionNotFoundException(TieredStorageIdMappingUtils.convertId(partitionId));
+ LOG.error("The consumer agent received an
PartitionUnRetryAbleException.", throwable);
+ } else {
+ LOG.error("The consumer agent received an exception.", throwable);
+ cause = throwable;
+ }
+ }
+ // notify input gate, the input gate will call
peekNextBufferSubpartitionId or getNextBufer,
+ // and process exception
+ notifyAvailable(partitionId, subPartitionId);
+ };
+ }
+
+ private void notifyAvailable(
+ TieredStoragePartitionId partitionId, TieredStorageSubpartitionId
subPartitionId) {
+ Map<TieredStorageSubpartitionId, CelebornChannelBufferReader>
subPartitionReaders =
+ bufferReaders.get(partitionId);
+ if (subPartitionReaders != null) {
+ CelebornChannelBufferReader channelBufferReader =
subPartitionReaders.get(subPartitionId);
+ if (channelBufferReader != null) {
+ availabilityNotifier.notifyAvailable(partitionId,
channelBufferReader.getInputChannelId());
+ }
+ }
+ }
+}
diff --git
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
index 02306a5ad..1a86130e4 100644
---
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
+++
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
@@ -118,8 +118,7 @@ public class CelebornTierFactory implements TierFactory {
List<TieredStorageConsumerSpec> tieredStorageConsumerSpecs,
List<TierShuffleDescriptor> shuffleDescriptors,
TieredStorageNettyService nettyService) {
- // TODO impl this in the follow-up PR.
- return null;
+ return new CelebornTierConsumerAgent(conf, tieredStorageConsumerSpecs,
shuffleDescriptors);
}
public static String getCelebornTierName() {
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
index 6e465ccc5..1be7b46ff 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
@@ -26,7 +26,7 @@ import
org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
// This is buffer wrapper used in celeborn worker only
// It doesn't need decode in worker.
public class ReadData extends RequestMessage {
- private long streamId;
+ protected long streamId;
public ReadData(long streamId, ByteBuf buf) {
super(new NettyManagedBuffer(buf));
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
index bdea84989..11a13118d 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SubPartitionReadData.java
@@ -22,38 +22,30 @@ import java.util.Objects;
import io.netty.buffer.ByteBuf;
-import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
-
/**
* Comparing {@link ReadData}, this class has an additional field of
subpartitionId. This class is
* added to keep the backward compatibility.
*/
-public class SubPartitionReadData extends RequestMessage {
- private long streamId;
+public class SubPartitionReadData extends ReadData {
private int subPartitionId;
public SubPartitionReadData(long streamId, int subPartitionId, ByteBuf buf) {
- super(new NettyManagedBuffer(buf));
- this.streamId = streamId;
+ super(streamId, buf);
this.subPartitionId = subPartitionId;
}
@Override
public int encodedLength() {
- return 8 + 4;
+ return super.encodedLength() + 4;
}
@Override
public void encode(ByteBuf buf) {
- buf.writeLong(streamId);
+ super.encode(buf);
buf.writeInt(subPartitionId);
}
- public long getStreamId() {
- return streamId;
- }
-
public int getSubPartitionId() {
return subPartitionId;
}
@@ -68,8 +60,8 @@ public class SubPartitionReadData extends RequestMessage {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SubPartitionReadData readData = (SubPartitionReadData) o;
- return streamId == readData.streamId
- && subPartitionId == readData.subPartitionId
+ return streamId == readData.getStreamId()
+ && subPartitionId == readData.getSubPartitionId()
&& super.equals(o);
}