http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala deleted file mode 100644 index 4894ecd..0000000 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network - -import java.nio.ByteBuffer - -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.io.Source - -import org.apache.spark._ - -private[spark] object ConnectionManagerTest extends Logging{ - def main(args: Array[String]) { - // <mesos cluster> - the master URL <slaves file> - a list slaves to run connectionTest on - // [num of tasks] - the number of parallel tasks to be initiated default is number of slave - // hosts [size of msg in MB (integer)] - the size of messages to be sent in each task, - // default is 10 [count] - how many times to run, default is 3 [await time in seconds] : - // await time (in seconds), default is 600 - if (args.length < 2) { - println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] " + - "[size of msg in MB (integer)] [count] [await time in seconds)] ") - System.exit(1) - } - - if (args(0).startsWith("local")) { - println("This runs only on a mesos cluster") - } - - val sc = new SparkContext(args(0), "ConnectionManagerTest") - val slavesFile = Source.fromFile(args(1)) - val slaves = slavesFile.mkString.split("\n") - slavesFile.close() - - /* println("Slaves") */ - /* slaves.foreach(println) */ - val tasknum = if (args.length > 2) args(2).toInt else slaves.length - val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 - val count = if (args.length > 4) args(4).toInt else 3 - val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second - println("Running " + count + " rounds of test: " + "parallel tasks = " + tasknum + ", " + - "msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime) - val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( - i => SparkEnv.get.connectionManager.id).collect() - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { - val 
connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - logInfo("Received [" + msg + "] from [" + id + "]") - None - }) - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map{ slaveConnManagerId => - { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - } - } - val results = futures.map(f => Await.result(f, awaitTime)) - val finishTime = System.currentTimeMillis - Thread.sleep(5000) - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * - 1000.0) + " MB/s" - logInfo(resultStr) - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } -} -
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
new file mode 100644
index 0000000..dcecb6b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.io.{FileInputStream, RandomAccessFile, File, InputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel.MapMode
+
+import com.google.common.io.ByteStreams
+import io.netty.buffer.{ByteBufInputStream, ByteBuf}
+
+import org.apache.spark.util.ByteBufferInputStream
+
+
+/**
+ * This interface provides an immutable view for data in the form of bytes. The implementation
+ * should specify how the data is provided:
+ *
+ * - FileSegmentManagedBuffer: data backed by part of a file
+ * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer
+ * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf
+ */
+sealed abstract class ManagedBuffer {
+  // Note that nioByteBuffer() and inputStream() are defined with parentheses because their
+  // implementations can have side effects (io operations).
+
+  /** Number of bytes of the data. */
+  def size: Long
+
+  /**
+   * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
+   * returned ByteBuffer should not affect the content of this buffer.
+   */
+  def nioByteBuffer(): ByteBuffer
+
+  /**
+   * Exposes this buffer's data as an InputStream. The underlying implementation does not
+   * necessarily check for the length of bytes read, so the caller is responsible for making sure
+   * it does not go over the limit.
+   */
+  def inputStream(): InputStream
+}
+
+
+/**
+ * A [[ManagedBuffer]] backed by a segment in a file.
+ *
+ * @param file file holding the data
+ * @param offset position in `file` of the first byte of the segment
+ * @param length number of bytes in the segment
+ */
+final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
+  extends ManagedBuffer {
+
+  override def size: Long = length
+
+  /** Memory-maps the segment read-only. */
+  override def nioByteBuffer(): ByteBuffer = {
+    val raf = new RandomAccessFile(file, "r")
+    try {
+      // The mapping stays valid after the file is closed, so the descriptor is released here
+      // instead of being leaked on every call.
+      raf.getChannel.map(MapMode.READ_ONLY, offset, length)
+    } finally {
+      raf.close()
+    }
+  }
+
+  /**
+   * Opens a stream positioned at `offset` that yields at most `length` bytes. The caller owns
+   * the returned stream and must close it.
+   */
+  override def inputStream(): InputStream = {
+    val is = new FileInputStream(file)
+    var handedOff = false
+    try {
+      // InputStream.skip may skip fewer bytes than requested; skipFully either reaches the
+      // offset or throws.
+      ByteStreams.skipFully(is, offset)
+      val limited = ByteStreams.limit(is, length)
+      handedOff = true
+      limited
+    } finally {
+      // Do not leak the descriptor when the segment could not be set up.
+      if (!handedOff) {
+        is.close()
+      }
+    }
+  }
+}
+
+
+/**
+ * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]].
+ */
+final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {
+
+  override def size: Long = buf.remaining()
+
+  override def nioByteBuffer() = buf.duplicate()
+
+  // Read through a duplicate so that consuming the stream does not move `buf`'s position,
+  // which would change `size` and the data seen by later calls.
+  override def inputStream() = new ByteBufferInputStream(buf.duplicate())
+}
+
+
+/**
+ * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]].
+ */
+final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer {
+
+  override def size: Long = buf.readableBytes()
+
+  override def nioByteBuffer() = buf.nioBuffer()
+
+  // Read through a duplicate so that consuming the stream does not advance `buf`'s reader
+  // index, which would change `size` and the data seen by later calls.
+  override def inputStream() = new ByteBufInputStream(buf.duplicate())
+
+  // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it.
+  def release(): Unit = buf.release()
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/Message.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
deleted file mode 100644
index 04ea50f..0000000
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package org.apache.spark.network - -import java.net.InetSocketAddress -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -private[spark] abstract class Message(val typ: Long, val id: Int) { - var senderAddress: InetSocketAddress = null - var started = false - var startTime = -1L - var finishTime = -1L - var isSecurityNeg = false - var hasError = false - - def size: Int - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - - def timeTaken(): String = (finishTime - startTime).toString + " ms" - - override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" -} - - -private[spark] object Message { - val BUFFER_MESSAGE = 1111111111L - - var lastId = 1 - - def getNewId() = synchronized { - lastId += 1 - if (lastId == 0) { - lastId += 1 - } - lastId - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { - if (dataBuffers == null) { - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } - if (dataBuffers.exists(_ == null)) { - throw new Exception("Attempting to create buffer message with null buffer") - } - new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = - createBufferMessage(dataBuffers, 0) - - def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { - if (dataBuffer == null) { - createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) - } else { - createBufferMessage(Array(dataBuffer), ackId) - } - } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = - createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = { - createBufferMessage(new Array[ByteBuffer](0), ackId) - } - - def create(header: MessageChunkHeader): Message = { - val newMessage: Message = header.typ match { - case 
BUFFER_MESSAGE => new BufferMessage(header.id, - ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) - } - newMessage.hasError = header.hasError - newMessage.senderAddress = header.address - newMessage - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/MessageChunk.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/MessageChunk.scala deleted file mode 100644 index d0f986a..0000000 --- a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -private[network] -class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - - val size = if (buffer == null) 0 else buffer.remaining - - lazy val buffers = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } - - override def toString = { - "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala deleted file mode 100644 index f3ecca5..0000000 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network - -import java.net.InetAddress -import java.net.InetSocketAddress -import java.nio.ByteBuffer - -private[spark] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val hasError: Boolean, - val securityNeg: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. - // Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). - putInt(securityNeg). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). - flip.asInstanceOf[ByteBuffer] - } - - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg - -} - - -private[spark] object MessageChunkHeader { - val HEADER_SIZE = 45 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val hasError = buffer.get() != 0 - val securityNeg = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, - new InetSocketAddress(ip, port)) - } -} 
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala deleted file mode 100644 index 53a6038..0000000 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network - -import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf} - -private[spark] object ReceiverTest { - def main(args: Array[String]) { - val conf = new SparkConf - val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - /* println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis) */ - val buffer = ByteBuffer.wrap("response".getBytes("utf-8")) - Some(Message.createBufferMessage(buffer, msg.id)) - }) - Thread.currentThread.join() - } -} - http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala deleted file mode 100644 index 9af9e2e..0000000 --- a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.StringBuilder - -import org.apache.spark._ -import org.apache.spark.network._ - -/** - * SecurityMessage is class that contains the connectionId and sasl token - * used in SASL negotiation. SecurityMessage has routines for converting - * it to and from a BufferMessage so that it can be sent by the ConnectionManager - * and easily consumed by users when received. - * The api was modeled after BlockMessage. - * - * The connectionId is the connectionId of the client side. Since - * message passing is asynchronous and its possible for the server side (receiving) - * to get multiple different types of messages on the same connection the connectionId - * is used to know which connnection the security message is intended for. - * - * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side - * is acting as a client and connecting to node_1. SASL negotiation has to occur - * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. - * node_1 receives the message from node_0 but before it can process it and send a response, - * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 - * and sends a security message of its own to authenticate as a client. Now node_0 gets - * the message and it needs to decide if this message is in response to it being a client - * (from the first send) or if its just node_1 trying to connect to it to send data. This - * is where the connectionId field is used. node_0 can lookup the connectionId to see if - * it is in response to it being a client or if its in response to someone sending other data. 
- * - * The format of a SecurityMessage as its sent is: - * - Length of the ConnectionId - * - ConnectionId - * - Length of the token - * - Token - */ -private[spark] class SecurityMessage() extends Logging { - - private var connectionId: String = null - private var token: Array[Byte] = null - - def set(byteArr: Array[Byte], newconnectionId: String) { - if (byteArr == null) { - token = new Array[Byte](0) - } else { - token = byteArr - } - connectionId = newconnectionId - } - - /** - * Read the given buffer and set the members of this class. - */ - def set(buffer: ByteBuffer) { - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - connectionId = idBuilder.toString() - - val tokenLength = buffer.getInt() - token = new Array[Byte](tokenLength) - if (tokenLength > 0) { - buffer.get(token, 0, tokenLength) - } - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getConnectionId: String = { - return connectionId - } - - def getToken: Array[Byte] = { - return token - } - - /** - * Create a BufferMessage that can be sent by the ConnectionManager containing - * the security information from this class. 
- * @return BufferMessage - */ - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - // 4 bytes for the length of the connectionId - // connectionId is of type char so multiple the length by 2 to get number of bytes - // 4 bytes for the length of token - // token is a byte buffer so just take the length - var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) - buffer.putInt(connectionId.length()) - connectionId.foreach((x: Char) => buffer.putChar(x)) - buffer.putInt(token.length) - - if (token.length > 0) { - buffer.put(token) - } - buffer.flip() - buffers += buffer - - var message = Message.createBufferMessage(buffers) - logDebug("message total size is : " + message.size) - message.isSecurityNeg = true - return message - } - - override def toString: String = { - "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]" - } -} - -private[spark] object SecurityMessage { - - /** - * Convert the given BufferMessage to a SecurityMessage by parsing the contents - * of the BufferMessage and populating the SecurityMessage fields. - * @param bufferMessage is a BufferMessage that was received - * @return new SecurityMessage - */ - def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(bufferMessage) - newSecurityMessage - } - - /** - * Create a SecurityMessage to send from a given saslResponse. 
- * @param response is the response to a challenge from the SaslClient or Saslserver - * @param connectionId the client connectionId we are negotiation authentication for - * @return a new SecurityMessage - */ - def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(response, connectionId) - newSecurityMessage - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/SenderTest.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala deleted file mode 100644 index ea2ad10..0000000 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network - -import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf} - -import scala.concurrent.Await -import scala.concurrent.duration.Duration -import scala.util.Try - -private[spark] object SenderTest { - def main(args: Array[String]) { - - if (args.length < 2) { - println("Usage: SenderTest <target host> <target port>") - System.exit(1) - } - - val targetHost = args(0) - val targetPort = args(1).toInt - val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - val conf = new SparkConf - val manager = new ConnectionManager(0, conf, new SecurityManager(conf)) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - val size = 100 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val targetServer = args(0) - - val count = 100 - (0 until count).foreach(i => { - val dataMessage = Message.createBufferMessage(buffer.duplicate) - val startTime = System.currentTimeMillis - /* println("Started timer at " + startTime) */ - val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage) - val responseStr: String = Try(Await.result(promise, Duration.Inf)) - .map { response => - val buffer = response.asInstanceOf[BufferMessage].buffers(0) - new String(buffer.array, "utf-8") - }.getOrElse("none") - - val finishTime = System.currentTimeMillis - val mb = size / 1024.0 / 1024.0 - val ms = finishTime - startTime - // val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms - // * 1000.0) + " MB/s" - val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + - (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr - println(resultStr) - }) - } -} - 
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
new file mode 100644
index 0000000..b573f1a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.nio
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
+
+import scala.collection.mutable.{ArrayBuffer, StringBuilder}
+
+// private[spark] because we need to register them in Kryo
+private[spark] case class GetBlock(id: BlockId)
+private[spark] case class GotBlock(id: BlockId, data: ByteBuffer)
+private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel)
+
+/**
+ * Wire representation of a single block request or response.
+ *
+ * Serialized layout, as written by `toBufferMessage` and read back by `set(ByteBuffer)`:
+ *
+ *  - message type (Int)
+ *  - number of chars in the block id (Int), followed by the chars themselves
+ *  - PutBlock only: storage level (Int) and replication (Int)
+ *  - GotBlock and PutBlock: payload length in bytes (Int), followed by the payload
+ */
+private[nio] class BlockMessage() {
+  // Un-initialized: typ = 0
+  // GetBlock: typ = 1
+  // GotBlock: typ = 2
+  // PutBlock: typ = 3
+  private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
+  private var id: BlockId = null
+  private var data: ByteBuffer = null
+  private var level: StorageLevel = null
+
+  def set(getBlock: GetBlock) {
+    typ = BlockMessage.TYPE_GET_BLOCK
+    id = getBlock.id
+  }
+
+  def set(gotBlock: GotBlock) {
+    typ = BlockMessage.TYPE_GOT_BLOCK
+    id = gotBlock.id
+    data = gotBlock.data
+  }
+
+  def set(putBlock: PutBlock) {
+    typ = BlockMessage.TYPE_PUT_BLOCK
+    id = putBlock.id
+    data = putBlock.data
+    level = putBlock.level
+  }
+
+  /**
+   * Populates this message by parsing `buffer`, starting at its current position.
+   *
+   * Throws an Exception("Error parsing buffer") when a length field does not fit the bytes
+   * that remain in `buffer`.
+   */
+  def set(buffer: ByteBuffer) {
+    typ = buffer.getInt()
+    val idLength = buffer.getInt()
+    // Every id char takes two bytes. Reject lengths the buffer cannot hold before sizing
+    // anything from them.
+    if (idLength < 0 || idLength > buffer.remaining / 2) {
+      throw new Exception("Error parsing buffer")
+    }
+    val idBuilder = new StringBuilder(idLength)
+    for (i <- 1 to idLength) {
+      idBuilder += buffer.getChar()
+    }
+    id = BlockId(idBuilder.toString)
+
+    if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+      val booleanInt = buffer.getInt()
+      val replication = buffer.getInt()
+      level = StorageLevel(booleanInt, replication)
+      data = readData(buffer)
+    } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+      data = readData(buffer)
+    }
+  }
+
+  /**
+   * Reads a length-prefixed payload that must extend exactly to the end of `buffer` and
+   * returns a copy of it, ready for reading.
+   */
+  private def readData(buffer: ByteBuffer): ByteBuffer = {
+    val dataLength = buffer.getInt()
+    // Check the declared length before allocating, so that a corrupt or negative value is
+    // reported as a parse error instead of driving the allocation.
+    if (dataLength != buffer.remaining) {
+      throw new Exception("Error parsing buffer")
+    }
+    val payload = ByteBuffer.allocate(dataLength)
+    payload.put(buffer)
+    payload.flip()
+    payload
+  }
+
+  def set(bufferMsg: BufferMessage) {
+    val buffer = bufferMsg.buffers.apply(0)
+    buffer.clear()
+    set(buffer)
+  }
+
+  def getType: Int = typ
+  def getId: BlockId = id
+  def getData: ByteBuffer = data
+  def getLevel: StorageLevel = level
+
+  /** Serializes this message into a BufferMessage; `data` is appended as-is, not copied. */
+  def toBufferMessage: BufferMessage = {
+    val buffers = new ArrayBuffer[ByteBuffer]()
+
+    val header = ByteBuffer.allocate(4 + 4 + id.name.length * 2)
+    header.putInt(typ).putInt(id.name.length)
+    id.name.foreach((x: Char) => header.putChar(x))
+    header.flip()
+    buffers += header
+
+    if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+      val levelBuffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication)
+      levelBuffer.flip()
+      buffers += levelBuffer
+    }
+
+    if (typ == BlockMessage.TYPE_PUT_BLOCK || typ == BlockMessage.TYPE_GOT_BLOCK) {
+      val lengthBuffer = ByteBuffer.allocate(4).putInt(data.remaining)
+      lengthBuffer.flip()
+      buffers += lengthBuffer
+      buffers += data
+    }
+
+    Message.createBufferMessage(buffers)
+  }
+
+  override def toString: String = {
+    "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
+    ", data = " + (if (data != null) data.remaining.toString else "null") + "]"
+  }
+}
+
+private[nio] object BlockMessage {
+  val TYPE_NON_INITIALIZED: Int = 0
+  val TYPE_GET_BLOCK: Int = 1
+  val TYPE_GOT_BLOCK: Int = 2
+  val TYPE_PUT_BLOCK: Int = 3
+
+  def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(bufferMessage)
+    newBlockMessage
+  }
+
+  def fromByteBuffer(buffer: ByteBuffer): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(buffer)
+    newBlockMessage
+  }
+
+  def fromGetBlock(getBlock: GetBlock): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(getBlock)
+    newBlockMessage
+  }
+
+  def fromGotBlock(gotBlock: GotBlock): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(gotBlock)
+    newBlockMessage
+  }
+
+  def fromPutBlock(putBlock: PutBlock): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(putBlock)
+    newBlockMessage
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
new file mode 100644
index 0000000..a1a2c00
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
*/

