This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 3b8d6c86 [#133] feat(netty): Implement ShuffleServer interface. (#879)
3b8d6c86 is described below
commit 3b8d6c86b7a2c89ec05ffa67928df640aba15ac3
Author: Xianming Lei <[email protected]>
AuthorDate: Wed May 17 16:29:38 2023 +0800
[#133] feat(netty): Implement ShuffleServer interface. (#879)
### What changes were proposed in this pull request?
Implement ShuffleServer interface.
### Why are the changes needed?
For #133.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
existing UTs.
Co-authored-by: leixianming <[email protected]>
---
.../apache/uniffle/common/ShuffleDataResult.java | 23 +-
.../uniffle/common/ShufflePartitionedBlock.java | 29 +-
.../common/netty/client/TransportClient.java | 4 +
.../netty/client/TransportClientFactory.java | 14 +-
.../common/netty/client/TransportContext.java | 34 +-
.../common/netty/handle/BaseMessageHandler.java | 32 +-
.../common/netty/handle/MessageHandler.java | 36 +-
.../netty/handle/TransportChannelHandler.java | 139 +++++++
.../netty/handle/TransportRequestHandler.java | 61 +++
.../netty/handle/TransportResponseHandler.java | 94 ++++-
.../apache/uniffle/common/util/ByteBufUtils.java | 5 +
.../org/apache/uniffle/common/util/Constants.java | 2 +
.../common/ShufflePartitionedBlockTest.java | 4 +-
.../uniffle/common/util/ByteBufUtilsTest.java | 26 ++
.../uniffle/server/buffer/ShuffleBuffer.java | 16 +-
.../server/netty/ShuffleServerNettyHandler.java | 407 +++++++++++++++++++++
.../apache/uniffle/server/netty/StreamServer.java | 22 +-
.../uniffle/server/ShuffleFlushManagerTest.java | 5 +-
.../server/buffer/ShuffleBufferManagerTest.java | 9 +-
.../uniffle/server/buffer/ShuffleBufferTest.java | 3 +-
.../server/storage/MultiStorageManagerTest.java | 9 +-
.../StorageManagerFallbackStrategyTest.java | 6 +-
.../handler/impl/HdfsShuffleWriteHandler.java | 3 +-
.../handler/impl/LocalFileWriteHandler.java | 4 +-
.../storage/HdfsShuffleHandlerTestBase.java | 3 +-
.../handler/impl/LocalFileHandlerTestBase.java | 5 +-
26 files changed, 864 insertions(+), 131 deletions(-)
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
b/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
index 19e93dd1..98a8b1d7 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleDataResult.java
@@ -21,10 +21,14 @@ import java.nio.ByteBuffer;
import java.util.List;
import com.google.common.collect.Lists;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.uniffle.common.util.ByteBufUtils;
public class ShuffleDataResult {
- private final ByteBuffer data;
+ private final ByteBuf data;
private final List<BufferSegment> bufferSegments;
public ShuffleDataResult() {
@@ -36,6 +40,11 @@ public class ShuffleDataResult {
}
public ShuffleDataResult(ByteBuffer data, List<BufferSegment>
bufferSegments) {
+ this.data = data != null ? Unpooled.wrappedBuffer(data) :
Unpooled.EMPTY_BUFFER;
+ this.bufferSegments = bufferSegments;
+ }
+
+ public ShuffleDataResult(ByteBuf data, List<BufferSegment> bufferSegments) {
this.data = data;
this.bufferSegments = bufferSegments;
}
@@ -51,16 +60,17 @@ public class ShuffleDataResult {
if (data.hasArray()) {
return data.array();
}
- ByteBuffer dataBuffer = data.duplicate();
- byte[] byteArray = new byte[dataBuffer.remaining()];
- dataBuffer.get(byteArray);
- return byteArray;
+ return ByteBufUtils.readBytes(data);
}
- public ByteBuffer getDataBuffer() {
+ public ByteBuf getDataBuf() {
return data;
}
+ public ByteBuffer getDataBuffer() {
+ return data.nioBuffer();
+ }
+
public List<BufferSegment> getBufferSegments() {
return bufferSegments;
}
@@ -68,5 +78,4 @@ public class ShuffleDataResult {
public boolean isEmpty() {
return bufferSegments == null || bufferSegments.isEmpty() || data == null
|| data.capacity() == 0;
}
-
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java
b/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java
index 9bee3fe3..04ba6542 100644
---
a/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java
+++
b/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java
@@ -17,16 +17,18 @@
package org.apache.uniffle.common;
-import java.util.Arrays;
import java.util.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
public class ShufflePartitionedBlock {
private int length;
private long crc;
private long blockId;
private int uncompressLength;
- private byte[] data;
+ private ByteBuf data;
private long taskAttemptId;
public ShufflePartitionedBlock(
@@ -41,6 +43,21 @@ public class ShufflePartitionedBlock {
this.blockId = blockId;
this.uncompressLength = uncompressLength;
this.taskAttemptId = taskAttemptId;
+ this.data = data == null ? Unpooled.EMPTY_BUFFER :
Unpooled.wrappedBuffer(data);
+ }
+
+ public ShufflePartitionedBlock(
+ int length,
+ int uncompressLength,
+ long crc,
+ long blockId,
+ long taskAttemptId,
+ ByteBuf data) {
+ this.length = length;
+ this.crc = crc;
+ this.blockId = blockId;
+ this.uncompressLength = uncompressLength;
+ this.taskAttemptId = taskAttemptId;
this.data = data;
}
@@ -62,12 +79,12 @@ public class ShufflePartitionedBlock {
return length == that.length
&& crc == that.crc
&& blockId == that.blockId
- && Arrays.equals(data, that.data);
+ && data.equals(that.data);
}
@Override
public int hashCode() {
- return Objects.hash(length, crc, blockId, Arrays.hashCode(data));
+ return Objects.hash(length, crc, blockId, data);
}
public int getLength() {
@@ -94,11 +111,11 @@ public class ShufflePartitionedBlock {
this.blockId = blockId;
}
- public byte[] getData() {
+ public ByteBuf getData() {
return data;
}
- public void setData(byte[] data) {
+ public void setData(ByteBuf data) {
this.data = data;
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
index c82408af..01846a2f 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
@@ -165,4 +165,8 @@ public class TransportClient implements Closeable {
channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
}
+ public void timeOut() {
+ this.timedOut = true;
+ }
+
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
index c8056151..3e9e706a 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
@@ -40,7 +40,7 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.netty.IOMode;
import org.apache.uniffle.common.netty.TransportFrameDecoder;
-import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
+import org.apache.uniffle.common.netty.handle.TransportChannelHandler;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.NettyUtils;
@@ -121,8 +121,11 @@ public class TransportClientFactory implements Closeable {
// Make sure that the channel will not timeout by updating the last use
time of the
// handler. Then check that the client is still alive, in case it timed
out before
// this code was able to update things.
- TransportResponseHandler handler =
-
cachedClient.getChannel().pipeline().get(TransportResponseHandler.class);
+ TransportChannelHandler handler =
+
cachedClient.getChannel().pipeline().get(TransportChannelHandler.class);
+ synchronized (handler) {
+ handler.getResponseHandler().updateTimeOfLastRequest();
+ }
if (cachedClient.isActive()) {
logger.trace(
@@ -197,9 +200,8 @@ public class TransportClientFactory implements Closeable {
new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
- TransportResponseHandler transportResponseHandler =
context.initializePipeline(ch, decoder);
- TransportClient client = new TransportClient(ch,
transportResponseHandler);
- clientRef.set(client);
+ TransportChannelHandler transportResponseHandler =
context.initializePipeline(ch, decoder);
+ clientRef.set(transportResponseHandler.getClient());
channelRef.set(ch);
}
});
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
index 134b633a..5acb46dd 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
@@ -17,6 +17,7 @@
package org.apache.uniffle.common.netty.client;
+import io.netty.channel.Channel;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
@@ -24,34 +25,59 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.netty.MessageEncoder;
+import org.apache.uniffle.common.netty.handle.BaseMessageHandler;
+import org.apache.uniffle.common.netty.handle.TransportChannelHandler;
+import org.apache.uniffle.common.netty.handle.TransportRequestHandler;
import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
public class TransportContext {
private static final Logger logger =
LoggerFactory.getLogger(TransportContext.class);
private TransportConf transportConf;
+ private final BaseMessageHandler msgHandler;
+ private boolean closeIdleConnections;
private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
public TransportContext(TransportConf transportConf) {
+ this(transportConf, true);
+ }
+
+ public TransportContext(TransportConf transportConf, boolean
closeIdleConnections) {
+ this(transportConf, null, closeIdleConnections);
+ }
+
+ public TransportContext(TransportConf transportConf, BaseMessageHandler
msgHandler, boolean closeIdleConnections) {
this.transportConf = transportConf;
+ this.msgHandler = msgHandler;
+ this.closeIdleConnections = closeIdleConnections;
}
public TransportClientFactory createClientFactory() {
return new TransportClientFactory(this);
}
- public TransportResponseHandler initializePipeline(
+ public TransportChannelHandler initializePipeline(
SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
- TransportResponseHandler responseHandler = new
TransportResponseHandler(channel);
+ TransportChannelHandler channelHandler = createChannelHandler(channel,
msgHandler);
channel
.pipeline()
.addLast("encoder", ENCODER) // out
.addLast("decoder", decoder) // in
.addLast(
"idleStateHandler", new IdleStateHandler(0, 0,
transportConf.connectionTimeoutMs() / 1000))
- .addLast("responseHandler", responseHandler);
- return responseHandler;
+ .addLast("responseHandler", channelHandler);
+ return channelHandler;
+ }
+
+ private TransportChannelHandler createChannelHandler(
+ Channel channel, BaseMessageHandler msgHandler) {
+ TransportResponseHandler responseHandler = new
TransportResponseHandler(channel);
+ TransportClient client = new TransportClient(channel, responseHandler);
+ TransportRequestHandler requestHandler =
+ new TransportRequestHandler(client, msgHandler);
+ return new TransportChannelHandler(
+ client, responseHandler, requestHandler,
transportConf.connectionTimeoutMs(), closeIdleConnections);
}
public TransportConf getConf() {
diff --git
a/server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
b/common/src/main/java/org/apache/uniffle/common/netty/handle/BaseMessageHandler.java
similarity index 54%
copy from
server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
copy to
common/src/main/java/org/apache/uniffle/common/netty/handle/BaseMessageHandler.java
index d2551e55..91ba7a01 100644
---
a/server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/handle/BaseMessageHandler.java
@@ -15,34 +15,14 @@
* limitations under the License.
*/
-package org.apache.uniffle.server.netty.decoder;
+package org.apache.uniffle.common.netty.handle;
-import java.util.List;
+import org.apache.uniffle.common.netty.client.TransportClient;
+import org.apache.uniffle.common.netty.protocol.RequestMessage;
-import io.netty.buffer.ByteBuf;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.handler.codec.ByteToMessageDecoder;
+public interface BaseMessageHandler {
-public class StreamServerInitDecoder extends ByteToMessageDecoder {
+ void receive(TransportClient client, RequestMessage msg);
- public StreamServerInitDecoder() {
- }
-
- private void addDecoder(ChannelHandlerContext ctx, byte type) {
-
- }
-
- @Override
- protected void decode(ChannelHandlerContext ctx,
- ByteBuf in,
- List<Object> out) {
- if (in.readableBytes() < Byte.BYTES) {
- return;
- }
- in.markReaderIndex();
- byte magicByte = in.readByte();
- in.resetReaderIndex();
-
- addDecoder(ctx, magicByte);
- }
+ void exceptionCaught(Throwable cause, TransportClient client);
}
diff --git
a/server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
b/common/src/main/java/org/apache/uniffle/common/netty/handle/MessageHandler.java
similarity index 54%
rename from
server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
rename to
common/src/main/java/org/apache/uniffle/common/netty/handle/MessageHandler.java
index d2551e55..6712ddfc 100644
---
a/server/src/main/java/org/apache/uniffle/server/netty/decoder/StreamServerInitDecoder.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/handle/MessageHandler.java
@@ -15,34 +15,20 @@
* limitations under the License.
*/
-package org.apache.uniffle.server.netty.decoder;
+package org.apache.uniffle.common.netty.handle;
-import java.util.List;
+import org.apache.uniffle.common.netty.protocol.Message;
-import io.netty.buffer.ByteBuf;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.handler.codec.ByteToMessageDecoder;
+public abstract class MessageHandler<T extends Message> {
+ /** Handles the receipt of a single message. */
+ public abstract void handle(T message) throws Exception;
-public class StreamServerInitDecoder extends ByteToMessageDecoder {
+ /** Invoked when the channel this MessageHandler is on is active. */
+ public abstract void channelActive();
- public StreamServerInitDecoder() {
- }
+ /** Invoked when an exception was caught on the Channel. */
+ public abstract void exceptionCaught(Throwable cause);
- private void addDecoder(ChannelHandlerContext ctx, byte type) {
-
- }
-
- @Override
- protected void decode(ChannelHandlerContext ctx,
- ByteBuf in,
- List<Object> out) {
- if (in.readableBytes() < Byte.BYTES) {
- return;
- }
- in.markReaderIndex();
- byte magicByte = in.readByte();
- in.resetReaderIndex();
-
- addDecoder(ctx, magicByte);
- }
+ /** Invoked when the channel this MessageHandler is on is inactive. */
+ public abstract void channelInactive();
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportChannelHandler.java
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportChannelHandler.java
new file mode 100644
index 00000000..f7fb9f33
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportChannelHandler.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.handle;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.timeout.IdleState;
+import io.netty.handler.timeout.IdleStateEvent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.client.TransportClient;
+import org.apache.uniffle.common.netty.protocol.RequestMessage;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+import org.apache.uniffle.common.util.NettyUtils;
+
+public class TransportChannelHandler extends ChannelInboundHandlerAdapter {
+ private static final Logger logger =
LoggerFactory.getLogger(TransportChannelHandler.class);
+
+ private final TransportClient client;
+ private final TransportResponseHandler responseHandler;
+ private final TransportRequestHandler requestHandler;
+ private final long requestTimeoutNs;
+ private final boolean closeIdleConnections;
+
+ public TransportChannelHandler(
+ TransportClient client,
+ TransportResponseHandler responseHandler,
+ TransportRequestHandler requestHandler,
+ long requestTimeoutMs,
+ boolean closeIdleConnections) {
+ this.client = client;
+ this.responseHandler = responseHandler;
+ this.requestHandler = requestHandler;
+ this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
+ this.closeIdleConnections = closeIdleConnections;
+ }
+
+ public TransportClient getClient() {
+ return client;
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
throws Exception {
+ logger.warn(
+ "Exception in connection from " +
NettyUtils.getRemoteAddress(ctx.channel()), cause);
+ requestHandler.exceptionCaught(cause);
+ responseHandler.exceptionCaught(cause);
+ ctx.close();
+ }
+
+ @Override
+ public void channelActive(ChannelHandlerContext ctx) throws Exception {
+ try {
+ requestHandler.channelActive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from request handler while channel is active",
e);
+ }
+ try {
+ responseHandler.channelActive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from response handler while channel is active",
e);
+ }
+ super.channelActive(ctx);
+ }
+
+ @Override
+ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+ try {
+ requestHandler.channelInactive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from request handler while channel is inactive",
e);
+ }
+ try {
+ responseHandler.channelInactive();
+ } catch (RuntimeException e) {
+ logger.error("Exception from response handler while channel is
inactive", e);
+ }
+ super.channelInactive(ctx);
+ }
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object request) throws
Exception {
+ if (request instanceof RequestMessage) {
+ requestHandler.handle((RequestMessage) request);
+ } else if (request instanceof RpcResponse) {
+ responseHandler.handle((RpcResponse) request);
+ } else {
+ ctx.fireChannelRead(request);
+ }
+ }
+
+ /** Triggered based on events from an {@link
io.netty.handler.timeout.IdleStateHandler}. */
+ @Override
+ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
+ if (evt instanceof IdleStateEvent) {
+ IdleStateEvent e = (IdleStateEvent) evt;
+ synchronized (this) {
+ boolean isActuallyOverdue =
+ System.nanoTime() - responseHandler.getTimeOfLastRequestNs() >
requestTimeoutNs;
+ if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) {
+ if (responseHandler.numOutstandingRequests() > 0) {
+ String address = NettyUtils.getRemoteAddress(ctx.channel());
+ logger.error(
+ "Connection to {} has been quiet for {} ms while there are
outstanding "
+ + "requests.",
+ address,
+ requestTimeoutNs / 1000 / 1000);
+ }
+ if (closeIdleConnections) {
+          // While closeIdleConnections is enabled, we also close the idle
connection
+ client.timeOut();
+ ctx.close();
+ }
+ }
+ }
+ }
+ ctx.fireUserEventTriggered(evt);
+ }
+
+ public TransportResponseHandler getResponseHandler() {
+ return responseHandler;
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportRequestHandler.java
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportRequestHandler.java
new file mode 100644
index 00000000..795cf270
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportRequestHandler.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.handle;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.client.TransportClient;
+import org.apache.uniffle.common.netty.protocol.RequestMessage;
+
+public class TransportRequestHandler extends MessageHandler<RequestMessage> {
+
+ private static final Logger logger =
LoggerFactory.getLogger(TransportRequestHandler.class);
+
+ /** Client on the same channel allowing us to talk back to the requester. */
+ private final TransportClient reverseClient;
+
+ /** Handles all RPC messages. */
+ private final BaseMessageHandler msgHandler;
+
+ public TransportRequestHandler(TransportClient reverseClient,
BaseMessageHandler msgHandler) {
+ this.reverseClient = reverseClient;
+ this.msgHandler = msgHandler;
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) {
+ msgHandler.exceptionCaught(cause, reverseClient);
+ }
+
+ @Override
+ public void channelActive() {
+ logger.debug("channelActive: {}", reverseClient.getSocketAddress());
+ }
+
+ @Override
+ public void channelInactive() {
+ logger.debug("channelInactive: {}", reverseClient.getSocketAddress());
+ }
+
+ @Override
+ public void handle(RequestMessage request) {
+ msgHandler.receive(reverseClient, request);
+ }
+}
+
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
index 86dd9953..b0681503 100644
---
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
@@ -17,56 +17,110 @@
package org.apache.uniffle.common.netty.handle;
+import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
import io.netty.channel.Channel;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.netty.client.RpcResponseCallback;
import org.apache.uniffle.common.netty.protocol.RpcResponse;
import org.apache.uniffle.common.util.NettyUtils;
-public class TransportResponseHandler extends ChannelInboundHandlerAdapter {
+public class TransportResponseHandler extends MessageHandler<RpcResponse> {
private static final Logger logger =
LoggerFactory.getLogger(TransportResponseHandler.class);
private Map<Long, RpcResponseCallback> outstandingRpcRequests;
private Channel channel;
+ /** Records the time (in system nanoseconds) that the last fetch or RPC
request was sent. */
+ private final AtomicLong timeOfLastRequestNs;
+
public TransportResponseHandler(Channel channel) {
this.channel = channel;
this.outstandingRpcRequests = new ConcurrentHashMap<>();
+ this.timeOfLastRequestNs = new AtomicLong(0);
+ }
+
+ public void addResponseCallback(long requestId, RpcResponseCallback
callback) {
+ updateTimeOfLastRequest();
+ if (outstandingRpcRequests.containsKey(requestId)) {
+      logger.warn("[addResponseCallback] requestId {} already exists!", requestId);
+ }
+ outstandingRpcRequests.put(requestId, callback);
+ }
+
+ public void removeRpcRequest(long requestId) {
+ outstandingRpcRequests.remove(requestId);
}
@Override
- public void channelRead(ChannelHandlerContext ctx, Object msg) throws
Exception {
- if (msg instanceof RpcResponse) {
- RpcResponse responseMessage = (RpcResponse) msg;
- RpcResponseCallback listener =
outstandingRpcRequests.get(responseMessage.getRequestId());
- if (listener == null) {
- logger.warn("Ignoring response from {} since it is not outstanding",
- NettyUtils.getRemoteAddress(channel));
- } else {
- listener.onSuccess(responseMessage);
- }
+ public void handle(RpcResponse message) throws Exception {
+ RpcResponseCallback listener =
outstandingRpcRequests.get(message.getRequestId());
+ if (listener == null) {
+ logger.warn("Ignoring response from {} since it is not outstanding",
+ NettyUtils.getRemoteAddress(channel));
} else {
- throw new RssException("receive unexpected message!");
+ listener.onSuccess(message);
}
- super.channelRead(ctx, msg);
}
- public void addResponseCallback(long requestId, RpcResponseCallback
callback) {
- outstandingRpcRequests.put(requestId, callback);
+ @Override
+ public void channelActive() {
+
}
- public void removeRpcRequest(long requestId) {
- outstandingRpcRequests.remove(requestId);
+ @Override
+ public void exceptionCaught(Throwable cause) {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error(
+ "Still have {} requests outstanding when connection from {} is
closed",
+ numOutstandingRequests(),
+ remoteAddress);
+ failOutstandingRequests(cause);
+ }
+ }
+
+ @Override
+ public void channelInactive() {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error(
+ "Still have {} requests outstanding when connection from {} is
closed",
+ numOutstandingRequests(),
+ remoteAddress);
+ failOutstandingRequests(new IOException("Connection from " +
remoteAddress + " closed"));
+ }
+ }
+
+ public int numOutstandingRequests() {
+ return outstandingRpcRequests.size();
+ }
+
+ private void failOutstandingRequests(Throwable cause) {
+ for (Map.Entry<Long, RpcResponseCallback> entry :
outstandingRpcRequests.entrySet()) {
+ try {
+ entry.getValue().onFailure(cause);
+ } catch (Exception e) {
+ logger.warn("RpcResponseCallback.onFailure throws exception", e);
+ }
+ }
+
+ outstandingRpcRequests.clear();
}
+ /** Returns the time in nanoseconds of when the last request was sent out. */
+ public long getTimeOfLastRequestNs() {
+ return timeOfLastRequestNs.get();
+ }
+ /** Updates the time of the last request to the current system time. */
+ public void updateTimeOfLastRequest() {
+ timeOfLastRequestNs.set(System.nanoTime());
+ }
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
index afdfae27..8288df8f 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
@@ -80,4 +80,9 @@ public class ByteBufUtils {
buf.resetReaderIndex();
return bytes;
}
+
+ public static void readBytes(ByteBuf from, byte[] to, int offset, int
length) {
+ from.readBytes(to, offset, length);
+ from.resetReaderIndex();
+ }
}
diff --git a/common/src/main/java/org/apache/uniffle/common/util/Constants.java
b/common/src/main/java/org/apache/uniffle/common/util/Constants.java
index af959c5b..6fdd2754 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/Constants.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/Constants.java
@@ -75,4 +75,6 @@ public final class Constants {
public static final String DEVICE_NO_SPACE_ERROR_MESSAGE = "No space left on
device";
public static final String NETTY_STREAM_SERVICE_NAME = "netty.rpc.server";
public static final String GRPC_SERVICE_NAME = "grpc.server";
+
+ public static final int COMPOSITE_BYTE_BUF_MAX_COMPONENTS = 1024;
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/ShufflePartitionedBlockTest.java
b/common/src/test/java/org/apache/uniffle/common/ShufflePartitionedBlockTest.java
index d8e660d4..569643bd 100644
---
a/common/src/test/java/org/apache/uniffle/common/ShufflePartitionedBlockTest.java
+++
b/common/src/test/java/org/apache/uniffle/common/ShufflePartitionedBlockTest.java
@@ -23,6 +23,8 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
@@ -40,7 +42,7 @@ public class ShufflePartitionedBlockTest {
assertEquals(3, b1.getBlockId());
ShufflePartitionedBlock b3 = new ShufflePartitionedBlock(1, 1, 2, 3, 3,
buf);
- assertArrayEquals(buf, b3.getData());
+ assertArrayEquals(buf, ByteBufUtils.readBytes(b3.getData()));
}
@Test
diff --git
a/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java
b/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java
index 3f60eaef..fa2f6099 100644
--- a/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java
@@ -17,7 +17,10 @@
package org.apache.uniffle.common.util;
+import java.nio.charset.StandardCharsets;
+
import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.Test;
@@ -42,5 +45,28 @@ public class ByteBufUtilsTest {
byteBuf.clear();
ByteBufUtils.writeLengthAndString(byteBuf, null);
assertNull(ByteBufUtils.readLengthAndString(byteBuf));
+
+ byteBuf.clear();
+ ByteBufUtils.writeLengthAndString(byteBuf, expectedString);
+ ByteBuf byteBuf1 = Unpooled.buffer(100);
+ ByteBufUtils.writeLengthAndString(byteBuf1, expectedString);
+ final int expectedLength = byteBuf.readableBytes() +
byteBuf1.readableBytes();
+ CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+ compositeByteBuf.addComponent(true, byteBuf);
+ compositeByteBuf.addComponent(true, byteBuf1);
+
+ ByteBuf res = Unpooled.buffer(100);
+ ByteBufUtils.copyByteBuf(compositeByteBuf, res);
+ assertEquals(expectedLength, res.readableBytes() - Integer.BYTES);
+
+ res.clear();
+ ByteBufUtils.copyByteBuf(compositeByteBuf, res);
+ assertEquals(expectedLength, res.readableBytes() - Integer.BYTES);
+
+ byteBuf.clear();
+ byte[] bytes = expectedString.getBytes(StandardCharsets.UTF_8);
+ byteBuf.writeBytes(bytes);
+ ByteBufUtils.readBytes(byteBuf, bytes, 1, byteBuf.readableBytes() - 1);
+ assertEquals("ttest_st", new String(bytes, StandardCharsets.UTF_8));
}
}
diff --git
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
index 78d76b38..fcb6567e 100644
--- a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
+++ b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBuffer.java
@@ -25,6 +25,8 @@ import java.util.function.Supplier;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.CompositeByteBuf;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -161,11 +163,11 @@ public class ShuffleBuffer {
updateBufferSegmentsAndResultBlocks(
lastBlockId, readBufferSize, bufferSegments, readBlocks,
expectedTaskIds);
if (!bufferSegments.isEmpty()) {
- int length = calculateDataLength(bufferSegments);
- byte[] data = new byte[length];
+ CompositeByteBuf byteBuf =
+ new CompositeByteBuf(ByteBufAllocator.DEFAULT, true,
Constants.COMPOSITE_BYTE_BUF_MAX_COMPONENTS);
// copy result data
- updateShuffleData(readBlocks, data);
- return new ShuffleDataResult(data, bufferSegments);
+ updateShuffleData(readBlocks, byteBuf);
+ return new ShuffleDataResult(byteBuf, bufferSegments);
}
} catch (Exception e) {
LOG.error("Exception happened when getShuffleData in buffer", e);
@@ -238,16 +240,16 @@ public class ShuffleBuffer {
return bufferSegment.getOffset() + bufferSegment.getLength();
}
- private void updateShuffleData(List<ShufflePartitionedBlock> readBlocks,
byte[] data) {
+ private void updateShuffleData(List<ShufflePartitionedBlock> readBlocks,
CompositeByteBuf data) {
int offset = 0;
for (ShufflePartitionedBlock block : readBlocks) {
// fill shuffle data
try {
- System.arraycopy(block.getData(), 0, data, offset, block.getLength());
+ data.addComponent(true, block.getData().retain());
} catch (Exception e) {
LOG.error("Unexpected exception for System.arraycopy, length["
+ block.getLength() + "], offset["
- + offset + "], dataLength[" + data.length + "]", e);
+ + offset + "], dataLength[" + data.capacity() + "]", e);
throw e;
}
offset += block.getLength();
diff --git
a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
new file mode 100644
index 00000000..b87d57a3
--- /dev/null
+++
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.server.netty;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleDataResult;
+import org.apache.uniffle.common.ShuffleIndexResult;
+import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.exception.FileNotFoundException;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.netty.client.TransportClient;
+import org.apache.uniffle.common.netty.handle.BaseMessageHandler;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataRequest;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataResponse;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexRequest;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse;
+import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest;
+import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataResponse;
+import org.apache.uniffle.common.netty.protocol.RequestMessage;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+import org.apache.uniffle.common.netty.protocol.SendShuffleDataRequest;
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.server.ShuffleDataReadEvent;
+import org.apache.uniffle.server.ShuffleServer;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.server.ShuffleServerGrpcMetrics;
+import org.apache.uniffle.server.ShuffleServerMetrics;
+import org.apache.uniffle.server.ShuffleTaskManager;
+import org.apache.uniffle.server.buffer.PreAllocatedBufferInfo;
+import org.apache.uniffle.storage.common.Storage;
+import org.apache.uniffle.storage.common.StorageReadMetrics;
+import org.apache.uniffle.storage.util.ShuffleStorageUtils;
+
+public class ShuffleServerNettyHandler implements BaseMessageHandler {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(ShuffleServerNettyHandler.class);
+ private static final int RPC_TIMEOUT = 60000;
+ private final ShuffleServer shuffleServer;
+
+ public ShuffleServerNettyHandler(ShuffleServer shuffleServer) {
+ this.shuffleServer = shuffleServer;
+ }
+
+ @Override
+ public void receive(TransportClient client, RequestMessage msg) {
+ if (msg instanceof SendShuffleDataRequest) {
+ handleSendShuffleDataRequest(client, (SendShuffleDataRequest)msg);
+ } else if (msg instanceof GetLocalShuffleDataRequest) {
+ handleGetLocalShuffleData(client, (GetLocalShuffleDataRequest)msg);
+ } else if (msg instanceof GetLocalShuffleIndexRequest) {
+ handleGetLocalShuffleIndexRequest(client,
(GetLocalShuffleIndexRequest)msg);
+ } else if (msg instanceof GetMemoryShuffleDataRequest) {
+ handleGetMemoryShuffleDataRequest(client,
(GetMemoryShuffleDataRequest)msg);
+ } else {
+ throw new RssException("Can not handle message " + msg.type());
+ }
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause, TransportClient client) {
+ LOG.error("exception caught {}", client.getSocketAddress(), cause);
+ }
+
+ public void handleSendShuffleDataRequest(
+ TransportClient client, SendShuffleDataRequest req) {
+ RpcResponse rpcResponse;
+ String appId = req.getAppId();
+ int shuffleId = req.getShuffleId();
+ long requireBufferId = req.getRequireId();
+ long timestamp = req.getTimestamp();
+ if (timestamp > 0) {
+ /*
+ * Here we record the transport time, but we don't consider the impact
of data size on transport time.
+ * The amount of data will not cause great fluctuations in latency. For
example, 100K costs 1ms,
+ * and 1M costs 10ms. This seems like a normal fluctuation, but it may
rise to 10s when the server load is high.
+ * In addition, we need to pay attention to that the time of the client
machine and the machine
+ * time of the Shuffle Server should be kept in sync. TransportTime is
accurate only if this condition is met.
+ * */
+ long transportTime = System.currentTimeMillis() - timestamp;
+ if (transportTime > 0) {
+ shuffleServer.getGrpcMetrics().recordTransportTime(
+ ShuffleServerGrpcMetrics.SEND_SHUFFLE_DATA_METHOD, transportTime);
+ }
+ }
+ int requireSize = shuffleServer
+
.getShuffleTaskManager().getRequireBufferSize(requireBufferId);
+
+ StatusCode ret = StatusCode.SUCCESS;
+ String responseMessage = "OK";
+ if (req.getPartitionToBlocks().size() > 0) {
+ ShuffleServerMetrics.counterTotalReceivedDataSize.inc(requireSize);
+ ShuffleTaskManager manager = shuffleServer.getShuffleTaskManager();
+ PreAllocatedBufferInfo info =
manager.getAndRemovePreAllocatedBuffer(requireBufferId);
+ boolean isPreAllocated = info != null;
+ if (!isPreAllocated) {
+ String errorMsg = "Can't find requireBufferId[" + requireBufferId + "]
for appId[" + appId
+ + "], shuffleId[" + shuffleId + "]";
+ LOG.warn(errorMsg);
+ responseMessage = errorMsg;
+ rpcResponse = new RpcResponse(req.getRequestId(),
StatusCode.INTERNAL_ERROR, responseMessage);
+ client.sendRpcSync(rpcResponse, RPC_TIMEOUT);
+ return;
+ }
+ final long start = System.currentTimeMillis();
+ List<ShufflePartitionedData> shufflePartitionedData =
toPartitionedData(req);
+ long alreadyReleasedSize = 0;
+ for (ShufflePartitionedData spd : shufflePartitionedData) {
+ String shuffleDataInfo = "appId[" + appId + "], shuffleId[" + shuffleId
+ + "], partitionId[" +
spd.getPartitionId() + "]";
+ try {
+ ret = manager.cacheShuffleData(appId, shuffleId, isPreAllocated,
spd);
+ if (ret != StatusCode.SUCCESS) {
+ String errorMsg = "Error happened when shuffleEngine.write for "
+ + shuffleDataInfo + ", statusCode=" + ret;
+ LOG.error(errorMsg);
+ responseMessage = errorMsg;
+ break;
+ } else {
+ long toReleasedSize = spd.getTotalBlockSize();
+ // after each cacheShuffleData call, the `preAllocatedSize` is
updated timely.
+ manager.releasePreAllocatedSize(toReleasedSize);
+ alreadyReleasedSize += toReleasedSize;
+ manager.updateCachedBlockIds(appId, shuffleId,
spd.getPartitionId(), spd.getBlockList());
+ }
+ } catch (Exception e) {
+ String errorMsg = "Error happened when shuffleEngine.write for "
+ + shuffleDataInfo + ": " + e.getMessage();
+ ret = StatusCode.INTERNAL_ERROR;
+ responseMessage = errorMsg;
+ LOG.error(errorMsg);
+ break;
+ }
+ }
+ // since the required buffer id is only used once, the shuffle client
would try to require another buffer whether
+ // current connection succeeded or not. Therefore, the
preAllocatedBuffer is first get and removed, then after
+ // cacheShuffleData finishes, the preAllocatedSize should be updated
accordingly.
+ if (info.getRequireSize() > alreadyReleasedSize) {
+ manager.releasePreAllocatedSize(info.getRequireSize() -
alreadyReleasedSize);
+ }
+ rpcResponse = new RpcResponse(req.getRequestId(), ret, responseMessage);
+ long costTime = System.currentTimeMillis() - start;
+
shuffleServer.getGrpcMetrics().recordProcessTime(ShuffleServerGrpcMetrics.SEND_SHUFFLE_DATA_METHOD,
costTime);
+ LOG.debug("Cache Shuffle Data for appId[" + appId + "], shuffleId[" +
shuffleId
+ + "], cost " + costTime
+ + " ms with " + shufflePartitionedData.size() + " blocks
and " + requireSize + " bytes");
+ } else {
+ rpcResponse = new RpcResponse(req.getRequestId(),
StatusCode.INTERNAL_ERROR, "No data in request");
+ }
+
+ client.sendRpcSync(rpcResponse, RPC_TIMEOUT);
+ }
+
+ public void handleGetMemoryShuffleDataRequest(
+ TransportClient client, GetMemoryShuffleDataRequest req) {
+ String appId = req.getAppId();
+ int shuffleId = req.getShuffleId();
+ int partitionId = req.getPartitionId();
+ long blockId = req.getLastBlockId();
+ int readBufferSize = req.getReadBufferSize();
+ long timestamp = req.getTimestamp();
+
+ if (timestamp > 0) {
+ long transportTime = System.currentTimeMillis() - timestamp;
+ if (transportTime > 0) {
+ shuffleServer.getGrpcMetrics().recordTransportTime(
+ ShuffleServerGrpcMetrics.GET_MEMORY_SHUFFLE_DATA_METHOD,
transportTime);
+ }
+ }
+ long start = System.currentTimeMillis();
+ StatusCode status = StatusCode.SUCCESS;
+ String msg = "OK";
+ GetMemoryShuffleDataResponse response;
+ String requestInfo = "appId[" + appId + "], shuffleId[" + shuffleId + "],
partitionId["
+ + partitionId + "]";
+
+ // todo: if can get the exact memory size?
+ if
(shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(readBufferSize))
{
+ try {
+ ShuffleDataResult shuffleDataResult = shuffleServer
+ .getShuffleTaskManager()
+ .getInMemoryShuffleData(
+ appId,
+ shuffleId,
+ partitionId,
+ blockId,
+ readBufferSize,
+
req.getExpectedTaskIdsBitmap()
+ );
+ ByteBuf data = Unpooled.EMPTY_BUFFER;
+ List<BufferSegment> bufferSegments = Lists.newArrayList();
+ if (shuffleDataResult != null) {
+ data = Unpooled.wrappedBuffer(shuffleDataResult.getDataBuffer());
+ bufferSegments = shuffleDataResult.getBufferSegments();
+
ShuffleServerMetrics.counterTotalReadDataSize.inc(data.readableBytes());
+
ShuffleServerMetrics.counterTotalReadMemoryDataSize.inc(data.readableBytes());
+ }
+ long costTime = System.currentTimeMillis() - start;
+ shuffleServer.getGrpcMetrics().recordProcessTime(
+ ShuffleServerGrpcMetrics.GET_MEMORY_SHUFFLE_DATA_METHOD, costTime);
+ LOG.info("Successfully getInMemoryShuffleData cost {} ms with {} bytes
shuffle"
+ + " data for {}", costTime, data.readableBytes(),
requestInfo);
+
+ response = new GetMemoryShuffleDataResponse(req.getRequestId(),
status, msg, bufferSegments, data);
+ } catch (Exception e) {
+ status = StatusCode.INTERNAL_ERROR;
+ msg = "Error happened when get in memory shuffle data for "
+ + requestInfo + ", " + e.getMessage();
+ LOG.error(msg, e);
+ response = new GetMemoryShuffleDataResponse(req.getRequestId(),
+ status, msg, Lists.newArrayList(), Unpooled.EMPTY_BUFFER);
+ } finally {
+
shuffleServer.getShuffleBufferManager().releaseReadMemory(readBufferSize);
+ }
+ } else {
+ status = StatusCode.INTERNAL_ERROR;
+ msg = "Can't require memory to get in memory shuffle data";
+ LOG.error(msg + " for " + requestInfo);
+ response = new GetMemoryShuffleDataResponse(req.getRequestId(),
+ status, msg, Lists.newArrayList(), Unpooled.EMPTY_BUFFER);
+ }
+ client.sendRpcSync(response, RPC_TIMEOUT);
+ }
+
+ public void handleGetLocalShuffleIndexRequest(
+ TransportClient client, GetLocalShuffleIndexRequest req) {
+ String appId = req.getAppId();
+ int shuffleId = req.getShuffleId();
+ int partitionId = req.getPartitionId();
+ int partitionNumPerRange = req.getPartitionNumPerRange();
+ int partitionNum = req.getPartitionNum();
+ StatusCode status = StatusCode.SUCCESS;
+ String msg = "OK";
+ GetLocalShuffleIndexResponse response;
+ String requestInfo = "appId[" + appId + "], shuffleId[" + shuffleId + "],
partitionId["
+ + partitionId + "]";
+
+ int[] range = ShuffleStorageUtils.getPartitionRange(partitionId,
partitionNumPerRange, partitionNum);
+ Storage storage = shuffleServer.getStorageManager()
+ .selectStorage(new ShuffleDataReadEvent(appId,
shuffleId, partitionId, range[0]));
+ if (storage != null) {
+ storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
+ }
+ // Index file is expected small size and won't cause oom problem with the
assumed size. An index segment is 40B,
+ // with the default size - 2MB, it can support 50k blocks for shuffle data.
+ long assumedFileSize = shuffleServer
+
.getShuffleServerConf().getLong(ShuffleServerConf.SERVER_SHUFFLE_INDEX_SIZE_HINT);
+ if
(shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(assumedFileSize))
{
+ try {
+ final long start = System.currentTimeMillis();
+ ShuffleIndexResult shuffleIndexResult =
shuffleServer.getShuffleTaskManager().getShuffleIndex(
+ appId, shuffleId, partitionId, partitionNumPerRange, partitionNum);
+
+ ByteBuffer data = shuffleIndexResult.getIndexData();
+ ShuffleServerMetrics.counterTotalReadDataSize.inc(data.remaining());
+
ShuffleServerMetrics.counterTotalReadLocalIndexFileSize.inc(data.remaining());
+ response = new GetLocalShuffleIndexResponse(req.getRequestId(),
+ status, msg, Unpooled.wrappedBuffer(data),
shuffleIndexResult.getDataFileLen());
+ long readTime = System.currentTimeMillis() - start;
+ LOG.info("Successfully getShuffleIndex cost {} ms for {}"
+ + " bytes with {}", readTime, data.remaining(),
requestInfo);
+ } catch (FileNotFoundException indexFileNotFoundException) {
+ LOG.warn("Index file for {} is not found, maybe the data has been
flushed to cold storage.",
+ requestInfo, indexFileNotFoundException);
+ response = new GetLocalShuffleIndexResponse(req.getRequestId(),
status, msg, Unpooled.EMPTY_BUFFER, 0L);
+ } catch (Exception e) {
+ status = StatusCode.INTERNAL_ERROR;
+ msg = "Error happened when get shuffle index for " + requestInfo + ",
" + e.getMessage();
+ LOG.error(msg, e);
+ response = new GetLocalShuffleIndexResponse(req.getRequestId(),
status, msg, Unpooled.EMPTY_BUFFER, 0L);
+ } finally {
+
shuffleServer.getShuffleBufferManager().releaseReadMemory(assumedFileSize);
+ }
+ } else {
+ status = StatusCode.INTERNAL_ERROR;
+ msg = "Can't require memory to get shuffle index";
+ LOG.error(msg + " for " + requestInfo);
+ response = new GetLocalShuffleIndexResponse(req.getRequestId(), status,
msg, Unpooled.EMPTY_BUFFER, 0L);
+ }
+ client.sendRpcSync(response, RPC_TIMEOUT);
+ }
+
+ public void handleGetLocalShuffleData(
+ TransportClient client, GetLocalShuffleDataRequest req) {
+ String appId = req.getAppId();
+ int shuffleId = req.getShuffleId();
+ int partitionId = req.getPartitionId();
+ int partitionNumPerRange = req.getPartitionNumPerRange();
+ int partitionNum = req.getPartitionNum();
+ long offset = req.getOffset();
+ int length = req.getLength();
+ long timestamp = req.getTimestamp();
+ if (timestamp > 0) {
+ long transportTime = System.currentTimeMillis() - timestamp;
+ if (transportTime > 0) {
+ shuffleServer.getGrpcMetrics().recordTransportTime(
+ ShuffleServerGrpcMetrics.GET_SHUFFLE_DATA_METHOD, transportTime);
+ }
+ }
+ String storageType =
shuffleServer.getShuffleServerConf().get(RssBaseConf.RSS_STORAGE_TYPE);
+ StatusCode status = StatusCode.SUCCESS;
+ String msg = "OK";
+ GetLocalShuffleDataResponse response;
+ ShuffleDataResult sdr;
+ String requestInfo = "appId[" + appId + "], shuffleId[" + shuffleId + "],
partitionId["
+ + partitionId + "]" + "offset[" + offset + "]" +
"length[" + length + "]";
+
+ int[] range = ShuffleStorageUtils.getPartitionRange(partitionId,
partitionNumPerRange, partitionNum);
+ Storage storage = shuffleServer
+ .getStorageManager()
+ .selectStorage(
+ new ShuffleDataReadEvent(appId, shuffleId,
partitionId, range[0])
+ );
+ if (storage != null) {
+ storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
+ }
+
+ if
(shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(length)) {
+ try {
+ long start = System.currentTimeMillis();
+ sdr = shuffleServer.getShuffleTaskManager().getShuffleData(appId,
shuffleId, partitionId,
+ partitionNumPerRange, partitionNum, storageType, offset, length);
+ long readTime = System.currentTimeMillis() - start;
+ ShuffleServerMetrics.counterTotalReadTime.inc(readTime);
+
ShuffleServerMetrics.counterTotalReadDataSize.inc(sdr.getData().length);
+
ShuffleServerMetrics.counterTotalReadLocalDataFileSize.inc(sdr.getData().length);
+ shuffleServer.getGrpcMetrics().recordProcessTime(
+ ShuffleServerGrpcMetrics.GET_SHUFFLE_DATA_METHOD, readTime);
+ LOG.info("Successfully getShuffleData cost {} ms for shuffle"
+ + " data with {}", readTime, requestInfo);
+ response = new GetLocalShuffleDataResponse(req.getRequestId(),
+ status, msg, sdr.getDataBuf());
+ } catch (Exception e) {
+ status = StatusCode.INTERNAL_ERROR;
+ msg = "Error happened when get shuffle data for " + requestInfo + ", "
+ e.getMessage();
+ LOG.error(msg, e);
+ response = new GetLocalShuffleDataResponse(req.getRequestId(), status,
msg, Unpooled.EMPTY_BUFFER);
+ } finally {
+ shuffleServer.getShuffleBufferManager().releaseReadMemory(length);
+ }
+ } else {
+ status = StatusCode.INTERNAL_ERROR;
+ msg = "Can't require memory to get shuffle data";
+ LOG.error(msg + " for " + requestInfo);
+ response = new GetLocalShuffleDataResponse(req.getRequestId(), status,
msg, Unpooled.EMPTY_BUFFER);
+ }
+ client.sendRpcSync(response, RPC_TIMEOUT);
+ }
+
+ private List<ShufflePartitionedData>
toPartitionedData(SendShuffleDataRequest req) {
+ List<ShufflePartitionedData> ret = Lists.newArrayList();
+
+ for (Map.Entry<Integer, List<ShuffleBlockInfo>> entry:
req.getPartitionToBlocks().entrySet()) {
+ ret.add(new ShufflePartitionedData(
+ entry.getKey(),
+ toPartitionedBlock(entry.getValue())));
+ }
+ return ret;
+ }
+
+ private ShufflePartitionedBlock[] toPartitionedBlock(List<ShuffleBlockInfo>
blocks) {
+ if (blocks == null || blocks.size() == 0) {
+ return new ShufflePartitionedBlock[]{};
+ }
+ ShufflePartitionedBlock[] ret = new ShufflePartitionedBlock[blocks.size()];
+ int i = 0;
+ for (ShuffleBlockInfo block : blocks) {
+ ret[i] = new ShufflePartitionedBlock(
+ block.getLength(),
+ block.getUncompressLength(),
+ block.getCrc(),
+ block.getBlockId(),
+ block.getTaskAttemptId(),
+ block.getData());
+ i++;
+ }
+ return ret;
+ }
+}
+
diff --git
a/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
b/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
index 57f91338..ae412b38 100644
--- a/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
+++ b/server/src/main/java/org/apache/uniffle/server/netty/StreamServer.java
@@ -19,12 +19,10 @@ package org.apache.uniffle.server.netty;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
-import java.util.function.Supplier;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelFuture;
-import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
@@ -36,13 +34,15 @@ import io.netty.channel.socket.nio.NioServerSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.uniffle.common.netty.TransportFrameDecoder;
+import org.apache.uniffle.common.netty.client.TransportConf;
+import org.apache.uniffle.common.netty.client.TransportContext;
import org.apache.uniffle.common.rpc.ServerInterface;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.ExitUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
-import org.apache.uniffle.server.netty.decoder.StreamServerInitDecoder;
public class StreamServer implements ServerInterface {
@@ -75,8 +75,7 @@ public class StreamServer implements ServerInterface {
int backlogSize,
int timeoutMillis,
int sendBuf,
- int receiveBuf,
- Supplier<ChannelHandler[]> handlerSupplier) {
+ int receiveBuf) {
ServerBootstrap serverBootstrap = new ServerBootstrap().group(bossGroup,
workerGroup);
if (bossGroup instanceof EpollEventLoopGroup) {
serverBootstrap.channel(EpollServerSocketChannel.class);
@@ -84,10 +83,13 @@ public class StreamServer implements ServerInterface {
serverBootstrap.channel(NioServerSocketChannel.class);
}
+ ShuffleServerNettyHandler serverNettyHandler = new
ShuffleServerNettyHandler(shuffleServer);
+ TransportContext transportContext =
+ new TransportContext(new TransportConf(shuffleServerConf),
serverNettyHandler, true);
serverBootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(final SocketChannel ch) {
- ch.pipeline().addLast(handlerSupplier.get());
+ transportContext.initializePipeline(ch, new TransportFrameDecoder());
}
})
.option(ChannelOption.SO_BACKLOG, backlogSize)
@@ -121,15 +123,12 @@ public class StreamServer implements ServerInterface {
@Override
public void startOnPort(int port) throws Exception {
- Supplier<ChannelHandler[]> streamHandlers = () -> new ChannelHandler[]{
- new StreamServerInitDecoder()
- };
+
ServerBootstrap serverBootstrap = bootstrapChannel(shuffleBossGroup,
shuffleWorkerGroup,
shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_CONNECT_BACKLOG),
shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_CONNECT_TIMEOUT),
shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_SEND_BUF),
-
shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_RECEIVE_BUF),
- streamHandlers);
+
shuffleServerConf.getInteger(ShuffleServerConf.NETTY_SERVER_RECEIVE_BUF));
// Bind the ports and save the results so that the channels can be closed
later.
// If the second bind fails, the first one gets cleaned up in the shutdown.
@@ -143,6 +142,7 @@ public class StreamServer implements ServerInterface {
}
}
+ @Override
public void stop() {
if (channelFuture != null) {
channelFuture.channel().close().awaitUninterruptibly(10L,
TimeUnit.SECONDS);
diff --git
a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
index 53eb4ef8..659a9724 100644
---
a/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
+++
b/server/src/test/java/org/apache/uniffle/server/ShuffleFlushManagerTest.java
@@ -539,7 +539,7 @@ public class ShuffleFlushManagerTest extends HdfsTestBase {
// case3: local disk is full or corrupted, fallback to HDFS
List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
- new ShufflePartitionedBlock(100000, 1000, 1, 1, 1L, null)
+ new ShufflePartitionedBlock(100000, 1000, 1, 1, 1L, (byte[]) null)
);
ShuffleDataFlushEvent bigEvent = new ShuffleDataFlushEvent(1, "1", 1, 1,
1, 100, blocks, null, null);
bigEvent.setUnderStorage(((MultiStorageManager)storageManager).getWarmStorageManager().selectStorage(event));
@@ -571,7 +571,8 @@ public class ShuffleFlushManagerTest extends HdfsTestBase {
Thread.sleep(1 * 1000);
} while (manager.getEventNumInFlush() != 0);
- List<ShufflePartitionedBlock> blocks = Lists.newArrayList(new
ShufflePartitionedBlock(100, 1000, 1, 1, 1L, null));
+ List<ShufflePartitionedBlock> blocks =
+ Lists.newArrayList(new ShufflePartitionedBlock(100, 1000, 1, 1, 1L,
(byte[]) null));
ShuffleDataFlushEvent bigEvent = new ShuffleDataFlushEvent(1, "1", 1, 1,
1, 100, blocks, null, null);
bigEvent.setUnderStorage(storageManager.selectStorage(event));
storageManager.updateWriteMetrics(bigEvent, 0);
diff --git
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
index ff39cac6..dabd1e08 100644
---
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
+++
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferManagerTest.java
@@ -35,6 +35,7 @@ import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShufflePartitionedData;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ByteBufUtils;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.server.ShuffleFlushManager;
import org.apache.uniffle.server.ShuffleServer;
@@ -173,11 +174,11 @@ public class ShuffleBufferManagerTest extends
BufferTestBase {
// validate get shuffle data
ShuffleDataResult sdr = shuffleBufferManager.getShuffleData(
appId, 2, 0, Constants.INVALID_BLOCK_ID, 60);
- assertArrayEquals(spd2.getBlockList()[0].getData(), sdr.getData());
+
assertArrayEquals(ByteBufUtils.readBytes(spd2.getBlockList()[0].getData()),
sdr.getData());
long lastBlockId = spd2.getBlockList()[0].getBlockId();
sdr = shuffleBufferManager.getShuffleData(
appId, 2, 0, lastBlockId, 100);
- assertArrayEquals(spd3.getBlockList()[0].getData(), sdr.getData());
+
assertArrayEquals(ByteBufUtils.readBytes(spd3.getBlockList()[0].getData()),
sdr.getData());
// flush happen
ShufflePartitionedData spd5 = createData(0, 10);
shuffleBufferManager.cacheShuffleData(appId, 4, false, spd5);
@@ -193,11 +194,11 @@ public class ShuffleBufferManagerTest extends
BufferTestBase {
// data in flush buffer now, it also can be got before flush finish
sdr = shuffleBufferManager.getShuffleData(
appId, 2, 0, Constants.INVALID_BLOCK_ID, 60);
- assertArrayEquals(spd2.getBlockList()[0].getData(), sdr.getData());
+
assertArrayEquals(ByteBufUtils.readBytes(spd2.getBlockList()[0].getData()),
sdr.getData());
lastBlockId = spd2.getBlockList()[0].getBlockId();
sdr = shuffleBufferManager.getShuffleData(
appId, 2, 0, lastBlockId, 100);
- assertArrayEquals(spd3.getBlockList()[0].getData(), sdr.getData());
+
assertArrayEquals(ByteBufUtils.readBytes(spd3.getBlockList()[0].getData()),
sdr.getData());
// cache data again, it should cause flush
spd1 = createData(0, 10);
shuffleBufferManager.cacheShuffleData(appId, 1, false, spd1);
diff --git
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
index 39ab4e61..b4c0ee4e 100644
---
a/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
+++
b/server/src/test/java/org/apache/uniffle/server/buffer/ShuffleBufferTest.java
@@ -28,6 +28,7 @@ import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShufflePartitionedBlock;
import org.apache.uniffle.common.ShufflePartitionedData;
+import org.apache.uniffle.common.util.ByteBufUtils;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.server.ShuffleDataFlushEvent;
@@ -603,7 +604,7 @@ public class ShuffleBufferTest extends BufferTestBase {
int offset = 0;
for (ShufflePartitionedData spd : spds) {
ShufflePartitionedBlock block = spd.getBlockList()[0];
- System.arraycopy(block.getData(), 0, expectedData, offset,
block.getLength());
+ ByteBufUtils.readBytes(block.getData(), expectedData, offset,
block.getLength());
offset += block.getLength();
}
return expectedData;
diff --git
a/server/src/test/java/org/apache/uniffle/server/storage/MultiStorageManagerTest.java
b/server/src/test/java/org/apache/uniffle/server/storage/MultiStorageManagerTest.java
index 3cea6d32..1c66ad43 100644
---
a/server/src/test/java/org/apache/uniffle/server/storage/MultiStorageManagerTest.java
+++
b/server/src/test/java/org/apache/uniffle/server/storage/MultiStorageManagerTest.java
@@ -47,7 +47,8 @@ public class MultiStorageManagerTest {
String remoteStorage = "test";
String appId = "selectStorageManagerTest_appId";
manager.registerRemoteStorage(appId, new RemoteStorageInfo(remoteStorage));
- List<ShufflePartitionedBlock> blocks = Lists.newArrayList(new
ShufflePartitionedBlock(100, 1000, 1, 1, 1L, null));
+ List<ShufflePartitionedBlock> blocks =
+ Lists.newArrayList(new ShufflePartitionedBlock(100, 1000, 1, 1, 1L,
(byte[]) null));
ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
1, appId, 1, 1, 1, 1000, blocks, null, null);
assertTrue((manager.selectStorage(event) instanceof LocalStorage));
@@ -81,7 +82,7 @@ public class MultiStorageManagerTest {
* is enabled.
*/
List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
- new ShufflePartitionedBlock(10001, 1000, 1, 1, 1L, null)
+ new ShufflePartitionedBlock(10001, 1000, 1, 1, 1L, (byte[]) null)
);
ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
1, appId, 1, 1, 1, 100000, blocks, null, null);
@@ -112,7 +113,7 @@ public class MultiStorageManagerTest {
* case1: big event should be written into cold storage directly
*/
List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
- new ShufflePartitionedBlock(10001, 1000, 1, 1, 1L, null)
+ new ShufflePartitionedBlock(10001, 1000, 1, 1, 1L, (byte[]) null)
);
ShuffleDataFlushEvent hugeEvent = new ShuffleDataFlushEvent(
1, appId, 1, 1, 1, 10001, blocks, null, null);
@@ -121,7 +122,7 @@ public class MultiStorageManagerTest {
/**
* case2: fallback when disk can not write
*/
- blocks = Lists.newArrayList(new ShufflePartitionedBlock(100, 1000, 1, 1,
1L, null));
+ blocks = Lists.newArrayList(new ShufflePartitionedBlock(100, 1000, 1, 1,
1L, (byte[]) null));
ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
1, appId, 1, 1, 1, 1000, blocks, null, null);
Storage storage = manager.selectStorage(event);
diff --git
a/server/src/test/java/org/apache/uniffle/server/storage/StorageManagerFallbackStrategyTest.java
b/server/src/test/java/org/apache/uniffle/server/storage/StorageManagerFallbackStrategyTest.java
index c5b4512b..2446c9af 100644
---
a/server/src/test/java/org/apache/uniffle/server/storage/StorageManagerFallbackStrategyTest.java
+++
b/server/src/test/java/org/apache/uniffle/server/storage/StorageManagerFallbackStrategyTest.java
@@ -54,7 +54,7 @@ public class StorageManagerFallbackStrategyTest {
String appId = "testDefaultFallbackStrategy_appId";
     coldStorageManager.registerRemoteStorage(appId, new RemoteStorageInfo(remoteStorage));
List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
- new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, null));
+ new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, (byte[]) null));
ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
1, appId, 1, 1, 1, 1000, blocks, null, null);
event.increaseRetryTimes();
@@ -90,7 +90,7 @@ public class StorageManagerFallbackStrategyTest {
String appId = "testHdfsFallbackStrategy_appId";
     coldStorageManager.registerRemoteStorage(appId, new RemoteStorageInfo(remoteStorage));
List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
- new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, null));
+ new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, (byte[]) null));
ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
1, appId, 1, 1, 1, 1000, blocks, null, null);
event.increaseRetryTimes();
@@ -112,7 +112,7 @@ public class StorageManagerFallbackStrategyTest {
String appId = "testLocalFallbackStrategy_appId";
     coldStorageManager.registerRemoteStorage(appId, new RemoteStorageInfo(remoteStorage));
List<ShufflePartitionedBlock> blocks = Lists.newArrayList(
- new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, null));
+ new ShufflePartitionedBlock(100, 1000, 1, 1, 1L, (byte[]) null));
ShuffleDataFlushEvent event = new ShuffleDataFlushEvent(
1, appId, 1, 1, 1, 1000, blocks, null, null);
event.increaseRetryTimes();
diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleWriteHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleWriteHandler.java
index da701702..776115ef 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleWriteHandler.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HdfsShuffleWriteHandler.java
@@ -32,6 +32,7 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.ShufflePartitionedBlock;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.filesystem.HadoopFilesystemProvider;
+import org.apache.uniffle.common.util.ByteBufUtils;
import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler;
import org.apache.uniffle.storage.util.ShuffleStorageUtils;
@@ -118,7 +119,7 @@ public class HdfsShuffleWriteHandler implements ShuffleWriteHandler {
long blockId = block.getBlockId();
long crc = block.getCrc();
long startOffset = dataWriter.nextOffset();
- dataWriter.writeData(block.getData());
+ dataWriter.writeData(ByteBufUtils.readBytes(block.getData()));
FileBasedShuffleSegment segment = new FileBasedShuffleSegment(
blockId, startOffset, block.getLength(),
block.getUncompressLength(), crc, block.getTaskAttemptId());
diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java
index 00e43927..25f504fb 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileWriteHandler.java
@@ -28,6 +28,7 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.ShufflePartitionedBlock;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ByteBufUtils;
import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler;
import org.apache.uniffle.storage.util.ShuffleStorageUtils;
@@ -106,7 +107,8 @@ public class LocalFileWriteHandler implements ShuffleWriteHandler {
long blockId = block.getBlockId();
long crc = block.getCrc();
long startOffset = dataWriter.nextOffset();
- dataWriter.writeData(block.getData());
+ dataWriter.writeData(ByteBufUtils.readBytes(block.getData()));
+ block.getData().release();
FileBasedShuffleSegment segment = new FileBasedShuffleSegment(
blockId, startOffset, block.getLength(),
block.getUncompressLength(), crc, block.getTaskAttemptId());
diff --git a/storage/src/test/java/org/apache/uniffle/storage/HdfsShuffleHandlerTestBase.java b/storage/src/test/java/org/apache/uniffle/storage/HdfsShuffleHandlerTestBase.java
index dd8f8859..52d34efc 100644
--- a/storage/src/test/java/org/apache/uniffle/storage/HdfsShuffleHandlerTestBase.java
+++ b/storage/src/test/java/org/apache/uniffle/storage/HdfsShuffleHandlerTestBase.java
@@ -30,6 +30,7 @@ import org.apache.hadoop.fs.Path;
import org.apache.uniffle.common.BufferSegment;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.util.ByteBufUtils;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
@@ -88,7 +89,7 @@ public class HdfsShuffleHandlerTestBase {
offset += spb.getLength();
segments.add(segment);
if (doWrite) {
- writer.writeData(spb.getData());
+ writer.writeData(ByteBufUtils.readBytes(spb.getData()));
}
}
expectedIndexSegments.put(partitionId, segments);
diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileHandlerTestBase.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileHandlerTestBase.java
index 89eb24d5..199e6ad2 100644
--- a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileHandlerTestBase.java
+++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/LocalFileHandlerTestBase.java
@@ -33,6 +33,7 @@ import org.apache.uniffle.common.ShuffleDataSegment;
import org.apache.uniffle.common.ShuffleIndexResult;
import org.apache.uniffle.common.ShufflePartitionedBlock;
import org.apache.uniffle.common.segment.FixedSizeSegmentSplitter;
+import org.apache.uniffle.common.util.ByteBufUtils;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.storage.common.FileBasedShuffleSegment;
import org.apache.uniffle.storage.handler.api.ServerReadHandler;
@@ -58,9 +59,11 @@ public class LocalFileHandlerTestBase {
   public static void writeTestData(List<ShufflePartitionedBlock> blocks, ShuffleWriteHandler handler,
       Map<Long, byte[]> expectedData, Set<Long> expectedBlockIds) throws Exception {
+ blocks.forEach(block -> block.getData().retain());
handler.write(blocks);
blocks.forEach(block -> expectedBlockIds.add(block.getBlockId()));
-    blocks.forEach(block -> expectedData.put(block.getBlockId(), block.getData()));
+    blocks.forEach(block -> expectedData.put(block.getBlockId(), ByteBufUtils.readBytes(block.getData())));
+    blocks.forEach(block -> block.getData().release());
}
   public static void validateResult(ServerReadHandler readHandler, Set<Long> expectedBlockIds,