http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/magic/PostProcessorSpec.scala
----------------------------------------------------------------------
diff --git a/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/magic/PostProcessorSpec.scala b/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/magic/PostProcessorSpec.scala
new file mode 100644
index 0000000..1557460
--- /dev/null
+++ b/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/magic/PostProcessorSpec.scala
@@ -0,0 +1,123 @@
+package com.ibm.spark.kernel.protocol.v5.magic
+
+import com.ibm.spark.interpreter.Interpreter
+import com.ibm.spark.kernel.protocol.v5._
+import com.ibm.spark.magic.{CellMagicOutput, LineMagicOutput}
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.{FunSpec, Matchers}
+
+class PostProcessorSpec extends FunSpec with Matchers with MockitoSugar{ // Specs for PostProcessor's matchCellMagic, matchLineMagic, processLineMagic and process
+  describe("#matchCellMagic") {
+    it("should return the cell magic output when the Left contains a " +
+       "CellMagicOutput") {
+      val processor = new PostProcessor(mock[Interpreter])
+      val codeOutput = "some output"
+      val cmo = CellMagicOutput()
+      val left = Left(cmo)
+      processor.matchCellMagic(codeOutput, left) should be(cmo) // the CellMagicOutput itself is returned, not codeOutput
+    }
+
+    it("should package the original code when the Left does not contain a " +
+       "CellMagicOutput") {
+      val processor = new PostProcessor(mock[Interpreter])
+      val codeOutput = "some output"
+      val left = Left("") // a Left holding something other than CellMagicOutput
+      val data = Data(MIMEType.PlainText -> codeOutput)
+      processor.matchCellMagic(codeOutput, left) should be(data) // falls back to wrapping codeOutput as plain-text Data
+    }
+  }
+
+  describe("#matchLineMagic") {
+    it("should process the code output when the Right contains a " +
+       "LineMagicOutput") {
+      val processor = spy(new PostProcessor(mock[Interpreter])) // spy so the delegated call can be verified below
+      val codeOutput = "some output"
+      val lmo = LineMagicOutput
+      val right = Right(lmo)
+      processor.matchLineMagic(codeOutput, right)
+      verify(processor).processLineMagic(codeOutput) // delegation to processLineMagic is the behaviour under test
+    }
+
+    it("should package the original code when the Right does not contain a " +
+       "LineMagicOutput") {
+      val processor = new PostProcessor(mock[Interpreter])
+      val codeOutput = "some output"
+      val right = Right("") // a Right holding something other than LineMagicOutput
+      val data = Data(MIMEType.PlainText -> codeOutput)
+      processor.matchLineMagic(codeOutput, right) should be(data)
+    }
+  }
+
+  describe("#processLineMagic") {
+    it("should remove the result of the magic invocation if it is the last " +
+       "line") {
+      val processor = new PostProcessor(mock[Interpreter])
+      val x = "hello world"
+      val codeOutput = s"$x\nsome other output" // only x should survive; the last line is dropped
+      val data = Data(MIMEType.PlainText -> x)
+      processor.processLineMagic(codeOutput) should be(data)
+    }
+  }
+
+  describe("#process") {
+    it("should call matchCellMagic when the last variable is a Left") {
+      val intp = mock[Interpreter]
+      val left = Left("")
+      // Need to mock lastExecutionVariableName as it is being chained with
+      // the read method
+      doReturn(Some("")).when(intp).lastExecutionVariableName
+      doReturn(Some(left)).when(intp).read(anyString()) // any variable name reads back this value
+
+      val processor = spy(new PostProcessor(intp)) // spy so the dispatch to matchCellMagic can be verified
+      val codeOutput = "hello"
+      processor.process(codeOutput)
+      verify(processor).matchCellMagic(codeOutput, left)
+    }
+
+    it("should call matchLineMagic when the last variable is a Right") {
+      val intp = mock[Interpreter]
+      val right = Right("")
+      // Need to mock lastExecutionVariableName as it is being chained with
+      // the read method
+      doReturn(Some("")).when(intp).lastExecutionVariableName
+      doReturn(Some(right)).when(intp).read(anyString()) // any variable name reads back this value
+
+      val processor = spy(new PostProcessor(intp)) // spy so the dispatch to matchLineMagic can be verified
+      val codeOutput = "hello"
+      processor.process(codeOutput)
+      verify(processor).matchLineMagic(codeOutput, right)
+    }
+
+    it("should package the original code output when the Left is not a " +
+      "Left[CellMagicOutput, Nothing]") {
+      val intp = mock[Interpreter]
+      val left = Left("")
+      // Need to mock lastExecutionVariableName as it is being chained with
+      // the read method
+      doReturn(Some("")).when(intp).lastExecutionVariableName
+      doReturn(Some(left)).when(intp).read(anyString())
+
+      val processor = spy(new PostProcessor(intp)) // NOTE(review): nothing is verified on this spy; a plain instance would do
+      val codeOutput = "hello"
+      val data = Data(MIMEType.PlainText -> codeOutput)
+      processor.process(codeOutput) should be(data)
+    }
+
+    it("should package the original code output when the Right is not a " +
+      "Right[LineMagicOutput, Nothing]") {
+      val intp = mock[Interpreter]
+      val right = Right("")
+      // Need to mock lastExecutionVariableName as it is being chained with
+      // the read method
+      doReturn(Some("")).when(intp).lastExecutionVariableName
+      doReturn(Some(right)).when(intp).read(anyString())
+
+      val processor = spy(new PostProcessor(intp)) // NOTE(review): nothing is verified on this spy; a plain instance would do
+      val codeOutput = "hello"
+      val data = Data(MIMEType.PlainText -> codeOutput)
+      processor.process(codeOutput) should be(data)
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/relay/ExecuteRequestRelaySpec.scala ---------------------------------------------------------------------- diff --git a/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/relay/ExecuteRequestRelaySpec.scala b/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/relay/ExecuteRequestRelaySpec.scala new file mode 100644 index 0000000..7adf620 --- /dev/null +++ b/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/relay/ExecuteRequestRelaySpec.scala @@ -0,0 +1,183 @@ +/* + * Copyright 2014 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.ibm.spark.kernel.protocol.v5.relay
+
+import java.io.OutputStream
+
+import akka.actor._
+import akka.testkit.{ImplicitSender, TestKit, TestProbe}
+import com.ibm.spark.interpreter.{ExecuteAborted, ExecuteError}
+import com.ibm.spark.kernel.protocol.v5._
+import com.ibm.spark.kernel.protocol.v5.content._
+import com.ibm.spark.kernel.protocol.v5.kernel.ActorLoader
+import com.ibm.spark.kernel.protocol.v5.magic.{MagicParser, PostProcessor}
+import com.ibm.spark.magic.MagicLoader
+import com.ibm.spark.magic.dependencies.DependencyMap
+import com.typesafe.config.ConfigFactory
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.{BeforeAndAfter, FunSpecLike, Matchers}
+
+object ExecuteRequestRelaySpec { // Akka config for this spec's ActorSystem (log level WARNING)
+  val config = """
+    akka {
+      loglevel = "WARNING"
+    }"""
+}
+
+class ExecuteRequestRelaySpec extends TestKit(
+  ActorSystem(
+    "ExecuteRequestRelayActorSystem",
+    ConfigFactory.parseString(ExecuteRequestRelaySpec.config)
+  )
+) with ImplicitSender with FunSpecLike with Matchers with MockitoSugar
+  with BeforeAndAfter
+{ // Specs for how ExecuteRequestRelay turns InterpreterActor replies (abort, error, success) into its own reply
+  var mockActorLoader: ActorLoader = _
+  var interpreterActorProbe: TestProbe = _
+
+  before {
+    mockActorLoader = mock[ActorLoader]
+    interpreterActorProbe = new TestProbe(system) // stands in for the InterpreterActor
+    val mockInterpreterActorSelection =
+      system.actorSelection(interpreterActorProbe.ref.path.toString)
+    doReturn(mockInterpreterActorSelection).when(mockActorLoader)
+      .load(SystemActorType.Interpreter)
+  }
+
+  describe("ExecuteRequestRelay") {
+    describe("#receive(KernelMessage)") {
+      it("should handle an abort returned by the InterpreterActor") {
+        val executeRequest =
+          ExecuteRequest("%myMagic", false, true, UserExpressions(), true)
+
+        val mockMagicLoader = mock[MagicLoader]
+        val mockPostProcessor = mock[PostProcessor]
+        val mockDependencyMap = mock[DependencyMap]
+        doReturn(mockDependencyMap).when(mockMagicLoader).dependencyMap
+
+        val mockMagicParser = mock[MagicParser]
+        doReturn(Left(executeRequest.code))
+          .when(mockMagicParser).parse(executeRequest.code) // NOTE(review): presumably Left(code) means "run as plain code" — confirm against MagicParser
+
+        val executeRequestRelay = system.actorOf(Props(
+          classOf[ExecuteRequestRelay], mockActorLoader,
+          mockMagicLoader, mockMagicParser, mockPostProcessor
+        ))
+
+        // Send the message to the ExecuteRequestRelay
+        executeRequestRelay !
+          ((executeRequest, mock[KernelMessage], mock[OutputStream]))
+
+        // `expected` is the InterpreterActor's reply, not the relay's output,
+        // which is a tuple of ExecuteReply and ExecuteResult (asserted below)
+        val expected = new ExecuteAborted()
+        interpreterActorProbe.expectMsgClass(
+          classOf[(ExecuteRequest, KernelMessage, OutputStream)]
+        )
+
+        // Reply with an abort
+        interpreterActorProbe.reply(Right(expected))
+
+        expectMsg(
+          (ExecuteReplyAbort(1), ExecuteResult(1, Data(), Metadata()))
+        )
+      }
+
+      it("should handle an error returned by the InterpreterActor") {
+        val executeRequest =
+          ExecuteRequest("%myMagic", false, true, UserExpressions(), true)
+
+        val mockMagicLoader = mock[MagicLoader]
+        val mockPostProcessor = mock[PostProcessor]
+        val mockDependencyMap = mock[DependencyMap]
+        doReturn(mockDependencyMap).when(mockMagicLoader).dependencyMap
+
+        val mockMagicParser = mock[MagicParser]
+        doReturn(Left(executeRequest.code))
+          .when(mockMagicParser).parse(executeRequest.code)
+
+        val executeRequestRelay = system.actorOf(Props(
+          classOf[ExecuteRequestRelay], mockActorLoader,
+          mockMagicLoader, mockMagicParser, mockPostProcessor
+        ))
+
+        // Send the message to the ExecuteRequestRelay
+        executeRequestRelay !
+          ((executeRequest, mock[KernelMessage], mock[OutputStream]))
+
+        // `expected` is the InterpreterActor's reply, not the relay's output,
+        // which is a tuple of ExecuteReply and ExecuteResult (asserted below)
+        val expected = ExecuteError("NAME", "MESSAGE", List())
+        interpreterActorProbe.expectMsgClass(
+          classOf[(ExecuteRequest, KernelMessage, OutputStream)]
+        )
+
+        // Reply with an error
+        interpreterActorProbe.reply(Right(expected))
+
+        expectMsg(( // error fields are copied into ExecuteReplyError; the result carries expected.toString
+          ExecuteReplyError(1, Some(expected.name), Some(expected.value),
+            Some(expected.stackTrace.map(_.toString).toList)),
+          ExecuteResult(1, Data("text/plain" -> expected.toString), Metadata())
+        ))
+      }
+
+      it("should return an (ExecuteReply, ExecuteResult) on interpreter " +
+         "success") {
+        val expected = "SOME OTHER VALUE"
+        val executeRequest =
+          ExecuteRequest("notAMagic", false, true, UserExpressions(), true)
+
+        val mockMagicLoader = mock[MagicLoader]
+        val mockPostProcessor = mock[PostProcessor]
+        doReturn(Data(MIMEType.PlainText -> expected))
+          .when(mockPostProcessor).process(expected) // successful output is routed through the PostProcessor
+
+        val mockDependencyMap = mock[DependencyMap]
+        doReturn(mockDependencyMap).when(mockMagicLoader).dependencyMap
+
+        val mockMagicParser = mock[MagicParser]
+        doReturn(Left(executeRequest.code))
+          .when(mockMagicParser).parse(executeRequest.code)
+
+        val executeRequestRelay = system.actorOf(Props(
+          classOf[ExecuteRequestRelay], mockActorLoader,
+          mockMagicLoader, mockMagicParser, mockPostProcessor
+        ))
+
+        // Send the message to the ExecuteRequestRelay
+        executeRequestRelay !
+          ((executeRequest, mock[KernelMessage], mock[OutputStream]))
+
+        // The probe replies as the InterpreterActor would; the relay's output
+        // is a tuple of ExecuteReply and ExecuteResult (asserted below)
+        interpreterActorProbe.expectMsgClass(
+          classOf[(ExecuteRequest, KernelMessage, OutputStream)]
+        )
+
+        // Reply with a successful interpret
+        interpreterActorProbe.reply(Left(expected))
+
+        expectMsg((
+          ExecuteReplyOk(1, Some(Payloads()), Some(UserExpressions())),
+          ExecuteResult(1, Data(MIMEType.PlainText -> expected), Metadata())
+        ))
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/relay/KernelMessageRelaySpec.scala
----------------------------------------------------------------------
diff --git a/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/relay/KernelMessageRelaySpec.scala b/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/relay/KernelMessageRelaySpec.scala
new file mode 100644
index 0000000..63fa50c
--- /dev/null
+++ b/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/relay/KernelMessageRelaySpec.scala
@@ -0,0 +1,243 @@
+/*
+ * Copyright 2014 IBM Corp.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.ibm.spark.kernel.protocol.v5.relay
+
+import akka.actor._
+import akka.testkit.{ImplicitSender, TestKit, TestProbe}
+import com.ibm.spark.communication.ZMQMessage
+import com.ibm.spark.communication.security.SecurityActorType
+import com.ibm.spark.kernel.protocol.v5._
+import com.ibm.spark.kernel.protocol.v5.kernel.{ActorLoader, Utilities}
+import Utilities._
+import org.mockito.Mockito._
+import org.scalatest.concurrent.{PatienceConfiguration, ScalaFutures}
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.time.{Millis, Seconds, Span}
+import org.scalatest.{BeforeAndAfter, FunSpecLike, Matchers}
+import org.mockito.Matchers.{eq => mockEq}
+import org.mockito.AdditionalMatchers.{not => mockNot}
+import scala.concurrent.duration._
+import com.ibm.spark.kernel.protocol.v5.KernelMessage
+import scala.concurrent._
+import akka.pattern.pipe
+import scala.util.Random
+import ExecutionContext.Implicits.global
+
+class KernelMessageRelaySpec extends TestKit(ActorSystem("RelayActorSystem"))
+  with ImplicitSender with FunSpecLike with Matchers with MockitoSugar
+  with BeforeAndAfter with ScalaFutures { // Specs for KernelMessageRelay: signature-manager routing, readiness gating, ordering
+  private val IncomingMessageType = MessageType.Incoming.CompleteRequest.toString
+  private val OutgoingMessageType = MessageType.Outgoing.CompleteReply.toString
+
+  private val header: Header = Header("<UUID>", "<USER>", "<SESSION>",
+    "<TYPE>", "<VERSION>")
+  private val parentHeader: Header = Header("<PARENT-UUID>", "<PARENT-USER>",
+    "<PARENT-SESSION>", "<PARENT-TYPE>", "<PARENT-VERSION>")
+  private val incomingKernelMessage: KernelMessage = KernelMessage(Seq("<ID>"),
+    "<SIGNATURE>", header.copy(msg_type = IncomingMessageType),
+    parentHeader, Metadata(), "<CONTENT>")
+  private val outgoingKernelMessage: KernelMessage = KernelMessage(Seq("<ID>"),
+    "<SIGNATURE>", header.copy(msg_type = OutgoingMessageType),
+    incomingKernelMessage.header, Metadata(), "<CONTENT>")
+  private val incomingZmqStrings = "1" :: "2" :: "3" :: "4" :: Nil
+
+  private var actorLoader: ActorLoader = _
+  private var signatureProbe: TestProbe = _
+  private var signatureSelection: ActorSelection = _
+  private var captureProbe: TestProbe = _
+  private var captureSelection: ActorSelection = _
+  private var handlerProbe: TestProbe = _ // NOTE(review): never assigned or read in this spec
+  private var handlerSelection: ActorSelection = _ // NOTE(review): never assigned or read in this spec
+  private var relayWithoutSignatureManager: ActorRef = _
+  private var relayWithSignatureManager: ActorRef = _
+
+  before {
+    // Create a mock ActorLoader for the Relay we are going to test
+    actorLoader = mock[ActorLoader]
+
+    // Create a probe for the signature manager and mock the ActorLoader to
+    // return the associated ActorSelection
+    signatureProbe = TestProbe()
+    signatureSelection = system.actorSelection(signatureProbe.ref.path.toString)
+    when(actorLoader.load(SecurityActorType.SignatureManager))
+      .thenReturn(signatureSelection)
+
+    // Create a probe to capture output from the relay for testing
+    captureProbe = TestProbe()
+    captureSelection = system.actorSelection(captureProbe.ref.path.toString)
+    when(actorLoader.load(mockNot(mockEq(SecurityActorType.SignatureManager)))) // every other actor type resolves to the capture probe
+      .thenReturn(captureSelection)
+
+    relayWithoutSignatureManager = system.actorOf(Props(
+      classOf[KernelMessageRelay], actorLoader, false,
+      mock[Map[String, String]], mock[Map[String, String]]
+    ))
+
+    relayWithSignatureManager = system.actorOf(Props(
+      classOf[KernelMessageRelay], actorLoader, true,
+      mock[Map[String, String]], mock[Map[String, String]]
+    ))
+  }
+
+  describe("Relay") {
+    describe("#receive") {
+      describe("when not using the signature manager") {
+        it("should not send anything to SignatureManager for incoming") {
+          relayWithoutSignatureManager ! true // Mark as ready for incoming
+          relayWithoutSignatureManager ! incomingKernelMessage
+          signatureProbe.expectNoMsg(25.millis)
+        }
+
+        it("should not send anything to SignatureManager for outgoing") {
+          relayWithoutSignatureManager ! outgoingKernelMessage
+          signatureProbe.expectNoMsg(25.millis)
+        }
+
+        it("should relay KernelMessage for incoming") {
+          relayWithoutSignatureManager ! true // Mark as ready for incoming
+          relayWithoutSignatureManager !
+            ((incomingZmqStrings, incomingKernelMessage))
+          captureProbe.expectMsg(incomingKernelMessage)
+        }
+
+        it("should relay KernelMessage for outgoing") {
+          relayWithoutSignatureManager ! outgoingKernelMessage
+          captureProbe.expectMsg(outgoingKernelMessage)
+        }
+      }
+
+      describe("when using the signature manager") {
+        it("should verify the signature if the message is incoming") {
+          relayWithSignatureManager ! true // Mark as ready for incoming
+          relayWithSignatureManager ! incomingKernelMessage
+          signatureProbe.expectMsg(incomingKernelMessage)
+        }
+
+        it("should construct the signature if the message is outgoing") {
+          relayWithSignatureManager ! outgoingKernelMessage
+          signatureProbe.expectMsg(outgoingKernelMessage)
+        }
+      }
+
+      describe("when not ready") {
+        it("should not relay the message if it is incoming") {
+          val incomingMessage: ZMQMessage = incomingKernelMessage // implicit conversion, presumably from Utilities._ — confirm
+
+          relayWithoutSignatureManager ! incomingMessage
+          captureProbe.expectNoMsg(25.millis)
+        }
+
+        it("should relay the message if it is outgoing") {
+          relayWithoutSignatureManager ! outgoingKernelMessage
+          captureProbe.expectMsg(outgoingKernelMessage)
+        }
+      }
+
+      describe("when ready") {
+        it("should relay the message if it is incoming") {
+          relayWithoutSignatureManager ! true // Mark as ready for incoming
+          relayWithoutSignatureManager !
+            ((incomingZmqStrings, incomingKernelMessage))
+          captureProbe.expectMsg(incomingKernelMessage)
+        }
+
+        it("should relay the message if it is outgoing") {
+          relayWithoutSignatureManager ! true // Mark as ready for incoming
+          relayWithoutSignatureManager ! outgoingKernelMessage
+          captureProbe.expectMsg(outgoingKernelMessage)
+        }
+      }
+
+      describe("multiple messages in order"){
+        it("should relay messages in the order they were received") {
+          // Setup the base actor system and the relay
+          val actorLoader = mock[ActorLoader]
+          val kernelMessageRelay = system.actorOf(Props(
+            classOf[KernelMessageRelay], actorLoader, true,
+            mock[Map[String, String]], mock[Map[String, String]]
+          ))
+          // Where all of the messages are relayed to, otherwise NPE
+          val captureProbe = TestProbe()
+          val captureSelection = system.actorSelection(captureProbe.ref.path.toString)
+          when(actorLoader.load(MessageType.Incoming.CompleteRequest))
+            .thenReturn(captureSelection)
+
+
+          val n = 5
+          val chaoticPromise: Promise[String] = Promise()
+          var actual : List[String] = List() // NOTE(review): appended from ChaoticActor worker threads without synchronization; assumes the relay awaits each reply before forwarding the next message — confirm
+          val expected = (0 until n).map(_.toString).toList
+
+          // setup a ChaoticActor to accumulate message values
+          // A promise succeeds after n messages have been accumulated
+          val chaoticActor: ActorRef = system.actorOf(Props(
+            classOf[ChaoticActor[Boolean]],
+            (paramVal: Any) => {
+              val tuple = paramVal.asInstanceOf[(String, Seq[_])]
+              actual = actual :+ tuple._1
+              if (actual.length == n) chaoticPromise.success("Done")
+              true
+            }
+          ))
+
+          when(actorLoader.load(SecurityActorType.SignatureManager)) // the ChaoticActor plays the signature manager
+            .thenReturn(system.actorSelection(chaoticActor.path))
+
+          kernelMessageRelay ! true
+
+          // Sends messages with contents = to values of increasing counter
+          sendKernelMessages(n, kernelMessageRelay)
+          // Message values should be accumulated in the proper order
+          whenReady(chaoticPromise.future,
+            PatienceConfiguration.Timeout(Span(3, Seconds)),
+            PatienceConfiguration.Interval(Span(100, Millis))) {
+            case _: String =>
+              actual should be(expected)
+          }
+
+        }
+      }
+    }
+  }
+  def sendKernelMessages(n: Int, kernelMessageRelay: ActorRef): Unit ={ // signature and content of message i are both i.toString
+    // Sends n messages to the relay
+    (0 until n).foreach (i => {
+      val km = KernelMessage(Seq("<ID>"), s"${i}",
+        header.copy(msg_type = IncomingMessageType), parentHeader,
+        Metadata(), s"${i}")
+      kernelMessageRelay ! Tuple2(Seq("SomeString"), km)
+    })
+
+  }
+}
+
+
+case class ChaoticActor[U](receiveFunc : Any => U) extends Actor { // Test helper: answers each message with receiveFunc(message) after a random delay
+  override def receive: Receive = {
+    case fVal: Any =>
+      // The test actor system runs the actors on a single thread, so we must
+      // simulate asynchronous behaviour by starting a new thread
+      val promise = Promise[U]()
+      promise.future pipeTo sender // sender is read here, while still inside receive
+      new Thread(new Runnable {
+        override def run(): Unit = {
+          Thread.sleep(Random.nextInt(30) * 10) // random 0-290 ms delay to shuffle completion order
+          promise.success(receiveFunc(fVal))
+        }
+      }).start()
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/stream/KernelInputStreamSpec.scala
----------------------------------------------------------------------
diff --git a/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/stream/KernelInputStreamSpec.scala b/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/stream/KernelInputStreamSpec.scala
new file mode 100644
index 0000000..5d06cb6
--- /dev/null
+++ b/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/stream/KernelInputStreamSpec.scala
@@ -0,0 +1,142 @@
+/*
+ * Copyright 2015 IBM Corp.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.ibm.spark.kernel.protocol.v5.stream
+
+import akka.actor.{ActorRef, Actor, ActorSystem}
+import akka.testkit.{TestActorRef, TestKit, TestProbe}
+import com.ibm.spark.kernel.protocol.v5._
+import com.ibm.spark.kernel.protocol.v5.content.InputRequest
+import com.ibm.spark.kernel.protocol.v5.kernel.ActorLoader
+import org.mockito.Mockito._
+import org.scalatest._
+import org.scalatest.mock.MockitoSugar
+import play.api.libs.json.Json
+
+import scala.concurrent.duration._
+
+class KernelInputStreamSpec
+  extends TestKit(ActorSystem("KernelInputStreamActorSystem"))
+  with FunSpecLike with Matchers with GivenWhenThen with BeforeAndAfter
+  with MockitoSugar
+{ // Specs for KernelInputStream: the input_request it sends and the buffered reads of the reply
+
+  private var mockActorLoader: ActorLoader = _
+  private var mockKMBuilder: KMBuilder = _
+  private var kernelInputOutputHandlerProbe: TestProbe = _
+  private var kernelInputStream: KernelInputStream = _
+  private var fakeInputOutputHandlerActor: ActorRef = _
+
+  private val TestReplyString = "some reply"
+
+  before {
+    mockActorLoader = mock[ActorLoader]
+    mockKMBuilder = KMBuilder() // No need to really mock this
+
+    kernelInputStream = new KernelInputStream(mockActorLoader, mockKMBuilder)
+
+    kernelInputOutputHandlerProbe = TestProbe()
+    fakeInputOutputHandlerActor = TestActorRef(new Actor {
+      override def receive: Receive = {
+        // Handle case for getting an input_request
+        case kernelMessage: KernelMessage =>
+          val messageType = kernelMessage.header.msg_type
+          kernelInputOutputHandlerProbe.ref ! kernelMessage // forward every message to the probe for assertions
+          if (messageType == MessageType.Outgoing.InputRequest.toString)
+            sender ! TestReplyString // only input_request messages get a reply
+      }
+    })
+
+    // Add the actor that routes to our test probe and responds with a fake
+    // set of data
+    doReturn(system.actorSelection(fakeInputOutputHandlerActor.path.toString))
+      .when(mockActorLoader).load(MessageType.Incoming.InputReply)
+  }
+
+  describe("KernelInputStream") {
+    describe("#available") {
+      it("should be zero when no input has been read") {
+        kernelInputStream.available() should be (0)
+      }
+
+      it("should match the bytes remaining internally") {
+        kernelInputStream.read()
+
+        kernelInputStream.available() should be (TestReplyString.length - 1) // one byte already consumed by read()
+      }
+    }
+
+    describe("#read") {
+      it("should send a request for more data if the buffer is empty") {
+        // Fresh input stream has nothing in its buffer
+        kernelInputStream.read()
+
+        // Verify that a message was sent out requesting data
+        kernelInputOutputHandlerProbe.expectMsgPF() {
+          case KernelMessage(_, _, header, _, _, _)
+            if header.msg_type == MessageType.Outgoing.InputRequest.toString =>
+              true
+        }
+      }
+
+      it("should use the provided prompt in its requests") {
+        val expected = KernelInputStream.DefaultPrompt
+
+        // Fresh input stream has nothing in its buffer
+        kernelInputStream.read()
+
+        // Verify that a message was sent out requesting data with the
+        // specific prompt
+        kernelInputOutputHandlerProbe.expectMsgPF() {
+          case KernelMessage(_, _, header, _, _, contentString)
+            if header.msg_type == MessageType.Outgoing.InputRequest.toString =>
+              Json.parse(contentString).as[InputRequest].prompt should be (expected)
+        }
+      }
+
+      it("should use the provided password flag in its requests") {
+        val expected = KernelInputStream.DefaultPassword
+
+        // Fresh input stream has nothing in its buffer
+        kernelInputStream.read()
+
+        // Verify that a message was sent out requesting data with the
+        // specific password flag
+        kernelInputOutputHandlerProbe.expectMsgPF() {
+          case KernelMessage(_, _, header, _, _, contentString)
+            if header.msg_type == MessageType.Outgoing.InputRequest.toString =>
+              Json.parse(contentString).as[InputRequest].password should be (expected)
+        }
+      }
+
+      it("should return the next byte from the current buffer") {
+        kernelInputStream.read() should be (TestReplyString.head) // Int vs Char: relies on Scala's numeric == between boxed values
+      }
+
+      it("should not send a request for more data if data is in the buffer") {
+        // Run read for length of message (avoiding sending out a second
+        // request)
+        val readLength = TestReplyString.length
+
+        for (i <- 1 to readLength)
+          kernelInputStream.read() should be (TestReplyString.charAt(i - 1))
+
+        kernelInputOutputHandlerProbe.expectMsgClass(classOf[KernelMessage])
+        kernelInputOutputHandlerProbe.expectNoMsg(300.milliseconds) // exactly one request was sent
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/stream/KernelOuputStreamSpec.scala
----------------------------------------------------------------------
diff --git a/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/stream/KernelOuputStreamSpec.scala b/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/stream/KernelOuputStreamSpec.scala
new file mode 100644
index 0000000..cb4f158
--- /dev/null
+++ b/kernel/src/test/scala/org/apache/toree/kernel/protocol/v5/stream/KernelOuputStreamSpec.scala
@@ -0,0 +1,283 @@
+/*
+ * Copyright 2014 IBM Corp.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.ibm.spark.kernel.protocol.v5.stream
+
+import java.util.UUID
+
+import akka.actor.{ActorSelection, ActorSystem}
+import akka.testkit.{TestKit, TestProbe}
+import com.ibm.spark.kernel.protocol.v5._
+import com.ibm.spark.kernel.protocol.v5.kernel.ActorLoader
+import com.ibm.spark.utils.ScheduledTaskManager
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+import org.scalatest._
+import play.api.libs.json._
+import com.ibm.spark.kernel.protocol.v5.content.StreamContent
+
+import scala.concurrent.duration._
+
+class KernelOuputStreamSpec
+  extends TestKit(ActorSystem("KernelOutputStreamActorSystem"))
+  with FunSpecLike with Matchers with GivenWhenThen with BeforeAndAfter
+  with MockitoSugar
+{ // Specs for KernelOutputStream: byte buffering, periodic-flush scheduling and the stream messages it relays
+
+  private var mockActorLoader: ActorLoader = _
+  private var mockScheduledTaskManager: MockScheduledTaskManager = _
+  private var kernelOutputRelayProbe: TestProbe = _
+
+  //
+  // SHARED ELEMENTS BETWEEN TESTS
+  //
+
+  private val ExecutionCount = 3 // NOTE(review): not referenced anywhere in this spec
+
+  private val MaxMessageTimeout = 1.second
+  private val MaxNoMessageTimeout = 200.milliseconds
+
+  private val GeneratedTaskId = UUID.randomUUID().toString
+
+  private val skeletonBuilder = KMBuilder()
+    .withIds(Nil).withSignature("").withContentString("")
+    .withParentHeader(Header("", "", "", "", "5.0"))
+
+  /**
+   * This stubs out the methods of the scheduled task manager and provides a
+   * form of verification, which is not (easily) doable with Mockito due to the
+   * call-by-name argument in addTask.
+   */
+  private class MockScheduledTaskManager extends ScheduledTaskManager {
+    private var addTaskCalled = false
+    private var removeTaskCalled = false
+    private var stopCalled = false
+
+    def verifyAddTaskCalled(): Unit = addTaskCalled should be (true)
+    def verifyRemoveTaskCalled(): Unit = removeTaskCalled should be (true)
+    def verifyStopCalled(): Unit = stopCalled should be (true)
+    def verifyAddTaskNotCalled(): Unit = addTaskCalled should be (false)
+    def verifyRemoveTaskNotCalled(): Unit = removeTaskCalled should be (false)
+    def verifyStopNotCalled(): Unit = stopCalled should be (false)
+    def resetVerify(): Unit = {
+      addTaskCalled = false
+      removeTaskCalled = false
+      stopCalled = false
+    }
+
+    override def addTask[T](executionDelay: Long, timeInterval: Long, task: => T): String =
+      { addTaskCalled = true; GeneratedTaskId }
+
+    override def removeTask(taskId: String): Boolean =
+      { removeTaskCalled = true; true }
+
+    override def stop(): Unit = stopCalled = true
+
+    def teardown(): Unit = super.stop() // stop() is stubbed above, so call the real implementation
+  }
+
+  before {
+    // Create a mock ActorLoader for the KernelOutputStream we are testing
+    mockActorLoader = mock[ActorLoader]
+
+    mockScheduledTaskManager = new MockScheduledTaskManager
+
+    // Create a probe for the relay and mock the ActorLoader to return the
+    // associated ActorSelection
+    kernelOutputRelayProbe = TestProbe()
+    val kernelOutputRelaySelection: ActorSelection =
+      system.actorSelection(kernelOutputRelayProbe.ref.path.toString)
+    doReturn(kernelOutputRelaySelection)
+      .when(mockActorLoader).load(SystemActorType.KernelMessageRelay)
+  }
+
+  after {
+    mockScheduledTaskManager.teardown()
+  }
+
+  describe("KernelOutputStream") {
+    describe("#write(Int)") {
+      it("should add a new byte to the internal list") {
+        Given("a kernel output stream with a skeleton kernel builder")
+        val kernelOutputStream = new KernelOutputStream(
+          mockActorLoader, skeletonBuilder, mockScheduledTaskManager
+        )
+
+        When("a byte is written to the stream")
+        val expected = 'a'
+        kernelOutputStream.write(expected) // Char widens to Int, exercising write(Int)
+
+        Then("it should be appended to the internal list")
+        kernelOutputStream.flush()
+        val message = kernelOutputRelayProbe
+          .receiveOne(MaxMessageTimeout).asInstanceOf[KernelMessage]
+        val streamContent = Json.parse(message.contentString).as[StreamContent]
+        streamContent.text should be (expected.toString)
+      }
+
+      it("should enable periodic flushing") {
+        Given("a kernel output stream with a skeleton kernel builder")
+        val kernelOutputStream = new KernelOutputStream(
+          mockActorLoader, skeletonBuilder, mockScheduledTaskManager
+        )
+
+        When("a byte is written to the stream")
+        val expected = 'a'
+        kernelOutputStream.write(expected)
+
+        Then("it should add a task to periodically flush")
+        mockScheduledTaskManager.verifyAddTaskCalled()
+      }
+
+      it("should not enable periodic flushing if already enabled") {
+        Given("a kernel output stream with a skeleton kernel builder")
+        val kernelOutputStream = new KernelOutputStream(
+          mockActorLoader, skeletonBuilder, mockScheduledTaskManager
+        )
+
+        And("periodic flushing is already enabled")
+        kernelOutputStream.write('a')
+        mockScheduledTaskManager.verifyAddTaskCalled()
+        mockScheduledTaskManager.resetVerify()
+
+        When("a byte is written to the stream")
+        kernelOutputStream.write('b')
+
+        Then("it should not add a task to periodically flush")
+        mockScheduledTaskManager.verifyAddTaskNotCalled()
+      }
+    }
+    describe("#flush") {
+      it("should disable periodic flushing") {
+        Given("a kernel output stream with a skeleton kernel builder")
+        val kernelOutputStream = new KernelOutputStream(
+          mockActorLoader, skeletonBuilder, mockScheduledTaskManager
+        )
+
+        When("a byte is written to the stream")
+        val expected = 'a'
+        kernelOutputStream.write(expected)
+
+        And("flush is invoked")
+        kernelOutputStream.flush()
+
+        Then("it should remove the task to periodically flush")
+        mockScheduledTaskManager.verifyRemoveTaskCalled()
+      }
+
+      it("should not disable periodic flushing if not enabled") {
+        Given("a kernel output stream with a skeleton kernel builder")
+        val kernelOutputStream = new KernelOutputStream(
+          mockActorLoader, skeletonBuilder, mockScheduledTaskManager
+        )
+
+        When("flush is invoked")
+        kernelOutputStream.flush()
+
+        Then("it should not remove the task to periodically flush")
+        mockScheduledTaskManager.verifyRemoveTaskNotCalled()
+      }
+
+      it("should not send empty (whitespace) messages if flag is false") {
+        Given("a kernel output stream with send empty output set to false")
+        val kernelOutputStream = new KernelOutputStream(
+          mockActorLoader, skeletonBuilder, mockScheduledTaskManager,
+          sendEmptyOutput = false
+        )
+
+        When("whitespace is created and flushed")
+        val expected = "\r \r \n \t"
+        kernelOutputStream.write(expected.getBytes)
+        kernelOutputStream.flush()
+
+        Then("no message should be sent")
+        kernelOutputRelayProbe.expectNoMsg(MaxNoMessageTimeout)
+      }
+
+      it("should send empty (whitespace) messages if flag is true") {
+        Given("a kernel output stream with send empty output set to true")
+        val kernelOutputStream = new KernelOutputStream(
+          mockActorLoader, skeletonBuilder, mockScheduledTaskManager,
+          sendEmptyOutput = true
+        )
+
+        When("whitespace is created and flushed")
+        val expected = "\r \r \n \t"
+        kernelOutputStream.write(expected.getBytes)
+        kernelOutputStream.flush()
+
+        Then("the whitespace message should have been sent")
+        val message = kernelOutputRelayProbe
+          .receiveOne(MaxMessageTimeout).asInstanceOf[KernelMessage]
+        val actual = Json.parse(message.contentString).as[StreamContent].text
+
+        actual should be (expected)
+      }
+
+      it("should set the ids of the kernel message") {
+        Given("a kernel output stream with a skeleton kernel builder")
+        val kernelOutputStream = new KernelOutputStream(
+          mockActorLoader, skeletonBuilder, mockScheduledTaskManager
+        )
+
+        When("a string is written as the result and flushed")
+        val expected = "some string"
+        kernelOutputStream.write(expected.getBytes)
+        kernelOutputStream.flush()
+
+        Then("the ids should be set to the stream message type")
+        val message = kernelOutputRelayProbe
+          .receiveOne(MaxMessageTimeout).asInstanceOf[KernelMessage]
+        message.ids should be (Seq(MessageType.Outgoing.Stream.toString))
+      }
+
+      it("should set the message type in the header of the kernel message to a stream") {
+        Given("a kernel output stream with a skeleton kernel builder")
+        val kernelOutputStream = new KernelOutputStream(
+          mockActorLoader, skeletonBuilder, mockScheduledTaskManager
+        )
+
+        When("a string is written as the result and flushed")
+        val expected = "some string"
+        kernelOutputStream.write(expected.getBytes)
+        kernelOutputStream.flush()
+
+        Then("the msg_type in the header should be the stream message type")
+        val message = kernelOutputRelayProbe
+          .receiveOne(MaxMessageTimeout).asInstanceOf[KernelMessage]
+        message.header.msg_type should be (MessageType.Outgoing.Stream.toString)
+      }
+
+      it("should set the content string of the kernel message") {
+        Given("a kernel output stream with a skeleton kernel builder")
+        val kernelOutputStream = new KernelOutputStream(
+          mockActorLoader, skeletonBuilder, mockScheduledTaskManager
+        )
+
+        When("a string is written as the result and flushed")
+        val expected = "some string"
+        kernelOutputStream.write(expected.getBytes)
+        kernelOutputStream.flush()
+
+        Then("the content string should have its text set to the string")
+        val message = kernelOutputRelayProbe
+          .receiveOne(MaxMessageTimeout).asInstanceOf[KernelMessage]
+        val streamContent = Json.parse(message.contentString).as[StreamContent]
+        streamContent.text should be (expected)
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/magic/builtin/AddDepsSpec.scala
----------------------------------------------------------------------
diff --git a/kernel/src/test/scala/org/apache/toree/magic/builtin/AddDepsSpec.scala b/kernel/src/test/scala/org/apache/toree/magic/builtin/AddDepsSpec.scala
new file mode 100644
index 0000000..956088c --- /dev/null +++ b/kernel/src/test/scala/org/apache/toree/magic/builtin/AddDepsSpec.scala @@ -0,0 +1,184 @@ +/* + * Copyright 2014 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.ibm.spark.magic.builtin + +import java.io.{ByteArrayOutputStream, OutputStream} +import java.net.URL + +import com.ibm.spark.dependencies.DependencyDownloader +import com.ibm.spark.interpreter.Interpreter +import com.ibm.spark.utils.ArgumentParsingSupport +import org.apache.spark.SparkContext +import org.scalatest.mock.MockitoSugar +import org.scalatest.{GivenWhenThen, Matchers, FunSpec} +import org.mockito.Mockito._ +import org.mockito.Matchers._ + +import com.ibm.spark.magic._ +import com.ibm.spark.magic.dependencies._ + +class AddDepsSpec extends FunSpec with Matchers with MockitoSugar + with GivenWhenThen +{ + describe("AddDeps"){ + describe("#execute") { + it("should print out the help message if the input is invalid") { + val byteArrayOutputStream = new ByteArrayOutputStream() + val mockIntp = mock[Interpreter] + val mockSC = mock[SparkContext] + val mockDownloader = mock[DependencyDownloader] + var printHelpWasRun = false + + val addDepsMagic = new AddDeps + with IncludeSparkContext + with IncludeInterpreter + with IncludeOutputStream + with IncludeDependencyDownloader + with ArgumentParsingSupport + { + override val sparkContext: SparkContext = mockSC + override val interpreter: Interpreter = mockIntp + override val 
dependencyDownloader: DependencyDownloader = + mockDownloader + override val outputStream: OutputStream = byteArrayOutputStream + + override def printHelp( + outputStream: OutputStream, usage: String + ): Unit = printHelpWasRun = true + } + + val expected = LineMagicOutput + val actual = addDepsMagic.execute("notvalid") + + printHelpWasRun should be (true) + verify(mockIntp, times(0)).addJars(any()) + verify(mockIntp, times(0)).bind(any(), any(), any(), any()) + verify(mockSC, times(0)).addJar(any()) + verify(mockDownloader, times(0)).retrieve( + anyString(), anyString(), anyString(), anyBoolean(), anyBoolean()) + actual should be (expected) + } + + it("should set the retrievals transitive to true if provided") { + val mockDependencyDownloader = mock[DependencyDownloader] + doReturn(Nil).when(mockDependencyDownloader).retrieve( + anyString(), anyString(), anyString(), anyBoolean(), anyBoolean()) + + val addDepsMagic = new AddDeps + with IncludeSparkContext + with IncludeInterpreter + with IncludeOutputStream + with IncludeDependencyDownloader + with ArgumentParsingSupport + { + override val sparkContext: SparkContext = mock[SparkContext] + override val interpreter: Interpreter = mock[Interpreter] + override val dependencyDownloader: DependencyDownloader = + mockDependencyDownloader + override val outputStream: OutputStream = mock[OutputStream] + } + + val expected = "com.ibm.spark" :: "kernel" :: "1.0" :: "--transitive" :: Nil + addDepsMagic.execute(expected.mkString(" ")) + + verify(mockDependencyDownloader).retrieve( + expected(0), expected(1), expected(2), true) + } + + it("should set the retrieval's transitive to false if not provided") { + val mockDependencyDownloader = mock[DependencyDownloader] + doReturn(Nil).when(mockDependencyDownloader).retrieve( + anyString(), anyString(), anyString(), anyBoolean(), anyBoolean()) + + val addDepsMagic = new AddDeps + with IncludeSparkContext + with IncludeInterpreter + with IncludeOutputStream + with 
IncludeDependencyDownloader + with ArgumentParsingSupport + { + override val sparkContext: SparkContext = mock[SparkContext] + override val interpreter: Interpreter = mock[Interpreter] + override val dependencyDownloader: DependencyDownloader = + mockDependencyDownloader + override val outputStream: OutputStream = mock[OutputStream] + } + + val expected = "com.ibm.spark" :: "kernel" :: "1.0" :: Nil + addDepsMagic.execute(expected.mkString(" ")) + + verify(mockDependencyDownloader).retrieve( + expected(0), expected(1), expected(2), false) + } + + it("should add retrieved artifacts to the interpreter") { + val mockDependencyDownloader = mock[DependencyDownloader] + doReturn(Nil).when(mockDependencyDownloader).retrieve( + anyString(), anyString(), anyString(), anyBoolean(), anyBoolean()) + val mockInterpreter = mock[Interpreter] + + val addDepsMagic = new AddDeps + with IncludeSparkContext + with IncludeInterpreter + with IncludeOutputStream + with IncludeDependencyDownloader + with ArgumentParsingSupport + { + override val sparkContext: SparkContext = mock[SparkContext] + override val interpreter: Interpreter = mockInterpreter + override val dependencyDownloader: DependencyDownloader = + mockDependencyDownloader + override val outputStream: OutputStream = mock[OutputStream] + } + + val expected = "com.ibm.spark" :: "kernel" :: "1.0" :: Nil + addDepsMagic.execute(expected.mkString(" ")) + + verify(mockInterpreter).addJars(any[URL]) + } + + it("should add retrieved artifacts to the spark context") { + val mockDependencyDownloader = mock[DependencyDownloader] + val fakeUrl = new URL("file:/foo") + doReturn(fakeUrl :: fakeUrl :: fakeUrl :: Nil) + .when(mockDependencyDownloader).retrieve( + anyString(), anyString(), anyString(), anyBoolean(), anyBoolean() + ) + val mockSparkContext = mock[SparkContext] + + val addDepsMagic = new AddDeps + with IncludeSparkContext + with IncludeInterpreter + with IncludeOutputStream + with IncludeDependencyDownloader + with 
ArgumentParsingSupport + { + override val sparkContext: SparkContext = mockSparkContext + override val interpreter: Interpreter = mock[Interpreter] + override val dependencyDownloader: DependencyDownloader = + mockDependencyDownloader + override val outputStream: OutputStream = mock[OutputStream] + } + + val expected = "com.ibm.spark" :: "kernel" :: "1.0" :: Nil + addDepsMagic.execute(expected.mkString(" ")) + + verify(mockSparkContext, times(3)).addJar(anyString()) + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala ---------------------------------------------------------------------- diff --git a/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala b/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala new file mode 100644 index 0000000..5612c8a --- /dev/null +++ b/kernel/src/test/scala/org/apache/toree/magic/builtin/AddJarSpec.scala @@ -0,0 +1,220 @@ +/* + * Copyright 2014 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.ibm.spark.magic.builtin + +import java.io.OutputStream +import java.net.URL +import java.nio.file.{FileSystems, Files} + +import com.ibm.spark.interpreter.Interpreter +import com.ibm.spark.magic.dependencies.{IncludeConfig, IncludeOutputStream, IncludeInterpreter, IncludeSparkContext} +import com.typesafe.config.ConfigFactory +import org.apache.spark.SparkContext +import org.scalatest.{Matchers, FunSpec} +import org.scalatest.mock.MockitoSugar + +import org.mockito.Mockito._ +import org.mockito.Matchers._ +import com.ibm.spark.magic.MagicLoader + +class AddJarSpec extends FunSpec with Matchers with MockitoSugar { + describe("AddJar"){ + describe("#execute") { + it("should call addJar on the provided SparkContext and addJars on the " + + "provided interpreter") { + val mockSparkContext = mock[SparkContext] + val mockInterpreter = mock[Interpreter] + val mockOutputStream = mock[OutputStream] + val mockMagicLoader = mock[MagicLoader] + val testConfig = ConfigFactory.load() + + val addJarMagic = new AddJar + with IncludeSparkContext + with IncludeInterpreter + with IncludeOutputStream + with IncludeConfig + { + override val sparkContext: SparkContext = mockSparkContext + override val interpreter: Interpreter = mockInterpreter + override val outputStream: OutputStream = mockOutputStream + override lazy val magicLoader: MagicLoader = mockMagicLoader + override val config = testConfig + override def downloadFile(fileUrl: URL, destinationUrl: URL): URL = + new URL("file://someFile") // Cannot mock URL + } + + addJarMagic.execute("""http://www.example.com/someJar.jar""") + + verify(mockSparkContext).addJar(anyString()) + verify(mockInterpreter).addJars(any[URL]) + verify(mockMagicLoader, times(0)).addJar(any()) + } + + it("should raise exception if jar file does not end in .jar or .zip") { + val mockOutputStream = mock[OutputStream] + + val addJarMagic = new AddJar + with IncludeOutputStream + { + override val outputStream: OutputStream = 
mockOutputStream + } + + intercept[IllegalArgumentException] { + addJarMagic.execute("""http://www.example.com/""") + } + intercept[IllegalArgumentException] { + addJarMagic.execute("""http://www.example.com/not_a_jar""") + } + } + + it("should extract jar file name from jar URL") { + val mockOutputStream = mock[OutputStream] + + val addJarMagic = new AddJar + with IncludeOutputStream + { + override val outputStream: OutputStream = mockOutputStream + } + + var url = """http://www.example.com/someJar.jar""" + var jarName = addJarMagic.getFileFromLocation(url) + assert(jarName == "someJar.jar") + + url = """http://www.example.com/remotecontent?filepath=/path/to/someJar.jar""" + jarName = addJarMagic.getFileFromLocation(url) + assert(jarName == "someJar.jar") + + url = """http://www.example.com/""" + jarName = addJarMagic.getFileFromLocation(url) + assert(jarName == "") + } + + it("should use a cached jar if the force option is not provided") { + val mockSparkContext = mock[SparkContext] + val mockInterpreter = mock[Interpreter] + val mockOutputStream = mock[OutputStream] + var downloadFileCalled = false // Used to verify that downloadFile + // was or was not called in this test + val testConfig = ConfigFactory.load() + + val addJarMagic = new AddJar + with IncludeSparkContext + with IncludeInterpreter + with IncludeOutputStream + with IncludeConfig + { + override val sparkContext: SparkContext = mockSparkContext + override val interpreter: Interpreter = mockInterpreter + override val outputStream: OutputStream = mockOutputStream + override val config = testConfig + override def downloadFile(fileUrl: URL, destinationUrl: URL): URL = { + downloadFileCalled = true + new URL("file://someFile") // Cannot mock URL + } + } + + // Create a temporary file representing our jar to fake the cache + val tmpFilePath = Files.createTempFile( + FileSystems.getDefault.getPath(AddJar.getJarDir(testConfig)), + "someJar", + ".jar" + ) + + addJarMagic.execute( + 
"""http://www.example.com/""" + tmpFilePath.getFileName) + + tmpFilePath.toFile.delete() + + downloadFileCalled should be (false) + verify(mockSparkContext).addJar(anyString()) + verify(mockInterpreter).addJars(any[URL]) + } + + it("should not use a cached jar if the force option is provided") { + val mockSparkContext = mock[SparkContext] + val mockInterpreter = mock[Interpreter] + val mockOutputStream = mock[OutputStream] + var downloadFileCalled = false // Used to verify that downloadFile + // was or was not called in this test + val testConfig = ConfigFactory.load() + + val addJarMagic = new AddJar + with IncludeSparkContext + with IncludeInterpreter + with IncludeOutputStream + with IncludeConfig + { + override val sparkContext: SparkContext = mockSparkContext + override val interpreter: Interpreter = mockInterpreter + override val outputStream: OutputStream = mockOutputStream + override val config = testConfig + override def downloadFile(fileUrl: URL, destinationUrl: URL): URL = { + downloadFileCalled = true + new URL("file://someFile") // Cannot mock URL + } + } + + // Create a temporary file representing our jar to fake the cache + val tmpFilePath = Files.createTempFile( + FileSystems.getDefault.getPath(AddJar.getJarDir(testConfig)), + "someJar", + ".jar" + ) + + addJarMagic.execute( + """-f http://www.example.com/""" + tmpFilePath.getFileName) + + tmpFilePath.toFile.delete() + + downloadFileCalled should be (true) + verify(mockSparkContext).addJar(anyString()) + verify(mockInterpreter).addJars(any[URL]) + } + + it("should add magic jar to magicloader and not to interpreter and spark"+ + "context") { + val mockSparkContext = mock[SparkContext] + val mockInterpreter = mock[Interpreter] + val mockOutputStream = mock[OutputStream] + val mockMagicLoader = mock[MagicLoader] + val testConfig = ConfigFactory.load() + + val addJarMagic = new AddJar + with IncludeSparkContext + with IncludeInterpreter + with IncludeOutputStream + with IncludeConfig + { + override val 
sparkContext: SparkContext = mockSparkContext + override val interpreter: Interpreter = mockInterpreter + override val outputStream: OutputStream = mockOutputStream + override lazy val magicLoader: MagicLoader = mockMagicLoader + override val config = testConfig + override def downloadFile(fileUrl: URL, destinationUrl: URL): URL = + new URL("file://someFile") // Cannot mock URL + } + + addJarMagic.execute( + """--magic http://www.example.com/someJar.jar""") + + verify(mockMagicLoader).addJar(any()) + verify(mockSparkContext, times(0)).addJar(anyString()) + verify(mockInterpreter, times(0)).addJars(any[URL]) + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/magic/builtin/BuiltinLoaderSpec.scala ---------------------------------------------------------------------- diff --git a/kernel/src/test/scala/org/apache/toree/magic/builtin/BuiltinLoaderSpec.scala b/kernel/src/test/scala/org/apache/toree/magic/builtin/BuiltinLoaderSpec.scala new file mode 100644 index 0000000..d2abde7 --- /dev/null +++ b/kernel/src/test/scala/org/apache/toree/magic/builtin/BuiltinLoaderSpec.scala @@ -0,0 +1,40 @@ +/* + * Copyright 2014 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.ibm.spark.magic.builtin + +import org.scalatest.mock.MockitoSugar +import org.scalatest.{Matchers, FunSpec} + +class BuiltinLoaderSpec extends FunSpec with Matchers with MockitoSugar { + describe("BuiltinLoader") { + describe("#getClasses") { + it("should return classes in a package") { + val pkg = this.getClass.getPackage.getName + val classes = new BuiltinLoader().getClasses(pkg) + classes.size shouldNot be(0) + } + } + + describe("#loadClasses") { + it("should return class objects for classes in a package") { + val pkg = this.getClass.getPackage.getName + val classes = new BuiltinLoader().loadClasses(pkg).toList + classes.contains(this.getClass) should be (true) + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/magic/builtin/HtmlSpec.scala ---------------------------------------------------------------------- diff --git a/kernel/src/test/scala/org/apache/toree/magic/builtin/HtmlSpec.scala b/kernel/src/test/scala/org/apache/toree/magic/builtin/HtmlSpec.scala new file mode 100644 index 0000000..541cfcd --- /dev/null +++ b/kernel/src/test/scala/org/apache/toree/magic/builtin/HtmlSpec.scala @@ -0,0 +1,37 @@ +/* + * Copyright 2014 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.ibm.spark.magic.builtin + +import com.ibm.spark.kernel.protocol.v5.MIMEType +import com.ibm.spark.magic.CellMagicOutput +import org.scalatest.mock.MockitoSugar +import org.scalatest.{FunSpec, Matchers} + +class HtmlSpec extends FunSpec with Matchers with MockitoSugar { + describe("Html"){ + describe("#execute") { + it("should return the entire cell's contents with the MIME type of " + + "text/html") { + val htmlMagic = new Html + + val code = "some code on a line\nanother line" + val expected = CellMagicOutput(MIMEType.TextHtml -> code) + htmlMagic.execute(code) should be (expected) + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/magic/builtin/JavaScriptSpec.scala ---------------------------------------------------------------------- diff --git a/kernel/src/test/scala/org/apache/toree/magic/builtin/JavaScriptSpec.scala b/kernel/src/test/scala/org/apache/toree/magic/builtin/JavaScriptSpec.scala new file mode 100644 index 0000000..33dfbd5 --- /dev/null +++ b/kernel/src/test/scala/org/apache/toree/magic/builtin/JavaScriptSpec.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2014 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.ibm.spark.magic.builtin + +import org.scalatest.mock.MockitoSugar +import org.scalatest.{FunSpec, Matchers} +import com.ibm.spark.magic.CellMagicOutput +import com.ibm.spark.kernel.protocol.v5.MIMEType + +class JavaScriptSpec extends FunSpec with Matchers with MockitoSugar { + describe("JavaScript"){ + describe("#execute") { + it("should return the entire cell's contents with the MIME type of text/javascript") { + val javaScriptMagic = new JavaScript + + val code = "some code on a line\nmore code on another line" + val expected = CellMagicOutput(MIMEType.ApplicationJavaScript -> code) + javaScriptMagic.execute(code) should be (expected) + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/magic/builtin/LSMagicSpec.scala ---------------------------------------------------------------------- diff --git a/kernel/src/test/scala/org/apache/toree/magic/builtin/LSMagicSpec.scala b/kernel/src/test/scala/org/apache/toree/magic/builtin/LSMagicSpec.scala new file mode 100644 index 0000000..60b4a6a --- /dev/null +++ b/kernel/src/test/scala/org/apache/toree/magic/builtin/LSMagicSpec.scala @@ -0,0 +1,65 @@ +package com.ibm.spark.magic.builtin + +import java.io.OutputStream +import java.net.URL + +import com.ibm.spark.interpreter.Interpreter +import com.ibm.spark.magic.dependencies.{IncludeOutputStream, IncludeInterpreter, IncludeSparkContext} +import com.ibm.spark.magic.{CellMagic, LineMagic} +import org.apache.spark.SparkContext +import org.scalatest.{Matchers, FunSpec} +import org.scalatest.mock.MockitoSugar + +import org.mockito.Mockito._ +import org.mockito.Matchers._ + +class TestLSMagic(sc: SparkContext, intp: Interpreter, os: OutputStream) + extends LSMagic + with IncludeSparkContext + with IncludeInterpreter + with IncludeOutputStream + { + override val sparkContext: SparkContext = sc + override val interpreter: Interpreter = intp + override val outputStream: OutputStream = os + } + 
+class LSMagicSpec extends FunSpec with Matchers with MockitoSugar { + describe("LSMagic") { + + describe("#execute") { + it("should call println with a magics message") { + val lsm = spy(new TestLSMagic( + mock[SparkContext], mock[Interpreter], mock[OutputStream]) + ) + val classList = new BuiltinLoader().loadClasses() + lsm.execute("") + verify(lsm).magicNames("%", classOf[LineMagic], classList) + verify(lsm).magicNames("%%", classOf[CellMagic], classList) + } + } + + describe("#magicNames") { + it("should filter classnames by interface") { + val prefix = "%" + val interface = classOf[LineMagic] + val classes : List[Class[_]] = List(classOf[LSMagic], classOf[Integer]) + val lsm = new TestLSMagic( + mock[SparkContext], mock[Interpreter], mock[OutputStream]) + lsm.magicNames(prefix, interface, classes).length should be(1) + } + it("should prepend prefix to each name"){ + val prefix = "%" + val className = classOf[LSMagic].getSimpleName + val interface = classOf[LineMagic] + val expected = s"${prefix}${className}" + val classes : List[Class[_]] = List(classOf[LSMagic], classOf[Integer]) + val lsm = new TestLSMagic( + mock[SparkContext], mock[Interpreter], mock[OutputStream]) + lsm.magicNames(prefix, interface, classes) should be(List(expected)) + } + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/magic/builtin/RDDSpec.scala ---------------------------------------------------------------------- diff --git a/kernel/src/test/scala/org/apache/toree/magic/builtin/RDDSpec.scala b/kernel/src/test/scala/org/apache/toree/magic/builtin/RDDSpec.scala new file mode 100644 index 0000000..56b4cb7 --- /dev/null +++ b/kernel/src/test/scala/org/apache/toree/magic/builtin/RDDSpec.scala @@ -0,0 +1,117 @@ +/* + * Copyright 2014 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.ibm.spark.magic.builtin + +import com.ibm.spark.interpreter.Results.Result +import com.ibm.spark.interpreter.{Results, ExecuteAborted, ExecuteError, Interpreter} +import com.ibm.spark.kernel.protocol.v5.MIMEType +import com.ibm.spark.magic.dependencies.{IncludeKernelInterpreter, IncludeInterpreter} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar +import org.scalatest.{BeforeAndAfter, FunSpec, Matchers} +import play.api.libs.json.Json + +class RDDSpec extends FunSpec with Matchers with MockitoSugar with BeforeAndAfter { + + val resOutput = "res1: org.apache.spark.sql.SchemaRDD =" + + val mockInterpreter = mock[Interpreter] + val mockDataFrame = mock[DataFrame] + val mockRdd = mock[org.apache.spark.rdd.RDD[Any]] + val mockStruct = mock[StructType] + val columns = Seq("foo", "bar").toArray + val rows = Array( Array("a", "b"), Array("c", "d") ) + + doReturn(mockStruct).when(mockDataFrame).schema + doReturn(columns).when(mockStruct).fieldNames + doReturn(mockRdd).when(mockDataFrame).map(any())(any()) + doReturn(rows).when(mockRdd).take(anyInt()) + + val rddMagic = new RDD with IncludeKernelInterpreter { + override val kernelInterpreter: Interpreter = mockInterpreter + } + + before { + doReturn(Some("someRDD")).when(mockInterpreter).lastExecutionVariableName + doReturn(Some(mockDataFrame)).when(mockInterpreter).read(anyString()) + doReturn((Results.Success, Left(resOutput))) + 
.when(mockInterpreter).interpret(anyString(), anyBoolean()) + } + + describe("RDD") { + describe("#execute") { + it("should return valid JSON when the executed code evaluates to a " + + "SchemaRDD") { + val magicOutput = rddMagic.execute("schemaRDD") + magicOutput.contains(MIMEType.ApplicationJson) should be (true) + Json.parse(magicOutput(MIMEType.ApplicationJson)) + } + + it("should return normally when the executed code does not evaluate to " + + "a SchemaRDD") { + doReturn((mock[Result], Left("foo"))).when(mockInterpreter) + .interpret(anyString(), anyBoolean()) + val magicOutput = rddMagic.execute("") + magicOutput.contains(MIMEType.PlainText) should be (true) + } + + it("should return error message when the interpreter does not return " + + "SchemaRDD as expected") { + doReturn(Some("foo")).when(mockInterpreter).read(anyString()) + val magicOutput = rddMagic.execute("") + magicOutput.contains(MIMEType.PlainText) should be (true) + } + + it("should throw a Throwable if the interpreter returns an ExecuteError"){ + val expected = "some error message" + val mockExecuteError = mock[ExecuteError] + doReturn(expected).when(mockExecuteError).value + + doReturn((mock[Result], Right(mockExecuteError))).when(mockInterpreter) + .interpret(anyString(), anyBoolean()) + val actual = { + val exception = intercept[Throwable] { + rddMagic.execute("") + } + exception.getLocalizedMessage + } + + actual should be (expected) + } + + it("should throw a Throwable if the interpreter returns an " + + "ExecuteAborted") { + val expected = "RDD magic aborted!" 
+ val mockExecuteAborted = mock[ExecuteAborted] + + doReturn((mock[Result], Right(mockExecuteAborted))) + .when(mockInterpreter).interpret(anyString(), anyBoolean()) + val actual = { + val exception = intercept[Throwable] { + rddMagic.execute("") + } + exception.getLocalizedMessage + } + + actual should be (expected) + } + } + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/kernel/src/test/scala/org/apache/toree/utils/json/RddToJsonSpec.scala ---------------------------------------------------------------------- diff --git a/kernel/src/test/scala/org/apache/toree/utils/json/RddToJsonSpec.scala b/kernel/src/test/scala/org/apache/toree/utils/json/RddToJsonSpec.scala new file mode 100644 index 0000000..121d35e --- /dev/null +++ b/kernel/src/test/scala/org/apache/toree/utils/json/RddToJsonSpec.scala @@ -0,0 +1,55 @@ +/* + * Copyright 2014 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.ibm.spark.utils.json + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType +import org.scalatest.mock.MockitoSugar +import org.scalatest.{Matchers, FunSpec} +import org.mockito.Mockito._ +import org.mockito.Matchers._ +import play.api.libs.json.{JsArray, JsString, Json} + +class RddToJsonSpec extends FunSpec with MockitoSugar with Matchers { + + val mockDataFrame = mock[DataFrame] + val mockRdd = mock[RDD[Any]] + val mockStruct = mock[StructType] + val columns = Seq("foo", "bar").toArray + val rows = Array( Array("a", "b"), Array("c", "d") ) + + doReturn(mockStruct).when(mockDataFrame).schema + doReturn(columns).when(mockStruct).fieldNames + doReturn(mockRdd).when(mockDataFrame).map(any())(any()) + doReturn(rows).when(mockRdd).take(anyInt()) + + describe("RddToJson") { + describe("#convert(SchemaRDD)") { + it("should convert to valid JSON object") { + + val json = RddToJson.convert(mockDataFrame) + val jsValue = Json.parse(json) + + jsValue \ "columns" should be (JsArray(Seq(JsString("foo"), JsString("bar")))) + jsValue \ "rows" should be (JsArray(Seq( + JsArray(Seq(JsString("a"), JsString("b"))), + JsArray(Seq(JsString("c"), JsString("d")))))) + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/macros/src/main/scala/com/ibm/spark/annotations/Experimental.scala ---------------------------------------------------------------------- diff --git a/macros/src/main/scala/com/ibm/spark/annotations/Experimental.scala b/macros/src/main/scala/com/ibm/spark/annotations/Experimental.scala deleted file mode 100644 index 241cdc6..0000000 --- a/macros/src/main/scala/com/ibm/spark/annotations/Experimental.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright 2014 IBM Corp. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.ibm.spark.annotations - -import scala.language.experimental.macros -import scala.annotation.{StaticAnnotation, Annotation} - -/** - * Marks as experimental, indicating that the API is subject to change between - * minor versions. - */ -class Experimental extends Annotation with StaticAnnotation http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/macros/src/main/scala/org/apache/toree/annotations/Experimental.scala ---------------------------------------------------------------------- diff --git a/macros/src/main/scala/org/apache/toree/annotations/Experimental.scala b/macros/src/main/scala/org/apache/toree/annotations/Experimental.scala new file mode 100644 index 0000000..241cdc6 --- /dev/null +++ b/macros/src/main/scala/org/apache/toree/annotations/Experimental.scala @@ -0,0 +1,25 @@ +/* + * Copyright 2014 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.ibm.spark.annotations + +import scala.language.experimental.macros +import scala.annotation.{StaticAnnotation, Annotation} + +/** + * Marks as experimental, indicating that the API is subject to change between + * minor versions. + */ +class Experimental extends Annotation with StaticAnnotation http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/protocol/src/main/scala/com/ibm/spark/comm/CommCallbacks.scala ---------------------------------------------------------------------- diff --git a/protocol/src/main/scala/com/ibm/spark/comm/CommCallbacks.scala b/protocol/src/main/scala/com/ibm/spark/comm/CommCallbacks.scala deleted file mode 100644 index d462874..0000000 --- a/protocol/src/main/scala/com/ibm/spark/comm/CommCallbacks.scala +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright 2014 IBM Corp. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.ibm.spark.comm - -import com.ibm.spark.annotations.Experimental -import com.ibm.spark.kernel.protocol.v5._ - -import scala.util.Try - -@Experimental -object CommCallbacks { - type OpenCallback = (CommWriter, UUID, String, MsgData) => Unit - type MsgCallback = (CommWriter, UUID, MsgData) => Unit - type CloseCallback = (CommWriter, UUID, MsgData) => Unit -} - -import com.ibm.spark.comm.CommCallbacks._ - -/** - * Represents available callbacks to be triggered when various Comm events - * are triggered. 
- * - * @param openCallbacks The sequence of open callbacks - * @param msgCallbacks The sequence of msg callbacks - * @param closeCallbacks The sequence of close callbacks - */ -@Experimental -class CommCallbacks( - val openCallbacks: Seq[CommCallbacks.OpenCallback] = Nil, - val msgCallbacks: Seq[CommCallbacks.MsgCallback] = Nil, - val closeCallbacks: Seq[CommCallbacks.CloseCallback] = Nil -) { - - /** - * Adds a new open callback to be triggered. - * - * @param openCallback The open callback to add - * - * @return The updated CommCallbacks instance - */ - def addOpenCallback(openCallback: OpenCallback): CommCallbacks = - new CommCallbacks( - openCallbacks :+ openCallback, - msgCallbacks, - closeCallbacks - ) - - /** - * Adds a new msg callback to be triggered. - * - * @param msgCallback The msg callback to add - * - * @return The updated CommCallbacks instance - */ - def addMsgCallback(msgCallback: MsgCallback): CommCallbacks = - new CommCallbacks( - openCallbacks, - msgCallbacks :+ msgCallback, - closeCallbacks - ) - - /** - * Adds a new close callback to be triggered. - * - * @param closeCallback The close callback to add - * - * @return The updated CommCallbacks instance - */ - def addCloseCallback(closeCallback: CloseCallback): CommCallbacks = - new CommCallbacks( - openCallbacks, - msgCallbacks, - closeCallbacks :+ closeCallback - ) - - /** - * Removes the specified open callback from the collection of callbacks. - * - * @param openCallback The open callback to remove - * - * @return The updated CommCallbacks instance - */ - def removeOpenCallback(openCallback: OpenCallback): CommCallbacks = - new CommCallbacks( - openCallbacks.filterNot(_ == openCallback), - msgCallbacks, - closeCallbacks - ) - - /** - * Removes the specified msg callback from the collection of callbacks. 
- * - * @param msgCallback The msg callback to remove - * - * @return The updated CommCallbacks instance - */ - def removeMsgCallback(msgCallback: MsgCallback): CommCallbacks = - new CommCallbacks( - openCallbacks, - msgCallbacks.filterNot(_ == msgCallback), - closeCallbacks - ) - - /** - * Removes the specified close callback from the collection of callbacks. - * - * @param closeCallback The close callback to remove - * - * @return The updated CommCallbacks instance - */ - def removeCloseCallback(closeCallback: CloseCallback): CommCallbacks = - new CommCallbacks( - openCallbacks, - msgCallbacks, - closeCallbacks.filterNot(_ == closeCallback) - ) - - /** - * Executes all registered open callbacks and returns a sequence of results. - * - * @param commWriter The Comm Writer that can be used for responses - * @param commId The Comm Id to pass to all open callbacks - * @param targetName The Comm Target Name to pass to all open callbacks - * @param data The data to pass to all open callbacks - * - * @return The sequence of results from trying to execute callbacks - */ - def executeOpenCallbacks( - commWriter: CommWriter, commId: UUID, targetName: String, data: MsgData - ) = openCallbacks.map(f => Try(f(commWriter, commId, targetName, data))) - - /** - * Executes all registered msg callbacks and returns a sequence of results. - * - * @param commWriter The Comm Writer that can be used for responses - * @param commId The Comm Id to pass to all msg callbacks - * @param data The data to pass to all msg callbacks - * - * @return The sequence of results from trying to execute callbacks - */ - def executeMsgCallbacks(commWriter: CommWriter, commId: UUID, data: MsgData) = - msgCallbacks.map(f => Try(f(commWriter, commId, data))) - - /** - * Executes all registered close callbacks and returns a sequence of results. 
- * - * @param commWriter The Comm Writer that can be used for responses - * @param commId The Comm Id to pass to all close callbacks - * @param data The data to pass to all close callbacks - * - * @return The sequence of results from trying to execute callbacks - */ - def executeCloseCallbacks(commWriter: CommWriter, commId: UUID, data: MsgData) = - closeCallbacks.map(f => Try(f(commWriter, commId, data))) -} http://git-wip-us.apache.org/repos/asf/incubator-toree/blob/68f7ddd6/protocol/src/main/scala/com/ibm/spark/comm/CommManager.scala ---------------------------------------------------------------------- diff --git a/protocol/src/main/scala/com/ibm/spark/comm/CommManager.scala b/protocol/src/main/scala/com/ibm/spark/comm/CommManager.scala deleted file mode 100644 index c15f1a5..0000000 --- a/protocol/src/main/scala/com/ibm/spark/comm/CommManager.scala +++ /dev/null @@ -1,165 +0,0 @@ -/* - * Copyright 2014 IBM Corp. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.ibm.spark.comm - -import java.util.UUID - -import com.ibm.spark.annotations.Experimental -import com.ibm.spark.comm.CommCallbacks.{CloseCallback, OpenCallback} -import com.ibm.spark.kernel.protocol.v5 -import com.ibm.spark.kernel.protocol.v5._ -import com.ibm.spark.kernel.protocol.v5.content.CommContent - -/** - * Represents a manager for Comm connections that facilitates and maintains - * connections started and received through this service. 
- * - * @param commRegistrar The registrar to use for callback registration - */ -@Experimental -abstract class CommManager(private val commRegistrar: CommRegistrar) { - /** - * The base function to call that performs a link given the target name and - * the Comm id for the specific instance. - */ - private val linkFunc: OpenCallback = - (_, commId, targetName, _) => commRegistrar.link(targetName, commId) - - /** - * The base function to call that performs an unlink given the Comm id for - * the specific instance. - */ - private val unlinkFunc: CloseCallback = - (_, commId, _) => commRegistrar.unlink(commId) - - // TODO: This is potentially bad design considering appending methods to - // CommWriter will not require this class to be updated! - /** - * Represents a wrapper for a CommWriter instance that links and unlinks - * when invoked. - * - * @param commWriter The CommWriter instance to wrap - */ - private class WrapperCommWriter(private val commWriter: CommWriter) - extends CommWriter(commWriter.commId) - { - override protected[comm] def sendCommKernelMessage[ - T <: KernelMessageContent with CommContent - ](commContent: T): Unit = commWriter.sendCommKernelMessage(commContent) - - // Overridden to unlink before sending close message - override def writeClose(data: MsgData): Unit = { - unlinkFunc(this, this.commId, data) - commWriter.writeClose(data) - } - - // Overridden to unlink before sending close message - override def close(): Unit = { - unlinkFunc(this, this.commId, null) - commWriter.close() - } - - // Overriden to link before sending open message - override def writeOpen(targetName: String, data: MsgData): Unit = { - linkFunc(this, this.commId, targetName, data) - commWriter.writeOpen(targetName, data) - } - - override def writeMsg(data: MsgData): Unit = commWriter.writeMsg(data) - override def write(cbuf: Array[Char], off: Int, len: Int): Unit = - commWriter.write(cbuf, off, len) - override def flush(): Unit = commWriter.flush() - } - - /** - * Loads 
the specified target and provides a registrar pointing to the target. - * - * @param targetName The name of the target to load - * - * @return The CommRegistrar pointing to the target - */ - def withTarget(targetName: String): CommRegistrar = { - commRegistrar.withTarget(targetName) - } - - /** - * Registers a new Comm for use on the kernel. Establishes default callbacks - * to link and unlink specific Comm instances for the new target. - * - * @param targetName The name of the target to register - * - * @return The new CommRegistrar set to the provided target - */ - def register(targetName: String): CommRegistrar = { - commRegistrar.register(targetName) - .addOpenHandler(linkFunc) - .addCloseHandler(unlinkFunc) - } - - /** - * Unregisters the specified target used for Comm messages. - * - * @param targetName The name of the target to unregister - * - * @return Some collection of callbacks associated with the target if it was - * registered, otherwise None - */ - def unregister(targetName: String): Option[CommCallbacks] = { - commRegistrar.unregister(targetName) - } - - /** - * Indicates whether or not the specified target is currently registered with - * this Comm manager. - * - * @param targetName The name of the target - * - * @return True if the target is registered, otherwise false - */ - def isRegistered(targetName: String): Boolean = - commRegistrar.isRegistered(targetName) - - /** - * Opens a new Comm connection. Establishes a new link between the specified - * target and the generated Comm id. 
- * - * @param targetName The name of the target to connect - * @param data The optional data to send - * - * @return The new CommWriter representing the connection - */ - def open(targetName: String, data: v5.MsgData = v5.MsgData.Empty): CommWriter = { - val commId = UUID.randomUUID().toString - - // Create our CommWriter and wrap it to establish links and unlink on close - val commWriter = new WrapperCommWriter(newCommWriter(commId)) - - // Establish the actual connection - commWriter.writeOpen(targetName, data) - - commWriter - } - - /** - * Creates a new CommWriter instance given the Comm id. - * - * @param commId The Comm id to use with the Comm writer - * - * @return The new CommWriter instance - */ - protected def newCommWriter(commId: v5.UUID): CommWriter -}