package org.apache.spark.network.nio

import java.nio.ByteBuffer

import org.apache.spark._
import org.apache.spark.storage.{StorageLevel, TestBlockId}

import scala.collection.mutable.ArrayBuffer

/**
 * A sequence of [[BlockMessage]]s that can be packed into, and unpacked from, a single
 * [[BufferMessage]]. On the wire each block message is preceded by a 4-byte int holding
 * its size in bytes.
 */
private[nio]
class BlockMessageArray(var blockMessages: Seq[BlockMessage])
  extends Seq[BlockMessage] with Logging {

  def this(bm: BlockMessage) = this(Array(bm))

  // Start out empty rather than null: this class IS a Seq, so length/iterator/toString
  // must be callable even before set() has been invoked.
  def this() = this(Seq.empty[BlockMessage])

  def apply(i: Int): BlockMessage = blockMessages(i)

  def iterator: Iterator[BlockMessage] = blockMessages.iterator

  def length: Int = blockMessages.length

  /**
   * Replaces the content of this array with the block messages parsed from the first
   * buffer of `bufferMessage`, reading that buffer from position 0 up to its capacity.
   */
  def set(bufferMessage: BufferMessage): Unit = {
    val startTime = System.currentTimeMillis
    val newBlockMessages = new ArrayBuffer[BlockMessage]()
    val buffer = bufferMessage.buffers(0)
    // clear() resets position/limit only; the bytes are left untouched.
    buffer.clear()
    /*
    println()
    println("BlockMessageArray: ")
    while(buffer.remaining > 0) {
      print(buffer.get())
    }
    buffer.rewind()
    println()
    println()
    */
    while (buffer.remaining() > 0) {
      val size = buffer.getInt()
      logDebug("Creating block message of size " + size + " bytes")
      // View over the next `size` bytes; parsing it does not move `buffer`'s own position.
      val newBuffer = buffer.slice()
      newBuffer.clear()
      newBuffer.limit(size)
      logDebug("Trying to convert buffer " + newBuffer + " to block message")
      val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
      logDebug("Created " + newBlockMessage)
      newBlockMessages += newBlockMessage
      // Skip over the bytes consumed through the slice.
      buffer.position(buffer.position() + size)
    }
    val finishTime = System.currentTimeMillis
    logDebug("Converted block message array from buffer message in " +
      (finishTime - startTime) / 1000.0  + " s")
    this.blockMessages = newBlockMessages
  }

  /**
   * Serializes all contained block messages into one [[BufferMessage]]; every block
   * message contributes a 4-byte size buffer followed by its own buffers.
   */
  def toBufferMessage: BufferMessage = {
    val buffers = new ArrayBuffer[ByteBuffer]()

    blockMessages.foreach(blockMessage => {
      val bufferMessage = blockMessage.toBufferMessage
      logDebug("Adding " + blockMessage)
      val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size)
      sizeBuffer.flip
      buffers += sizeBuffer
      buffers ++= bufferMessage.buffers
      logDebug("Added " + bufferMessage)
    })

    logDebug("Buffer list:")
    buffers.foreach((x: ByteBuffer) => logDebug("" + x))
    /*
    println()
    println("BlockMessageArray: ")
    buffers.foreach(b => {
      while(b.remaining > 0) {
        print(b.get())
      }
      b.rewind()
    })
    println()
    println()
    */
    Message.createBufferMessage(buffers)
  }
}

