Repository: spark Updated Branches: refs/heads/master 51ce99735 -> dff015533
http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java new file mode 100644 index 0000000..352f865 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.Set; + +import com.google.common.base.Throwables; +import com.google.common.collect.Sets; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.util.NettyUtils; + +/** + * A handler that processes requests from clients and writes chunk data back. Each handler is + * attached to a single Netty channel, and keeps track of which streams have been fetched via this + * channel, in order to clean them up if the channel is terminated (see #channelUnregistered). + * + * The messages should have been processed by the pipeline setup by {@link TransportServer}. + */ +public class TransportRequestHandler extends MessageHandler<RequestMessage> { + private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); + + /** The Netty channel that this handler is associated with. */ + private final Channel channel; + + /** Client on the same channel allowing us to talk back to the requester. */ + private final TransportClient reverseClient; + + /** Returns each chunk part of a stream. */ + private final StreamManager streamManager; + + /** Handles all RPC messages. */ + private final RpcHandler rpcHandler; + + /** List of all stream ids that have been read on this handler, used for cleanup. */ + private final Set<Long> streamIds; + + public TransportRequestHandler( + Channel channel, + TransportClient reverseClient, + StreamManager streamManager, + RpcHandler rpcHandler) { + this.channel = channel; + this.reverseClient = reverseClient; + this.streamManager = streamManager; + this.rpcHandler = rpcHandler; + this.streamIds = Sets.newHashSet(); + } + + @Override + public void exceptionCaught(Throwable cause) { + } + + @Override + public void channelUnregistered() { + // Inform the StreamManager that these streams will no longer be read from. + for (long streamId : streamIds) { + streamManager.connectionTerminated(streamId); + } + } + + @Override + public void handle(RequestMessage request) { + if (request instanceof ChunkFetchRequest) { + processFetchRequest((ChunkFetchRequest) request); + } else if (request instanceof RpcRequest) { + processRpcRequest((RpcRequest) request); + } else { + throw new IllegalArgumentException("Unknown request type: " + request); + } + } + + private void processFetchRequest(final ChunkFetchRequest req) { + final String client = NettyUtils.getRemoteAddress(channel); + streamIds.add(req.streamChunkId.streamId); + + logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); + + ManagedBuffer buf; + try { + buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); + } catch (Exception e) { + logger.error(String.format( + "Error opening block %s for request from %s", req.streamChunkId, client), e); + respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); + return; + } + + respond(new ChunkFetchSuccess(req.streamChunkId, buf)); + } + + private void processRpcRequest(final RpcRequest req) { + try { + rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + respond(new RpcResponse(req.requestId, response)); + } + + @Override + public void onFailure(Throwable e) { + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + }); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + } + + /** + * Responds to a single message with some Encodable object. If a failure occurs while sending, + * it will be logged and the channel closed. + */ + private void respond(final Encodable result) { + final String remoteAddress = channel.remoteAddress().toString(); + channel.writeAndFlush(result).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + logger.trace(String.format("Sent result %s to client %s", result, remoteAddress)); + } else { + logger.error(String.format("Error sending result %s to %s; closing connection", + result, remoteAddress), future.cause()); + channel.close(); + } + } + } + ); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java new file mode 100644 index 0000000..2430707 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.io.Closeable; +import java.net.InetSocketAddress; +import java.util.concurrent.TimeUnit; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Server for the efficient, low-level streaming service. + */ +public class TransportServer implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportServer.class); + + private final TransportContext context; + private final TransportConf conf; + + private ServerBootstrap bootstrap; + private ChannelFuture channelFuture; + private int port = -1; + + public TransportServer(TransportContext context) { + this.context = context; + this.conf = context.getConf(); + + init(); + } + + public int getPort() { + if (port == -1) { + throw new IllegalStateException("Server not initialized"); + } + return port; + } + + private void init() { + + IOMode ioMode = IOMode.valueOf(conf.ioMode()); + EventLoopGroup bossGroup = + NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); + EventLoopGroup workerGroup = bossGroup; + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(NettyUtils.getServerChannelClass(ioMode)) + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); + + if (conf.backLog() > 0) { + bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); + } + + if (conf.receiveBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf()); + } + + if (conf.sendBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf()); + } + + bootstrap.childHandler(new ChannelInitializer<SocketChannel>() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + context.initializePipeline(ch); + } + }); + + channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort())); + channelFuture.syncUninterruptibly(); + + port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); + logger.debug("Shuffle server started on port :" + port); + } + + @Override + public void close() { + if (channelFuture != null) { + // close is a local operation and should finish with milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS); + channelFuture = null; + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(); + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully(); + } + bootstrap = null; + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java new file mode 100644 index 0000000..d944d9d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +/** + * Provides a mechanism for constructing a {@link TransportConf} using some sort of configuration. + */ +public abstract class ConfigProvider { + /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */ + public abstract String get(String name); + + public String get(String name, String defaultValue) { + try { + return get(name); + } catch (NoSuchElementException e) { + return defaultValue; + } + } + + public int getInt(String name, int defaultValue) { + return Integer.parseInt(get(name, Integer.toString(defaultValue))); + } + + public long getLong(String name, long defaultValue) { + return Long.parseLong(get(name, Long.toString(defaultValue))); + } + + public double getDouble(String name, double defaultValue) { + return Double.parseDouble(get(name, Double.toString(defaultValue))); + } + + public boolean getBoolean(String name, boolean defaultValue) { + return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue))); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/util/IOMode.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java new file mode 100644 index 0000000..6b208d9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +/** + * Selector for which form of low-level IO we should use. + * NIO is always available, while EPOLL is only available on Linux. + * AUTO is used to select EPOLL if it's available, or NIO otherwise. + */ +public enum IOMode { + NIO, EPOLL +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java new file mode 100644 index 0000000..32ba3f5 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.Closeable; +import java.io.IOException; + +import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class JavaUtils { + private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); + + /** Closes the given object, ignoring IOExceptions. */ + public static void closeQuietly(Closeable closeable) { + try { + closeable.close(); + } catch (IOException e) { + logger.error("IOException should not have been thrown.", e); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java new file mode 100644 index 0000000..b187234 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.concurrent.ThreadFactory; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.netty.channel.Channel; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.epoll.EpollSocketChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +/** + * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. + */ +public class NettyUtils { + /** Creates a Netty EventLoopGroup based on the IOMode. */ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + + ThreadFactory threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(threadPrefix + "-%d") + .build(); + + switch (mode) { + case NIO: + return new NioEventLoopGroup(numThreads, threadFactory); + case EPOLL: + return new EpollEventLoopGroup(numThreads, threadFactory); + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct (client) SocketChannel class based on IOMode. */ + public static Class<? extends Channel> getClientChannelClass(IOMode mode) { + switch (mode) { + case NIO: + return NioSocketChannel.class; + case EPOLL: + return EpollSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct ServerSocketChannel class based on IOMode. */ + public static Class<? extends ServerChannel> getServerChannelClass(IOMode mode) { + switch (mode) { + case NIO: + return NioServerSocketChannel.class; + case EPOLL: + return EpollServerSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** + * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. + * This is used before all decoders. + */ + public static ByteToMessageDecoder createFrameDecoder() { + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 8 + // lengthAdjustment = -8, i.e. exclude the 8 byte length itself + // initialBytesToStrip = 8, i.e. strip out the length field itself + return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); + } + + /** Returns the remote address on the channel or "<remote address>" if none exists. */ + public static String getRemoteAddress(Channel channel) { + if (channel != null && channel.remoteAddress() != null) { + return channel.remoteAddress().toString(); + } + return "<unknown remote>"; + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java new file mode 100644 index 0000000..80f65d9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +/** + * A central location that tracks all the settings we expose to users. + */ +public class TransportConf { + private final ConfigProvider conf; + + public TransportConf(ConfigProvider conf) { + this.conf = conf; + } + + /** Port the server listens on. Default to a random port. */ + public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); } + + /** IO mode: nio or epoll */ + public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + + /** Connect timeout in secs. Default 120 secs. */ + public int connectionTimeoutMs() { + return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; + } + + /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ + public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } + + /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ + public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } + + /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ + public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } + + /** + * Receive buffer size (SO_RCVBUF). + * Note: the optimal size for receive buffer and send buffer should be + * latency * network_bandwidth. + * Assuming latency = 1ms, network_bandwidth = 10Gbps + * buffer size should be ~ 1.25MB + */ + public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } + + /** Send buffer size (SO_SNDBUF). */ + public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java ---------------------------------------------------------------------- diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java new file mode 100644 index 0000000..738dca9 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.TransportConf; + +public class ChunkFetchIntegrationSuite { + static final long STREAM_ID = 1; + static final int BUFFER_CHUNK_INDEX = 0; + static final int FILE_CHUNK_INDEX = 1; + + static TransportServer server; + static TransportClientFactory clientFactory; + static StreamManager streamManager; + static File testFile; + + static ManagedBuffer bufferChunk; + static ManagedBuffer fileChunk; + + @BeforeClass + public static void setUp() throws Exception { + int bufSize = 100000; + final ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + bufferChunk = new NioManagedBuffer(buf); + + testFile = File.createTempFile("shuffle-test-file", "txt"); + testFile.deleteOnExit(); + RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + fp.close(); + fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + streamManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + assertEquals(STREAM_ID, streamId); + if (chunkIndex == BUFFER_CHUNK_INDEX) { + return new NioManagedBuffer(buf); + } else if (chunkIndex == FILE_CHUNK_INDEX) { + return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25); + } else { + throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); + } + } + }; + TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler()); + server = context.createServer(); + clientFactory = context.createClientFactory(); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + testFile.delete(); + } + + class FetchResult { + public Set<Integer> successChunks; + public Set<Integer> failedChunks; + public List<ManagedBuffer> buffers; + + public void releaseBuffers() { + for (ManagedBuffer buffer : buffers) { + buffer.release(); + } + } + } + + private FetchResult fetchChunks(List<Integer> chunkIndices) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + + final FetchResult res = new FetchResult(); + res.successChunks = Collections.synchronizedSet(new HashSet<Integer>()); + res.failedChunks = Collections.synchronizedSet(new HashSet<Integer>()); + res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>()); + + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + buffer.retain(); + res.successChunks.add(chunkIndex); + res.buffers.add(buffer); + sem.release(); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + res.failedChunks.add(chunkIndex); + sem.release(); + } + }; + + for (int chunkIndex : chunkIndices) { + client.fetchChunk(STREAM_ID, chunkIndex, callback); + } + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void fetchBufferChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchFileChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchNonExistentChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(12345)); + assertTrue(res.successChunks.isEmpty()); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertTrue(res.buffers.isEmpty()); + } + + @Test + public void fetchBothChunks() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchChunkAndNonExistent() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + private void assertBufferListsEqual(List<ManagedBuffer> list0, List<ManagedBuffer> list1) + throws Exception { + assertEquals(list0.size(), list1.size()); + for (int i = 0; i < list0.size(); i ++) { + assertBuffersEqual(list0.get(i), list1.get(i)); + } + } + + private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + ByteBuffer nio0 = buffer0.nioByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer(); + + int len = nio0.remaining(); + assertEquals(nio0.remaining(), nio1.remaining()); + for (int i = 0; i < len; i ++) { + assertEquals(nio0.get(), nio1.get()); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java ---------------------------------------------------------------------- diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java new file mode 100644 index 0000000..7aa37ef --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java @@ -0,0 +1,28 @@ +package org.apache.spark.network;/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.RpcHandler; + +/** Test RpcHandler which always returns a zero-sized success. */ +public class NoOpRpcHandler implements RpcHandler { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + callback.onSuccess(new byte[0]); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java ---------------------------------------------------------------------- diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java new file mode 100644 index 0000000..43dc0cf --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.MessageDecoder; +import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.util.NettyUtils; + +public class ProtocolSuite { + private void testServerToClient(Message msg) { + EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); + serverChannel.writeOutbound(msg); + + EmbeddedChannel clientChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new MessageDecoder()); + + while (!serverChannel.outboundMessages().isEmpty()) { + clientChannel.writeInbound(serverChannel.readOutbound()); + } + + assertEquals(1, clientChannel.inboundMessages().size()); + assertEquals(msg, clientChannel.readInbound()); + } + + private void testClientToServer(Message msg) { + EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); + clientChannel.writeOutbound(msg); + + EmbeddedChannel serverChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new MessageDecoder()); + + while (!clientChannel.outboundMessages().isEmpty()) { + serverChannel.writeInbound(clientChannel.readOutbound()); + } + + assertEquals(1, serverChannel.inboundMessages().size()); + assertEquals(msg, serverChannel.readInbound()); + } + + @Test + public void requests() { + testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); + testClientToServer(new RpcRequest(12345, new byte[0])); + testClientToServer(new RpcRequest(12345, new byte[100])); + } + + @Test + public void responses() { + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); + testServerToClient(new RpcResponse(12345, new byte[0])); + testServerToClient(new RpcResponse(12345, new byte[1000])); + testServerToClient(new RpcFailure(0, "this is an error")); + testServerToClient(new RpcFailure(0, "")); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java ---------------------------------------------------------------------- diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java new file mode 100644 index 0000000..9f216dd --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.base.Charsets; +import com.google.common.collect.Sets; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.TransportConf; + +public class RpcIntegrationSuite { + static TransportServer server; + static TransportClientFactory clientFactory; + static RpcHandler rpcHandler; + + @BeforeClass + public static void setUp() throws Exception { + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + rpcHandler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + String msg = new String(message, Charsets.UTF_8); + String[] parts = msg.split("/"); + if (parts[0].equals("hello")) { + callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8)); + } else if (parts[0].equals("return error")) { + callback.onFailure(new RuntimeException("Returned: " + parts[1])); + } else if (parts[0].equals("throw error")) { + throw new RuntimeException("Thrown: " + parts[1]); + } + } + }; + TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + } + + class RpcResult { + public Set<String> successMessages; + public Set<String> errorMessages; + } + + private RpcResult sendRPC(String ... commands) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + + final RpcResult res = new RpcResult(); + res.successMessages = Collections.synchronizedSet(new HashSet<String>()); + res.errorMessages = Collections.synchronizedSet(new HashSet<String>()); + + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(byte[] message) { + res.successMessages.add(new String(message, Charsets.UTF_8)); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + }; + + for (String command : commands) { + client.sendRpc(command.getBytes(Charsets.UTF_8), callback); + } + + if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void singleRPC() throws Exception { + RpcResult res = sendRPC("hello/Aaron"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!")); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void doubleRPC() throws Exception { + RpcResult res = sendRPC("hello/Aaron", "hello/Reynold"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!")); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void returnErrorRPC() throws Exception { + RpcResult res = sendRPC("return error/OK"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK")); + } + + @Test + public void throwErrorRPC() throws Exception { + RpcResult res = sendRPC("throw error/uh-oh"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh")); + } + + @Test + public void doubleTrouble() throws Exception { + RpcResult res = sendRPC("return error/OK", "throw error/uh-oh"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh")); + } + + @Test + public void sendSuccessAndFailure() throws Exception { + RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!")); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); + } + + private void assertErrorsContain(Set<String> errors, Set<String> contains) { + assertEquals(contains.size(), errors.size()); + + Set<String> remainingErrors = Sets.newHashSet(errors); + for (String contain : contains) { + Iterator<String> it = remainingErrors.iterator(); + boolean foundMatch = false; + while (it.hasNext()) { + if (it.next().contains(contain)) { + it.remove(); + foundMatch = true; + break; + } + } + assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); + } + + assertTrue(remainingErrors.isEmpty()); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java ---------------------------------------------------------------------- diff --git a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java new file mode 100644 index 0000000..f4e0a24 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.util.NoSuchElementException; + +import org.apache.spark.network.util.ConfigProvider; + +/** Uses System properties to obtain config values. */ +public class SystemPropertyConfigProvider extends ConfigProvider { + @Override + public String get(String name) { + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java ---------------------------------------------------------------------- diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java new file mode 100644 index 0000000..38113a9 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Preconditions; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). + * + * Used for testing. + */ +public class TestManagedBuffer extends ManagedBuffer { + + private final int len; + private NettyManagedBuffer underlying; + + public TestManagedBuffer(int len) { + Preconditions.checkArgument(len <= Byte.MAX_VALUE); + this.len = len; + byte[] byteArray = new byte[len]; + for (int i = 0; i < len; i ++) { + byteArray[i] = (byte) i; + } + this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)); + } + + + @Override + public long size() { + return underlying.size(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return underlying.nioByteBuffer(); + } + + @Override + public InputStream createInputStream() throws IOException { + return underlying.createInputStream(); + } + + @Override + public ManagedBuffer retain() { + underlying.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + underlying.release(); + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return underlying.convertToNetty(); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ManagedBuffer) { + try { + ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer(); + if (nioBuf.remaining() != len) { + return false; + } else { + for (int i = 0; i < len; i ++) { + if (nioBuf.get() != i) { + return false; + } + } + return true; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return false; + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/test/java/org/apache/spark/network/TestUtils.java ---------------------------------------------------------------------- diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java new file mode 100644 index 0000000..56a2b80 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TestUtils.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.net.InetAddress; + +public class TestUtils { + public static String getLocalHost() { + try { + return InetAddress.getLocalHost().getHostAddress(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java ---------------------------------------------------------------------- diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java new file mode 100644 index 0000000..3ef9646 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.util.concurrent.TimeoutException; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.DefaultStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +public class TransportClientFactorySuite { + private TransportConf conf; + private TransportContext context; + private TransportServer server1; + private TransportServer server2; + + @Before + public void setUp() { + conf = new TransportConf(new SystemPropertyConfigProvider()); + StreamManager streamManager = new DefaultStreamManager(); + RpcHandler rpcHandler = new NoOpRpcHandler(); + context = new TransportContext(conf, streamManager, rpcHandler); + server1 = context.createServer(); + server2 = context.createServer(); + } + + @After + public void tearDown() { + JavaUtils.closeQuietly(server1); + JavaUtils.closeQuietly(server2); + } + + @Test + public void createAndReuseBlockClients() throws TimeoutException { + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c3.isActive()); + assertTrue(c1 == c2); + assertTrue(c1 != c3); + factory.close(); + } + + @Test + public void neverReturnInactiveClients() throws Exception { + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + c1.close(); + + long start = System.currentTimeMillis(); + while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertFalse(c1 == c2); + assertTrue(c2.isActive()); + factory.close(); + } + + @Test + public void closeBlockClientsWithFactory() throws TimeoutException { + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c2.isActive()); + factory.close(); + assertFalse(c1.isActive()); + assertFalse(c2.isActive()); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java ---------------------------------------------------------------------- diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java new file mode 100644 index 0000000..17a03eb --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.local.LocalChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; + +public class TransportResponseHandlerSuite { + @Test + public void handleSuccessfulFetch() { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedFetch() { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); + verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void clearAllOutstandingRequests() { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(new StreamChunkId(1, 0), callback); + handler.addFetchRequest(new StreamChunkId(1, 1), callback); + handler.addFetchRequest(new StreamChunkId(1, 2), callback); + assertEquals(3, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); + handler.exceptionCaught(new Exception("duh duh duhhhh")); + + // should fail both b2 and b3 + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); + verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleSuccessfulRPC() { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + handler.addRpcRequest(12345, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored + assertEquals(1, handler.numOutstandingRequests()); + + byte[] arr = new byte[10]; + handler.handle(new RpcResponse(12345, arr)); + verify(callback, times(1)).onSuccess(eq(arr)); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedRPC() { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + handler.addRpcRequest(12345, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(12345, "oh no")); + verify(callback, times(1)).onFailure((Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index abcb971..e4c9247 100644 --- a/pom.xml +++ b/pom.xml @@ -91,6 +91,7 @@ <module>graphx</module> <module>mllib</module> <module>tools</module> + <module>network/common</module> <module>streaming</module> <module>sql/catalyst</module> <module>sql/core</module> http://git-wip-us.apache.org/repos/asf/spark/blob/dff01553/project/MimaExcludes.scala ---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 95152b5..adbdc5d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -51,6 +51,11 @@ object MimaExcludes { // MapStatus should be private[spark] ProblemFilters.exclude[IncompatibleTemplateDefProblem]( "org.apache.spark.scheduler.MapStatus"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.PathResolver"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.client.BlockClientListener"), + // TaskContext was promoted to Abstract class ProblemFilters.exclude[AbstractClassProblem]( "org.apache.spark.TaskContext"), --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
