This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new bc9aaaaf [#584] feat(netty): Add transport client pool for netty (#771)
bc9aaaaf is described below
commit bc9aaaafad4a2a614064559e6f46c14191817664
Author: xumanbu <[email protected]>
AuthorDate: Mon Apr 3 11:28:56 2023 +0800
[#584] feat(netty): Add transport client pool for netty (#771)
### What changes were proposed in this pull request?
1. Add TransportClient, the Netty RPC client.
2. Add TransportClientFactory, which maintains the connection pool.
3. Add TransportContext, which holds the context needed to create a
TransportClientFactory and sets up Netty channel pipelines with a
TransportResponseHandler.
4. Add TransportConf, the Netty transport configuration created from RssConf.
### Why are the changes needed?
Fix: #584
### Does this PR introduce _any_ user-facing change?
Adds client configuration options and the ability to reuse Netty clients.
TODO: update the user documentation after the Netty feature is completed.
@xumanbu
### How was this patch tested?
Tested locally.
Co-authored-by: jam.xu <[email protected]>
---
.../uniffle/common/config/RssClientConf.java | 52 +++++
.../uniffle/common/netty/MessageEncoder.java | 5 +
.../common/netty/client/RpcResponseCallback.java | 36 +++
.../common/netty/client/TransportClient.java | 144 ++++++++++++
.../netty/client/TransportClientFactory.java | 250 +++++++++++++++++++++
.../uniffle/common/netty/client/TransportConf.java | 64 ++++++
.../common/netty/client/TransportContext.java | 60 +++++
.../netty/handle/TransportResponseHandler.java | 72 ++++++
.../common/netty/EncoderAndDecoderTest.java | 4 +-
.../netty/client/TransportClientFactoryTest.java | 107 +++++++++
.../netty/client/TransportClientTestBase.java | 119 ++++++++++
11 files changed, 911 insertions(+), 2 deletions(-)
diff --git
a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index 119ab4e6..10f56bd2 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -19,6 +19,7 @@ package org.apache.uniffle.common.config;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.netty.IOMode;
import static org.apache.uniffle.common.compression.Codec.Type.LZ4;
@@ -43,4 +44,55 @@ public class RssClientConf {
.defaultValue(ShuffleDataDistributionType.NORMAL)
.withDescription("The type of partition shuffle data distribution,
including normal and local_order. "
+ "The default value is normal. This config is only valid in
Spark3.x");
+
+ public static final ConfigOption<Integer> NETTY_IO_CONNECT_TIMEOUT_MS =
ConfigOptions
+ .key("rss.client.netty.io.connect.timeout.ms")
+ .intType()
+ .defaultValue(10 * 1000)
+ .withDescription("netty connect to server time out mills");
+
+ public static final ConfigOption<IOMode> NETTY_IO_MODE = ConfigOptions
+ .key("rss.client.netty.io.mode")
+ .enumType(IOMode.class)
+ .defaultValue(IOMode.NIO)
+ .withDescription("Netty EventLoopGroup backend, available options: NIO,
EPOLL.");
+
+ public static final ConfigOption<Integer> NETTY_IO_CONNECTION_TIMEOUT_MS =
ConfigOptions
+ .key("rss.client.netty.client.connection.timeout.ms")
+ .intType()
+ .defaultValue(10 * 60 * 1000)
+ .withDescription("connection active timeout");
+
+ public static final ConfigOption<Integer> NETTY_CLIENT_THREADS =
ConfigOptions
+ .key("rss.client.netty.client.threads")
+ .intType()
+ .defaultValue(0)
+ .withDescription("Number of threads used in the client thread pool.");
+
+ public static final ConfigOption<Boolean> NETWORK_CLIENT_PREFER_DIRECT_BUFS
= ConfigOptions
+ .key("rss.client.netty.client.prefer.direct.bufs")
+ .booleanType()
+ .defaultValue(true)
+ .withDescription("If true, we will prefer allocating off-heap byte
buffers within Netty.");
+
+ public static final ConfigOption<Integer>
NETTY_CLIENT_NUM_CONNECTIONS_PER_PEER = ConfigOptions
+ .key("rss.client.netty.client.connections.per.peer")
+ .intType()
+ .defaultValue(2)
+ .withDescription("Number of concurrent connections between two nodes.");
+
+ public static final ConfigOption<Integer> NETTY_CLIENT_RECEIVE_BUFFER =
ConfigOptions
+ .key("rss.client.netty.client.receive.buffer")
+ .intType()
+ .defaultValue(0)
+ .withDescription("Receive buffer size (SO_RCVBUF). Note: the optimal
size for receive buffer and send buffer "
+ + "should be latency * network_bandwidth. Assuming latency = 1ms,
network_bandwidth = 10Gbps "
+ + "buffer size should be ~ 1.25MB.");
+
+ public static final ConfigOption<Integer> NETTY_CLIENT_SEND_BUFFER =
ConfigOptions
+ .key("rss.client.netty.client.send.buffer")
+ .intType()
+ .defaultValue(0)
+ .withDescription("Send buffer size (SO_SNDBUF).");
+
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
index 4167e53a..e3537ecd 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java
@@ -39,6 +39,11 @@ public class MessageEncoder extends
ChannelOutboundHandlerAdapter {
private static final Logger LOG =
LoggerFactory.getLogger(MessageEncoder.class);
+ public static final MessageEncoder INSTANCE = new MessageEncoder();
+
+ private MessageEncoder() {
+ }
+
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise
promise) {
// todo: support zero copy
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/client/RpcResponseCallback.java
b/common/src/main/java/org/apache/uniffle/common/netty/client/RpcResponseCallback.java
new file mode 100644
index 00000000..6de925c0
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/client/RpcResponseCallback.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+
+public interface RpcResponseCallback {
+ /**
+ * Successful serialized result from server.
+ *
+ * <p>After `onSuccess` returns, `response` will be recycled and its content
will become invalid.
+ * Please copy the content of `response` if you want to use it after
`onSuccess` returns.
+ */
+ void onSuccess(RpcResponse rpcResponse);
+
+ /**
+ * Exception either propagated from server or raised on client side.
+ */
+ void onFailure(Throwable e);
+}
+
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
new file mode 100644
index 00000000..34ebb20a
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClient.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
+import org.apache.uniffle.common.netty.protocol.Message;
+import org.apache.uniffle.common.util.NettyUtils;
+
+
+public class TransportClient implements Closeable {
+ private static final Logger logger =
LoggerFactory.getLogger(TransportClient.class);
+
+ private Channel channel;
+ private TransportResponseHandler handler;
+ private volatile boolean timedOut;
+
+ private static final AtomicLong counter = new AtomicLong();
+
+ public TransportClient(Channel channel, TransportResponseHandler handler) {
+ this.channel = Objects.requireNonNull(channel);
+ this.handler = Objects.requireNonNull(handler);
+ this.timedOut = false;
+ }
+
+ public Channel getChannel() {
+ return channel;
+ }
+
+ public boolean isActive() {
+ return !timedOut && (channel.isOpen() || channel.isActive());
+ }
+
+ public SocketAddress getSocketAddress() {
+ return channel.remoteAddress();
+ }
+
+ public ChannelFuture sendShuffleData(Message message, RpcResponseCallback
callback) {
+ if (logger.isTraceEnabled()) {
+ logger.trace("Pushing data to {}", NettyUtils.getRemoteAddress(channel));
+ }
+ long requestId = requestId();
+ handler.addResponseCallback(requestId, callback);
+ RpcChannelListener listener = new RpcChannelListener(requestId, callback);
+ return channel.writeAndFlush(message).addListener(listener);
+ }
+
+ public static long requestId() {
+ return counter.getAndIncrement();
+ }
+
+ public class StdChannelListener implements GenericFutureListener<Future<?
super Void>> {
+ final long startTime;
+ final Object requestId;
+
+ public StdChannelListener(Object requestId) {
+ this.startTime = System.currentTimeMillis();
+ this.requestId = requestId;
+ }
+
+ @Override
+ public void operationComplete(Future<? super Void> future) throws
Exception {
+ if (future.isSuccess()) {
+ if (logger.isTraceEnabled()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace(
+ "Sending request {} to {} took {} ms",
+ requestId,
+ NettyUtils.getRemoteAddress(channel),
+ timeTaken);
+ }
+ } else {
+ String errorMsg =
+ String.format(
+ "Failed to send request %s to %s: %s, channel will be closed",
+ requestId, NettyUtils.getRemoteAddress(channel),
future.cause());
+ logger.warn(errorMsg);
+ channel.close();
+ try {
+ handleFailure(errorMsg, future.cause());
+ } catch (Exception e) {
+ logger.error("Uncaught exception in RPC response callback handler!",
e);
+ }
+ }
+ }
+
+ protected void handleFailure(String errorMsg, Throwable cause) {
+ logger.error("Error encountered " + errorMsg, cause);
+ }
+ }
+
+ private class RpcChannelListener extends StdChannelListener {
+ final long rpcRequestId;
+ final RpcResponseCallback callback;
+
+ RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) {
+ super("RPC " + rpcRequestId);
+ this.rpcRequestId = rpcRequestId;
+ this.callback = callback;
+ }
+
+ @Override
+ protected void handleFailure(String errorMsg, Throwable cause) {
+ handler.removeRpcRequest(rpcRequestId);
+ callback.onFailure(new IOException(errorMsg, cause));
+ }
+ }
+
+
+ @Override
+ public void close() throws IOException {
+ // close is a local operation and should finish with milliseconds; timeout
just to be safe
+ channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ }
+
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
new file mode 100644
index 00000000..c8056151
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportClientFactory.java
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.Objects;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicReference;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.IOMode;
+import org.apache.uniffle.common.netty.TransportFrameDecoder;
+import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
+import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.common.util.NettyUtils;
+
+public class TransportClientFactory implements Closeable {
+
+ /**
+ * A simple data structure to track the pool of clients between two peer
nodes.
+ */
+ private static class ClientPool {
+ TransportClient[] clients;
+ Object[] locks;
+
+ ClientPool(int size) {
+ clients = new TransportClient[size];
+ locks = new Object[size];
+ for (int i = 0; i < size; i++) {
+ locks[i] = new Object();
+ }
+ }
+ }
+
+ private static final Logger logger =
LoggerFactory.getLogger(TransportClientFactory.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+ private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
+
+ /**
+ * Random number generator for picking connections between peers.
+ */
+ private final Random rand;
+
+ private final int numConnectionsPerPeer;
+
+ private final Class<? extends Channel> socketChannelClass;
+ private EventLoopGroup workerGroup;
+ private PooledByteBufAllocator pooledAllocator;
+
+ public TransportClientFactory(TransportContext context) {
+ this.context = Objects.requireNonNull(context);
+ this.conf = context.getConf();
+ this.connectionPool = JavaUtils.newConcurrentMap();
+ this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
+ this.rand = new Random();
+
+ IOMode ioMode = conf.ioMode();
+ this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
+ this.workerGroup =
+ NettyUtils.createEventLoop(ioMode, conf.clientThreads(),
"netty-rpc-client");
+ this.pooledAllocator =
+ NettyUtils.createPooledByteBufAllocator(
+ conf.preferDirectBufs(), false, conf.clientThreads());
+ }
+
+ public TransportClient createClient(String remoteHost, int remotePort, int
partitionId)
+ throws IOException, InterruptedException {
+ return createClient(remoteHost, remotePort, partitionId, new
TransportFrameDecoder());
+ }
+
+ public TransportClient createClient(
+ String remoteHost, int remotePort, int partitionId,
ChannelInboundHandlerAdapter decoder)
+ throws IOException, InterruptedException {
+ // Get connection from the connection pool first.
+ // If it is not found or not active, create a new one.
+ // Use unresolved address here to avoid DNS resolution each time we
creates a client.
+ final InetSocketAddress unresolvedAddress =
+ InetSocketAddress.createUnresolved(remoteHost, remotePort);
+
+ // Create the ClientPool if we don't have it yet.
+ ClientPool clientPool = connectionPool.computeIfAbsent(unresolvedAddress, x
+ -> new ClientPool(numConnectionsPerPeer));
+
+ int clientIndex =
+ partitionId < 0 ? rand.nextInt(numConnectionsPerPeer) : partitionId %
numConnectionsPerPeer;
+ TransportClient cachedClient = clientPool.clients[clientIndex];
+
+ if (cachedClient != null && cachedClient.isActive()) {
+ // Make sure that the channel will not timeout by updating the last use
time of the
+ // handler. Then check that the client is still alive, in case it timed
out before
+ // this code was able to update things.
+ TransportResponseHandler handler =
+
cachedClient.getChannel().pipeline().get(TransportResponseHandler.class);
+
+ if (cachedClient.isActive()) {
+ logger.trace(
+ "Returning cached connection to {}: {}",
cachedClient.getSocketAddress(), cachedClient);
+ return cachedClient;
+ }
+ }
+
+ // If we reach here, we don't have an existing connection open. Let's
create a new one.
+ // Multiple threads might race here to create new connections. Keep only
one of them active.
+ final long preResolveHost = System.nanoTime();
+ final InetSocketAddress resolvedAddress = new
InetSocketAddress(remoteHost, remotePort);
+ final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) /
1000000;
+ if (hostResolveTimeMs > 2000) {
+ logger.warn("DNS resolution for {} took {} ms", resolvedAddress,
hostResolveTimeMs);
+ } else {
+ logger.trace("DNS resolution for {} took {} ms", resolvedAddress,
hostResolveTimeMs);
+ }
+
+ synchronized (clientPool.locks[clientIndex]) {
+ cachedClient = clientPool.clients[clientIndex];
+
+ if (cachedClient != null) {
+ if (cachedClient.isActive()) {
+ logger.trace("Returning cached connection to {}: {}",
resolvedAddress, cachedClient);
+ return cachedClient;
+ } else {
+ logger.info("Found inactive connection to {}, creating a new one.",
resolvedAddress);
+ }
+ }
+ clientPool.clients[clientIndex] = internalCreateClient(resolvedAddress,
decoder);
+ return clientPool.clients[clientIndex];
+ }
+ }
+
+ public TransportClient createClient(String remoteHost, int remotePort)
+ throws IOException, InterruptedException {
+ return createClient(remoteHost, remotePort, -1);
+ }
+
+ /**
+ * Create a completely new {@link TransportClient} to the given remote host
/ port. This
+ * connection is not pooled.
+ *
+ * <p>As with {@link #createClient(String, int)}, this method is blocking.
+ */
+ private TransportClient internalCreateClient(
+ InetSocketAddress address, ChannelInboundHandlerAdapter decoder)
+ throws IOException, InterruptedException {
+ Bootstrap bootstrap = new Bootstrap();
+ bootstrap
+ .group(workerGroup)
+ .channel(socketChannelClass)
+ // Disable Nagle's Algorithm since we don't want packets to wait
+ .option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs())
+ .option(ChannelOption.ALLOCATOR, pooledAllocator);
+
+ if (conf.receiveBuf() > 0) {
+ bootstrap.option(ChannelOption.SO_RCVBUF, conf.receiveBuf());
+ }
+
+ if (conf.sendBuf() > 0) {
+ bootstrap.option(ChannelOption.SO_SNDBUF, conf.sendBuf());
+ }
+
+ final AtomicReference<TransportClient> clientRef = new AtomicReference<>();
+ final AtomicReference<Channel> channelRef = new AtomicReference<>();
+
+ bootstrap.handler(
+ new ChannelInitializer<SocketChannel>() {
+ @Override
+ public void initChannel(SocketChannel ch) {
+ TransportResponseHandler transportResponseHandler =
context.initializePipeline(ch, decoder);
+ TransportClient client = new TransportClient(ch,
transportResponseHandler);
+ clientRef.set(client);
+ channelRef.set(ch);
+ }
+ });
+
+ // Connect to the remote server
+ ChannelFuture cf = bootstrap.connect(address);
+ if (!cf.await(conf.connectTimeoutMs())) {
+ throw new IOException(
+ String.format("Connecting to %s timed out (%s ms)", address,
conf.connectTimeoutMs()));
+ } else if (cf.cause() != null) {
+ throw new IOException(String.format("Failed to connect to %s", address),
cf.cause());
+ }
+
+ TransportClient client = clientRef.get();
+ assert client != null : "Channel future completed successfully with null
client";
+
+ logger.debug("Connection to {} successful", address);
+
+ return client;
+ }
+
+ /**
+ * Close all connections in the connection pool, and shutdown the worker
thread pool.
+ */
+ @Override
+ public void close() {
+ // Go through all clients and close them if they are active.
+ for (ClientPool clientPool : connectionPool.values()) {
+ for (int i = 0; i < clientPool.clients.length; i++) {
+ TransportClient client = clientPool.clients[i];
+ if (client != null) {
+ clientPool.clients[i] = null;
+ JavaUtils.closeQuietly(client);
+ }
+ }
+ }
+ connectionPool.clear();
+
+
+ if (workerGroup != null && !workerGroup.isShuttingDown()) {
+ workerGroup.shutdownGracefully();
+ }
+ }
+
+ public TransportContext getContext() {
+ return context;
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportConf.java
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportConf.java
new file mode 100644
index 00000000..a664cc19
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportConf.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.netty.IOMode;
+
+public class TransportConf {
+
+ private final RssConf rssConf;
+
+ public TransportConf(RssConf rssConf) {
+ this.rssConf = rssConf;
+ }
+
+ public IOMode ioMode() {
+ return rssConf.get(RssClientConf.NETTY_IO_MODE);
+ }
+
+ public int connectTimeoutMs() {
+ return rssConf.get(RssClientConf.NETTY_IO_CONNECT_TIMEOUT_MS);
+ }
+
+ public int connectionTimeoutMs() {
+ return rssConf.get(RssClientConf.NETTY_IO_CONNECTION_TIMEOUT_MS);
+ }
+
+ public int clientThreads() {
+ return rssConf.get(RssClientConf.NETTY_CLIENT_THREADS);
+ }
+
+ public int numConnectionsPerPeer() {
+ return rssConf.get(RssClientConf.NETTY_CLIENT_NUM_CONNECTIONS_PER_PEER);
+ }
+
+ public boolean preferDirectBufs() {
+ return rssConf.get(RssClientConf.NETWORK_CLIENT_PREFER_DIRECT_BUFS);
+ }
+
+ public int receiveBuf() {
+ return rssConf.get(RssClientConf.NETTY_CLIENT_RECEIVE_BUFFER);
+ }
+
+ public int sendBuf() {
+ return rssConf.get(RssClientConf.NETTY_CLIENT_SEND_BUFFER);
+ }
+
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
new file mode 100644
index 00000000..134b633a
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/client/TransportContext.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.timeout.IdleStateHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.netty.MessageEncoder;
+import org.apache.uniffle.common.netty.handle.TransportResponseHandler;
+
+public class TransportContext {
+ private static final Logger logger =
LoggerFactory.getLogger(TransportContext.class);
+
+ private TransportConf transportConf;
+
+ private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
+
+ public TransportContext(TransportConf transportConf) {
+ this.transportConf = transportConf;
+ }
+
+ public TransportClientFactory createClientFactory() {
+ return new TransportClientFactory(this);
+ }
+
+ public TransportResponseHandler initializePipeline(
+ SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
+ TransportResponseHandler responseHandler = new
TransportResponseHandler(channel);
+ channel
+ .pipeline()
+ .addLast("encoder", ENCODER) // out
+ .addLast("decoder", decoder) // in
+ .addLast(
+ "idleStateHandler", new IdleStateHandler(0, 0,
transportConf.connectionTimeoutMs() / 1000))
+ .addLast("responseHandler", responseHandler);
+ return responseHandler;
+ }
+
+ public TransportConf getConf() {
+ return transportConf;
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
new file mode 100644
index 00000000..86dd9953
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/handle/TransportResponseHandler.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.handle;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.netty.client.RpcResponseCallback;
+import org.apache.uniffle.common.netty.protocol.RpcResponse;
+import org.apache.uniffle.common.util.NettyUtils;
+
+
+public class TransportResponseHandler extends ChannelInboundHandlerAdapter {
+ private static final Logger logger =
LoggerFactory.getLogger(TransportResponseHandler.class);
+
+ private Map<Long, RpcResponseCallback> outstandingRpcRequests;
+ private Channel channel;
+
+ public TransportResponseHandler(Channel channel) {
+ this.channel = channel;
+ this.outstandingRpcRequests = new ConcurrentHashMap<>();
+ }
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) throws
Exception {
+ if (msg instanceof RpcResponse) {
+ RpcResponse responseMessage = (RpcResponse) msg;
+ RpcResponseCallback listener =
outstandingRpcRequests.get(responseMessage.getRequestId());
+ if (listener == null) {
+ logger.warn("Ignoring response from {} since it is not outstanding",
+ NettyUtils.getRemoteAddress(channel));
+ } else {
+ listener.onSuccess(responseMessage);
+ }
+ } else {
+ throw new RssException("receive unexpected message!");
+ }
+ super.channelRead(ctx, msg);
+ }
+
+ public void addResponseCallback(long requestId, RpcResponseCallback
callback) {
+ outstandingRpcRequests.put(requestId, callback);
+ }
+
+ public void removeRpcRequest(long requestId) {
+ outstandingRpcRequests.remove(requestId);
+ }
+
+
+}
diff --git
a/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
b/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
index 7adce841..63441b55 100644
---
a/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
+++
b/common/src/test/java/org/apache/uniffle/common/netty/EncoderAndDecoderTest.java
@@ -107,7 +107,7 @@ public class EncoderAndDecoderTest {
new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
- ch.pipeline().addLast("ClientEncoder", new MessageEncoder())
+ ch.pipeline().addLast("ClientEncoder", MessageEncoder.INSTANCE)
.addLast("ClientDecoder", new TransportFrameDecoder())
.addLast("ClientResponseHandler", new MockResponseHandler());
channelRef.set(ch);
@@ -152,7 +152,7 @@ public class EncoderAndDecoderTest {
serverBootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(final SocketChannel ch) {
- ch.pipeline().addLast("ServerEncoder", new MessageEncoder())
+ ch.pipeline().addLast("ServerEncoder", MessageEncoder.INSTANCE)
.addLast("ServerDecoder", new TransportFrameDecoder())
.addLast("ServerResponseHandler", new MockResponseHandler());
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientFactoryTest.java
b/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientFactoryTest.java
new file mode 100644
index 00000000..cdf8958c
--- /dev/null
+++
b/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientFactoryTest.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import java.io.IOException;
+
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.config.RssBaseConf;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Exercises client creation through {@link TransportClientFactory} against the mock servers
+ * started via {@link TransportClientTestBase}.
+ *
+ * <p>Every factory created by a test is closed in a {@code finally} block so that a failing
+ * assertion does not leave the factory (and whatever it holds) open for later tests.
+ */
+public class TransportClientFactoryTest extends TransportClientTestBase {
+
+  private static final int SERVER_PORT_RANGE_START = 10000;
+  private static final int SERVER_PORT_RANGE_END = 10005;
+  private static final String CONNECTIONS_PER_PEER_KEY =
+      "rss.client.netty.client.connections.per.peer";
+
+  /** Registers one mock server per port in the inclusive port range and starts them all. */
+  @BeforeAll
+  public static void setupServer() {
+    for (int port = SERVER_PORT_RANGE_START; port <= SERVER_PORT_RANGE_END; port++) {
+      mockServers.add(new MockServer(port));
+    }
+    startMockServer();
+  }
+
+  /**
+   * Builds a {@link TransportContext} from a fresh {@link RssBaseConf} whose
+   * connections-per-peer setting is {@code connectionsPerPeer}.
+   */
+  private static TransportContext newContext(int connectionsPerPeer) {
+    RssBaseConf rssBaseConf = new RssBaseConf();
+    rssBaseConf.setInteger(CONNECTIONS_PER_PEER_KEY, connectionsPerPeer);
+    return new TransportContext(new TransportConf(rssBaseConf));
+  }
+
+  /**
+   * A client obtained from one factory is active; after closing it, a client obtained from a
+   * second factory of the same context is a different, active client.
+   */
+  @Test
+  public void testCreateClient() throws IOException, InterruptedException {
+    TransportContext transportContext = newContext(1);
+
+    TransportClient transportClient1;
+    TransportClientFactory firstFactory = transportContext.createClientFactory();
+    try {
+      transportClient1 = firstFactory.createClient("localhost", SERVER_PORT_RANGE_START, 1);
+      assertTrue(transportClient1.isActive());
+      transportClient1.close();
+    } finally {
+      firstFactory.close();
+    }
+
+    TransportClientFactory secondFactory = transportContext.createClientFactory();
+    try {
+      TransportClient transportClient2 =
+          secondFactory.createClient("localhost", SERVER_PORT_RANGE_START, 1);
+      assertNotEquals(transportClient1, transportClient2);
+      assertTrue(transportClient2.isActive());
+    } finally {
+      secondFactory.close();
+    }
+  }
+
+  /** With the default configuration, two requests for the same server and partition are equal. */
+  @Test
+  public void testClientReuse() throws IOException, InterruptedException {
+    TransportContext transportContext = new TransportContext(new TransportConf(new RssBaseConf()));
+    TransportClientFactory transportClientFactory = transportContext.createClientFactory();
+    try {
+      TransportClient client1 =
+          transportClientFactory.createClient("localhost", SERVER_PORT_RANGE_START, 1);
+      TransportClient client2 =
+          transportClientFactory.createClient("localhost", SERVER_PORT_RANGE_START, 1);
+      assertEquals(client1, client2);
+    } finally {
+      transportClientFactory.close();
+    }
+  }
+
+  /**
+   * With one connection per peer, partitions 1 and 2 of the same server yield equal clients;
+   * with ten connections per peer they yield different clients.
+   */
+  @Test
+  public void testClientDiffPartition() throws IOException, InterruptedException {
+    TransportClientFactory singleConnectionFactory = newContext(1).createClientFactory();
+    try {
+      TransportClient client1 =
+          singleConnectionFactory.createClient("localhost", SERVER_PORT_RANGE_START, 1);
+      TransportClient client2 =
+          singleConnectionFactory.createClient("localhost", SERVER_PORT_RANGE_START, 2);
+      assertEquals(client1, client2);
+    } finally {
+      singleConnectionFactory.close();
+    }
+
+    TransportClientFactory multiConnectionFactory = newContext(10).createClientFactory();
+    try {
+      TransportClient client1 =
+          multiConnectionFactory.createClient("localhost", SERVER_PORT_RANGE_START, 1);
+      TransportClient client2 =
+          multiConnectionFactory.createClient("localhost", SERVER_PORT_RANGE_START, 2);
+      assertNotEquals(client1, client2);
+    } finally {
+      multiConnectionFactory.close();
+    }
+  }
+
+  /** Clients for two different server ports are different, even for the same partition. */
+  @Test
+  public void testClientDiffServer() throws IOException, InterruptedException {
+    TransportClientFactory transportClientFactory = newContext(1).createClientFactory();
+    try {
+      TransportClient client1 =
+          transportClientFactory.createClient("localhost", SERVER_PORT_RANGE_START, 1);
+      TransportClient client2 =
+          transportClientFactory.createClient("localhost", SERVER_PORT_RANGE_START + 1, 1);
+      assertNotEquals(client1, client2);
+    } finally {
+      transportClientFactory.close();
+    }
+  }
+
+}
diff --git
a/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientTestBase.java
b/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientTestBase.java
new file mode 100644
index 00000000..dbe7fb41
--- /dev/null
+++
b/common/src/test/java/org/apache/uniffle/common/netty/client/TransportClientTestBase.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.client;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import org.junit.jupiter.api.AfterAll;
+
+/**
+ * Base class for transport client tests: keeps a static list of {@link MockServer}s that
+ * subclasses populate, starts them on request and stops them after all tests have run.
+ */
+public abstract class TransportClientTestBase {
+
+  protected static List<MockServer> mockServers = Lists.newArrayList();
+
+  /**
+   * Starts every server currently in {@link #mockServers}.
+   *
+   * @throws RuntimeException wrapping the {@link IOException} if a server fails to start
+   */
+  protected static void startMockServer() {
+    for (MockServer shuffleServer : mockServers) {
+      try {
+        shuffleServer.start();
+      } catch (IOException e) {
+        throw new RuntimeException(
+            String.format("start mock server on port %s failed", shuffleServer.port), e);
+      }
+    }
+  }
+
+  /** Stops every registered mock server and empties {@link #mockServers}. */
+  @AfterAll
+  public static void shutdownServers() throws Exception {
+    for (MockServer shuffleServer : mockServers) {
+      shuffleServer.stop();
+    }
+    mockServers.clear();
+  }
+
+  /** Netty server bound to a fixed port whose child channels echo back whatever they read. */
+  public static class MockServer {
+    ServerBootstrap bootstrap;
+    ChannelFuture channelFuture;
+    private EventLoopGroup bossGroup;
+    private EventLoopGroup workerGroup;
+    int port;
+
+    /**
+     * Creates the event loop groups (one boss thread, two worker threads); nothing is bound
+     * until {@link #start()} is called.
+     *
+     * @param port port the server binds to in {@link #start()}
+     */
+    public MockServer(int port) {
+      this.port = port;
+      this.bossGroup = new NioEventLoopGroup(1);
+      this.workerGroup = new NioEventLoopGroup(2);
+    }
+
+    /**
+     * Configures the bootstrap with a {@link MockEchoServerHandler} per child channel and
+     * blocks until the bind to {@link #port} has completed.
+     *
+     * @throws IOException if the calling thread is interrupted while waiting for the bind; the
+     *     server is stopped and the thread's interrupt flag is restored before throwing
+     */
+    public void start() throws IOException {
+      bootstrap = new ServerBootstrap();
+      bootstrap.group(bossGroup, workerGroup)
+          .channel(NioServerSocketChannel.class)
+          .childHandler(new ChannelInitializer<SocketChannel>() {
+            @Override
+            public void initChannel(SocketChannel ch) throws Exception {
+              ChannelPipeline p = ch.pipeline();
+              p.addLast(new MockEchoServerHandler());
+            }
+          });
+      try {
+        channelFuture = bootstrap.bind(port).sync();
+      } catch (InterruptedException e) {
+        stop();
+        // Returning normally here would make the caller believe the server is listening.
+        // Restore the interrupt flag and report the failure through the declared exception.
+        Thread.currentThread().interrupt();
+        throw new IOException(
+            String.format("interrupted while binding mock server to port %s", port), e);
+      }
+    }
+
+    /**
+     * Closes the server channel (waiting up to 10 seconds) and shuts both event loop groups
+     * down. Fields are nulled afterwards, so calling this again is a no-op.
+     */
+    public void stop() {
+      if (channelFuture != null) {
+        channelFuture.channel().close().awaitUninterruptibly(10L, TimeUnit.SECONDS);
+        channelFuture = null;
+      }
+      if (bossGroup != null) {
+        bossGroup.shutdownGracefully();
+        workerGroup.shutdownGracefully();
+        bossGroup = null;
+        workerGroup = null;
+      }
+    }
+  }
+
+  /** Inbound handler that writes every message it reads straight back to the same channel. */
+  static class MockEchoServerHandler extends ChannelInboundHandlerAdapter {
+
+    /** Writes {@code msg} back through {@code ctx} and flushes. */
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg) {
+      ctx.writeAndFlush(msg);
+    }
+
+    /** Prints the stack trace of {@code cause} and closes the channel. */
+    @Override
+    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+      cause.printStackTrace();
+      ctx.close();
+    }
+  }
+}