private[nio] object BlockMessageArray {

  /** Builds a [[BlockMessageArray]] by parsing `bufferMessage`. */
  def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
    val newBlockMessageArray = new BlockMessageArray()
    newBlockMessageArray.set(bufferMessage)
    newBlockMessageArray
  }

  /**
   * Manual round-trip check: builds ten alternating put/get block messages, serializes
   * them, copies the bytes into one contiguous buffer, parses them back and prints them.
   */
  def main(args: Array[String]): Unit = {
    val blockMessages =
      (0 until 10).map { i =>
        if (i % 2 == 0) {
          val buffer =  ByteBuffer.allocate(100)
          buffer.clear
          BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer,
            StorageLevel.MEMORY_ONLY_SER))
        } else {
          BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString)))
        }
      }
    val blockMessageArray = new BlockMessageArray(blockMessages)
    println("Block message array created")

    val bufferMessage = blockMessageArray.toBufferMessage
    println("Converted to buffer message")

    val totalSize = bufferMessage.size
    val newBuffer = ByteBuffer.allocate(totalSize)
    newBuffer.clear()
    bufferMessage.buffers.foreach(buffer => {
      assert (0 == buffer.position())
      newBuffer.put(buffer)
      buffer.rewind()
    })
    newBuffer.flip
    val newBufferMessage = Message.createBufferMessage(newBuffer)
    println("Copied to new buffer message, size = " + newBufferMessage.size)

    val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
    println("Converted back to block message array")
    newBlockMessageArray.foreach(blockMessage => {
      blockMessage.getType match {
        case BlockMessage.TYPE_PUT_BLOCK => {
          val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
          println(pB)
        }
        case BlockMessage.TYPE_GET_BLOCK => {
          val gB = new GetBlock(blockMessage.getId)
          println(gB)
        }
        // Report anything else instead of dying with a MatchError.
        case other => {
          println("Unexpected block message type: " + other)
        }
      }
    })
  }
}


