This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new e99eade1 [#133] feat(netty): Add Netty Utils (#727)
e99eade1 is described below
commit e99eade19bd8ea0fb9883f65358c320760a66c1f
Author: Xianming Lei <[email protected]>
AuthorDate: Sun Mar 19 01:49:23 2023 +0800
[#133] feat(netty): Add Netty Utils (#727)
### What changes were proposed in this pull request?
Add netty utils.
### Why are the changes needed?
Add Netty utilities in preparation for replacing gRPC with Netty.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
UT.
Co-authored-by: leixianming <[email protected]>
---
.../org/apache/spark/network/util/NettyUtils.java | 137 ++++++++++++++++++++
common/pom.xml | 5 +
.../{util/ThreadUtils.java => netty/IOMode.java} | 17 +--
.../protocol/Encodable.java} | 15 +--
.../uniffle/common/netty/protocol/Message.java | 72 +++++++++++
.../uniffle/common/netty/protocol/RpcResponse.java | 85 +++++++++++++
.../apache/uniffle/common/util/ByteBufUtils.java | 57 +++++++++
.../util/{ThreadUtils.java => JavaUtils.java} | 23 ++--
.../org/apache/uniffle/common/util/NettyUtils.java | 114 +++++++++++++++++
.../apache/uniffle/common/util/ThreadUtils.java | 6 +
.../uniffle/common/util/ByteBufUtilsTest.java | 46 +++++++
.../apache/uniffle/common/util/JavaUtilsTest.java} | 24 ++--
.../apache/uniffle/common/util/NettyUtilsTest.java | 138 +++++++++++++++++++++
pom.xml | 1 -
14 files changed, 700 insertions(+), 40 deletions(-)
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/network/util/NettyUtils.java
b/client-spark/spark2/src/main/java/org/apache/spark/network/util/NettyUtils.java
new file mode 100644
index 00000000..b231e626
--- /dev/null
+++
b/client-spark/spark2/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.lang.reflect.Field;
+import java.util.concurrent.ThreadFactory;
+
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.epoll.EpollServerSocketChannel;
+import io.netty.channel.epoll.EpollSocketChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.util.concurrent.DefaultThreadFactory;
+import io.netty.util.internal.PlatformDependent;
+import io.netty.util.internal.SystemPropertyUtil;
+
+/**
+ * Copied from Spark in order to override the createPooledByteBufAllocator method:
+ * the property DEFAULT_TINY_CACHE_SIZE does not exist in netty>4.1.47, so the tiny cache
+ * size is read here from the io.netty.allocator.tinyCacheSize system property while the
+ * remaining allocator defaults are read reflectively via getPrivateStaticField.
+ */
+public class NettyUtils {
+
+ /** Tiny cache size: system property io.netty.allocator.tinyCacheSize, defaulting to 512. */
+ private static final int DEFAULT_TINY_CACHE_SIZE =
SystemPropertyUtil.getInt("io.netty.allocator.tinyCacheSize", 512);
+
+ /** Creates a new ThreadFactory which prefixes each thread with the given name. */
+ public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
+ return new DefaultThreadFactory(threadPoolPrefix, true);
+ }
+
+ /**
+ * Creates a Netty EventLoopGroup based on the IOMode.
+ *
+ * @param mode NIO or EPOLL; any other value is rejected
+ * @param numThreads thread count passed to the event loop group constructor
+ * @param threadPrefix name prefix handed to createThreadFactory
+ * @throws IllegalArgumentException if the mode is not handled by the switch
+ */
+ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads,
String threadPrefix) {
+ ThreadFactory threadFactory = createThreadFactory(threadPrefix);
+
+ switch (mode) {
+ case NIO:
+ return new NioEventLoopGroup(numThreads, threadFactory);
+ case EPOLL:
+ return new EpollEventLoopGroup(numThreads, threadFactory);
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /**
+ * Returns the correct (client) SocketChannel class based on IOMode.
+ *
+ * @throws IllegalArgumentException if the mode is neither NIO nor EPOLL
+ */
+ public static Class<? extends Channel> getClientChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioSocketChannel.class;
+ case EPOLL:
+ return EpollSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /**
+ * Returns the correct ServerSocketChannel class based on IOMode.
+ *
+ * @throws IllegalArgumentException if the mode is neither NIO nor EPOLL
+ */
+ public static Class<? extends ServerChannel> getServerChannelClass(IOMode
mode) {
+ switch (mode) {
+ case NIO:
+ return NioServerSocketChannel.class;
+ case EPOLL:
+ return EpollServerSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /**
+ * Creates a new TransportFrameDecoder. This is used before all decoders.
+ * NOTE(review): the inherited comment described this as a LengthFieldBasedFrameDecoder
+ * whose first 8 bytes are the length of the frame; confirm that this still matches
+ * TransportFrameDecoder.
+ */
+ public static TransportFrameDecoder createFrameDecoder() {
+ return new TransportFrameDecoder();
+ }
+
+ /** Returns the remote address on the channel or "<unknown remote>" if none exists. */
+ public static String getRemoteAddress(Channel channel) {
+ if (channel != null && channel.remoteAddress() != null) {
+ return channel.remoteAddress().toString();
+ }
+ return "<unknown remote>";
+ }
+
+ /**
+ * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local
+ * caches are disabled for TransportClients because the ByteBufs are allocated by the
+ * event loop thread, but released by the executor thread rather than the event loop
+ * thread. Those thread-local caches actually delay the recycling of buffers, leading to
+ * larger memory usage.
+ *
+ * @param allowDirectBufs direct buffers are preferred only when this is true and
+ * PlatformDependent.directBufferPreferred() agrees; when false the direct arena
+ * count is capped at 0
+ * @param allowCache when false, the tiny, small and normal cache sizes are passed as 0
+ * @param numCores upper bound for the arena counts; 0 means use
+ * Runtime.getRuntime().availableProcessors()
+ */
+ public static PooledByteBufAllocator createPooledByteBufAllocator(
+ boolean allowDirectBufs,
+ boolean allowCache,
+ int numCores) {
+ if (numCores == 0) {
+ numCores = Runtime.getRuntime().availableProcessors();
+ }
+ return new PooledByteBufAllocator(
+ allowDirectBufs && PlatformDependent.directBufferPreferred(),
+ Math.min(getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), numCores),
+ Math.min(getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
allowDirectBufs ? numCores : 0),
+ getPrivateStaticField("DEFAULT_PAGE_SIZE"),
+ getPrivateStaticField("DEFAULT_MAX_ORDER"),
+ allowCache ? DEFAULT_TINY_CACHE_SIZE : 0,
+ allowCache ? getPrivateStaticField("DEFAULT_SMALL_CACHE_SIZE") : 0,
+ allowCache ? getPrivateStaticField("DEFAULT_NORMAL_CACHE_SIZE") : 0
+ );
+ }
+
+ /**
+ * Used to get defaults from Netty's private static fields. The field is looked up by
+ * name on the class of PooledByteBufAllocator.DEFAULT and read as a static int; any
+ * failure is rethrown wrapped in a RuntimeException.
+ */
+ private static int getPrivateStaticField(String name) {
+ try {
+ Field f =
PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ return f.getInt(null);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/common/pom.xml b/common/pom.xml
index 6078c377..314828b7 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -80,6 +80,11 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-all</artifactId>
+ <version>${netty.version}</version>
+ </dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
b/common/src/main/java/org/apache/uniffle/common/netty/IOMode.java
similarity index 64%
copy from common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
copy to common/src/main/java/org/apache/uniffle/common/netty/IOMode.java
index f8000b6e..07a41a07 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/IOMode.java
@@ -15,18 +15,9 @@
* limitations under the License.
*/
-package org.apache.uniffle.common.util;
+package org.apache.uniffle.common.netty;
-import java.util.concurrent.ThreadFactory;
-
-import com.google.common.util.concurrent.ThreadFactoryBuilder;
-
-/**
- * Provide a general method to create a thread factory to make the code more
standardized
- */
-public class ThreadUtils {
-
- public static ThreadFactory getThreadFactory(String factoryName) {
- return new
ThreadFactoryBuilder().setDaemon(true).setNameFormat(factoryName).build();
- }
+/**
+ * Selector of the I/O transport used when creating Netty event loops and channel
+ * classes: NIO or EPOLL.
+ */
+public enum IOMode {
+ NIO,
+ EPOLL
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
similarity index 65%
copy from common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
copy to
common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
index f8000b6e..0ec305fa 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Encodable.java
@@ -15,18 +15,13 @@
* limitations under the License.
*/
-package org.apache.uniffle.common.util;
+package org.apache.uniffle.common.netty.protocol;
-import java.util.concurrent.ThreadFactory;
+import io.netty.buffer.ByteBuf;
-import com.google.common.util.concurrent.ThreadFactoryBuilder;
+/** Contract for objects that can write themselves into a Netty ByteBuf. */
+public interface Encodable {
-/**
- * Provide a general method to create a thread factory to make the code more
standardized
- */
-public class ThreadUtils {
+ /** Returns the number of bytes that encode(ByteBuf) is expected to write. */
+ int encodedLength();
- public static ThreadFactory getThreadFactory(String factoryName) {
- return new
ThreadFactoryBuilder().setDaemon(true).setNameFormat(factoryName).build();
- }
+ /** Writes the encoded form of this object into the given buffer. */
+ void encode(ByteBuf buf);
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
new file mode 100644
index 00000000..6eb2813b
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Base class of the messages sent over the Netty transport. A message is
+ * {@link Encodable} and reports a {@link Type} tag that {@link #decode(Type, ByteBuf)}
+ * uses to select the concrete decoder.
+ */
+public abstract class Message implements Encodable {
+
+  /** Returns the type tag identifying the concrete kind of this message. */
+  public abstract Type type();
+
+  /** One-byte wire tag identifying the kind of a message. */
+  public enum Type implements Encodable {
+    UNKNOWN_TYPE(-1),
+    RPC_RESPONSE(0);
+
+    private final byte id;
+
+    Type(int id) {
+      // The id is stored and transmitted as a single signed byte.
+      assert id < 128 : "Cannot have more than 128 message types";
+      this.id = (byte) id;
+    }
+
+    /** Returns the byte value written on the wire for this type. */
+    public byte id() {
+      return id;
+    }
+
+    @Override
+    public int encodedLength() {
+      return 1;
+    }
+
+    @Override
+    public void encode(ByteBuf buf) {
+      buf.writeByte(id);
+    }
+
+    /**
+     * Reads one byte from {@code buf} and maps it to a {@link Type}.
+     *
+     * @param buf buffer positioned at the type byte
+     * @return the decoded type
+     * @throws IllegalArgumentException if the byte is the {@link #UNKNOWN_TYPE} id or is
+     *     not the id of any known type
+     */
+    public static Type decode(ByteBuf buf) {
+      byte id = buf.readByte();
+      switch (id) {
+        case 0:
+          return RPC_RESPONSE;
+        case -1:
+          // -1 is the id of UNKNOWN_TYPE, which never represents a decodable message.
+          throw new IllegalArgumentException("Unknown type messages cannot be decoded.");
+        default:
+          throw new IllegalArgumentException("Unknown message type: " + id);
+      }
+    }
+  }
+
+  /**
+   * Decodes the body of a message whose type tag has already been read.
+   *
+   * @param msgType type previously obtained from {@link Type#decode(ByteBuf)}
+   * @param in buffer positioned at the start of the message body
+   * @return the decoded message
+   * @throws IllegalArgumentException if no decoder exists for {@code msgType}
+   */
+  public static Message decode(Type msgType, ByteBuf in) {
+    switch (msgType) {
+      case RPC_RESPONSE:
+        return RpcResponse.decode(in);
+      default:
+        throw new IllegalArgumentException("Unexpected message type: " + msgType);
+    }
+  }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
new file mode 100644
index 00000000..9fef38cb
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/RpcResponse.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+/**
+ * Response message carrying the id of the request it answers, a {@link StatusCode} and
+ * an optional human-readable message.
+ *
+ * <p>Wire layout written by {@link #encode(ByteBuf)}: 8-byte request id, 4-byte status,
+ * then the message as written by {@link ByteBufUtils#writeLengthAndString}.
+ */
+public class RpcResponse extends Message {
+  // Assigned once in the constructor and never modified afterwards.
+  private final long requestId;
+  private final StatusCode statusCode;
+  private final String retMessage;
+
+  /** Creates a response without a message ({@code retMessage} is {@code null}). */
+  public RpcResponse(long requestId, StatusCode statusCode) {
+    this(requestId, statusCode, null);
+  }
+
+  /**
+   * Creates a response.
+   *
+   * @param requestId id of the request being answered
+   * @param statusCode outcome of the request
+   * @param retMessage optional message, may be {@code null}
+   */
+  public RpcResponse(long requestId, StatusCode statusCode, String retMessage) {
+    this.requestId = requestId;
+    this.statusCode = statusCode;
+    this.retMessage = retMessage;
+  }
+
+  public long getRequestId() {
+    return requestId;
+  }
+
+  public StatusCode getStatusCode() {
+    return statusCode;
+  }
+
+  /** Returns the optional message, or {@code null} if none was supplied. */
+  public String getRetMessage() {
+    return retMessage;
+  }
+
+  @Override
+  public String toString() {
+    return "RpcResponse{"
+        + "requestId=" + requestId
+        + ", statusCode=" + statusCode
+        + ", retMessage='" + retMessage
+        + '\'' + '}';
+  }
+
+  @Override
+  public int encodedLength() {
+    return Long.BYTES + Integer.BYTES + ByteBufUtils.encodedLength(retMessage);
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(requestId);
+    // NOTE(review): this writes ordinal() while decode() reads the value back through
+    // StatusCode.fromCode(); the round trip is only correct if every constant's code
+    // equals its ordinal - confirm against StatusCode.
+    buf.writeInt(statusCode.ordinal());
+    ByteBufUtils.writeLengthAndString(buf, retMessage);
+  }
+
+  /**
+   * Reads a response from {@code buf} in the layout produced by {@link #encode(ByteBuf)}.
+   *
+   * @param buf buffer positioned at the request id
+   * @return the decoded response
+   */
+  public static RpcResponse decode(ByteBuf buf) {
+    long requestId = buf.readLong();
+    StatusCode statusCode = StatusCode.fromCode(buf.readInt());
+    String retMessage = ByteBufUtils.readLengthAndString(buf);
+    return new RpcResponse(requestId, statusCode, retMessage);
+  }
+
+  @Override
+  public Type type() {
+    return Type.RPC_RESPONSE;
+  }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
new file mode 100644
index 00000000..6b1f0dd0
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/util/ByteBufUtils.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.nio.charset.StandardCharsets;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Helpers for reading and writing length-prefixed values on Netty {@link ByteBuf}s.
+ *
+ * <p>A string is encoded as a 4-byte length followed by its UTF-8 bytes; a {@code null}
+ * string is encoded as the length {@code -1} with no payload.
+ */
+public class ByteBufUtils {
+
+  /**
+   * Returns the number of bytes {@link #writeLengthAndString(ByteBuf, String)} writes
+   * for {@code s}.
+   *
+   * @param s the string to measure, may be {@code null}
+   * @return 4 for a {@code null} string (length marker only), otherwise 4 plus the
+   *     length of the UTF-8 encoding of {@code s}
+   */
+  public static int encodedLength(String s) {
+    if (s == null) {
+      // writeLengthAndString emits only the -1 length marker for null.
+      return Integer.BYTES;
+    }
+    return Integer.BYTES + s.getBytes(StandardCharsets.UTF_8).length;
+  }
+
+  /**
+   * Writes {@code str} as a 4-byte length followed by its UTF-8 bytes, or the single
+   * length marker {@code -1} when {@code str} is {@code null}.
+   */
+  public static final void writeLengthAndString(ByteBuf buf, String str) {
+    if (str == null) {
+      buf.writeInt(-1);
+      return;
+    }
+
+    byte[] bytes = str.getBytes(StandardCharsets.UTF_8);
+    buf.writeInt(bytes.length);
+    buf.writeBytes(bytes);
+  }
+
+  /**
+   * Reads a string written by {@link #writeLengthAndString(ByteBuf, String)}.
+   *
+   * @return the decoded string, or {@code null} if the length marker is {@code -1}
+   */
+  public static final String readLengthAndString(ByteBuf buf) {
+    int length = buf.readInt();
+    if (length == -1) {
+      return null;
+    }
+
+    byte[] bytes = new byte[length];
+    buf.readBytes(bytes);
+    return new String(bytes, StandardCharsets.UTF_8);
+  }
+
+  /** Reads all currently readable bytes of {@code buf} into a new array. */
+  public static final byte[] readBytes(ByteBuf buf) {
+    byte[] bytes = new byte[buf.readableBytes()];
+    buf.readBytes(bytes);
+    return bytes;
+  }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/JavaUtils.java
similarity index 62%
copy from common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
copy to common/src/main/java/org/apache/uniffle/common/util/JavaUtils.java
index f8000b6e..318a0ca9 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/JavaUtils.java
@@ -17,16 +17,23 @@
package org.apache.uniffle.common.util;
-import java.util.concurrent.ThreadFactory;
+import java.io.Closeable;
+import java.io.IOException;
-import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-/**
- * Provide a general method to create a thread factory to make the code more
standardized
- */
-public class ThreadUtils {
+/** General-purpose Java helpers. */
+public class JavaUtils {
+ private static final Logger logger =
LoggerFactory.getLogger(JavaUtils.class);
- public static ThreadFactory getThreadFactory(String factoryName) {
- return new
ThreadFactoryBuilder().setDaemon(true).setNameFormat(factoryName).build();
+ /**
+ * Closes the given object if it is not null. An IOException thrown by close() is
+ * logged at error level and not rethrown.
+ */
+ public static void closeQuietly(Closeable closeable) {
+ try {
+ if (closeable != null) {
+ closeable.close();
+ }
+ } catch (IOException e) {
+ logger.error("IOException should not have been thrown.", e);
+ }
}
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
new file mode 100644
index 00000000..49707004
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/util/NettyUtils.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.util.concurrent.ThreadFactory;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.epoll.EpollSocketChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.util.internal.PlatformDependent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.IOMode;
+import org.apache.uniffle.common.netty.protocol.Message;
+
+/** Factory and helper methods for the Netty based transport. */
+public class NettyUtils {
+  private static final Logger logger = LoggerFactory.getLogger(NettyUtils.class);
+
+  /**
+   * Creates a Netty EventLoopGroup based on the IOMode.
+   *
+   * @param mode NIO or EPOLL
+   * @param numThreads thread count passed to the event loop group constructor
+   * @param threadPrefix name prefix handed to {@link ThreadUtils#getNettyThreadFactory}
+   * @throws IllegalArgumentException if the mode is not handled by the switch
+   */
+  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+    ThreadFactory threadFactory = ThreadUtils.getNettyThreadFactory(threadPrefix);
+
+    switch (mode) {
+      case NIO:
+        return new NioEventLoopGroup(numThreads, threadFactory);
+      case EPOLL:
+        return new EpollEventLoopGroup(numThreads, threadFactory);
+      default:
+        throw new IllegalArgumentException("Unknown io mode: " + mode);
+    }
+  }
+
+  /**
+   * Returns the correct (client) SocketChannel class based on IOMode.
+   *
+   * @throws IllegalArgumentException if the mode is not handled by the switch
+   */
+  public static Class<? extends Channel> getClientChannelClass(IOMode mode) {
+    switch (mode) {
+      case NIO:
+        return NioSocketChannel.class;
+      case EPOLL:
+        return EpollSocketChannel.class;
+      default:
+        throw new IllegalArgumentException("Unknown io mode: " + mode);
+    }
+  }
+
+  /**
+   * Creates a pooled ByteBuf allocator from Netty's public defaults.
+   *
+   * @param allowDirectBufs direct buffers are preferred only when this is true and
+   *     {@code PlatformDependent.directBufferPreferred()} agrees; when false the direct
+   *     arena count is capped at 0
+   * @param allowCache when false the small and normal cache sizes are 0 and caching
+   *     for all threads is disabled
+   * @param numCores upper bound for the arena counts; 0 means use
+   *     {@code Runtime.getRuntime().availableProcessors()}
+   */
+  public static PooledByteBufAllocator createPooledByteBufAllocator(
+      boolean allowDirectBufs, boolean allowCache, int numCores) {
+    if (numCores == 0) {
+      numCores = Runtime.getRuntime().availableProcessors();
+    }
+    return new PooledByteBufAllocator(
+        allowDirectBufs && PlatformDependent.directBufferPreferred(),
+        Math.min(PooledByteBufAllocator.defaultNumHeapArena(), numCores),
+        Math.min(PooledByteBufAllocator.defaultNumDirectArena(), allowDirectBufs ? numCores : 0),
+        PooledByteBufAllocator.defaultPageSize(),
+        PooledByteBufAllocator.defaultMaxOrder(),
+        allowCache ? PooledByteBufAllocator.defaultSmallCacheSize() : 0,
+        allowCache ? PooledByteBufAllocator.defaultNormalCacheSize() : 0,
+        allowCache && PooledByteBufAllocator.defaultUseCacheForAllThreads());
+  }
+
+  /** Returns the remote address on the channel or "<unknown remote>" if none exists. */
+  public static String getRemoteAddress(Channel channel) {
+    if (channel != null && channel.remoteAddress() != null) {
+      return channel.remoteAddress().toString();
+    }
+    return "<unknown remote>";
+  }
+
+  /**
+   * Encodes {@code msg} into a newly allocated buffer and writes and flushes it on
+   * {@code ctx}.
+   *
+   * @param ctx context used to allocate the buffer and to write it
+   * @param msg message to encode
+   * @param doWriteType whether the message type id is written ahead of the body
+   * @return the future returned by {@code ctx.writeAndFlush}
+   */
+  public static ChannelFuture writeResponseMsg(
+      ChannelHandlerContext ctx, Message msg, boolean doWriteType) {
+    // Reserve room for the optional type tag so the buffer never has to grow.
+    int capacity = msg.encodedLength() + (doWriteType ? msg.type().encodedLength() : 0);
+    ByteBuf responseMsgBuf = ctx.alloc().buffer(capacity);
+    try {
+      if (doWriteType) {
+        responseMsgBuf.writeByte(msg.type().id());
+      }
+      msg.encode(responseMsgBuf);
+      return ctx.writeAndFlush(responseMsgBuf);
+    } catch (Throwable ex) {
+      logger.warn("Caught exception, releasing ByteBuf", ex);
+      responseMsgBuf.release();
+      throw ex;
+    }
+  }
+
+  /** Formats the local and remote address of the context's channel. */
+  public static String getServerConnectionInfo(ChannelHandlerContext ctx) {
+    return getServerConnectionInfo(ctx.channel());
+  }
+
+  /** Formats the channel's addresses as "[local -> remote]". */
+  public static String getServerConnectionInfo(Channel channel) {
+    return String.format("[%s -> %s]", channel.localAddress(), channel.remoteAddress());
+  }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
index f8000b6e..27915ca0 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
@@ -20,6 +20,7 @@ package org.apache.uniffle.common.util;
import java.util.concurrent.ThreadFactory;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import io.netty.util.concurrent.DefaultThreadFactory;
/**
* Provide a general method to create a thread factory to make the code more
standardized
@@ -29,4 +30,9 @@ public class ThreadUtils {
public static ThreadFactory getThreadFactory(String factoryName) {
return new
ThreadFactoryBuilder().setDaemon(true).setNameFormat(factoryName).build();
}
+
+ /**
+  * Returns a Netty DefaultThreadFactory whose threads are named using the given prefix.
+  *
+  * @param threadPoolPrefix prefix passed to the DefaultThreadFactory as its pool name
+  * @return the newly created thread factory
+  */
+ public static ThreadFactory getNettyThreadFactory(String threadPoolPrefix) {
+   final ThreadFactory nettyFactory = new DefaultThreadFactory(threadPoolPrefix, true);
+   return nettyFactory;
+ }
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java
b/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java
new file mode 100644
index 00000000..3f60eaef
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/util/ByteBufUtilsTest.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+/** Round-trip tests for {@link ByteBufUtils}. */
+public class ByteBufUtilsTest {
+
+  @Test
+  public void test() {
+    ByteBuf byteBuf = Unpooled.buffer(100);
+    String expectedString = "test_str";
+    ByteBufUtils.writeLengthAndString(byteBuf, expectedString);
+    // encodedLength must report exactly what writeLengthAndString produced.
+    assertEquals(byteBuf.readableBytes(), ByteBufUtils.encodedLength(expectedString));
+    assertEquals(expectedString, ByteBufUtils.readLengthAndString(byteBuf));
+
+    byteBuf.clear();
+    // Use an explicit charset so the test does not depend on the platform default.
+    byte[] expectedBytes = expectedString.getBytes(java.nio.charset.StandardCharsets.UTF_8);
+    byteBuf.writeBytes(expectedBytes);
+    assertArrayEquals(expectedBytes, ByteBufUtils.readBytes(byteBuf));
+
+    byteBuf.clear();
+    ByteBufUtils.writeLengthAndString(byteBuf, null);
+    assertNull(ByteBufUtils.readLengthAndString(byteBuf));
+  }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
b/common/src/test/java/org/apache/uniffle/common/util/JavaUtilsTest.java
similarity index 68%
copy from common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
copy to common/src/test/java/org/apache/uniffle/common/util/JavaUtilsTest.java
index f8000b6e..3f041f01 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/JavaUtilsTest.java
@@ -17,16 +17,24 @@
package org.apache.uniffle.common.util;
-import java.util.concurrent.ThreadFactory;
+import java.io.Closeable;
+import java.io.IOException;
-import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.junit.jupiter.api.Test;
-/**
- * Provide a general method to create a thread factory to make the code more
standardized
- */
-public class ThreadUtils {
+/** Tests for JavaUtils. */
+public class JavaUtilsTest {
+
+ /** Closeable whose close() always throws an IOException. */
+ static class MockClient implements Closeable {
+
+ @Override
+ public void close() throws IOException {
+ throw new IOException("test exception!");
+ }
+ }
- public static ThreadFactory getThreadFactory(String factoryName) {
- return new
ThreadFactoryBuilder().setDaemon(true).setNameFormat(factoryName).build();
+ /** closeQuietly must return normally even though MockClient.close() throws. */
+ @Test
+ public void test() {
+ MockClient client = new MockClient();
+ JavaUtils.closeQuietly(client);
}
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/util/NettyUtilsTest.java
b/common/src/test/java/org/apache/uniffle/common/util/NettyUtilsTest.java
new file mode 100644
index 00000000..8422ead3
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/util/NettyUtilsTest.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.handler.codec.ByteToMessageDecoder;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.netty.IOMode;
+import org.apache.uniffle.common.netty.protocol.Message;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+import org.apache.uniffle.common.rpc.StatusCode;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * End-to-end test of {@link NettyUtils}: a local server answers any inbound bytes with a
+ * fixed {@link RpcResponse}, which the client decodes and the test thread verifies.
+ */
+public class NettyUtilsTest {
+  private EventLoopGroup bossGroup;
+  private EventLoopGroup workerGroup;
+  private ChannelFuture channelFuture;
+  private static final String EXPECTED_MESSAGE = "test_message";
+  // NOTE(review): a fixed port can collide with other processes on shared CI hosts.
+  private static final int PORT = 12345;
+
+  /** Server-side handler that answers any inbound bytes with a fixed RpcResponse. */
+  static class MockDecoder extends ByteToMessageDecoder {
+    @Override
+    protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List<Object> list)
+        throws Exception {
+      RpcResponse rpcResponse = new RpcResponse(1L, StatusCode.SUCCESS, EXPECTED_MESSAGE);
+      NettyUtils.writeResponseMsg(ctx, rpcResponse, true);
+    }
+  }
+
+  @Test
+  public void test() throws InterruptedException {
+    // Named clientGroup so it no longer shadows the server's workerGroup field.
+    EventLoopGroup clientGroup = NettyUtils.createEventLoop(IOMode.NIO, 2, "netty-client");
+    try {
+      PooledByteBufAllocator pooledByteBufAllocator =
+          NettyUtils.createPooledByteBufAllocator(true, false /* allowCache */, 2);
+      Bootstrap bootstrap = new Bootstrap();
+      bootstrap
+          .group(clientGroup)
+          .channel(NettyUtils.getClientChannelClass(IOMode.NIO))
+          .option(ChannelOption.TCP_NODELAY, true)
+          .option(ChannelOption.SO_KEEPALIVE, true)
+          .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 30000)
+          .option(ChannelOption.ALLOCATOR, pooledByteBufAllocator);
+      final AtomicReference<Channel> channelRef = new AtomicReference<>();
+      // Decoded values are handed to the test thread: an assertion failing on the Netty
+      // I/O thread would only surface as a pipeline exception and never fail the test.
+      final AtomicReference<Message.Type> typeRef = new AtomicReference<>();
+      final AtomicReference<RpcResponse> responseRef = new AtomicReference<>();
+      bootstrap.handler(
+          new ChannelInitializer<SocketChannel>() {
+            @Override
+            public void initChannel(SocketChannel ch) {
+              ch.pipeline().addLast(new ByteToMessageDecoder() {
+                @Override
+                protected void decode(
+                    ChannelHandlerContext channelHandlerContext,
+                    ByteBuf byteBuf,
+                    List<Object> list) {
+                  Message.Type messageType = Message.Type.decode(byteBuf);
+                  typeRef.set(messageType);
+                  responseRef.set((RpcResponse) Message.decode(messageType, byteBuf));
+                }
+              });
+              channelRef.set(ch);
+            }
+          });
+      // Block until the connection is established instead of sleeping and hoping that
+      // initChannel has already run.
+      bootstrap.connect("localhost", PORT).sync();
+      ByteBuf byteBuf = Unpooled.buffer(1);
+      byteBuf.writeByte(1);
+      channelRef.get().writeAndFlush(byteBuf);
+      channelRef.get().closeFuture().await(3L, TimeUnit.SECONDS);
+
+      assertEquals(
+          Message.Type.RPC_RESPONSE, typeRef.get(), "no response decoded before the timeout");
+      RpcResponse rpcResponse = responseRef.get();
+      assertEquals(1L, rpcResponse.getRequestId());
+      assertEquals(StatusCode.SUCCESS, rpcResponse.getStatusCode());
+      assertEquals(EXPECTED_MESSAGE, rpcResponse.getRetMessage());
+    } finally {
+      // Release the client threads and any channel still open on them.
+      clientGroup.shutdownGracefully();
+    }
+  }
+
+  @BeforeEach
+  public void startNettyServer() {
+    Supplier<ChannelHandler[]> handlerSupplier = () -> new ChannelHandler[]{
+        new MockDecoder()
+    };
+    bossGroup = NettyUtils.createEventLoop(IOMode.NIO, 1, "netty-boss-group");
+    workerGroup = NettyUtils.createEventLoop(IOMode.NIO, 5, "netty-worker-group");
+    ServerBootstrap serverBootstrap = new ServerBootstrap().group(bossGroup, workerGroup)
+        .channel(NioServerSocketChannel.class);
+    serverBootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
+      @Override
+      public void initChannel(final SocketChannel ch) {
+        ch.pipeline().addLast(handlerSupplier.get());
+      }
+    })
+        .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+        .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+        .childOption(ChannelOption.TCP_NODELAY, true)
+        .childOption(ChannelOption.SO_KEEPALIVE, true);
+    channelFuture = serverBootstrap.bind(PORT);
+    channelFuture.syncUninterruptibly();
+  }
+
+  @AfterEach
+  public void stopNettyServer() {
+    channelFuture.channel().close().awaitUninterruptibly(10L, TimeUnit.SECONDS);
+    bossGroup.shutdownGracefully();
+    workerGroup.shutdownGracefully();
+  }
+}
diff --git a/pom.xml b/pom.xml
index 1eb0d569..affa0ef2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1120,7 +1120,6 @@
<spark.version>2.3.4</spark.version>
<client.type>2</client.type>
<jackson.version>2.9.0</jackson.version>
- <netty.version>4.1.47.Final</netty.version>
</properties>
<modules>
<module>client-spark/common</module>