http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/com/cloudera/livy/repl/package.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/package.scala b/repl/src/main/scala/com/cloudera/livy/repl/package.scala deleted file mode 100644 index bf58ad4..0000000 --- a/repl/src/main/scala/com/cloudera/livy/repl/package.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy - -import org.json4s.JField - -package object repl { - type MimeTypeMap = List[JField] - - val APPLICATION_JSON = "application/json" - val APPLICATION_LIVY_TABLE_JSON = "application/vnd.livy.table.v1+json" - val IMAGE_PNG = "image/png" - val TEXT_PLAIN = "text/plain" -}
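The file above is one half of the package rename: the same repl package object is re-added under org.apache.livy later in this commit. Its MIME-type constants are the keys every interpreter uses when shaping statement output, since a successful result is a JSON object keyed by MIME type and a single statement can carry several representations at once. A minimal json4s sketch of that shape (constants inlined, values hypothetical, not part of the commit):

import org.json4s.JsonAST.JObject
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

object MimeMapSketch extends App {
  // Stand-ins for the package-object constants above.
  val TEXT_PLAIN = "text/plain"
  val IMAGE_PNG = "image/png"

  // One statement result carrying plain text plus a (fake) rendered plot.
  val content: JObject =
    (TEXT_PLAIN -> "res0: Int = 42") ~ (IMAGE_PNG -> "<base64-encoded PNG bytes>")

  println(compact(render(content)))
  // {"text/plain":"res0: Int = 42","image/png":"<base64-encoded PNG bytes>"}
}

SparkRInterpreter below builds exactly this shape when it attaches a base64 PNG next to the text output of a plotting command.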
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala b/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala new file mode 100644 index 0000000..cf2cf35 --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/AbstractSparkInterpreter.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.repl + +import java.io.ByteArrayOutputStream + +import scala.tools.nsc.interpreter.Results + +import org.apache.spark.rdd.RDD +import org.json4s.DefaultFormats +import org.json4s.Extraction +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ + +import org.apache.livy.Logging + +object AbstractSparkInterpreter { + private[repl] val KEEP_NEWLINE_REGEX = """(?<=\n)""".r + private val MAGIC_REGEX = "^%(\\w+)\\W*(.*)".r +} + +abstract class AbstractSparkInterpreter extends Interpreter with Logging { + import AbstractSparkInterpreter._ + + private implicit def formats = DefaultFormats + + protected val outputStream = new ByteArrayOutputStream() + + final def kind: String = "spark" + + protected def isStarted(): Boolean + + protected def interpret(code: String): Results.Result + + protected def valueOfTerm(name: String): Option[Any] + + override protected[repl] def execute(code: String): Interpreter.ExecuteResponse = + restoreContextClassLoader { + require(isStarted()) + + executeLines(code.trim.split("\n").toList, Interpreter.ExecuteSuccess(JObject( + (TEXT_PLAIN, JString("")) + ))) + } + + private def executeMagic(magic: String, rest: String): Interpreter.ExecuteResponse = { + magic match { + case "json" => executeJsonMagic(rest) + case "table" => executeTableMagic(rest) + case _ => + Interpreter.ExecuteError("UnknownMagic", f"Unknown magic command $magic") + } + } + + private def executeJsonMagic(name: String): Interpreter.ExecuteResponse = { + try { + val value = valueOfTerm(name) match { + case Some(obj: RDD[_]) => obj.asInstanceOf[RDD[_]].take(10) + case Some(obj) => obj + case None => return Interpreter.ExecuteError("NameError", f"Value $name does not exist") + } + + Interpreter.ExecuteSuccess(JObject( + (APPLICATION_JSON, Extraction.decompose(value)) + )) + } catch { + case _: Throwable => + Interpreter.ExecuteError("ValueError", "Failed to convert value into a JSON value") + } + } + + private class TypesDoNotMatch extends Exception + + private def convertTableType(value: JValue): String = { + value match { + case (JNothing | JNull) => "NULL_TYPE" + case JBool(_) => "BOOLEAN_TYPE" + case JString(_) => "STRING_TYPE" + case JInt(_) => "BIGINT_TYPE" 
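+ // (json4s represents every JSON integer as a BigInt-backed JInt, so integers are widest-typed as BIGINT_TYPE.)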
+ case JDouble(_) => "DOUBLE_TYPE" + case JDecimal(_) => "DECIMAL_TYPE" + case JArray(arr) => + if (allSameType(arr.iterator)) { + "ARRAY_TYPE" + } else { + throw new TypesDoNotMatch + } + case JObject(obj) => + if (allSameType(obj.iterator.map(_._2))) { + "MAP_TYPE" + } else { + throw new TypesDoNotMatch + } + } + } + + private def allSameType(values: Iterator[JValue]): Boolean = { + if (values.hasNext) { + val type_name = convertTableType(values.next()) + values.forall { case value => type_name.equals(convertTableType(value)) } + } else { + true + } + } + + private def executeTableMagic(name: String): Interpreter.ExecuteResponse = { + val value = valueOfTerm(name) match { + case Some(obj: RDD[_]) => obj.asInstanceOf[RDD[_]].take(10) + case Some(obj) => obj + case None => return Interpreter.ExecuteError("NameError", f"Value $name does not exist") + } + + extractTableFromJValue(Extraction.decompose(value)) + } + + private def extractTableFromJValue(value: JValue): Interpreter.ExecuteResponse = { + // Convert the value into JSON and map it to a table. + val rows: List[JValue] = value match { + case JArray(arr) => arr + case _ => List(value) + } + + try { + val headers = scala.collection.mutable.Map[String, Map[String, String]]() + + val data = rows.map { case row => + val cols: List[JField] = row match { + case JArray(arr: List[JValue]) => + arr.zipWithIndex.map { case (v, index) => JField(index.toString, v) } + case JObject(obj) => obj.sortBy(_._1) + case value: JValue => List(JField("0", value)) + } + + cols.map { case (k, v) => + val typeName = convertTableType(v) + + headers.get(k) match { + case Some(header) => + if (header.get("type").get != typeName) { + throw new TypesDoNotMatch + } + case None => + headers.put(k, Map( + "type" -> typeName, + "name" -> k + )) + } + + v + } + } + + Interpreter.ExecuteSuccess( + APPLICATION_LIVY_TABLE_JSON -> ( + ("headers" -> headers.toSeq.sortBy(_._1).map(_._2)) ~ ("data" -> data) + )) + } catch { + case _: TypesDoNotMatch => + Interpreter.ExecuteError("TypeError", "table rows have different types") + } + } + + private def executeLines( + lines: List[String], + resultFromLastLine: Interpreter.ExecuteResponse): Interpreter.ExecuteResponse = { + lines match { + case Nil => resultFromLastLine + case head :: tail => + val result = executeLine(head) + + result match { + case Interpreter.ExecuteIncomplete() => + tail match { + case Nil => + // ExecuteIncomplete could be caused by an actual incomplete statement (e.g. "sc.") + // or statements with just comments. + // To distinguish them, reissue the same statement wrapped in { }. + // If it is an actual incomplete statement, the interpreter will return an error. + // If it is some comment, the interpreter will return success. + executeLine(s"{\n$head\n}") match { + case Interpreter.ExecuteIncomplete() | Interpreter.ExecuteError(_, _, _) => + // Return the original error so users won't get a confusing error message.
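+ // That is, report the failure from the bare statement rather than from the "{ ... }" retry.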
+ result + case _ => resultFromLastLine + } + case next :: nextTail => + executeLines(head + "\n" + next :: nextTail, resultFromLastLine) + } + case Interpreter.ExecuteError(_, _, _) => + result + + case _ => + executeLines(tail, result) + } + } + } + + private def executeLine(code: String): Interpreter.ExecuteResponse = { + code match { + case MAGIC_REGEX(magic, rest) => + executeMagic(magic, rest) + case _ => + scala.Console.withOut(outputStream) { + interpret(code) match { + case Results.Success => + Interpreter.ExecuteSuccess( + TEXT_PLAIN -> readStdout() + ) + case Results.Incomplete => Interpreter.ExecuteIncomplete() + case Results.Error => + val (ename, traceback) = parseError(readStdout()) + Interpreter.ExecuteError("Error", ename, traceback) + } + } + } + } + + protected[repl] def parseError(stdout: String): (String, Seq[String]) = { + // An example of Scala compile error message: + // <console>:27: error: type mismatch; + // found : Int + // required: Boolean + + // An example of Scala runtime exception error message: + // java.lang.RuntimeException: message + // at .error(<console>:11) + // ... 32 elided + + // Return the first line as ename. Lines following as traceback. + + val lines = KEEP_NEWLINE_REGEX.split(stdout) + val ename = lines.headOption.map(_.trim).getOrElse("unknown error") + val traceback = lines.tail + + (ename, traceback) + } + + protected def restoreContextClassLoader[T](fn: => T): T = { + val currentClassLoader = Thread.currentThread().getContextClassLoader() + try { + fn + } finally { + Thread.currentThread().setContextClassLoader(currentClassLoader) + } + } + + private def readStdout() = { + val output = outputStream.toString("UTF-8").trim + outputStream.reset() + + output + } +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/BypassPySparkJob.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/BypassPySparkJob.scala b/repl/src/main/scala/org/apache/livy/repl/BypassPySparkJob.scala new file mode 100644 index 0000000..6a9bdd1 --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/BypassPySparkJob.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.livy.repl + +import java.nio.charset.StandardCharsets + +import org.apache.livy.{Job, JobContext} + +class BypassPySparkJob( + serializedJob: Array[Byte], + replDriver: ReplDriver) extends Job[Array[Byte]] { + + override def call(jc: JobContext): Array[Byte] = { + val interpreter = replDriver.interpreter + require(interpreter != null && interpreter.isInstanceOf[PythonInterpreter]) + val pi = interpreter.asInstanceOf[PythonInterpreter] + + val resultByteArray = pi.pysparkJobProcessor.processBypassJob(serializedJob) + val resultString = new String(resultByteArray, StandardCharsets.UTF_8) + if (resultString.startsWith("Client job error:")) { + throw new PythonJobException(resultString) + } + resultByteArray + } +} + http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala b/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala new file mode 100644 index 0000000..058b20b --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/Interpreter.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.repl + +import org.apache.spark.SparkContext +import org.json4s.JObject + +object Interpreter { + abstract class ExecuteResponse + + case class ExecuteSuccess(content: JObject) extends ExecuteResponse + case class ExecuteError(ename: String, + evalue: String, + traceback: Seq[String] = Seq()) extends ExecuteResponse + case class ExecuteIncomplete() extends ExecuteResponse + case class ExecuteAborted(message: String) extends ExecuteResponse +} + +trait Interpreter { + import Interpreter._ + + def kind: String + + /** + * Start the Interpreter. + * + * @return A SparkContext + */ + def start(): SparkContext + + /** + * Execute the code and return the result, it may + * take some time to execute. + */ + protected[repl] def execute(code: String): ExecuteResponse + + /** Shut down the interpreter. */ + def close(): Unit +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/ProcessInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/ProcessInterpreter.scala b/repl/src/main/scala/org/apache/livy/repl/ProcessInterpreter.scala new file mode 100644 index 0000000..7995ba0 --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/ProcessInterpreter.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.repl + +import java.io.{BufferedReader, InputStreamReader, IOException, PrintWriter} +import java.util.concurrent.locks.ReentrantLock + +import scala.concurrent.Promise +import scala.io.Source + +import org.apache.spark.SparkContext +import org.json4s.JValue + +import org.apache.livy.{Logging, Utils} +import org.apache.livy.client.common.ClientConf + +private sealed trait Request +private case class ExecuteRequest(code: String, promise: Promise[JValue]) extends Request +private case class ShutdownRequest(promise: Promise[Unit]) extends Request + +/** + * Abstract class that describes an interpreter that is running in a separate process. + * + * This type is not thread safe, so must be protected by a mutex. + * + * @param process + */ +abstract class ProcessInterpreter(process: Process) + extends Interpreter with Logging { + protected[this] val stdin = new PrintWriter(process.getOutputStream) + protected[this] val stdout = new BufferedReader(new InputStreamReader(process.getInputStream), 1) + + override def start(): SparkContext = { + waitUntilReady() + + if (ClientConf.TEST_MODE) { + null.asInstanceOf[SparkContext] + } else { + SparkContext.getOrCreate() + } + } + + override protected[repl] def execute(code: String): Interpreter.ExecuteResponse = { + try { + sendExecuteRequest(code) + } catch { + case e: Throwable => + Interpreter.ExecuteError(e.getClass.getName, e.getMessage) + } + } + + override def close(): Unit = { + if (Utils.isProcessAlive(process)) { + logger.info("Shutting down process") + sendShutdownRequest() + + try { + process.getInputStream.close() + process.getOutputStream.close() + } catch { + case _: IOException => + } + + try { + process.destroy() + } finally { + logger.info("process has been shut down") + } + } + } + + protected def sendExecuteRequest(request: String): Interpreter.ExecuteResponse + + protected def sendShutdownRequest(): Unit = {} + + protected def waitUntilReady(): Unit + + private[this] val stderrLock = new ReentrantLock() + private[this] var stderrLines = Seq[String]() + + protected def takeErrorLines(): String = { + stderrLock.lock() + try { + val lines = stderrLines + stderrLines = Seq() + lines.mkString("\n") + } finally { + stderrLock.unlock() + } + } + + private[this] val stderrThread = new Thread("process stderr thread") { + override def run() = { + val lines = Source.fromInputStream(process.getErrorStream).getLines() + + for (line <- lines) { + stderrLock.lock() + try { + stderrLines :+= line + } finally { + stderrLock.unlock() + } + } + } + } + + stderrThread.setDaemon(true) + stderrThread.start() + + private[this] val processWatcherThread = new Thread("process watcher thread") { + override def run() = { + val exitCode = process.waitFor() + if (exitCode != 0) { + error(f"Process has died with $exitCode") + 
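+ // Follow up with everything the child wrote to stderr, which usually holds the real cause.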
error(stderrLines.mkString("\n")) + } + } + } + + processWatcherThread.setDaemon(true) + processWatcherThread.start() +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/PySparkJobProcessor.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/PySparkJobProcessor.scala b/repl/src/main/scala/org/apache/livy/repl/PySparkJobProcessor.scala new file mode 100644 index 0000000..95b91cf --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/PySparkJobProcessor.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.livy.repl + +trait PySparkJobProcessor { + def processBypassJob(job: Array[Byte]): Array[Byte] + + def addFile(path: String) + + def addPyFile(path: String) + + def getLocalTmpDirPath: String +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala b/repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala new file mode 100644 index 0000000..bfd5d76 --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/PythonInterpreter.scala @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.livy.repl + +import java.io._ +import java.lang.ProcessBuilder.Redirect +import java.lang.reflect.Proxy +import java.nio.file.{Files, Paths} + +import scala.annotation.tailrec +import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import org.apache.spark.{SparkConf, SparkContext} +import org.json4s.{DefaultFormats, JValue} +import org.json4s.JsonAST.JObject +import org.json4s.jackson.JsonMethods._ +import org.json4s.jackson.Serialization.write +import py4j._ +import py4j.reflection.PythonProxyHandler + +import org.apache.livy.Logging +import org.apache.livy.client.common.ClientConf +import org.apache.livy.rsc.BaseProtocol +import org.apache.livy.rsc.driver.BypassJobWrapper +import org.apache.livy.sessions._ + +// scalastyle:off println +object PythonInterpreter extends Logging { + + def apply(conf: SparkConf, kind: Kind): Interpreter = { + val pythonExec = kind match { + case PySpark() => sys.env.getOrElse("PYSPARK_PYTHON", "python") + case PySpark3() => sys.env.getOrElse("PYSPARK3_PYTHON", "python3") + case _ => throw new IllegalArgumentException(s"Unknown kind: $kind") + } + + val gatewayServer = new GatewayServer(null, 0) + gatewayServer.start() + + val builder = new ProcessBuilder(Seq(pythonExec, createFakeShell().toString).asJava) + + val env = builder.environment() + + val pythonPath = sys.env.getOrElse("PYTHONPATH", "") + .split(File.pathSeparator) + .++(if (!ClientConf.TEST_MODE) findPySparkArchives() else Nil) + .++(if (!ClientConf.TEST_MODE) findPyFiles() else Nil) + + env.put("PYSPARK_PYTHON", pythonExec) + env.put("PYTHONPATH", pythonPath.mkString(File.pathSeparator)) + env.put("PYTHONUNBUFFERED", "YES") + env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) + env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", ".")) + env.put("LIVY_SPARK_MAJOR_VERSION", conf.get("spark.livy.spark_major_version", "1")) + builder.redirectError(Redirect.PIPE) + val process = builder.start() + new PythonInterpreter(process, gatewayServer, kind.toString) + } + + private def findPySparkArchives(): Seq[String] = { + sys.env.get("PYSPARK_ARCHIVES_PATH") + .map(_.split(",").toSeq) + .getOrElse { + sys.env.get("SPARK_HOME").map { sparkHome => + val pyLibPath = Seq(sparkHome, "python", "lib").mkString(File.separator) + val pyArchivesFile = new File(pyLibPath, "pyspark.zip") + require(pyArchivesFile.exists(), + "pyspark.zip not found; cannot run pyspark application in YARN mode.") + + val py4jFile = Files.newDirectoryStream(Paths.get(pyLibPath), "py4j-*-src.zip") + .iterator() + .next() + .toFile + + require(py4jFile.exists(), + "py4j-*-src.zip not found; cannot run pyspark application in YARN mode.") + Seq(pyArchivesFile.getAbsolutePath, py4jFile.getAbsolutePath) + }.getOrElse(Seq()) + } + } + + private def findPyFiles(): Seq[String] = { + val pyFiles = sys.props.getOrElse("spark.submit.pyFiles", "").split(",") + + if (sys.env.getOrElse("SPARK_YARN_MODE", "") == "true") { + // In spark mode, these files have been localized into the current directory. 
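+ // Keeping just the base name makes each path resolve inside the YARN container's working directory.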
+ pyFiles.map { file => + val name = new File(file).getName + new File(name).getAbsolutePath + } + } else { + pyFiles + } + } + + private def createFakeShell(): File = { + val source: InputStream = getClass.getClassLoader.getResourceAsStream("fake_shell.py") + + val file = Files.createTempFile("", "").toFile + file.deleteOnExit() + + val sink = new FileOutputStream(file) + val buf = new Array[Byte](1024) + var n = source.read(buf) + + while (n > 0) { + sink.write(buf, 0, n) + n = source.read(buf) + } + + source.close() + sink.close() + + file + } + + private def initiatePy4jCallbackGateway(server: GatewayServer): PySparkJobProcessor = { + val f = server.getClass.getDeclaredField("gateway") + f.setAccessible(true) + val gateway = f.get(server).asInstanceOf[Gateway] + val command: String = "f" + Protocol.ENTRY_POINT_OBJECT_ID + ";" + + "org.apache.livy.repl.PySparkJobProcessor" + getPythonProxy(command, gateway).asInstanceOf[PySparkJobProcessor] + } + + // This method is a hack to get around the classLoader issues faced in py4j 0.8.2.1 for + // dynamically adding jars to the driver. The change is to use the context classLoader instead + // of the system classLoader when initiating a new Proxy instance + // ISSUE - https://issues.apache.org/jira/browse/SPARK-6047 + // FIX - https://github.com/bartdag/py4j/pull/196 + private def getPythonProxy(commandPart: String, gateway: Gateway): Any = { + val proxyString = commandPart.substring(1, commandPart.length) + val parts = proxyString.split(";") + val length: Int = parts.length + val interfaces = ArrayBuffer.fill[Class[_]](length - 1){ null } + if (length < 2) { + throw new Py4JException("Invalid Python Proxy.") + } + else { + var proxy: Int = 1 + while (proxy < length) { + try { + interfaces(proxy - 1) = Class.forName(parts(proxy)) + if (!interfaces(proxy - 1).isInterface) { + throw new Py4JException("This class " + parts(proxy) + + " is not an interface and cannot be used as a Python Proxy.") + } + } catch { + case exception: ClassNotFoundException => { + throw new Py4JException("Invalid interface name: " + parts(proxy)) + } + } + proxy += 1 + } + + val pythonProxyHandler = try { + classOf[PythonProxyHandler].getConstructor(classOf[String], classOf[Gateway]) + .newInstance(parts(0), gateway) + } catch { + case NonFatal(e) => + classOf[PythonProxyHandler].getConstructor(classOf[String], + Class.forName("py4j.CallbackClient"), classOf[Gateway]) + .newInstance(parts(0), gateway.getCallbackClient, gateway) + } + + Proxy.newProxyInstance(Thread.currentThread.getContextClassLoader, + interfaces.toArray, pythonProxyHandler.asInstanceOf[PythonProxyHandler]) + } + } +} + +private class PythonInterpreter( + process: Process, + gatewayServer: GatewayServer, + pyKind: String) + extends ProcessInterpreter(process) + with Logging +{ + implicit val formats = DefaultFormats + + override def kind: String = pyKind + + private[repl] val pysparkJobProcessor = + PythonInterpreter.initiatePy4jCallbackGateway(gatewayServer) + + override def close(): Unit = { + try { + super.close() + } finally { + gatewayServer.shutdown() + } + } + + @tailrec + final override protected def waitUntilReady(): Unit = { + val READY_REGEX = "READY\\(port=([0-9]+)\\)".r + stdout.readLine() match { + case null => + case READY_REGEX(port) => updatePythonGatewayPort(port.toInt) + case _ => waitUntilReady() + } + } + + override protected def sendExecuteRequest(code: String): Interpreter.ExecuteResponse = { + sendRequest(Map("msg_type" -> "execute_request", "content" -> Map("code" -> code))) 
match { + case Some(response) => + assert((response \ "msg_type").extract[String] == "execute_reply") + + val content = response \ "content" + + (content \ "status").extract[String] match { + case "ok" => + Interpreter.ExecuteSuccess((content \ "data").extract[JObject]) + case "error" => + val ename = (content \ "ename").extract[String] + val evalue = (content \ "evalue").extract[String] + val traceback = (content \ "traceback").extract[Seq[String]] + + Interpreter.ExecuteError(ename, evalue, traceback) + case status => + Interpreter.ExecuteError("Internal Error", f"Unknown status $status") + } + case None => + Interpreter.ExecuteAborted(takeErrorLines()) + } + } + + override protected def sendShutdownRequest(): Unit = { + sendRequest(Map( + "msg_type" -> "shutdown_request", + "content" -> () + )).foreach { case rep => + warn(f"process failed to shut down while returning $rep") + } + } + + private def sendRequest(request: Map[String, Any]): Option[JValue] = { + stdin.println(write(request)) + stdin.flush() + + Option(stdout.readLine()).map { case line => + parse(line) + } + } + + def addFile(path: String): Unit = { + pysparkJobProcessor.addFile(path) + } + + def addPyFile(driver: ReplDriver, conf: SparkConf, path: String): Unit = { + val localCopyDir = new File(pysparkJobProcessor.getLocalTmpDirPath) + val localCopyFile = driver.copyFileToLocal(localCopyDir, path, SparkContext.getOrCreate(conf)) + pysparkJobProcessor.addPyFile(localCopyFile.getPath) + if (path.endsWith(".jar")) { + driver.addLocalFileToClassLoader(localCopyFile) + } + } + + private def updatePythonGatewayPort(port: Int): Unit = { + // The python gateway port can be 0 only when LivyConf.TEST_MODE is true + // Py4j 0.10 has different API signature for "getCallbackClient", use reflection to handle it. + if (port != 0) { + val callbackClient = gatewayServer.getClass + .getMethod("getCallbackClient") + .invoke(gatewayServer) + + val field = Class.forName("py4j.CallbackClient").getDeclaredField("port") + field.setAccessible(true) + field.setInt(callbackClient, port.toInt) + } + } +} + +case class PythonJobException(message: String) extends Exception(message) {} + +// scalastyle:on println http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala b/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala new file mode 100644 index 0000000..75966be --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/ReplDriver.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.livy.repl + +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import io.netty.channel.ChannelHandlerContext +import org.apache.spark.SparkConf +import org.apache.spark.api.java.JavaSparkContext + +import org.apache.livy.Logging +import org.apache.livy.rsc.{BaseProtocol, ReplJobResults, RSCConf} +import org.apache.livy.rsc.BaseProtocol.ReplState +import org.apache.livy.rsc.driver._ +import org.apache.livy.rsc.rpc.Rpc +import org.apache.livy.sessions._ + +class ReplDriver(conf: SparkConf, livyConf: RSCConf) + extends RSCDriver(conf, livyConf) + with Logging { + + private[repl] var session: Session = _ + + private val kind = Kind(livyConf.get(RSCConf.Entry.SESSION_KIND)) + + private[repl] var interpreter: Interpreter = _ + + override protected def initializeContext(): JavaSparkContext = { + interpreter = kind match { + case PySpark() => PythonInterpreter(conf, PySpark()) + case PySpark3() => + PythonInterpreter(conf, PySpark3()) + case Spark() => new SparkInterpreter(conf) + case SparkR() => SparkRInterpreter(conf) + } + session = new Session(livyConf, interpreter, { s => broadcast(new ReplState(s.toString)) }) + + Option(Await.result(session.start(), Duration.Inf)) + .map(new JavaSparkContext(_)) + .orNull + } + + override protected def shutdownContext(): Unit = { + if (session != null) { + try { + session.close() + } finally { + super.shutdownContext() + } + } + } + + def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.ReplJobRequest): Int = { + session.execute(msg.code) + } + + def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.CancelReplJobRequest): Unit = { + session.cancel(msg.id) + } + + /** + * Return statement results. Results are sorted by statement id. + */ + def handle(ctx: ChannelHandlerContext, msg: BaseProtocol.GetReplJobResults): ReplJobResults = { + val statements = if (msg.allResults) { + session.statements.values.toArray + } else { + assert(msg.from != null) + assert(msg.size != null) + if (msg.size == 1) { + session.statements.get(msg.from).toArray + } else { + val until = msg.from + msg.size + session.statements.filterKeys(id => id >= msg.from && id < until).values.toArray + } + } + + // Update progress of statements when queried + statements.foreach { s => + s.updateProgress(session.progressOfStatement(s.id)) + } + + new ReplJobResults(statements.sortBy(_.id)) + } + + override protected def createWrapper(msg: BaseProtocol.BypassJobRequest): BypassJobWrapper = { + kind match { + case PySpark() | PySpark3() => new BypassJobWrapper(this, msg.id, + new BypassPySparkJob(msg.serializedJob, this)) + case _ => super.createWrapper(msg) + } + } + + override protected def addFile(path: String): Unit = { + require(interpreter != null) + interpreter match { + case pi: PythonInterpreter => pi.addFile(path) + case _ => super.addFile(path) + } + } + + override protected def addJarOrPyFile(path: String): Unit = { + require(interpreter != null) + interpreter match { + case pi: PythonInterpreter => pi.addPyFile(this, conf, path) + case _ => super.addJarOrPyFile(path) + } + } + + override protected def onClientAuthenticated(client: Rpc): Unit = { + if (session != null) { + client.call(new ReplState(session.state.toString)) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/Session.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala 
b/repl/src/main/scala/org/apache/livy/repl/Session.scala new file mode 100644 index 0000000..40176ea --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.repl + +import java.util.{LinkedHashMap => JLinkedHashMap} +import java.util.Map.Entry +import java.util.concurrent.Executors +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration._ + +import org.apache.spark.SparkContext +import org.json4s.jackson.JsonMethods.{compact, render} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ + +import org.apache.livy.Logging +import org.apache.livy.rsc.RSCConf +import org.apache.livy.rsc.driver.{Statement, StatementState} +import org.apache.livy.sessions._ + +object Session { + val STATUS = "status" + val OK = "ok" + val ERROR = "error" + val EXECUTION_COUNT = "execution_count" + val DATA = "data" + val ENAME = "ename" + val EVALUE = "evalue" + val TRACEBACK = "traceback" +} + +class Session( + livyConf: RSCConf, + interpreter: Interpreter, + stateChangedCallback: SessionState => Unit = { _ => }) + extends Logging { + import Session._ + + private val interpreterExecutor = ExecutionContext.fromExecutorService( + Executors.newSingleThreadExecutor()) + + private val cancelExecutor = ExecutionContext.fromExecutorService( + Executors.newSingleThreadExecutor()) + + private implicit val formats = DefaultFormats + + @volatile private[repl] var _sc: Option[SparkContext] = None + + private var _state: SessionState = SessionState.NotStarted() + + // Number of statements kept in driver's memory + private val numRetainedStatements = livyConf.getInt(RSCConf.Entry.RETAINED_STATEMENT_NUMBER) + + private val _statements = new JLinkedHashMap[Int, Statement] { + protected override def removeEldestEntry(eldest: Entry[Int, Statement]): Boolean = { + size() > numRetainedStatements + } + }.asScala + + private val newStatementId = new AtomicInteger(0) + + stateChangedCallback(_state) + + def start(): Future[SparkContext] = { + val future = Future { + changeState(SessionState.Starting()) + val sc = interpreter.start() + _sc = Option(sc) + changeState(SessionState.Idle()) + sc + }(interpreterExecutor) + + future.onFailure { case _ => changeState(SessionState.Error()) }(interpreterExecutor) + future + } + + def kind: String = interpreter.kind + + def state: SessionState = _state + + def statements: collection.Map[Int, Statement] = _statements.synchronized { + _statements.toMap + } + + def execute(code: String): Int = { + val statementId = newStatementId.getAndIncrement() + val statement = new Statement(statementId, code, 
StatementState.Waiting, null) + _statements.synchronized { _statements(statementId) = statement } + + Future { + setJobGroup(statementId) + statement.compareAndTransit(StatementState.Waiting, StatementState.Running) + + if (statement.state.get() == StatementState.Running) { + statement.output = executeCode(statementId, code) + } + + statement.compareAndTransit(StatementState.Running, StatementState.Available) + statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) + statement.updateProgress(1.0) + }(interpreterExecutor) + + statementId + } + + def cancel(statementId: Int): Unit = { + val statementOpt = _statements.synchronized { _statements.get(statementId) } + if (statementOpt.isEmpty) { + return + } + + val statement = statementOpt.get + if (statement.state.get().isOneOf( + StatementState.Available, StatementState.Cancelled, StatementState.Cancelling)) { + return + } else { + // statement 1 is running and statement 2 is waiting. User cancels + // statement 2 then cancels statement 1. The 2nd cancel call will loop and block the 1st + // cancel call since cancelExecutor is single threaded. To avoid this, set the statement + // state to cancelled when cancelling a waiting statement. + statement.compareAndTransit(StatementState.Waiting, StatementState.Cancelled) + statement.compareAndTransit(StatementState.Running, StatementState.Cancelling) + } + + info(s"Cancelling statement $statementId...") + + Future { + val deadline = livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TIMEOUT).millis.fromNow + + while (statement.state.get() == StatementState.Cancelling) { + if (deadline.isOverdue()) { + info(s"Failed to cancel statement $statementId.") + statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) + } else { + _sc.foreach(_.cancelJobGroup(statementId.toString)) + if (statement.state.get() == StatementState.Cancelling) { + Thread.sleep(livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TRIGGER_INTERVAL)) + } + } + } + + if (statement.state.get() == StatementState.Cancelled) { + info(s"Statement $statementId cancelled.") + } + }(cancelExecutor) + } + + def close(): Unit = { + interpreterExecutor.shutdown() + cancelExecutor.shutdown() + interpreter.close() + } + + /** + * Get the current progress of given statement id. 
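+ * Progress is taken from Spark's status tracker: completed tasks over total tasks across all jobs in the statement's job group.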
+ */ + def progressOfStatement(stmtId: Int): Double = { + val jobGroup = statementIdToJobGroup(stmtId) + + _sc.map { sc => + val jobIds = sc.statusTracker.getJobIdsForGroup(jobGroup) + val jobs = jobIds.flatMap { id => sc.statusTracker.getJobInfo(id) } + val stages = jobs.flatMap { job => + job.stageIds().flatMap(sc.statusTracker.getStageInfo) + } + + val taskCount = stages.map(_.numTasks).sum + val completedTaskCount = stages.map(_.numCompletedTasks).sum + if (taskCount == 0) { + 0.0 + } else { + completedTaskCount.toDouble / taskCount + } + }.getOrElse(0.0) + } + + private def changeState(newState: SessionState): Unit = { + synchronized { + _state = newState + } + stateChangedCallback(newState) + } + + private def executeCode(executionCount: Int, code: String): String = { + changeState(SessionState.Busy()) + + def transitToIdle() = { + val executingLastStatement = executionCount == newStatementId.intValue() - 1 + if (_statements.isEmpty || executingLastStatement) { + changeState(SessionState.Idle()) + } + } + + val resultInJson = try { + interpreter.execute(code) match { + case Interpreter.ExecuteSuccess(data) => + transitToIdle() + + (STATUS -> OK) ~ + (EXECUTION_COUNT -> executionCount) ~ + (DATA -> data) + + case Interpreter.ExecuteIncomplete() => + transitToIdle() + + (STATUS -> ERROR) ~ + (EXECUTION_COUNT -> executionCount) ~ + (ENAME -> "Error") ~ + (EVALUE -> "incomplete statement") ~ + (TRACEBACK -> Seq.empty[String]) + + case Interpreter.ExecuteError(ename, evalue, traceback) => + transitToIdle() + + (STATUS -> ERROR) ~ + (EXECUTION_COUNT -> executionCount) ~ + (ENAME -> ename) ~ + (EVALUE -> evalue) ~ + (TRACEBACK -> traceback) + + case Interpreter.ExecuteAborted(message) => + changeState(SessionState.Error()) + + (STATUS -> ERROR) ~ + (EXECUTION_COUNT -> executionCount) ~ + (ENAME -> "Error") ~ + (EVALUE -> f"Interpreter died:\n$message") ~ + (TRACEBACK -> Seq.empty[String]) + } + } catch { + case e: Throwable => + error("Exception when executing code", e) + + transitToIdle() + + (STATUS -> ERROR) ~ + (EXECUTION_COUNT -> executionCount) ~ + (ENAME -> f"Internal Error: ${e.getClass.getName}") ~ + (EVALUE -> e.getMessage) ~ + (TRACEBACK -> Seq.empty[String]) + } + + compact(render(resultInJson)) + } + + private def setJobGroup(statementId: Int): String = { + val jobGroup = statementIdToJobGroup(statementId) + val cmd = Kind(interpreter.kind) match { + case Spark() => + // A dummy value to avoid automatic value binding in scala REPL. 
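+ // Binding to a val also keeps the REPL from echoing an automatic resN value for the Unit result.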
+ s"""val _livyJobGroup$jobGroup = sc.setJobGroup("$jobGroup",""" + + s""""Job group for statement $jobGroup")""" + case PySpark() | PySpark3() => + s"""sc.setJobGroup("$jobGroup", "Job group for statement $jobGroup")""" + case SparkR() => + interpreter.asInstanceOf[SparkRInterpreter].sparkMajorVersion match { + case "1" => + s"""setJobGroup(sc, "$jobGroup", "Job group for statement $jobGroup", """ + + "FALSE)" + case "2" => + s"""setJobGroup("$jobGroup", "Job group for statement $jobGroup", FALSE)""" + } + } + // Set the job group + executeCode(statementId, cmd) + } + + private def statementIdToJobGroup(statementId: Int): String = { + statementId.toString + } +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/SparkContextInitializer.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/SparkContextInitializer.scala b/repl/src/main/scala/org/apache/livy/repl/SparkContextInitializer.scala new file mode 100644 index 0000000..a4a57f9 --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/SparkContextInitializer.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.repl + +import org.apache.spark.{SparkConf, SparkContext} + +import org.apache.livy.Logging + +/** + * A mixin trait for Spark entry point creation. This trait exists two different code path + * separately for Spark1 and Spark2, depends on whether SparkSession exists or not. 
+ */ +trait SparkContextInitializer extends Logging { + self: SparkInterpreter => + + def createSparkContext(conf: SparkConf): Unit = { + if (isSparkSessionPresent()) { + spark2CreateContext(conf) + } else { + spark1CreateContext(conf) + } + } + + private def spark1CreateContext(conf: SparkConf): Unit = { + sparkContext = SparkContext.getOrCreate(conf) + var sqlContext: Object = null + + if (conf.getBoolean("spark.repl.enableHiveContext", false)) { + try { + val loader = Option(Thread.currentThread().getContextClassLoader) + .getOrElse(getClass.getClassLoader) + if (loader.getResource("hive-site.xml") == null) { + warn("livy.repl.enable-hive-context is true but no hive-site.xml found on classpath.") + } + + sqlContext = Class.forName("org.apache.spark.sql.hive.HiveContext") + .getConstructor(classOf[SparkContext]).newInstance(sparkContext).asInstanceOf[Object] + info("Created sql context (with Hive support).") + } catch { + case _: NoClassDefFoundError => + sqlContext = Class.forName("org.apache.spark.sql.SQLContext") + .getConstructor(classOf[SparkContext]).newInstance(sparkContext).asInstanceOf[Object] + info("Created sql context.") + } + } else { + sqlContext = Class.forName("org.apache.spark.sql.SQLContext") + .getConstructor(classOf[SparkContext]).newInstance(sparkContext).asInstanceOf[Object] + info("Created sql context.") + } + + bind("sc", "org.apache.spark.SparkContext", sparkContext, List("""@transient""")) + bind("sqlContext", sqlContext.getClass.getCanonicalName, sqlContext, List("""@transient""")) + + execute("import org.apache.spark.SparkContext._") + execute("import sqlContext.implicits._") + execute("import sqlContext.sql") + execute("import org.apache.spark.sql.functions._") + } + + private def spark2CreateContext(conf: SparkConf): Unit = { + val sparkClz = Class.forName("org.apache.spark.sql.SparkSession$") + val sparkObj = sparkClz.getField("MODULE$").get(null) + + val builderMethod = sparkClz.getMethod("builder") + val builder = builderMethod.invoke(sparkObj) + builder.getClass.getMethod("config", classOf[SparkConf]).invoke(builder, conf) + + var spark: Object = null + if (conf.get("spark.sql.catalogImplementation", "in-memory").toLowerCase == "hive") { + if (sparkClz.getMethod("hiveClassesArePresent").invoke(sparkObj).asInstanceOf[Boolean]) { + val loader = Option(Thread.currentThread().getContextClassLoader) + .getOrElse(getClass.getClassLoader) + if (loader.getResource("hive-site.xml") == null) { + warn("livy.repl.enable-hive-context is true but no hive-site.xml found on classpath.") + } + + builder.getClass.getMethod("enableHiveSupport").invoke(builder) + spark = builder.getClass.getMethod("getOrCreate").invoke(builder) + info("Created Spark session (with Hive support).") + } else { + spark = builder.getClass.getMethod("getOrCreate").invoke(builder) + info("Created Spark session.") + } + } else { + spark = builder.getClass.getMethod("getOrCreate").invoke(builder) + info("Created Spark session.") + } + + sparkContext = spark.getClass.getMethod("sparkContext").invoke(spark) + .asInstanceOf[SparkContext] + + bind("spark", spark.getClass.getCanonicalName, spark, List("""@transient""")) + bind("sc", "org.apache.spark.SparkContext", sparkContext, List("""@transient""")) + + execute("import org.apache.spark.SparkContext._") + execute("import spark.implicits._") + execute("import spark.sql") + execute("import org.apache.spark.sql.functions._") + } + + private def isSparkSessionPresent(): Boolean = { + try { + Class.forName("org.apache.spark.sql.SparkSession") + true + } 
catch { + case _: ClassNotFoundException | _: NoClassDefFoundError => false + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/SparkRInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/SparkRInterpreter.scala b/repl/src/main/scala/org/apache/livy/repl/SparkRInterpreter.scala new file mode 100644 index 0000000..b745861 --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/SparkRInterpreter.scala @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy.repl + +import java.io.{File, FileOutputStream} +import java.lang.ProcessBuilder.Redirect +import java.nio.file.Files +import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit} + +import scala.annotation.tailrec +import scala.collection.JavaConverters._ +import scala.reflect.runtime.universe + +import org.apache.commons.codec.binary.Base64 +import org.apache.commons.lang.StringEscapeUtils +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.json4s._ +import org.json4s.JsonDSL._ + +import org.apache.livy.client.common.ClientConf +import org.apache.livy.rsc.RSCConf + +private case class RequestResponse(content: String, error: Boolean) + +// scalastyle:off println +object SparkRInterpreter { + private val LIVY_END_MARKER = "----LIVY_END_OF_COMMAND----" + private val LIVY_ERROR_MARKER = "----LIVY_END_OF_ERROR----" + private val PRINT_MARKER = f"""print("$LIVY_END_MARKER")""" + private val EXPECTED_OUTPUT = f"""[1] "$LIVY_END_MARKER"""" + + private val PLOT_REGEX = ( + "(" + + "(?:bagplot)|" + + "(?:barplot)|" + + "(?:boxplot)|" + + "(?:dotchart)|" + + "(?:hist)|" + + "(?:lines)|" + + "(?:pie)|" + + "(?:pie3D)|" + + "(?:plot)|" + + "(?:qqline)|" + + "(?:qqnorm)|" + + "(?:scatterplot)|" + + "(?:scatterplot3d)|" + + "(?:scatterplot\\.matrix)|" + + "(?:splom)|" + + "(?:stripchart)|" + + "(?:vioplot)" + + ")" + ).r.unanchored + + def apply(conf: SparkConf): SparkRInterpreter = { + val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt + val mirror = universe.runtimeMirror(getClass.getClassLoader) + val sparkRBackendClass = mirror.classLoader.loadClass("org.apache.spark.api.r.RBackend") + val backendInstance = sparkRBackendClass.getDeclaredConstructor().newInstance() + + var sparkRBackendPort = 0 + val initialized = new Semaphore(0) + // Launch a SparkR backend server for the R process to connect to + val backendThread = new Thread("SparkR backend") { + override def run(): Unit = { + sparkRBackendPort = sparkRBackendClass.getMethod("init").invoke(backendInstance) + 
.asInstanceOf[Int] + + initialized.release() + sparkRBackendClass.getMethod("run").invoke(backendInstance) + } + } + + backendThread.setDaemon(true) + backendThread.start() + try { + // Wait for RBackend initialization to finish + initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS) + val rExec = conf.getOption("spark.r.shell.command") + .orElse(sys.env.get("SPARKR_DRIVER_R")) + .getOrElse("R") + + var packageDir = "" + if (sys.env.getOrElse("SPARK_YARN_MODE", "") == "true") { + packageDir = "./sparkr" + } else { + // local mode + val rLibPath = new File(sys.env.getOrElse("SPARKR_PACKAGE_DIR", + Seq(sys.env.getOrElse("SPARK_HOME", "."), "R", "lib").mkString(File.separator))) + if (!ClientConf.TEST_MODE) { + require(rLibPath.exists(), "Cannot find sparkr package directory.") + packageDir = rLibPath.getAbsolutePath() + } + } + + val builder = new ProcessBuilder(Seq(rExec, "--slave @").asJava) + val env = builder.environment() + env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", ".")) + env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + env.put("SPARKR_PACKAGE_DIR", packageDir) + env.put("R_PROFILE_USER", + Seq(packageDir, "SparkR", "profile", "general.R").mkString(File.separator)) + + builder.redirectErrorStream(true) + val process = builder.start() + new SparkRInterpreter(process, backendInstance, backendThread, + conf.get("spark.livy.spark_major_version", "1"), + conf.getBoolean("spark.repl.enableHiveContext", false)) + } catch { + case e: Exception => + if (backendThread != null) { + backendThread.interrupt() + } + throw e + } + } +} + +class SparkRInterpreter(process: Process, + backendInstance: Any, + backendThread: Thread, + val sparkMajorVersion: String, + hiveEnabled: Boolean) + extends ProcessInterpreter(process) { + import SparkRInterpreter._ + + implicit val formats = DefaultFormats + + private[this] var executionCount = 0 + override def kind: String = "sparkr" + private[this] val isStarted = new CountDownLatch(1) + + final override protected def waitUntilReady(): Unit = { + // Set the option to catch and ignore errors instead of halting. + sendRequest("options(error = dump.frames)") + if (!ClientConf.TEST_MODE) { + sendRequest("library(SparkR)") + if (sparkMajorVersion >= "2") { + if (hiveEnabled) { + sendRequest("spark <- SparkR::sparkR.session()") + } else { + sendRequest("spark <- SparkR::sparkR.session(enableHiveSupport=FALSE)") + } + sendRequest( + """sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "getJavaSparkContext", spark)""") + } else { + sendRequest("sc <- sparkR.init()") + if (hiveEnabled) { + sendRequest("sqlContext <- sparkRHive.init(sc)") + } else { + sendRequest("sqlContext <- sparkRSQL.init(sc)") + } + } + } + + isStarted.countDown() + executionCount = 0 + } + + override protected def sendExecuteRequest(command: String): Interpreter.ExecuteResponse = { + isStarted.await() + var code = command + + // Create an image file if this command is trying to plot. + val tempFile = PLOT_REGEX.findFirstIn(code).map { case _ => + val tempFile = Files.createTempFile("", ".png") + val tempFileString = tempFile.toAbsolutePath + + code = f"""png("$tempFileString")\n$code\ndev.off()""" + + tempFile + } + + try { + val response = sendRequest(code) + + if (response.error) { + Interpreter.ExecuteError("Error", response.content) + } else { + var content: JObject = TEXT_PLAIN -> response.content + + // If we rendered anything, pass along the last image.
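+ // The image bytes are read back, base64-encoded, and attached under the image/png key.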
+ tempFile.foreach { case file => + val bytes = Files.readAllBytes(file) + if (bytes.nonEmpty) { + val image = Base64.encodeBase64String(bytes) + content = content ~ (IMAGE_PNG -> image) + } + } + + Interpreter.ExecuteSuccess(content) + } + + } catch { + case e: Error => + Interpreter.ExecuteError("Error", e.output) + case e: Exited => + Interpreter.ExecuteAborted(e.getMessage) + } finally { + tempFile.foreach(Files.delete) + } + + } + + private def sendRequest(code: String): RequestResponse = { + stdin.println(s"""tryCatch(eval(parse(text="${StringEscapeUtils.escapeJava(code)}")) + |,error = function(e) sprintf("%s%s", e, "${LIVY_ERROR_MARKER}")) + """.stripMargin) + stdin.flush() + + stdin.println(PRINT_MARKER) + stdin.flush() + + readTo(EXPECTED_OUTPUT, LIVY_ERROR_MARKER) + } + + override protected def sendShutdownRequest() = { + stdin.println("q()") + stdin.flush() + + while (stdout.readLine() != null) {} + } + + override def close(): Unit = { + try { + val closeMethod = backendInstance.getClass().getMethod("close") + closeMethod.setAccessible(true) + closeMethod.invoke(backendInstance) + + backendThread.interrupt() + backendThread.join() + } finally { + super.close() + } + } + + @tailrec + private def readTo( + marker: String, + errorMarker: String, + output: StringBuilder = StringBuilder.newBuilder): RequestResponse = { + var char = readChar(output) + + // Remove any ANSI color codes which match the pattern "\u001b\\[[0-9;]*[mG]". + // It would be easier to do this with a regex, but unfortunately I don't see an easy way to do + // without copying the StringBuilder into a string for each character. + if (char == '\u001b') { + if (readChar(output) == '[') { + char = readDigits(output) + + if (char == 'm' || char == 'G') { + output.delete(output.lastIndexOf('\u001b'), output.length) + } + } + } + + if (output.endsWith(marker)) { + var result = stripMarker(output.toString(), marker) + + if (result.endsWith(errorMarker + "\"")) { + result = stripMarker(result, "\\n" + errorMarker) + RequestResponse(result, error = true) + } else { + RequestResponse(result, error = false) + } + } else { + readTo(marker, errorMarker, output) + } + } + + private def stripMarker(result: String, marker: String): String = { + result.replace(marker, "") + .stripPrefix("\n") + .stripSuffix("\n") + } + + private def readChar(output: StringBuilder): Char = { + val byte = stdout.read() + if (byte == -1) { + throw new Exited(output.toString()) + } else { + val char = byte.toChar + output.append(char) + char + } + } + + @tailrec + private def readDigits(output: StringBuilder): Char = { + val byte = stdout.read() + if (byte == -1) { + throw new Exited(output.toString()) + } + + val char = byte.toChar + + if (('0' to '9').contains(char)) { + output.append(char) + readDigits(output) + } else { + char + } + } + + private class Exited(val output: String) extends Exception {} + private class Error(val output: String) extends Exception {} +} +// scalastyle:on println http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/main/scala/org/apache/livy/repl/package.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/org/apache/livy/repl/package.scala b/repl/src/main/scala/org/apache/livy/repl/package.scala new file mode 100644 index 0000000..7bc9fe9 --- /dev/null +++ b/repl/src/main/scala/org/apache/livy/repl/package.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.livy + +import org.json4s.JField + +package object repl { + type MimeTypeMap = List[JField] + + val APPLICATION_JSON = "application/json" + val APPLICATION_LIVY_TABLE_JSON = "application/vnd.livy.table.v1+json" + val IMAGE_PNG = "image/png" + val TEXT_PLAIN = "text/plain" +} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/test/scala/com/cloudera/livy/repl/BaseInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/BaseInterpreterSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/BaseInterpreterSpec.scala deleted file mode 100644 index 0c410d2..0000000 --- a/repl/src/test/scala/com/cloudera/livy/repl/BaseInterpreterSpec.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.repl - -import org.scalatest.{FlatSpec, Matchers} - -import com.cloudera.livy.LivyBaseUnitTestSuite - -abstract class BaseInterpreterSpec extends FlatSpec with Matchers with LivyBaseUnitTestSuite { - - def createInterpreter(): Interpreter - - def withInterpreter(testCode: Interpreter => Any): Unit = { - val interpreter = createInterpreter() - try { - interpreter.start() - testCode(interpreter) - } finally { - interpreter.close() - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala deleted file mode 100644 index 59db4b3..0000000 --- a/repl/src/test/scala/com/cloudera/livy/repl/BaseSessionSpec.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.repl - -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.json4s._ -import org.scalatest.{FlatSpec, Matchers} -import org.scalatest.concurrent.Eventually._ - -import com.cloudera.livy.LivyBaseUnitTestSuite -import com.cloudera.livy.rsc.RSCConf -import com.cloudera.livy.rsc.driver.{Statement, StatementState} -import com.cloudera.livy.sessions.SessionState - -abstract class BaseSessionSpec extends FlatSpec with Matchers with LivyBaseUnitTestSuite { - - implicit val formats = DefaultFormats - - private val rscConf = new RSCConf(new Properties()) - - protected def execute(session: Session)(code: String): Statement = { - val id = session.execute(code) - eventually(timeout(30 seconds), interval(100 millis)) { - val s = session.statements(id) - s.state.get() shouldBe StatementState.Available - s - } - } - - protected def withSession(testCode: Session => Any): Unit = { - val stateChangedCalled = new AtomicInteger() - val session = - new Session(rscConf, createInterpreter(), { _ => stateChangedCalled.incrementAndGet() }) - try { - // Session's constructor should fire an initial state change event. - stateChangedCalled.intValue() shouldBe 1 - Await.ready(session.start(), 30 seconds) - assert(session.state === SessionState.Idle()) - // There should be at least 1 state change event fired when session transits to idle. - stateChangedCalled.intValue() should (be > 1) - testCode(session) - } finally { - session.close() - } - } - - protected def createInterpreter(): Interpreter - - it should "start in the starting or idle state" in { - val session = new Session(rscConf, createInterpreter()) - val future = session.start() - try { - eventually(timeout(30 seconds), interval(100 millis)) { - session.state should (equal (SessionState.Starting()) or equal (SessionState.Idle())) - } - Await.ready(future, 60 seconds) - } finally { - session.close() - } - } - - it should "eventually become the idle state" in withSession { session => - session.state should equal (SessionState.Idle()) - } - -} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala deleted file mode 100644 index 7917b90..0000000 --- a/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala +++ /dev/null @@ -1,284 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.repl - -import org.apache.spark.SparkConf -import org.json4s.{DefaultFormats, JNull, JValue} -import org.json4s.JsonDSL._ -import org.scalatest._ - -import com.cloudera.livy.rsc.RSCConf -import com.cloudera.livy.sessions._ - -abstract class PythonBaseInterpreterSpec extends BaseInterpreterSpec { - - it should "execute `1 + 2` == 3" in withInterpreter { interpreter => - val response = interpreter.execute("1 + 2") - response should equal (Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "3" - )) - } - - it should "execute multiple statements" in withInterpreter { interpreter => - var response = interpreter.execute("x = 1") - response should equal (Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "" - )) - - response = interpreter.execute("y = 2") - response should equal (Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "" - )) - - response = interpreter.execute("x + y") - response should equal (Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "3" - )) - } - - it should "execute multiple statements in one block" in withInterpreter { interpreter => - val response = interpreter.execute( - """ - |x = 1 - | - |y = 2 - | - |x + y - """.stripMargin) - response should equal(Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "3" - )) - } - - it should "parse a class" in withInterpreter { interpreter => - val response = interpreter.execute( - """ - |class Counter(object): - | def __init__(self): - | self.count = 0 - | - | def add_one(self): - | self.count += 1 - | - | def add_two(self): - | self.count += 2 - | - |counter = Counter() - |counter.add_one() - |counter.add_two() - |counter.count - """.stripMargin) - response should equal(Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "3" - )) - } - - it should "do json magic" in withInterpreter { interpreter => - val response = interpreter.execute( - """x = [[1, 'a'], [3, 'b']] - |%json x - """.stripMargin) - - response should equal(Interpreter.ExecuteSuccess( - APPLICATION_JSON -> List[JValue]( - List[JValue](1, "a"), - List[JValue](3, "b") - ) - )) - } - - it should "do table magic" in withInterpreter { interpreter => - val response = interpreter.execute( - """x = [[1, 'a'], [3, 'b']] - |%table x - """.stripMargin) - - response should equal(Interpreter.ExecuteSuccess( - APPLICATION_LIVY_TABLE_JSON -> ( - ("headers" -> List( - ("type" -> "INT_TYPE") ~ ("name" -> "0"), - ("type" -> "STRING_TYPE") ~ ("name" -> "1") - )) ~ - ("data" -> List( - List[JValue](1, "a"), - List[JValue](3, "b") - )) - ) - )) - } - - it should "do table magic with None type value" in withInterpreter { interpreter => - val response = interpreter.execute( - """x = [{"a":"1", "b":None}, {"a":"2", "b":2}] - |%table x - """.stripMargin) - - response should equal(Interpreter.ExecuteSuccess( - APPLICATION_LIVY_TABLE_JSON -> ( - ("headers" -> List( - ("type" -> "STRING_TYPE") ~ ("name" -> "a"), - ("type" -> "INT_TYPE") ~ ("name" -> "b") - )) ~ - ("data" -> List( - List[JValue]("1", JNull), - List[JValue]("2", 
2) - )) - ) - )) - } - - it should "do table magic with None type Row" in withInterpreter { interpreter => - val response = interpreter.execute( - """x = [{"a":None, "b":None}, {"a":"2", "b":2}] - |%table x - """.stripMargin) - - response should equal(Interpreter.ExecuteSuccess( - APPLICATION_LIVY_TABLE_JSON -> ( - ("headers" -> List( - ("type" -> "STRING_TYPE") ~ ("name" -> "a"), - ("type" -> "INT_TYPE") ~ ("name" -> "b") - )) ~ - ("data" -> List( - List[JValue](JNull, JNull), - List[JValue]("2", 2) - )) - ) - )) - } - - it should "allow magic inside statements" in withInterpreter { interpreter => - val response = interpreter.execute( - """x = [[1, 'a'], [3, 'b']] - |%table x - |1 + 2 - """.stripMargin) - - response should equal(Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "3" - )) - } - - it should "capture stdout" in withInterpreter { interpreter => - val response = interpreter.execute("print('Hello World')") - response should equal(Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "Hello World" - )) - } - - it should "report an error if accessing an unknown variable" in withInterpreter { interpreter => - val response = interpreter.execute("x") - response should equal(Interpreter.ExecuteError( - "NameError", - "name 'x' is not defined", - List( - "Traceback (most recent call last):\n", - "NameError: name 'x' is not defined\n" - ) - )) - } - - it should "report an error if empty magic command" in withInterpreter { interpreter => - val response = interpreter.execute("%") - response should equal(Interpreter.ExecuteError( - "UnknownMagic", - "magic command not specified", - List("UnknownMagic: magic command not specified\n") - )) - } - - it should "report an error if unknown magic command" in withInterpreter { interpreter => - val response = interpreter.execute("%foo") - response should equal(Interpreter.ExecuteError( - "UnknownMagic", - "unknown magic command 'foo'", - List("UnknownMagic: unknown magic command 'foo'\n") - )) - } - - it should "not execute part of the block if there is a syntax error" in withInterpreter { intp => - var response = intp.execute( - """x = 1 - |' - """.stripMargin) - - response should equal(Interpreter.ExecuteError( - "SyntaxError", - "EOL while scanning string literal (<stdin>, line 2)", - List( - " File \"<stdin>\", line 2\n", - " '\n", - " ^\n", - "SyntaxError: EOL while scanning string literal\n" - ) - )) - - response = intp.execute("x") - response should equal(Interpreter.ExecuteError( - "NameError", - "name 'x' is not defined", - List( - "Traceback (most recent call last):\n", - "NameError: name 'x' is not defined\n" - ) - )) - } -} - -class Python2InterpreterSpec extends PythonBaseInterpreterSpec { - - implicit val formats = DefaultFormats - - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark()) - - // Scalastyle is treating unicode escape as non ascii characters. Turn off the check. 
- // scalastyle:off non.ascii.character.disallowed - it should "print unicode correctly" in withInterpreter { intp => - intp.execute("print(u\"\u263A\")") should equal(Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "\u263A" - )) - intp.execute("""print(u"\u263A")""") should equal(Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "\u263A" - )) - intp.execute("""print("\xE2\x98\xBA")""") should equal(Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "\u263A" - )) - } - // scalastyle:on non.ascii.character.disallowed -} - -class Python3InterpreterSpec extends PythonBaseInterpreterSpec { - - implicit val formats = DefaultFormats - - override protected def withFixture(test: NoArgTest): Outcome = { - assume(!sys.props.getOrElse("skipPySpark3Tests", "false").toBoolean, "Skipping PySpark3 tests.") - test() - } - - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark3()) - - it should "check python version is 3.x" in withInterpreter { interpreter => - val response = interpreter.execute("""import sys - |sys.version >= '3' - """.stripMargin) - response should equal (Interpreter.ExecuteSuccess( - TEXT_PLAIN -> "True" - )) - } -} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/412ccc8f/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala deleted file mode 100644 index 6bafdf7..0000000 --- a/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.cloudera.livy.repl - -import org.apache.spark.SparkConf -import org.json4s.Extraction -import org.json4s.jackson.JsonMethods.parse -import org.scalatest._ - -import com.cloudera.livy.rsc.RSCConf -import com.cloudera.livy.sessions._ - -abstract class PythonSessionSpec extends BaseSessionSpec { - - it should "execute `1 + 2` == 3" in withSession { session => - val statement = execute(session)("1 + 2") - statement.id should equal (0) - - val result = parse(statement.output) - val expectedResult = Extraction.decompose(Map( - "status" -> "ok", - "execution_count" -> 0, - "data" -> Map( - "text/plain" -> "3" - ) - )) - - result should equal (expectedResult) - } - - it should "execute `x = 1`, then `y = 2`, then `x + y`" in withSession { session => - val executeWithSession = execute(session)(_) - var statement = executeWithSession("x = 1") - statement.id should equal (0) - - var result = parse(statement.output) - var expectedResult = Extraction.decompose(Map( - "status" -> "ok", - "execution_count" -> 0, - "data" -> Map( - "text/plain" -> "" - ) - )) - - result should equal (expectedResult) - - statement = executeWithSession("y = 2") - statement.id should equal (1) - - result = parse(statement.output) - expectedResult = Extraction.decompose(Map( - "status" -> "ok", - "execution_count" -> 1, - "data" -> Map( - "text/plain" -> "" - ) - )) - - result should equal (expectedResult) - - statement = executeWithSession("x + y") - statement.id should equal (2) - - result = parse(statement.output) - expectedResult = Extraction.decompose(Map( - "status" -> "ok", - "execution_count" -> 2, - "data" -> Map( - "text/plain" -> "3" - ) - )) - - result should equal (expectedResult) - } - - it should "do table magic" in withSession { session => - val statement = execute(session)("x = [[1, 'a'], [3, 'b']]\n%table x") - statement.id should equal (0) - - val result = parse(statement.output) - val expectedResult = Extraction.decompose(Map( - "status" -> "ok", - "execution_count" -> 0, - "data" -> Map( - "application/vnd.livy.table.v1+json" -> Map( - "headers" -> List( - Map("type" -> "INT_TYPE", "name" -> "0"), - Map("type" -> "STRING_TYPE", "name" -> "1")), - "data" -> List(List(1, "a"), List(3, "b")) - ) - ) - )) - - result should equal (expectedResult) - } - - it should "capture stdout" in withSession { session => - val statement = execute(session)("""print('Hello World')""") - statement.id should equal (0) - - val result = parse(statement.output) - val expectedResult = Extraction.decompose(Map( - "status" -> "ok", - "execution_count" -> 0, - "data" -> Map( - "text/plain" -> "Hello World" - ) - )) - - result should equal (expectedResult) - } - - it should "report an error if accessing an unknown variable" in withSession { session => - val statement = execute(session)("""x""") - statement.id should equal (0) - - val result = parse(statement.output) - val expectedResult = Extraction.decompose(Map( - "status" -> "error", - "execution_count" -> 0, - "traceback" -> List( - "Traceback (most recent call last):\n", - "NameError: name 'x' is not defined\n" - ), - "ename" -> "NameError", - "evalue" -> "name 'x' is not defined" - )) - - result should equal (expectedResult) - } - - it should "report an error if exception is thrown" in withSession { session => - val statement = execute(session)( - """def func1(): - | raise Exception("message") - |def func2(): - | func1() - |func2() - """.stripMargin) - statement.id should equal (0) - - val result = parse(statement.output) - val expectedResult = 
Extraction.decompose(Map( - "status" -> "error", - "execution_count" -> 0, - "traceback" -> List( - "Traceback (most recent call last):\n", - " File \"<stdin>\", line 4, in func2\n", - " File \"<stdin>\", line 2, in func1\n", - "Exception: message\n" - ), - "ename" -> "Exception", - "evalue" -> "message" - )) - - result should equal (expectedResult) - } -} - -class Python2SessionSpec extends PythonSessionSpec { - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark()) -} - -class Python3SessionSpec extends PythonSessionSpec { - - override protected def withFixture(test: NoArgTest): Outcome = { - assume(!sys.props.getOrElse("skipPySpark3Tests", "false").toBoolean, "Skipping PySpark3 tests.") - test() - } - - override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark3()) - - it should "check python version is 3.x" in withSession { session => - val statement = execute(session)( - """import sys - |sys.version >= '3' - """.stripMargin) - statement.id should equal (0) - - val result = parse(statement.output) - val expectedResult = Extraction.decompose(Map( - "status" -> "ok", - "execution_count" -> 0, - "data" -> Map( - "text/plain" -> "True" - ) - )) - - result should equal (expectedResult) - } -}
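----------------------------------------------------------------------
A note on the SparkRInterpreter changes above: the plot handling in sendExecuteRequest() is easy to miss in diff form. When a statement matches PLOT_REGEX, the code is rewritten so R's png() device renders any figure into a temp file. Below is a minimal standalone Scala sketch of that wrapping; the regex here is a guessed stand-in, since the real PLOT_REGEX is defined outside this excerpt.

  import java.nio.file.{Files, Path}

  object PlotCaptureSketch {
    // Hypothetical stand-in for Livy's PLOT_REGEX (not shown in this diff).
    private val PlotRegex = """\bplot\(""".r

    // Mirror of the wrapping in sendExecuteRequest(): if the statement appears
    // to plot, route R's graphics device to a temp PNG around the user code.
    def wrapForPlot(code: String): (String, Option[Path]) =
      PlotRegex.findFirstIn(code) match {
        case Some(_) =>
          val tmp = Files.createTempFile("livy-plot-", ".png")
          (s"""png("${tmp.toAbsolutePath}")\n$code\ndev.off()""", Some(tmp))
        case None => (code, None)
      }

    def main(args: Array[String]): Unit = {
      val (wrapped, tmp) = wrapForPlot("plot(cars)")
      println(wrapped)
      tmp.foreach(Files.delete) // the real code deletes the file in a finally block
    }
  }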
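If the temp file comes back non-empty, the interpreter base64-encodes it and merges it into the json4s response next to the plain-text output. A sketch of that merge using the same json4s DSL and commons-codec Base64 the file already uses; the MIME keys mirror the constants in the repl package object included in this patch.

  import java.nio.file.{Files, Paths}

  import org.apache.commons.codec.binary.Base64
  import org.json4s.JObject
  import org.json4s.JsonDSL._

  object ImageResponseSketch {
    // Combine the captured text output with an optional rendered PNG, the way
    // the IMAGE_PNG branch above does.
    def buildContent(textOutput: String, pngPath: Option[String]): JObject = {
      var content: JObject = "text/plain" -> textOutput
      pngPath.foreach { p =>
        val bytes = Files.readAllBytes(Paths.get(p))
        if (bytes.nonEmpty) {
          content = content ~ ("image/png" -> Base64.encodeBase64String(bytes))
        }
      }
      content
    }
  }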
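sendRequest() and readTo() implement a small framing protocol over the R process's stdin/stdout: tryCatch tags errors with one marker, and a second marker tells the reader where a statement's output ends. The sketch below has the same shape but uses made-up marker values and a simulated stdout; the real constants (PRINT_MARKER, EXPECTED_OUTPUT, LIVY_ERROR_MARKER) live in SparkRInterpreter's companion object, outside this excerpt.

  import java.io.{BufferedReader, StringReader}

  import org.apache.commons.lang3.StringEscapeUtils

  object MarkerProtocolSketch {
    // Hypothetical marker strings, not Livy's actual constants.
    val ErrorMarker = "----LIVY_ERROR----"
    val OutputMarker = "----LIVY_OUTPUT----"

    // Same shape as sendRequest(): run the statement inside tryCatch so an R
    // error is printed with a trailing marker instead of killing the process,
    // then print the output marker so the reader knows where to stop.
    def wrap(code: String): String = {
      val escaped = StringEscapeUtils.escapeJava(code)
      s"""tryCatch(eval(parse(text="$escaped")),
         |  error = function(e) sprintf("%s%s", e, "$ErrorMarker"))
         |cat("$OutputMarker\\n")""".stripMargin
    }

    // Same shape as readTo(): accumulate stdout until the output marker, and
    // report an error when the error marker appears in what was read.
    def readTo(stdout: BufferedReader): (String, Boolean) = {
      val sb = new StringBuilder
      var line = stdout.readLine()
      while (line != null && line != OutputMarker) {
        sb.append(line).append('\n')
        line = stdout.readLine()
      }
      val raw = sb.toString
      (raw.replace(ErrorMarker, "").trim, raw.contains(ErrorMarker))
    }

    def main(args: Array[String]): Unit = {
      // Simulate one successful statement's stdout instead of a live R process.
      val fake = new BufferedReader(new StringReader(s"[1] 3\n$OutputMarker\n"))
      println(readTo(fake)) // prints ([1] 3,false)
    }
  }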
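Finally, readTo() strips ANSI color codes one character at a time because, as its comment notes, applying a regex would mean copying the StringBuilder into a String for every character read. Once a complete string is available, the same cleanup is a one-liner using the exact pattern quoted in that comment:

  object AnsiStripSketch {
    // Pattern quoted verbatim from the code comment in readTo().
    private val AnsiPattern = "\u001b\\[[0-9;]*[mG]".r

    def stripAnsi(s: String): String = AnsiPattern.replaceAllIn(s, "")

    def main(args: Array[String]): Unit = {
      println(stripAnsi("\u001b[31mred text\u001b[0m plain")) // "red text plain"
    }
  }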