http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala new file mode 100644 index 0000000..3b245c5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/

package org.apache.spark.network.nio

import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.storage.BlockManager


/**
 * A [[Message]] whose payload is a list of byte buffers. The message is handed out
 * piecewise as [[MessageChunk]]s, either for sending (consuming `buffers` front to back)
 * or for receiving (carving slices out of the single receive buffer).
 */
private[nio]
class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
  extends Message(Message.BUFFER_MESSAGE, id_) {

  // Captured once at construction; `size` keeps reporting this value afterwards.
  val initialSize = currentSize()
  var gotChunkForSendingOnce = false

  def size = initialSize

  /** Sum of the bytes currently remaining in all buffers (0 when there are none). */
  def currentSize() = {
    if (buffers != null && !buffers.isEmpty) {
      buffers.map(_.remaining).sum
    } else {
      0
    }
  }

  /**
   * Returns the next chunk of at most `maxChunkSize` payload bytes, advancing the source
   * buffer past the bytes handed out, or None once everything has been handed out.
   * A message of size 0 yields exactly one header-only chunk (with a null buffer).
   */
  def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
    if (maxChunkSize <= 0) {
      throw new Exception("Max chunk size is " + maxChunkSize)
    }

    val securityFlag = if (isSecurityNeg) 1 else 0

    if (size == 0 && !gotChunkForSendingOnce) {
      val emptyHeader =
        new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, securityFlag, senderAddress)
      val emptyChunk = new MessageChunk(emptyHeader, null)
      gotChunkForSendingOnce = true
      Some(emptyChunk)
    } else {
      // Drop (and dispose) fully consumed buffers sitting at the front of the list.
      while (!buffers.isEmpty && buffers(0).remaining == 0) {
        BlockManager.dispose(buffers(0))
        buffers.remove(0)
      }

      if (buffers.isEmpty) {
        None
      } else {
        val head = buffers(0)
        val payload =
          if (head.remaining > maxChunkSize) {
            head.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
          } else {
            head.duplicate()
          }
        head.position(head.position + payload.remaining)
        val header = new MessageChunkHeader(
          typ, id, size, payload.remaining, ackId, hasError, securityFlag, senderAddress)
        val chunk = new MessageChunk(header, payload)
        gotChunkForSendingOnce = true
        Some(chunk)
      }
    }
  }

  /**
   * Returns a chunk backed by the next `chunkSize` bytes of the receive buffer, advancing
   * that buffer past them, or None when the receive buffer has no space left.
   */
  def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
    // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
    if (buffers.size > 1) {
      throw new Exception("Attempting to get chunk from message with multiple data buffers")
    }
    val target = buffers(0)
    val securityFlag = if (isSecurityNeg) 1 else 0

    if (target.remaining <= 0) {
      None
    } else {
      if (target.remaining < chunkSize) {
        throw new Exception("Not enough space in data buffer for receiving chunk")
      }
      val window = target.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
      target.position(target.position + window.remaining)
      val header = new MessageChunkHeader(
        typ, id, size, window.remaining, ackId, hasError, securityFlag, senderAddress)
      Some(new MessageChunk(header, window))
    }
  }

  /** Flips every buffer of this message. */
  def flip() {
    for (b <- buffers) {
      b.flip
    }
  }

  def hasAckId() = (ackId != 0)

  def isCompletelyReceived() = !buffers(0).hasRemaining

  override def toString = {
    if (!hasAckId()) {
      "BufferMessage(id = " + id + ", size = " + size + ")"
    } else {
      "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
    }
  }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
