http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java new file mode 100644 index 0000000..f55b884 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; + +/** + * A {@link ManagedBuffer} backed by {@link ByteBuffer}. + */ +public final class NioManagedBuffer extends ManagedBuffer { + private final ByteBuffer buf; + + public NioManagedBuffer(ByteBuffer buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.remaining(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.duplicate(); + } + + @Override + public InputStream createInputStream() throws IOException { + return new ByteBufInputStream(Unpooled.wrappedBuffer(buf)); + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return Unpooled.wrappedBuffer(buf); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} +
http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java new file mode 100644 index 0000000..1fbdcd6 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * General exception caused by a remote exception while fetching a chunk. + */ +public class ChunkFetchFailureException extends RuntimeException { + public ChunkFetchFailureException(String errorMsg, Throwable cause) { + super(errorMsg, cause); + } + + public ChunkFetchFailureException(String errorMsg) { + super(errorMsg); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java new file mode 100644 index 0000000..519e6cb --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Callback for the result of a single chunk result. For a single stream, the callbacks are + * guaranteed to be called by the same thread in the same order as the requests for chunks were + * made. + * + * Note that if a general stream failure occurs, all outstanding chunk requests may be failed. + */ +public interface ChunkReceivedCallback { + /** + * Called upon receipt of a particular chunk. + * + * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this + * call returns. You must therefore either retain() the buffer or copy its contents before + * returning. + */ + void onSuccess(int chunkIndex, ManagedBuffer buffer); + + /** + * Called upon failure to fetch a particular chunk. Note that this may actually be called due + * to failure to fetch a prior chunk in this stream. + * + * After receiving a failure, the stream may or may not be valid. The client should not assume + * that the server's side of the stream has been closed. + */ + void onFailure(int chunkIndex, Throwable e); +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java new file mode 100644 index 0000000..6ec960d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * Callback for the result of a single RPC. This will be invoked once with either success or + * failure. + */ +public interface RpcResponseCallback { + /** Successful serialized result from server. */ + void onSuccess(byte[] response); + + /** Exception either propagated from server or raised on client side. */ + void onFailure(Throwable e); +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java new file mode 100644 index 0000000..b1732fc --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import com.google.common.base.Preconditions; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.util.NettyUtils; + +/** + * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow + * efficient transfer of a large amount of data, broken up into chunks with size ranging from + * hundreds of KB to a few MB. + * + * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane), + * the actual setup of the streams is done outside the scope of the transport layer. The convenience + * method "sendRPC" is provided to enable control plane communication between the client and server + * to perform this setup. + * + * For example, a typical workflow might be: + * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100 + * client.fetchChunk(streamId = 100, chunkIndex = 0, callback) + * client.fetchChunk(streamId = 100, chunkIndex = 1, callback) + * ... + * client.sendRPC(new CloseStream(100)) + * + * Construct an instance of TransportClient using {@link TransportClientFactory}. A single + * TransportClient may be used for multiple streams, but any given stream must be restricted to a + * single client, in order to avoid out-of-order responses. + * + * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is + * responsible for handling responses from the server. + * + * Concurrency: thread safe and can be called from multiple threads. + */ +public class TransportClient implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportClient.class); + + private final Channel channel; + private final TransportResponseHandler handler; + + public TransportClient(Channel channel, TransportResponseHandler handler) { + this.channel = Preconditions.checkNotNull(channel); + this.handler = Preconditions.checkNotNull(handler); + } + + public boolean isActive() { + return channel.isOpen() || channel.isActive(); + } + + /** + * Requests a single chunk from the remote side, from the pre-negotiated streamId. + * + * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though + * some streams may not support this. + * + * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed + * to be returned in the same order that they were requested, assuming only a single + * TransportClient is used to fetch the chunks. + * + * @param streamId Identifier that refers to a stream in the remote StreamManager. This should + * be agreed upon by client and server beforehand. + * @param chunkIndex 0-based index of the chunk to fetch + * @param callback Callback invoked upon successful receipt of chunk, or upon any failure. + */ + public void fetchChunk( + long streamId, + final int chunkIndex, + final ChunkReceivedCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); + final long startTime = System.currentTimeMillis(); + logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); + + final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); + handler.addFetchRequest(streamChunkId, callback); + + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr, + timeTaken); + } else { + String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, + serverAddr, future.cause()); + logger.error(errorMsg, future.cause()); + handler.removeFetchRequest(streamChunkId); + callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause())); + channel.close(); + } + } + }); + } + + /** + * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked + * with the server's response or upon any failure. + */ + public void sendRpc(byte[] message, final RpcResponseCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); + final long startTime = System.currentTimeMillis(); + logger.trace("Sending RPC to {}", serverAddr); + + final long requestId = UUID.randomUUID().getLeastSignificantBits(); + handler.addRpcRequest(requestId, callback); + + channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken); + } else { + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + serverAddr, future.cause()); + logger.error(errorMsg, future.cause()); + handler.removeRpcRequest(requestId); + callback.onFailure(new RuntimeException(errorMsg, future.cause())); + channel.close(); + } + } + }); + } + + @Override + public void close() { + // close is a local operation and should finish with milliseconds; timeout just to be safe + channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java new file mode 100644 index 0000000..10eb9ef --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.server.TransportChannelHandler; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Factory for creating {@link TransportClient}s by using createClient. + * + * The factory maintains a connection pool to other hosts and should return the same + * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for + * all {@link TransportClient}s. + */ +public class TransportClientFactory implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); + + private final TransportContext context; + private final TransportConf conf; + private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool; + + private final Class<? extends Channel> socketChannelClass; + private final EventLoopGroup workerGroup; + + public TransportClientFactory(TransportContext context) { + this.context = context; + this.conf = context.getConf(); + this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>(); + + IOMode ioMode = IOMode.valueOf(conf.ioMode()); + this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); + // TODO: Make thread pool name configurable. + this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); + } + + /** + * Create a new BlockFetchingClient connecting to the given remote host / port. + * + * This blocks until a connection is successfully established. + * + * Concurrency: This method is safe to call from multiple threads. + */ + public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException { + // Get connection from the connection pool first. + // If it is not found or not active, create a new one. + final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + TransportClient cachedClient = connectionPool.get(address); + if (cachedClient != null && cachedClient.isActive()) { + return cachedClient; + } else if (cachedClient != null) { + connectionPool.remove(address, cachedClient); // Remove inactive clients. + } + + logger.debug("Creating new connection to " + address); + + Bootstrap bootstrap = new Bootstrap(); + bootstrap.group(workerGroup) + .channel(socketChannelClass) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()); + + // Use pooled buffers to reduce temporary buffer allocation + bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()); + + final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>(); + + bootstrap.handler(new ChannelInitializer<SocketChannel>() { + @Override + public void initChannel(SocketChannel ch) { + TransportChannelHandler clientHandler = context.initializePipeline(ch); + client.set(clientHandler.getClient()); + } + }); + + // Connect to the remote server + ChannelFuture cf = bootstrap.connect(address); + if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { + throw new TimeoutException( + String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); + } else if (cf.cause() != null) { + throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause()); + } + + // Successful connection + assert client.get() != null : "Channel future completed successfully with null client"; + TransportClient oldClient = connectionPool.putIfAbsent(address, client.get()); + if (oldClient == null) { + return client.get(); + } else { + logger.debug("Two clients were created concurrently, second one will be disposed."); + client.get().close(); + return oldClient; + } + } + + /** Close all connections in the connection pool, and shutdown the worker thread pool. */ + @Override + public void close() { + for (TransportClient client : connectionPool.values()) { + try { + client.close(); + } catch (RuntimeException e) { + logger.warn("Ignoring exception during close", e); + } + } + connectionPool.clear(); + + if (workerGroup != null) { + workerGroup.shutdownGracefully(); + } + } + + /** + * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches + * are disabled because the ByteBufs are allocated by the event loop thread, but released by the + * executor thread rather than the event loop thread. Those thread-local caches actually delay + * the recycling of buffers, leading to larger memory usage. + */ + private PooledByteBufAllocator createPooledByteBufAllocator() { + return new PooledByteBufAllocator( + PlatformDependent.directBufferPreferred(), + getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), + getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), + getPrivateStaticField("DEFAULT_PAGE_SIZE"), + getPrivateStaticField("DEFAULT_MAX_ORDER"), + 0, // tinyCacheSize + 0, // smallCacheSize + 0 // normalCacheSize + ); + } + + /** Used to get defaults from Netty's private static fields. */ + private int getPrivateStaticField(String name) { + try { + Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.getInt(null); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java new file mode 100644 index 0000000..d896559 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.google.common.annotations.VisibleForTesting; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.ResponseMessage; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.server.MessageHandler; +import org.apache.spark.network.util.NettyUtils; + +/** + * Handler that processes server responses, in response to requests issued from a + * [[TransportClient]]. It works by tracking the list of outstanding requests (and their callbacks). + * + * Concurrency: thread safe and can be called from multiple threads. + */ +public class TransportResponseHandler extends MessageHandler<ResponseMessage> { + private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class); + + private final Channel channel; + + private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches; + + private final Map<Long, RpcResponseCallback> outstandingRpcs; + + public TransportResponseHandler(Channel channel) { + this.channel = channel; + this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>(); + this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>(); + } + + public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + outstandingFetches.put(streamChunkId, callback); + } + + public void removeFetchRequest(StreamChunkId streamChunkId) { + outstandingFetches.remove(streamChunkId); + } + + public void addRpcRequest(long requestId, RpcResponseCallback callback) { + outstandingRpcs.put(requestId, callback); + } + + public void removeRpcRequest(long requestId) { + outstandingRpcs.remove(requestId); + } + + /** + * Fire the failure callback for all outstanding requests. This is called when we have an + * uncaught exception or pre-mature connection termination. + */ + private void failOutstandingRequests(Throwable cause) { + for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) { + entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + } + for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) { + entry.getValue().onFailure(cause); + } + + // It's OK if new fetches appear, as they will fail immediately. + outstandingFetches.clear(); + outstandingRpcs.clear(); + } + + @Override + public void channelUnregistered() { + if (numOutstandingRequests() > 0) { + String remoteAddress = NettyUtils.getRemoteAddress(channel); + logger.error("Still have {} requests outstanding when connection from {} is closed", + numOutstandingRequests(), remoteAddress); + failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed")); + } + } + + @Override + public void exceptionCaught(Throwable cause) { + if (numOutstandingRequests() > 0) { + String remoteAddress = NettyUtils.getRemoteAddress(channel); + logger.error("Still have {} requests outstanding when connection from {} is closed", + numOutstandingRequests(), remoteAddress); + failOutstandingRequests(cause); + } + } + + @Override + public void handle(ResponseMessage message) { + String remoteAddress = NettyUtils.getRemoteAddress(channel); + if (message instanceof ChunkFetchSuccess) { + ChunkFetchSuccess resp = (ChunkFetchSuccess) message; + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + if (listener == null) { + logger.warn("Ignoring response for block {} from {} since it is not outstanding", + resp.streamChunkId, remoteAddress); + resp.buffer.release(); + } else { + outstandingFetches.remove(resp.streamChunkId); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer); + resp.buffer.release(); + } + } else if (message instanceof ChunkFetchFailure) { + ChunkFetchFailure resp = (ChunkFetchFailure) message; + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + if (listener == null) { + logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding", + resp.streamChunkId, remoteAddress, resp.errorString); + } else { + outstandingFetches.remove(resp.streamChunkId); + listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException( + "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString)); + } + } else if (message instanceof RpcResponse) { + RpcResponse resp = (RpcResponse) message; + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); + if (listener == null) { + logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", + resp.requestId, remoteAddress, resp.response.length); + } else { + outstandingRpcs.remove(resp.requestId); + listener.onSuccess(resp.response); + } + } else if (message instanceof RpcFailure) { + RpcFailure resp = (RpcFailure) message; + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); + if (listener == null) { + logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", + resp.requestId, remoteAddress, resp.errorString); + } else { + outstandingRpcs.remove(resp.requestId); + listener.onFailure(new RuntimeException(resp.errorString)); + } + } else { + throw new IllegalStateException("Unknown response type: " + message.type()); + } + } + + /** Returns total number of outstanding requests (fetch requests + rpcs) */ + @VisibleForTesting + public int numOutstandingRequests() { + return outstandingFetches.size() + outstandingRpcs.size(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java new file mode 100644 index 0000000..152af98 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Charsets; +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. + */ +public final class ChunkFetchFailure implements ResponseMessage { + public final StreamChunkId streamChunkId; + public final String errorString; + + public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { + this.streamChunkId = streamChunkId; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.ChunkFetchFailure; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length; + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); + buf.writeInt(errorBytes.length); + buf.writeBytes(errorBytes); + } + + public static ChunkFetchFailure decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + int numErrorStringBytes = buf.readInt(); + byte[] errorBytes = new byte[numErrorStringBytes]; + buf.readBytes(errorBytes); + return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchFailure) { + ChunkFetchFailure o = (ChunkFetchFailure) other; + return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("errorString", errorString) + .toString(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java new file mode 100644 index 0000000..980947c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). + */ +public final class ChunkFetchRequest implements RequestMessage { + public final StreamChunkId streamChunkId; + + public ChunkFetchRequest(StreamChunkId streamChunkId) { + this.streamChunkId = streamChunkId; + } + + @Override + public Type type() { return Type.ChunkFetchRequest; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + public static ChunkFetchRequest decode(ByteBuf buf) { + return new ChunkFetchRequest(StreamChunkId.decode(buf)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchRequest) { + ChunkFetchRequest o = (ChunkFetchRequest) other; + return streamChunkId.equals(o.streamChunkId); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .toString(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java new file mode 100644 index 0000000..ff49364 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched. + * + * Note that the server-side encoding of this messages does NOT include the buffer itself, as this + * may be written by Netty in a more efficient manner (i.e., zero-copy write). + * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. + */ +public final class ChunkFetchSuccess implements ResponseMessage { + public final StreamChunkId streamChunkId; + public final ManagedBuffer buffer; + + public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { + this.streamChunkId = streamChunkId; + this.buffer = buffer; + } + + @Override + public Type type() { return Type.ChunkFetchSuccess; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. */ + public static ChunkFetchSuccess decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new ChunkFetchSuccess(streamChunkId, managedBuf); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchSuccess) { + ChunkFetchSuccess o = (ChunkFetchSuccess) other; + return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("buffer", buffer) + .toString(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java new file mode 100644 index 0000000..b4e2994 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +/** + * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are + * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length. + * + * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by + * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than + * just copying data from it), then you must retain() the ByteBuf. + * + * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}. + */ +public interface Encodable { + /** Number of bytes of the encoded form of this object. */ + int encodedLength(); + + /** + * Serializes this object by writing into the given ByteBuf. + * This method must write exactly encodedLength() bytes. + */ + void encode(ByteBuf buf); +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/Message.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java new file mode 100644 index 0000000..d568370 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +/** An on-the-wire transmittable message. */ +public interface Message extends Encodable { + /** Used to identify this request type. */ + Type type(); + + /** Preceding every serialized Message is its type, which allows us to deserialize it. */ + public static enum Type implements Encodable { + ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), + RpcRequest(3), RpcResponse(4), RpcFailure(5); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 message types"; + this.id = (byte) id; + } + + public byte id() { return id; } + + @Override public int encodedLength() { return 1; } + + @Override public void encode(ByteBuf buf) { buf.writeByte(id); } + + public static Type decode(ByteBuf buf) { + byte id = buf.readByte(); + switch (id) { + case 0: return ChunkFetchRequest; + case 1: return ChunkFetchSuccess; + case 2: return ChunkFetchFailure; + case 3: return RpcRequest; + case 4: return RpcResponse; + case 5: return RpcFailure; + default: throw new IllegalArgumentException("Unknown message type: " + id); + } + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java new file mode 100644 index 0000000..81f8d7f --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageDecoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Decoder used by the client side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. + */ [email protected] +public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> { + + private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) { + Message.Type msgType = Message.Type.decode(in); + Message decoded = decode(msgType, in); + assert decoded.type() == msgType; + logger.trace("Received message " + msgType + ": " + decoded); + out.add(decoded); + } + + private Message decode(Message.Type msgType, ByteBuf in) { + switch (msgType) { + case ChunkFetchRequest: + return ChunkFetchRequest.decode(in); + + case ChunkFetchSuccess: + return ChunkFetchSuccess.decode(in); + + case ChunkFetchFailure: + return ChunkFetchFailure.decode(in); + + case RpcRequest: + return RpcRequest.decode(in); + + case RpcResponse: + return RpcResponse.decode(in); + + case RpcFailure: + return RpcFailure.decode(in); + + default: + throw new IllegalArgumentException("Unexpected message type: " + msgType); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java new file mode 100644 index 0000000..4cb8bec --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Encoder used by the server side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. + */ [email protected] +public final class MessageEncoder extends MessageToMessageEncoder<Message> { + + private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); + + /*** + * Encodes a Message by invoking its encode() method. For non-data messages, we will add one + * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. + * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the + * data to 'out', in order to enable zero-copy transfer. + */ + @Override + public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) { + Object body = null; + long bodyLength = 0; + + // Only ChunkFetchSuccesses have data besides the header. + // The body is used in order to enable zero-copy transfer for the payload. + if (in instanceof ChunkFetchSuccess) { + ChunkFetchSuccess resp = (ChunkFetchSuccess) in; + try { + bodyLength = resp.buffer.size(); + body = resp.buffer.convertToNetty(); + } catch (Exception e) { + // Re-encode this message as BlockFetchFailure. + logger.error(String.format("Error opening block %s for client %s", + resp.streamChunkId, ctx.channel().remoteAddress()), e); + encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out); + return; + } + } + + Message.Type msgType = in.type(); + // All messages have the frame length, message type, and message itself. + int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); + long frameLength = headerLength + bodyLength; + ByteBuf header = ctx.alloc().buffer(headerLength); + header.writeLong(frameLength); + msgType.encode(header); + in.encode(header); + assert header.writableBytes() == 0; + + out.add(header); + if (body != null && bodyLength > 0) { + out.add(body); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java new file mode 100644 index 0000000..31b15bb --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import org.apache.spark.network.protocol.Message; + +/** Messages from the client to the server. */ +public interface RequestMessage extends Message { + // token interface +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java new file mode 100644 index 0000000..6edffd1 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import org.apache.spark.network.protocol.Message; + +/** Messages from the server to the client. */ +public interface ResponseMessage extends Message { + // token interface +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java new file mode 100644 index 0000000..e239d4f --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Charsets; +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link RpcRequest} for a failed RPC. */ +public final class RpcFailure implements ResponseMessage { + public final long requestId; + public final String errorString; + + public RpcFailure(long requestId, String errorString) { + this.requestId = requestId; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.RpcFailure; } + + @Override + public int encodedLength() { + return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); + buf.writeInt(errorBytes.length); + buf.writeBytes(errorBytes); + } + + public static RpcFailure decode(ByteBuf buf) { + long requestId = buf.readLong(); + int numErrorStringBytes = buf.readInt(); + byte[] errorBytes = new byte[numErrorStringBytes]; + buf.readBytes(errorBytes); + return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcFailure) { + RpcFailure o = (RpcFailure) other; + return requestId == o.requestId && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("errorString", errorString) + .toString(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java new file mode 100644 index 0000000..099e934 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. + * This will correspond to a single + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). + */ +public final class RpcRequest implements RequestMessage { + /** Used to link an RPC request with its response. */ + public final long requestId; + + /** Serialized message to send to remote RpcHandler. */ + public final byte[] message; + + public RpcRequest(long requestId, byte[] message) { + this.requestId = requestId; + this.message = message; + } + + @Override + public Type type() { return Type.RpcRequest; } + + @Override + public int encodedLength() { + return 8 + 4 + message.length; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + buf.writeInt(message.length); + buf.writeBytes(message); + } + + public static RpcRequest decode(ByteBuf buf) { + long requestId = buf.readLong(); + int messageLen = buf.readInt(); + byte[] message = new byte[messageLen]; + buf.readBytes(message); + return new RpcRequest(requestId, message); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcRequest) { + RpcRequest o = (RpcRequest) other; + return requestId == o.requestId && Arrays.equals(message, o.message); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("message", message) + .toString(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java new file mode 100644 index 0000000..ed47947 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link RpcRequest} for a successful RPC. */ +public final class RpcResponse implements ResponseMessage { + public final long requestId; + public final byte[] response; + + public RpcResponse(long requestId, byte[] response) { + this.requestId = requestId; + this.response = response; + } + + @Override + public Type type() { return Type.RpcResponse; } + + @Override + public int encodedLength() { return 8 + 4 + response.length; } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + buf.writeInt(response.length); + buf.writeBytes(response); + } + + public static RpcResponse decode(ByteBuf buf) { + long requestId = buf.readLong(); + int responseLen = buf.readInt(); + byte[] response = new byte[responseLen]; + buf.readBytes(response); + return new RpcResponse(requestId, response); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcResponse) { + RpcResponse o = (RpcResponse) other; + return requestId == o.requestId && Arrays.equals(response, o.response); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("response", response) + .toString(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java new file mode 100644 index 0000000..d46a263 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** +* Encapsulates a request for a particular chunk of a stream. +*/ +public final class StreamChunkId implements Encodable { + public final long streamId; + public final int chunkIndex; + + public StreamChunkId(long streamId, int chunkIndex) { + this.streamId = streamId; + this.chunkIndex = chunkIndex; + } + + @Override + public int encodedLength() { + return 8 + 4; + } + + public void encode(ByteBuf buffer) { + buffer.writeLong(streamId); + buffer.writeInt(chunkIndex); + } + + public static StreamChunkId decode(ByteBuf buffer) { + assert buffer.readableBytes() >= 8 + 4; + long streamId = buffer.readLong(); + int chunkIndex = buffer.readInt(); + return new StreamChunkId(streamId, chunkIndex); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, chunkIndex); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamChunkId) { + StreamChunkId o = (StreamChunkId) other; + return streamId == o.streamId && chunkIndex == o.chunkIndex; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("chunkIndex", chunkIndex) + .toString(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java new file mode 100644 index 0000000..9688705 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.Iterator; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually + * fetched as chunks by the client. + */ +public class DefaultStreamManager extends StreamManager { + private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class); + + private final AtomicLong nextStreamId; + private final Map<Long, StreamState> streams; + + /** State of a single stream. */ + private static class StreamState { + final Iterator<ManagedBuffer> buffers; + + // Used to keep track of the index of the buffer that the user has retrieved, just to ensure + // that the caller only requests each chunk one at a time, in order. + int curChunk = 0; + + StreamState(Iterator<ManagedBuffer> buffers) { + this.buffers = buffers; + } + } + + public DefaultStreamManager() { + // For debugging purposes, start with a random stream id to help identifying different streams. + // This does not need to be globally unique, only unique to this class. + nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); + streams = new ConcurrentHashMap<Long, StreamState>(); + } + + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + StreamState state = streams.get(streamId); + if (chunkIndex != state.curChunk) { + throw new IllegalStateException(String.format( + "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk)); + } else if (!state.buffers.hasNext()) { + throw new IllegalStateException(String.format( + "Requested chunk index beyond end %s", chunkIndex)); + } + state.curChunk += 1; + ManagedBuffer nextChunk = state.buffers.next(); + + if (!state.buffers.hasNext()) { + logger.trace("Removing stream id {}", streamId); + streams.remove(streamId); + } + + return nextChunk; + } + + @Override + public void connectionTerminated(long streamId) { + // Release all remaining buffers. + StreamState state = streams.remove(streamId); + if (state != null && state.buffers != null) { + while (state.buffers.hasNext()) { + state.buffers.next().release(); + } + } + } + + /** + * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to + * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a + * client connection is closed before the iterator is fully drained, then the remaining buffers + * will all be release()'d. + */ + public long registerStream(Iterator<ManagedBuffer> buffers) { + long myStreamId = nextStreamId.getAndIncrement(); + streams.put(myStreamId, new StreamState(buffers)); + return myStreamId; + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java new file mode 100644 index 0000000..b80c151 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.protocol.Message; + +/** + * Handles either request or response messages coming off of Netty. A MessageHandler instance + * is associated with a single Netty Channel (though it may have multiple clients on the same + * Channel.) + */ +public abstract class MessageHandler<T extends Message> { + /** Handles the receipt of a single message. */ + public abstract void handle(T message); + + /** Invoked when an exception was caught on the Channel. */ + public abstract void exceptionCaught(Throwable cause); + + /** Invoked when the channel this MessageHandler is on has been unregistered. */ + public abstract void channelUnregistered(); +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java new file mode 100644 index 0000000..f54a696 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; + +/** + * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. + */ +public interface RpcHandler { + /** + * Receive a single RPC message. Any exception thrown while in this method will be sent back to + * the client in string form as a standard RPC failure. + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. + * @param message The serialized bytes of the RPC. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. + */ + void receive(TransportClient client, byte[] message, RpcResponseCallback callback); +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java new file mode 100644 index 0000000..5a9a14a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * The StreamManager is used to fetch individual chunks from a stream. This is used in + * {@link TransportRequestHandler} in order to respond to fetchChunk() requests. Creation of the + * stream is outside the scope of the transport layer, but a given stream is guaranteed to be read + * by only one client connection, meaning that getChunk() for a particular stream will be called + * serially and that once the connection associated with the stream is closed, that stream will + * never be used again. + */ +public abstract class StreamManager { + /** + * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the + * client. A single stream will be associated with a single TCP connection, so this method + * will not be called in parallel for a particular stream. + * + * Chunks may be requested in any order, and requests may be repeated, but it is not required + * that implementations support this behavior. + * + * The returned ManagedBuffer will be release()'d after being written to the network. + * + * @param streamId id of a stream that has been previously registered with the StreamManager. + * @param chunkIndex 0-indexed chunk of the stream that's requested + */ + public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); + + /** + * Indicates that the TCP connection that was tied to the given stream has been terminated. After + * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned + * up. + */ + public void connectionTerminated(long streamId) { } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java new file mode 100644 index 0000000..e491367 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ResponseMessage; +import org.apache.spark.network.util.NettyUtils; + +/** + * The single Transport-level Channel handler which is used for delegating requests to the + * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}. + * + * All channels created in the transport layer are bidirectional. When the Client initiates a Netty + * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server + * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server + * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the + * Client. + * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, + * for the Client's responses to the Server's requests. + */ +public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> { + private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); + + private final TransportClient client; + private final TransportResponseHandler responseHandler; + private final TransportRequestHandler requestHandler; + + public TransportChannelHandler( + TransportClient client, + TransportResponseHandler responseHandler, + TransportRequestHandler requestHandler) { + this.client = client; + this.responseHandler = responseHandler; + this.requestHandler = requestHandler; + } + + public TransportClient getClient() { + return client; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), + cause); + requestHandler.exceptionCaught(cause); + responseHandler.exceptionCaught(cause); + ctx.close(); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + try { + requestHandler.channelUnregistered(); + } catch (RuntimeException e) { + logger.error("Exception from request handler while unregistering channel", e); + } + try { + responseHandler.channelUnregistered(); + } catch (RuntimeException e) { + logger.error("Exception from response handler while unregistering channel", e); + } + super.channelUnregistered(ctx); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Message request) { + if (request instanceof RequestMessage) { + requestHandler.handle((RequestMessage) request); + } else { + responseHandler.handle((ResponseMessage) request); + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
