This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 04d6854dc [#1355] fix(client): Netty client will leak when decoding 
responses (#1455)
04d6854dc is described below

commit 04d6854dc9ab16fe6675eb10d166b9520cd05924
Author: RickyMa <[email protected]>
AuthorDate: Fri Jan 19 22:12:48 2024 +0800

    [#1355] fix(client): Netty client will leak when decoding responses (#1455)
    
    ### What changes were proposed in this pull request?
    
    In the current code, the frame `ByteBuf` is only released when `msg.body() == null`. However, when `msg.body() != null` and `msg.body().byteBuf()` is a `NettyManagedBuffer.EMPTY_BUFFER`, the frame `ByteBuf` is never released, so a memory leak occurs every time an RPC response from the ShuffleServer is decoded.
    If a Spark job runs long enough and issues enough requests, this eventually causes a significant memory leak on the client side (the Spark executor).
    The changes to the other files are mainly for readability and added robustness (e.g. releasing buffers in `finally` blocks and in `close()`), and have no side effects.
    
    ### Why are the changes needed?
    
    To fix the memory leak in the Netty client when decoding RPC responses.
    For [#1359](https://github.com/apache/incubator-uniffle/issues/1359)
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Reran the job successfully; Netty no longer logs any leak warnings. Added unit tests in `TransportFrameDecoderTest`.
---
 .../uniffle/client/impl/ShuffleReadClientImpl.java |   3 +
 .../common/netty/TransportFrameDecoder.java        |  17 +-
 .../common/netty/protocol/ResponseMessage.java     |   2 +-
 .../common/netty/EncoderAndDecoderTest.java        |   8 +-
 .../common/netty/TransportFrameDecoderTest.java    | 253 +++++++++++++++++++++
 .../handler/impl/DataSkippableReadHandler.java     |  13 +-
 6 files changed, 283 insertions(+), 13 deletions(-)

diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index 2c2fc25ce..7e2f930a7 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -293,6 +293,9 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
 
   @Override
   public void close() {
+    if (sdr != null) {
+      sdr.release();
+    }
     if (readBuffer != null) {
       RssUtils.releaseByteBuffer(readBuffer);
     }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
index 4a7b8ab4b..cfb0c40ae 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/TransportFrameDecoder.java
@@ -67,15 +67,26 @@ public class TransportFrameDecoder extends 
ChannelInboundHandlerAdapter implemen
       if (frame == null) {
         break;
       }
-      Message msg = Message.decode(curType, frame);
-      if (msg.body() == null) {
-        frame.release();
+      Message msg = null;
+      try {
+        msg = Message.decode(curType, frame);
+      } finally {
+        if (shouldRelease(msg)) {
+          frame.release();
+        }
       }
       ctx.fireChannelRead(msg);
       clear();
     }
   }
 
+  static boolean shouldRelease(Message msg) {
+    if (msg == null || msg.body() == null || msg.body().byteBuf() == null) {
+      return true;
+    }
+    return msg.body().byteBuf().readableBytes() == 0;
+  }
+
   private void clear() {
     curType = Message.Type.UNKNOWN_TYPE;
     msgSize = -1;
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/ResponseMessage.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/ResponseMessage.java
index 36e3f3c2e..aa897fd65 100644
--- 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/ResponseMessage.java
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/ResponseMessage.java
@@ -29,6 +29,6 @@ public abstract class ResponseMessage extends Message {
   }
 
   public ResponseMessage createFailureResponse(String error) {
-    throw new UnsupportedOperationException();
+    throw new UnsupportedOperationException(error);
   }
 }
diff --git 
a/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
 
b/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
index 127f62f07..7611673a5 100644
--- 
a/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
+++ 
b/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
@@ -138,7 +138,7 @@ public class EncoderAndDecoderTest {
                 1,
                 1,
                 1,
-                10,
+                data.length,
                 123,
                 Unpooled.wrappedBuffer(data).retain(),
                 shuffleServerInfoList,
@@ -149,7 +149,7 @@ public class EncoderAndDecoderTest {
                 1,
                 1,
                 1,
-                10,
+                data.length,
                 123,
                 Unpooled.wrappedBuffer(data).retain(),
                 shuffleServerInfoList,
@@ -162,7 +162,7 @@ public class EncoderAndDecoderTest {
                 1,
                 2,
                 1,
-                10,
+                data.length,
                 123,
                 Unpooled.wrappedBuffer(data).retain(),
                 shuffleServerInfoList,
@@ -173,7 +173,7 @@ public class EncoderAndDecoderTest {
                 1,
                 1,
                 2,
-                10,
+                data.length,
                 123,
                 Unpooled.wrappedBuffer(data).retain(),
                 shuffleServerInfoList,
diff --git 
a/common/src/test/java/org/apache/uniffle/common/netty/TransportFrameDecoderTest.java
 
b/common/src/test/java/org/apache/uniffle/common/netty/TransportFrameDecoderTest.java
new file mode 100644
index 000000000..1f907ebe9
--- /dev/null
+++ 
b/common/src/test/java/org/apache/uniffle/common/netty/TransportFrameDecoderTest.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.common.BufferSegment;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataRequest;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataResponse;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexRequest;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse;
+import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest;
+import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataResponse;
+import org.apache.uniffle.common.netty.protocol.Message;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+import org.apache.uniffle.common.netty.protocol.SendShuffleDataRequest;
+import org.apache.uniffle.common.rpc.StatusCode;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class TransportFrameDecoderTest {
+
+  /** test if the RPC response should be released after decoding */
+  @Test
+  public void testShouldRpcResponsesToBeReleased() {
+    RpcResponse rpcResponse1 = generateRpcResponse();
+    int length1 = rpcResponse1.encodedLength();
+    ByteBuf byteBuf1 = Unpooled.buffer(length1);
+    rpcResponse1.encode(byteBuf1);
+    assertEquals(byteBuf1.readableBytes(), length1);
+    Message message1 = Message.decode(rpcResponse1.type(), byteBuf1);
+    assertTrue(TransportFrameDecoder.shouldRelease(message1));
+    byteBuf1.release();
+
+    GetLocalShuffleDataResponse rpcResponse2 = 
generateGetLocalShuffleDataResponse();
+    int length2 = rpcResponse2.encodedLength();
+    byte[] body2 = generateBody();
+    ByteBuf byteBuf2 = Unpooled.buffer(length2 + body2.length);
+    rpcResponse2.encode(byteBuf2);
+    assertEquals(byteBuf2.readableBytes(), length2);
+    byteBuf2.writeBytes(body2);
+    Message message2 = Message.decode(rpcResponse2.type(), byteBuf2);
+    assertFalse(TransportFrameDecoder.shouldRelease(message2));
+    // after processing some business logic in the code, and finally release 
the body buffer
+    message2.body().release();
+
+    GetLocalShuffleIndexResponse rpcResponse3 = 
generateGetLocalShuffleIndexResponse();
+    int length3 = rpcResponse3.encodedLength();
+    byte[] body3 = generateBody();
+    ByteBuf byteBuf3 = Unpooled.buffer(length3 + body3.length);
+    rpcResponse3.encode(byteBuf3);
+    assertEquals(byteBuf3.readableBytes(), length3);
+    byteBuf3.writeBytes(body3);
+    Message message3 = Message.decode(rpcResponse3.type(), byteBuf3);
+    assertFalse(TransportFrameDecoder.shouldRelease(message3));
+    // after processing some business logic in the code, and finally release 
the body buffer
+    message3.body().release();
+
+    GetMemoryShuffleDataResponse rpcResponse4 = 
generateGetMemoryShuffleDataResponse();
+    int length4 = rpcResponse4.encodedLength();
+    byte[] body4 = generateBody();
+    ByteBuf byteBuf4 = Unpooled.buffer(length4 + body4.length);
+    rpcResponse4.encode(byteBuf4);
+    assertEquals(byteBuf4.readableBytes(), length4);
+    byteBuf4.writeBytes(body4);
+    Message message4 = Message.decode(rpcResponse4.type(), byteBuf4);
+    assertFalse(TransportFrameDecoder.shouldRelease(message4));
+    // after processing some business logic in the code, and finally release 
the body buffer
+    message4.body().release();
+  }
+
+  /** test if the RPC request should be released after decoding */
+  @Test
+  public void testShouldRpcRequestsToBeReleased() {
+    SendShuffleDataRequest rpcRequest1 = generateShuffleDataRequest();
+    int length1 = rpcRequest1.encodedLength();
+    ByteBuf byteBuf1 = Unpooled.buffer(length1);
+    rpcRequest1.encode(byteBuf1);
+    assertEquals(byteBuf1.readableBytes(), length1);
+    Message message1 = Message.decode(rpcRequest1.type(), byteBuf1);
+    assertTrue(TransportFrameDecoder.shouldRelease(message1));
+    byteBuf1.release();
+
+    GetLocalShuffleDataRequest rpcRequest2 = 
generateGetLocalShuffleDataRequest();
+    int length2 = rpcRequest2.encodedLength();
+    ByteBuf byteBuf2 = Unpooled.buffer(length2);
+    rpcRequest2.encode(byteBuf2);
+    assertEquals(byteBuf2.readableBytes(), length2);
+    Message message2 = Message.decode(rpcRequest2.type(), byteBuf2);
+    assertTrue(TransportFrameDecoder.shouldRelease(message2));
+    byteBuf2.release();
+
+    GetLocalShuffleIndexRequest rpcRequest3 = 
generateGetLocalShuffleIndexRequest();
+    int length3 = rpcRequest3.encodedLength();
+    ByteBuf byteBuf3 = Unpooled.buffer(length3);
+    rpcRequest3.encode(byteBuf3);
+    assertEquals(byteBuf3.readableBytes(), length3);
+    Message message3 = Message.decode(rpcRequest3.type(), byteBuf3);
+    assertTrue(TransportFrameDecoder.shouldRelease(message3));
+    byteBuf3.release();
+
+    GetMemoryShuffleDataRequest rpcRequest4 = 
generateGetMemoryShuffleDataRequest();
+    int length4 = rpcRequest4.encodedLength();
+    ByteBuf byteBuf4 = Unpooled.buffer(length4);
+    rpcRequest4.encode(byteBuf4);
+    assertEquals(byteBuf4.readableBytes(), length4);
+    Message message4 = Message.decode(rpcRequest4.type(), byteBuf4);
+    assertTrue(TransportFrameDecoder.shouldRelease(message4));
+    byteBuf4.release();
+  }
+
+  private byte[] generateBody() {
+    return new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+  }
+
+  private RpcResponse generateRpcResponse() {
+    RpcResponse rpcResponse1 = new RpcResponse(1, StatusCode.SUCCESS, 
"test_message");
+    return rpcResponse1;
+  }
+
+  private GetLocalShuffleDataResponse generateGetLocalShuffleDataResponse() {
+    byte[] data2 = new byte[] {1, 2, 3};
+    GetLocalShuffleDataResponse rpcResponse2 =
+        new GetLocalShuffleDataResponse(
+            1,
+            StatusCode.SUCCESS,
+            "",
+            new NettyManagedBuffer(Unpooled.wrappedBuffer(data2).retain()));
+    return rpcResponse2;
+  }
+
+  private GetLocalShuffleIndexResponse generateGetLocalShuffleIndexResponse() {
+    byte[] data3 = new byte[] {1, 2, 3};
+    GetLocalShuffleIndexResponse rpcResponse3 =
+        new GetLocalShuffleIndexResponse(
+            1, StatusCode.SUCCESS, "", Unpooled.wrappedBuffer(data3).retain(), 
23);
+    return rpcResponse3;
+  }
+
+  private GetMemoryShuffleDataResponse generateGetMemoryShuffleDataResponse() {
+    byte[] data4 = new byte[] {1, 2, 3, 4, 5};
+    List<BufferSegment> bufferSegments =
+        Lists.newArrayList(
+            new BufferSegment(1, 0, 5, 10, 123, 1), new BufferSegment(1, 0, 5, 
10, 345, 1));
+    GetMemoryShuffleDataResponse rpcResponse4 =
+        new GetMemoryShuffleDataResponse(
+            1, StatusCode.SUCCESS, "", bufferSegments, 
Unpooled.wrappedBuffer(data4).retain());
+    return rpcResponse4;
+  }
+
+  private SendShuffleDataRequest generateShuffleDataRequest() {
+    String appId = "test_app";
+    byte[] data = new byte[] {1, 2, 3};
+    List<ShuffleServerInfo> shuffleServerInfoList =
+        Arrays.asList(new ShuffleServerInfo("aaa", 1), new 
ShuffleServerInfo("bbb", 2));
+    List<ShuffleBlockInfo> shuffleBlockInfoList1 =
+        Arrays.asList(
+            new ShuffleBlockInfo(
+                1,
+                1,
+                1,
+                data.length,
+                123,
+                Unpooled.wrappedBuffer(data).retain(),
+                shuffleServerInfoList,
+                5,
+                0,
+                1),
+            new ShuffleBlockInfo(
+                1,
+                1,
+                1,
+                data.length,
+                123,
+                Unpooled.wrappedBuffer(data).retain(),
+                shuffleServerInfoList,
+                5,
+                0,
+                1));
+    List<ShuffleBlockInfo> shuffleBlockInfoList2 =
+        Arrays.asList(
+            new ShuffleBlockInfo(
+                1,
+                2,
+                1,
+                data.length,
+                123,
+                Unpooled.wrappedBuffer(data).retain(),
+                shuffleServerInfoList,
+                5,
+                0,
+                1),
+            new ShuffleBlockInfo(
+                1,
+                1,
+                2,
+                data.length,
+                123,
+                Unpooled.wrappedBuffer(data).retain(),
+                shuffleServerInfoList,
+                5,
+                0,
+                1));
+    Map<Integer, List<ShuffleBlockInfo>> partitionToBlocks = Maps.newHashMap();
+    partitionToBlocks.put(1, shuffleBlockInfoList1);
+    partitionToBlocks.put(2, shuffleBlockInfoList2);
+    return new SendShuffleDataRequest(1L, appId, 1, 1, partitionToBlocks, 
12345);
+  }
+
+  private GetLocalShuffleDataRequest generateGetLocalShuffleDataRequest() {
+    return new GetLocalShuffleDataRequest(
+        1, "test_app", 1, 1, 1, 100, 0, 200, System.currentTimeMillis());
+  }
+
+  private GetLocalShuffleIndexRequest generateGetLocalShuffleIndexRequest() {
+    return new GetLocalShuffleIndexRequest(1, "test_app", 1, 1, 1, 100);
+  }
+
+  private GetMemoryShuffleDataRequest generateGetMemoryShuffleDataRequest() {
+    Roaring64NavigableMap expectedTaskIdsBitmap = 
Roaring64NavigableMap.bitmapOf(1, 2, 3, 4, 5);
+    return new GetMemoryShuffleDataRequest(
+        1, "test_app", 1, 1, 1, 64, System.currentTimeMillis(), 
expectedTaskIdsBitmap);
+  }
+}
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
index f3ea3c1f5..220e02997 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
@@ -73,11 +73,14 @@ public abstract class DataSkippableReadHandler extends 
AbstractClientReadHandler
         return null;
       }
 
-      shuffleDataSegments =
-          SegmentSplitterFactory.getInstance()
-              .get(distributionType, expectTaskIds, readBufferSize)
-              .split(shuffleIndexResult);
-      shuffleIndexResult.release();
+      try {
+        shuffleDataSegments =
+            SegmentSplitterFactory.getInstance()
+                .get(distributionType, expectTaskIds, readBufferSize)
+                .split(shuffleIndexResult);
+      } finally {
+        shuffleIndexResult.release();
+      }
     }
 
     // We should skip unexpected and processed segments when handler is read

Reply via email to