new file mode 100644
index 0000000..74074a8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -0,0 +1,587 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.nio

import java.net._
import java.nio._
import java.nio.channels._

import org.apache.spark._

import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}

/**
 * Base class for one non-blocking socket connection managed through a selector.
 *
 * On construction the channel is switched to non-blocking mode and TCP_NODELAY,
 * SO_REUSEADDR and SO_KEEPALIVE are enabled on its socket. Subclasses decide which
 * selector interests to register and implement the actual `read()` / `write()`.
 * Close, exception and key-interest changes are reported through callbacks that the
 * owner installs with `onClose`, `onException` and `onKeyInterestChange`.
 */
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
    val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
  extends Logging {

  var sparkSaslServer: SparkSaslServer = null
  var sparkSaslClient: SparkSaslClient = null

  /** Derives the remote connection manager id from the channel's remote socket address. */
  def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
    this(channel_, selector_,
      ConnectionManagerId.fromSocketAddress(
        channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
  }

  channel.configureBlocking(false)
  channel.socket.setTcpNoDelay(true)
  channel.socket.setReuseAddress(true)
  channel.socket.setKeepAlive(true)
  /* channel.socket.setReceiveBufferSize(32768) */

  @volatile private var closed = false
  var onCloseCallback: Connection => Unit = null
  var onExceptionCallback: (Connection, Exception) => Unit = null
  var onKeyInterestChangeCallback: (Connection, Int) => Unit = null

  // NOTE: evaluated during construction through the (overridable) getRemoteAddress().
  val remoteAddress = getRemoteAddress()

  /**
   * Used to synchronize client requests: client's work-related requests must
   * wait until SASL authentication completes.
   */
  private val authenticated = new Object()

  def getAuthenticated(): Object = authenticated

  def isSaslComplete(): Boolean

  def resetForceReregister(): Boolean

  // Read channels typically do not register for write and write does not for read
  // Now, we do have write registering for read too (temporarily), but this is to detect
  // channel close NOT to actually read/consume data on it !
  // How does this work if/when we move to SSL ?

  // What is the interest to register with selector for when we want this connection to be selected
  def registerInterest()

  // What is the interest to register with selector for when we want this connection to
  // be de-selected
  // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack,
  // it will be SelectionKey.OP_READ (until we fix it properly)
  def unregisterInterest()

  // On receiving a read event, should we change the interest for this channel or not ?
  // Will be true for ReceivingConnection, false for SendingConnection.
  def changeInterestForRead(): Boolean

  /** Disposes whichever SASL server/client objects have been attached to this connection. */
  private def disposeSasl() {
    if (sparkSaslServer != null) {
      sparkSaslServer.dispose()
    }

    if (sparkSaslClient != null) {
      sparkSaslClient.dispose()
    }
  }

  // On receiving a write event, should we change the interest for this channel or not ?
  // Will be false for ReceivingConnection, true for SendingConnection.
  // Actually, for now, should not get triggered for ReceivingConnection
  def changeInterestForWrite(): Boolean

  def getRemoteConnectionManagerId(): ConnectionManagerId = {
    socketRemoteConnectionManagerId
  }

  /** The selection key of this channel for `selector`, or null if it is not registered. */
  def key() = channel.keyFor(selector)

  def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]

  // Returns whether we have to register for further reads or not.
  def read(): Boolean = {
    throw new UnsupportedOperationException(
      "Cannot read on connection of type " + this.getClass.toString)
  }

  // Returns whether we have to register for further writes or not.
  def write(): Boolean = {
    throw new UnsupportedOperationException(
      "Cannot write on connection of type " + this.getClass.toString)
  }

  /**
   * Marks the connection closed, cancels its selection key (if any), closes the channel,
   * disposes SASL state and finally invokes the close callback.
   */
  def close() {
    closed = true
    val k = key()
    if (k != null) {
      k.cancel()
    }
    channel.close()
    disposeSasl()
    callOnCloseCallback()
  }

  protected def isClosed: Boolean = closed

  def onClose(callback: Connection => Unit) {
    onCloseCallback = callback
  }

  def onException(callback: (Connection, Exception) => Unit) {
    onExceptionCallback = callback
  }

  def onKeyInterestChange(callback: (Connection, Int) => Unit) {
    onKeyInterestChangeCallback = callback
  }

  /** Forwards `e` to the exception callback, or logs it when none is registered. */
  def callOnExceptionCallback(e: Exception) {
    if (onExceptionCallback != null) {
      onExceptionCallback(this, e)
    } else {
      logError("Error in connection to " + getRemoteConnectionManagerId() +
        " and OnExceptionCallback not registered", e)
    }
  }

  /** Invokes the close callback, or logs a warning when none is registered. */
  def callOnCloseCallback() {
    if (onCloseCallback != null) {
      onCloseCallback(this)
    } else {
      // The callback that is missing here is the close callback, so name it correctly.
      logWarning("Connection to " + getRemoteConnectionManagerId() +
        " closed and OnCloseCallback not registered")
    }

  }

  /**
   * Asks the owner (through the key-interest callback) to set this connection's
   * interest ops to `ops`; throws if no such callback has been registered.
   */
  def changeConnectionKeyInterest(ops: Int) {
    if (onKeyInterestChangeCallback != null) {
      onKeyInterestChangeCallback(this, ops)
    } else {
      throw new Exception("OnKeyInterestChangeCallback not registered")
    }
  }

  /** Debug helper: prints the remaining bytes of `buffer` and restores its position. */
  def printRemainingBuffer(buffer: ByteBuffer) {
    val bytes = new Array[Byte](buffer.remaining)
    val curPosition = buffer.position
    buffer.get(bytes)
    bytes.foreach(x => print(x + " "))
    buffer.position(curPosition)
    print(" (" + bytes.size + ")")
  }

  /** Debug helper: prints `length` bytes starting at `position`, restoring the position. */
  def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
    val bytes = new Array[Byte](length)
    val curPosition = buffer.position
    buffer.position(position)
    buffer.get(bytes)
    bytes.foreach(x => print(x + " "))
    print(" (" + position + ", " + length + ")")
    buffer.position(curPosition)
  }
}


/**
 * Outgoing side of a connection: opens a new socket channel to `address`, queues
 * messages in an outbox and writes them to the channel chunk by chunk.
 */
private[nio]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
    remoteId_ : ConnectionManagerId, id_ : ConnectionId)
  extends Connection(SocketChannel.open, selector_, remoteId_, id_) {

  def isSaslComplete(): Boolean = {
    if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
  }

  /** Queue of pending messages; chunks are taken from the messages in round-robin order. */
  private class Outbox {
    val messages = new Queue[Message]()
    val defaultChunkSize = 65536
    var nextMessageToBeUsed = 0

    def addMessage(message: Message) {
      messages.synchronized {
        /* messages += message */
        messages.enqueue(message)
        logDebug("Added [" + message + "] to outbox for sending to " +
          "[" + getRemoteConnectionManagerId() + "]")
      }
    }

    /**
     * Returns the next chunk to send, or None when no queued message has data left.
     * A message that still produced a chunk is re-queued at the tail; a message that
     * produced none is finished and dropped from the queue.
     */
    def getChunk(): Option[MessageChunk] = {
      messages.synchronized {
        while (!messages.isEmpty) {
          /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
          /* val message = messages(nextMessageToBeUsed) */
          val message = messages.dequeue()
          val chunk = message.getChunkForSending(defaultChunkSize)
          if (chunk.isDefined) {
            messages.enqueue(message)
            nextMessageToBeUsed = nextMessageToBeUsed + 1
            if (!message.started) {
              logDebug(
                "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
              message.started = true
              message.startTime = System.currentTimeMillis
            }
            logTrace(
              "Sending chunk from [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
            return chunk
          } else {
            message.finishTime = System.currentTimeMillis
            logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
              "] in "  + message.timeTaken )
          }
        }
      }
      None
    }
  }

  // outbox is used as a lock - ensure that it is always used as a leaf (since methods which
  // lock it are invoked in context of other locks)
  private val outbox = new Outbox()
  /*
    This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly
    different purpose. This flag is to see if we need to force reregister for write even when we
    do not have any pending bytes to write to socket.
    This can happen due to a race between adding pending buffers, and checking for existence of
    data as detailed in https://github.com/mesos/spark/pull/791
   */
  private var needForceReregister = false

  val currentBuffers = new ArrayBuffer[ByteBuffer]()

  /* channel.socket.setSendBufferSize(256 * 1024) */

  override def getRemoteAddress() = address

  val DEFAULT_INTEREST = SelectionKey.OP_READ

  override def registerInterest() {
    // Registering read too - does not really help in most cases, but for some
    // it does - so let us keep it for now.
    changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
  }

  override def unregisterInterest() {
    changeConnectionKeyInterest(DEFAULT_INTEREST)
  }

  /**
   * Queues `message` for sending and, if the channel is already connected, registers
   * write interest so that the selector loop picks it up.
   */
  def send(message: Message) {
    outbox.synchronized {
      outbox.addMessage(message)
      needForceReregister = true
    }
    if (channel.isConnected) {
      registerInterest()
    }
  }

  // return previous value after resetting it.
  def resetForceReregister(): Boolean = {
    outbox.synchronized {
      val result = needForceReregister
      needForceReregister = false
      result
    }
  }

  // MUST be called within the selector loop
  def connect() {
    try{
      channel.register(selector, SelectionKey.OP_CONNECT)
      channel.connect(address)
      logInfo("Initiating connection to [" + address + "]")
    } catch {
      case e: Exception => {
        logError("Error connecting to " + address, e)
        callOnExceptionCallback(e)
      }
    }
  }

  /**
   * Completes a pending connect.
   *
   * @param force when true, carry on and register interest even if the channel reports
   *              that the connect has not finished yet
   * @return false only when the connect has not finished and `force` is false; true
   *         otherwise (including when an exception was reported through the callback)
   */
  def finishConnect(force: Boolean): Boolean = {
    try {
      // Typically, this should finish immediately since it was triggered by a connect
      // selection - though need not necessarily always complete successfully.
      val connected = channel.finishConnect
      if (!force && !connected) {
        logInfo(
          "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
        return false
      }

      // Fallback to previous behavior - assume finishConnect completed
      // This will happen only when finishConnect failed for some repeated number of times
      // (10 or so)
      // Is highly unlikely unless there was an unclean close of socket, etc
      registerInterest()
      logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
    } catch {
      case e: Exception => {
        logWarning("Error finishing connection to " + address, e)
        callOnExceptionCallback(e)
      }
    }
    true
  }

  /**
   * Writes as much pending data as the channel accepts.
   *
   * @return true when a buffer could only be written partially (caller should re-register
   *         for write); false when the outbox is drained or an error closed the connection
   */
  override def write(): Boolean = {
    try {
      while (true) {
        if (currentBuffers.size == 0) {
          outbox.synchronized {
            outbox.getChunk() match {
              case Some(chunk) => {
                val buffers = chunk.buffers
                // If we have 'seen' pending messages, then reset flag - since we handle that as
                // normal registering of event (below)
                if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()

                currentBuffers ++= buffers
              }
              case None => {
                // changeConnectionKeyInterest(0)
                /* key.interestOps(0) */
                return false
              }
            }
          }
        }

        if (currentBuffers.size > 0) {
          val buffer = currentBuffers(0)
          val remainingBytes = buffer.remaining
          val writtenBytes = channel.write(buffer)
          if (buffer.remaining == 0) {
            currentBuffers -= buffer
          }
          if (writtenBytes < remainingBytes) {
            // re-register for write.
            return true
          }
        }
      }
    } catch {
      case e: Exception => {
        logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
        callOnExceptionCallback(e)
        close()
        return false
      }
    }
    // should not happen - to keep scala compiler happy
    true
  }

  // This is a hack to determine if remote socket was closed or not.
  // SendingConnection DOES NOT expect to receive any data - if it does, it is an error
  // For a bunch of cases, read will return -1 in case remote socket is closed : hence we
  // register for reads to determine that.
  override def read(): Boolean = {
    // We don't expect the other side to send anything; so, we just read to detect an error or EOF.
    try {
      val length = channel.read(ByteBuffer.allocate(1))
      if (length == -1) { // EOF
        close()
      } else if (length > 0) {
        logWarning(
          "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
      }
    } catch {
      case e: Exception =>
        logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(),
          e)
        callOnExceptionCallback(e)
        close()
    }

    false
  }

  override def changeInterestForRead(): Boolean = false

  override def changeInterestForWrite(): Boolean = ! isClosed
}


/**
 * Incoming side of a connection: reads chunk headers and chunk payloads from an accepted
 * channel, reassembles them into [[BufferMessage]]s and hands completed messages to the
 * receive callback.
 */
// Must be created within selector loop - else deadlock
private[spark] class ReceivingConnection(
    channel_ : SocketChannel,
    selector_ : Selector,
    id_ : ConnectionId)
    extends Connection(channel_, selector_, id_) {

  def isSaslComplete(): Boolean = {
    if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
  }

  /** Messages that are currently being received, keyed by message id. */
  class Inbox() {
    val messages = new HashMap[Int, BufferMessage]()

    /**
     * Returns the chunk that the payload announced by `header` should be read into,
     * creating the message on its first chunk.
     */
    def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {

      def createNewMessage: BufferMessage = {
        val newMessage = Message.create(header).asInstanceOf[BufferMessage]
        newMessage.started = true
        newMessage.startTime = System.currentTimeMillis
        newMessage.isSecurityNeg = header.securityNeg == 1
        logDebug(
          "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
        messages += ((newMessage.id, newMessage))
        newMessage
      }

      val message = messages.getOrElseUpdate(header.id, createNewMessage)
      logTrace(
        "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
      message.getChunkForReceiving(header.chunkSize)
    }

    def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
      messages.get(chunk.header.id)
    }

    def removeMessage(message: Message) {
      messages -= message.id
    }
  }

  @volatile private var inferredRemoteManagerId: ConnectionManagerId = null

  override def getRemoteConnectionManagerId(): ConnectionManagerId = {
    val currId = inferredRemoteManagerId
    if (currId != null) currId else super.getRemoteConnectionManagerId()
  }

  // The receiver's remote address is the local socket on remote side : which is NOT
  // the connection manager id of the receiver.
  // We infer that from the messages we receive on the receiver socket.
  private def processConnectionManagerId(header: MessageChunkHeader) {
    val currId = inferredRemoteManagerId
    if (header.address == null || currId != null) return

    val managerId = ConnectionManagerId.fromSocketAddress(header.address)

    if (managerId != null) {
      inferredRemoteManagerId = managerId
    }
  }


  val inbox = new Inbox()
  val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
  var onReceiveCallback: (Connection, Message) => Unit = null
  var currentChunk: MessageChunk = null

  channel.register(selector, SelectionKey.OP_READ)

  /**
   * Reads as much as is available: first a complete chunk header, then that chunk's
   * payload, repeating until the channel has no more data.
   *
   * @return true when the caller should re-register for read events; false after EOF or
   *         an error (the connection has been closed in both cases)
   */
  override def read(): Boolean = {
    try {
      while (true) {
        if (currentChunk == null) {
          val headerBytesRead = channel.read(headerBuffer)
          if (headerBytesRead == -1) {
            close()
            return false
          }
          if (headerBuffer.remaining > 0) {
            // re-register for read event ...
            return true
          }
          headerBuffer.flip
          if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
            throw new Exception(
              "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
          }
          val header = MessageChunkHeader.create(headerBuffer)
          headerBuffer.clear()

          processConnectionManagerId(header)

          header.typ match {
            case Message.BUFFER_MESSAGE => {
              if (header.totalSize == 0) {
                // Empty message: there is no payload to wait for, deliver it right away.
                if (onReceiveCallback != null) {
                  onReceiveCallback(this, Message.create(header))
                }
                currentChunk = null
                // re-register for read event ...
                return true
              } else {
                currentChunk = inbox.getChunk(header).orNull
              }
            }
            case _ => throw new Exception("Message of unknown type received")
          }
        }

        if (currentChunk == null) throw new Exception("No message chunk to receive data")

        val bytesRead = channel.read(currentChunk.buffer)
        if (bytesRead == 0) {
          // re-register for read event ...
          return true
        } else if (bytesRead == -1) {
          close()
          return false
        }

        /* logDebug("Read " + bytesRead + " bytes for the buffer") */

        if (currentChunk.buffer.remaining == 0) {
          /* println("Filled buffer at " + System.currentTimeMillis) */
          val bufferMessage = inbox.getMessageForChunk(currentChunk).get
          if (bufferMessage.isCompletelyReceived) {
            bufferMessage.flip()
            bufferMessage.finishTime = System.currentTimeMillis
            logDebug("Finished receiving [" + bufferMessage + "] from " +
              "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
            if (onReceiveCallback != null) {
              onReceiveCallback(this, bufferMessage)
            }
            inbox.removeMessage(bufferMessage)
          }
          currentChunk = null
        }
      }
    } catch {
      case e: Exception => {
        logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
        callOnExceptionCallback(e)
        close()
        return false
      }
    }
    // should not happen - to keep scala compiler happy
    true
  }

  def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}

  // override def changeInterestForRead(): Boolean = ! isClosed
  override def changeInterestForRead(): Boolean = true

  override def changeInterestForWrite(): Boolean = {
    throw new IllegalStateException("Unexpected invocation right now")
  }

  override def registerInterest() {
    // A receiving connection is only ever interested in read events.
    changeConnectionKeyInterest(SelectionKey.OP_READ)
  }

  override def unregisterInterest() {
    changeConnectionKeyInterest(0)
  }

  // For read conn, always false.
  override def resetForceReregister(): Boolean = false
}
http://git-wip-us.apache.org/repos/asf/spark/blob/08ce1888/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala
new file mode 100644
index 0000000..764dc5e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala
@@ -0,0 +1,34 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/

package org.apache.spark.network.nio

/**
 * Identifies one connection: the connection manager it belongs to plus a number.
 * Its string form is `host_port_uniqId`.
 */
private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
  override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
}

private[nio] object ConnectionId {

  /**
   * Parses the string form produced by `ConnectionId.toString` (`host_port_uniqId`).
   *
   * The port and uniqId are taken from the last two `_`-separated fields and everything
   * before them is the host, so a host name that itself contains `_` parses correctly.
   *
   * @throws Exception if the string has fewer than three `_`-separated fields
   * @throws NumberFormatException if the port or uniqId field is not an integer
   */
  def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
    val res = connectionIdString.split("_").map(_.trim())
    if (res.size < 3) {
      throw new Exception("Error converting ConnectionId string: " + connectionIdString +
        " to a ConnectionId Object")
    }
    // For the usual three-field input this is just res(0).
    val host = res.dropRight(2).mkString("_")
    val port = res(res.size - 2).toInt
    val uniqId = res(res.size - 1).toInt
    new ConnectionId(new ConnectionManagerId(host, port), uniqId)
  }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org