LIVY-355. Refactor statement progress tracker to fix binary compatible issue (#323)
* Refactor statement progress tracker to fix binary compatible issue

Change-Id: Ie91fd77472aeebe138bd6711a0baa82269a6b247

* refactor again to simplify the code

Change-Id: I9380bcb8dd2b594250783633a3c68e290ac7ea28

* isolate statementId to job group logic

Change-Id: If554aee2c0b3d96b54804f94cbb8df9af7843ab4

Project: http://git-wip-us.apache.org/repos/asf/incubator-livy/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-livy/commit/0ddcaf68
Tree: http://git-wip-us.apache.org/repos/asf/incubator-livy/tree/0ddcaf68
Diff: http://git-wip-us.apache.org/repos/asf/incubator-livy/diff/0ddcaf68

Branch: refs/heads/master
Commit: 0ddcaf68d28a5120b1b74692bd92dde2301f5170
Parents: f5ef489
Author: Saisai Shao <sai.sai.s...@gmail.com>
Authored: Wed May 10 04:25:57 2017 +0800
Committer: Jeff Zhang <zjf...@gmail.com>
Committed: Tue May 9 13:25:57 2017 -0700

----------------------------------------------------------------------
 .../cloudera/livy/repl/SparkInterpreter.scala   |   4 +-
 .../livy/repl/SparkInterpreterSpec.scala        |   2 +-
 .../cloudera/livy/repl/SparkInterpreter.scala   |   4 +-
 .../livy/repl/SparkInterpreterSpec.scala        |   2 +-
 .../com/cloudera/livy/repl/Interpreter.scala    |  10 -
 .../cloudera/livy/repl/ProcessInterpreter.scala |   7 +-
 .../cloudera/livy/repl/PythonInterpreter.scala  |   9 +-
 .../com/cloudera/livy/repl/ReplDriver.scala     |  10 +-
 .../scala/com/cloudera/livy/repl/Session.scala  |  40 +++-
 .../cloudera/livy/repl/SparkRInterpreter.scala  |  10 +-
 .../livy/repl/StatementProgressListener.scala   | 162 -------------
 .../livy/repl/PythonInterpreterSpec.scala       |   6 +-
 .../cloudera/livy/repl/PythonSessionSpec.scala  |   6 +-
 .../livy/repl/ScalaInterpreterSpec.scala        |   2 +-
 .../livy/repl/SparkRInterpreterSpec.scala       |   3 +-
 .../cloudera/livy/repl/SparkRSessionSpec.scala  |   3 +-
 .../cloudera/livy/repl/SparkSessionSpec.scala   |  35 ++-
 .../repl/StatementProgressListenerSpec.scala    | 227 -------------------
 18 files changed, 93 insertions(+), 449 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/scala-2.10/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala
----------------------------------------------------------------------
diff --git a/repl/scala-2.10/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala b/repl/scala-2.10/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala
index 5ef5491..d736125 100644
--- a/repl/scala-2.10/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala
+++ b/repl/scala-2.10/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala
@@ -33,8 +33,7 @@ import org.apache.spark.repl.SparkIMain

 /**
  * This represents a Spark interpreter. It is not thread safe.
*/ -class SparkInterpreter(conf: SparkConf, - override val statementProgressListener: StatementProgressListener) +class SparkInterpreter(conf: SparkConf) extends AbstractSparkInterpreter with SparkContextInitializer { private var sparkIMain: SparkIMain = _ @@ -108,7 +107,6 @@ class SparkInterpreter(conf: SparkConf, createSparkContext(conf) } - sparkContext.addSparkListener(statementProgressListener) sparkContext } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/scala-2.10/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/scala-2.10/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala b/repl/scala-2.10/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala index 3df35b5..e2b783a 100644 --- a/repl/scala-2.10/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala +++ b/repl/scala-2.10/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala @@ -24,7 +24,7 @@ import com.cloudera.livy.LivyBaseUnitTestSuite class SparkInterpreterSpec extends FunSpec with Matchers with LivyBaseUnitTestSuite { describe("SparkInterpreter") { - val interpreter = new SparkInterpreter(null, null) + val interpreter = new SparkInterpreter(null) it("should parse Scala compile error.") { // Regression test for LIVY-260. http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/scala-2.11/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/scala-2.11/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala b/repl/scala-2.11/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala index 6735b3a..f08a46e 100644 --- a/repl/scala-2.11/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala +++ b/repl/scala-2.11/src/main/scala/com/cloudera/livy/repl/SparkInterpreter.scala @@ -33,8 +33,7 @@ import org.apache.spark.repl.SparkILoop /** * Scala 2.11 version of SparkInterpreter */ -class SparkInterpreter(conf: SparkConf, - override val statementProgressListener: StatementProgressListener) +class SparkInterpreter(conf: SparkConf) extends AbstractSparkInterpreter with SparkContextInitializer { protected var sparkContext: SparkContext = _ @@ -94,7 +93,6 @@ class SparkInterpreter(conf: SparkConf, createSparkContext(conf) } - sparkContext.addSparkListener(statementProgressListener) sparkContext } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/scala-2.11/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/scala-2.11/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala b/repl/scala-2.11/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala index 56656d7..5cb88e3 100644 --- a/repl/scala-2.11/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala +++ b/repl/scala-2.11/src/test/scala/com/cloudera/livy/repl/SparkInterpreterSpec.scala @@ -24,7 +24,7 @@ import com.cloudera.livy.LivyBaseUnitTestSuite class SparkInterpreterSpec extends FunSpec with Matchers with LivyBaseUnitTestSuite { describe("SparkInterpreter") { - val interpreter = new SparkInterpreter(null, null) + val interpreter = new SparkInterpreter(null) it("should parse Scala compile error.") { // Regression test for LIVY-. 
http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala b/repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala index 59ad878..fa3b640 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/Interpreter.scala @@ -37,8 +37,6 @@ trait Interpreter { def kind: String - def statementProgressListener: StatementProgressListener - /** * Start the Interpreter. * @@ -47,14 +45,6 @@ trait Interpreter { def start(): SparkContext /** - * Execute the code and return the result. - */ - def execute(statementId: Int, code: String): ExecuteResponse = { - statementProgressListener.setCurrentStatementId(statementId) - execute(code) - } - - /** * Execute the code and return the result, it may * take some time to execute. */ http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala b/repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala index 0414bbb..fe10697 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/ProcessInterpreter.scala @@ -41,8 +41,7 @@ private case class ShutdownRequest(promise: Promise[Unit]) extends Request * * @param process */ -abstract class ProcessInterpreter(process: Process, - override val statementProgressListener: StatementProgressListener) +abstract class ProcessInterpreter(process: Process) extends Interpreter with Logging { protected[this] val stdin = new PrintWriter(process.getOutputStream) protected[this] val stdout = new BufferedReader(new InputStreamReader(process.getInputStream), 1) @@ -53,9 +52,7 @@ abstract class ProcessInterpreter(process: Process, if (ClientConf.TEST_MODE) { null.asInstanceOf[SparkContext] } else { - val sc = SparkContext.getOrCreate() - sc.addSparkListener(statementProgressListener) - sc + SparkContext.getOrCreate() } } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/main/scala/com/cloudera/livy/repl/PythonInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/PythonInterpreter.scala b/repl/src/main/scala/com/cloudera/livy/repl/PythonInterpreter.scala index 6e80c09..a04dfef 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/PythonInterpreter.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/PythonInterpreter.scala @@ -45,7 +45,7 @@ import com.cloudera.livy.sessions._ // scalastyle:off println object PythonInterpreter extends Logging { - def apply(conf: SparkConf, kind: Kind, listener: StatementProgressListener): Interpreter = { + def apply(conf: SparkConf, kind: Kind): Interpreter = { val pythonExec = kind match { case PySpark() => sys.env.getOrElse("PYSPARK_PYTHON", "python") case PySpark3() => sys.env.getOrElse("PYSPARK3_PYTHON", "python3") @@ -72,7 +72,7 @@ object PythonInterpreter extends Logging { env.put("LIVY_SPARK_MAJOR_VERSION", conf.get("spark.livy.spark_major_version", "1")) builder.redirectError(Redirect.PIPE) val process = builder.start() - new PythonInterpreter(process, gatewayServer, kind.toString, listener) + 
new PythonInterpreter(process, gatewayServer, kind.toString) } private def findPySparkArchives(): Seq[String] = { @@ -190,9 +190,8 @@ object PythonInterpreter extends Logging { private class PythonInterpreter( process: Process, gatewayServer: GatewayServer, - pyKind: String, - listener: StatementProgressListener) - extends ProcessInterpreter(process, listener) + pyKind: String) + extends ProcessInterpreter(process) with Logging { implicit val formats = DefaultFormats http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala b/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala index d368c6a..695a9d0 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/ReplDriver.scala @@ -44,11 +44,11 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) override protected def initializeContext(): JavaSparkContext = { interpreter = kind match { - case PySpark() => PythonInterpreter(conf, PySpark(), new StatementProgressListener(livyConf)) + case PySpark() => PythonInterpreter(conf, PySpark()) case PySpark3() => - PythonInterpreter(conf, PySpark3(), new StatementProgressListener(livyConf)) - case Spark() => new SparkInterpreter(conf, new StatementProgressListener(livyConf)) - case SparkR() => SparkRInterpreter(conf, new StatementProgressListener(livyConf)) + PythonInterpreter(conf, PySpark3()) + case Spark() => new SparkInterpreter(conf) + case SparkR() => SparkRInterpreter(conf) } session = new Session(livyConf, interpreter, { s => broadcast(new ReplState(s.toString)) }) @@ -94,7 +94,7 @@ class ReplDriver(conf: SparkConf, livyConf: RSCConf) // Update progress of statements when queried statements.foreach { s => - s.updateProgress(interpreter.statementProgressListener.progressOfStatement(s.id)) + s.updateProgress(session.progressOfStatement(s.id)) } new ReplJobResults(statements.sortBy(_.id)) http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/main/scala/com/cloudera/livy/repl/Session.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/Session.scala b/repl/src/main/scala/com/cloudera/livy/repl/Session.scala index 54056a3..31e520c 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/Session.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/Session.scala @@ -170,6 +170,29 @@ class Session( interpreter.close() } + /** + * Get the current progress of given statement id. 
+ */ + def progressOfStatement(stmtId: Int): Double = { + val jobGroup = statementIdToJobGroup(stmtId) + + _sc.map { sc => + val jobIds = sc.statusTracker.getJobIdsForGroup(jobGroup) + val jobs = jobIds.flatMap { id => sc.statusTracker.getJobInfo(id) } + val stages = jobs.flatMap { job => + job.stageIds().flatMap(sc.statusTracker.getStageInfo) + } + + val taskCount = stages.map(_.numTasks).sum + val completedTaskCount = stages.map(_.numCompletedTasks).sum + if (taskCount == 0) { + 0.0 + } else { + completedTaskCount.toDouble / taskCount + } + }.getOrElse(0.0) + } + private def changeState(newState: SessionState): Unit = { synchronized { _state = newState @@ -188,7 +211,7 @@ class Session( } val resultInJson = try { - interpreter.execute(executionCount, code) match { + interpreter.execute(code) match { case Interpreter.ExecuteSuccess(data) => transitToIdle() @@ -240,23 +263,28 @@ class Session( } private def setJobGroup(statementId: Int): String = { + val jobGroup = statementIdToJobGroup(statementId) val cmd = Kind(interpreter.kind) match { case Spark() => // A dummy value to avoid automatic value binding in scala REPL. - s"""val _livyJobGroup$statementId = sc.setJobGroup("$statementId",""" + - s""""Job group for statement $statementId")""" + s"""val _livyJobGroup$jobGroup = sc.setJobGroup("$jobGroup",""" + + s""""Job group for statement $jobGroup")""" case PySpark() | PySpark3() => - s"""sc.setJobGroup("$statementId", "Job group for statement $statementId")""" + s"""sc.setJobGroup("$jobGroup", "Job group for statement $jobGroup")""" case SparkR() => interpreter.asInstanceOf[SparkRInterpreter].sparkMajorVersion match { case "1" => - s"""setJobGroup(sc, "$statementId", "Job group for statement $statementId", """ + + s"""setJobGroup(sc, "$jobGroup", "Job group for statement $jobGroup", """ + "FALSE)" case "2" => - s"""setJobGroup("$statementId", "Job group for statement $statementId", FALSE)""" + s"""setJobGroup("$jobGroup", "Job group for statement $jobGroup", FALSE)""" } } // Set the job group executeCode(statementId, cmd) } + + private def statementIdToJobGroup(statementId: Int): String = { + statementId.toString + } } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala b/repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala index 7318b1e..469d0a5 100644 --- a/repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala +++ b/repl/src/main/scala/com/cloudera/livy/repl/SparkRInterpreter.scala @@ -68,7 +68,7 @@ object SparkRInterpreter { ")" ).r.unanchored - def apply(conf: SparkConf, listener: StatementProgressListener): SparkRInterpreter = { + def apply(conf: SparkConf): SparkRInterpreter = { val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt val mirror = universe.runtimeMirror(getClass.getClassLoader) val sparkRBackendClass = mirror.classLoader.loadClass("org.apache.spark.api.r.RBackend") @@ -121,8 +121,7 @@ object SparkRInterpreter { val process = builder.start() new SparkRInterpreter(process, backendInstance, backendThread, conf.get("spark.livy.spark_major_version", "1"), - conf.getBoolean("spark.repl.enableHiveContext", false), - listener) + conf.getBoolean("spark.repl.enableHiveContext", false)) } catch { case e: Exception => if (backendThread != null) { @@ -137,9 +136,8 @@ class SparkRInterpreter(process: 
Process, backendInstance: Any, backendThread: Thread, val sparkMajorVersion: String, - hiveEnabled: Boolean, - statementProgressListener: StatementProgressListener) - extends ProcessInterpreter(process, statementProgressListener) { + hiveEnabled: Boolean) + extends ProcessInterpreter(process) { import SparkRInterpreter._ implicit val formats = DefaultFormats http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/main/scala/com/cloudera/livy/repl/StatementProgressListener.scala ---------------------------------------------------------------------- diff --git a/repl/src/main/scala/com/cloudera/livy/repl/StatementProgressListener.scala b/repl/src/main/scala/com/cloudera/livy/repl/StatementProgressListener.scala deleted file mode 100644 index ae2147b..0000000 --- a/repl/src/main/scala/com/cloudera/livy/repl/StatementProgressListener.scala +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to Cloudera, Inc. under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. Cloudera, Inc. licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.cloudera.livy.repl - -import scala.collection.mutable - -import com.google.common.annotations.VisibleForTesting -import org.apache.spark.Success -import org.apache.spark.scheduler._ - -import com.cloudera.livy.rsc.RSCConf - -/** - * [[StatementProgressListener]] is an implementation of SparkListener, used to track the progress - * of submitted statement, this class builds a mapping relation between statement, jobs, stages - * and tasks, and uses the finished task number to calculate the statement progress. - * - * By default 100 latest statement progresses will be kept, users could also configure - * livy.rsc.retained_statements to change the cached number. - * - * This statement progress can only reflect the statement in which has Spark jobs, if - * the statement submitted doesn't generate any Spark job, the progress will always return 0.0 - * until completed. - * - * Also if the statement includes several Spark jobs, the progress will be flipped because we - * don't know the actual number of Spark jobs/tasks generated before the statement executed. 
- */ -class StatementProgressListener(conf: RSCConf) extends SparkListener { - - case class TaskCount(var currFinishedTasks: Int, var totalTasks: Int) - case class JobState(jobId: Int, var isCompleted: Boolean) - - private val retainedStatements = conf.getInt(RSCConf.Entry.RETAINED_STATEMENT_NUMBER) - - /** Statement id to list of jobs map */ - @VisibleForTesting - private[repl] val statementToJobs = new mutable.LinkedHashMap[Int, Seq[JobState]]() - @VisibleForTesting - private[repl] val jobIdToStatement = new mutable.HashMap[Int, Int]() - /** Job id to list of stage ids map */ - @VisibleForTesting - private[repl] val jobIdToStages = new mutable.HashMap[Int, Seq[Int]]() - /** Stage id to number of finished/total tasks map */ - @VisibleForTesting - private[repl] val stageIdToTaskCount = new mutable.HashMap[Int, TaskCount]() - - @transient private var currentStatementId: Int = _ - - /** - * Set current statement id, onJobStart() will use current statement id to build the mapping - * relations. - */ - def setCurrentStatementId(stmtId: Int): Unit = { - currentStatementId = stmtId - } - - /** - * Get the current progress of given statement id. - */ - def progressOfStatement(stmtId: Int): Double = synchronized { - var finishedTasks = 0 - var totalTasks = 0 - - for { - job <- statementToJobs.getOrElse(stmtId, Seq.empty) - stageId <- jobIdToStages.getOrElse(job.jobId, Seq.empty) - taskCount <- stageIdToTaskCount.get(stageId) - } yield { - finishedTasks += taskCount.currFinishedTasks - totalTasks += taskCount.totalTasks - } - - if (totalTasks == 0) { - 0.0 - } else { - finishedTasks.toDouble / totalTasks - } - } - - /** - * Get the active job ids of the given statement id. - */ - def activeJobsOfStatement(stmtId: Int): Seq[Int] = synchronized { - statementToJobs.getOrElse(stmtId, Seq.empty).filter(!_.isCompleted).map(_.jobId) - } - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { - val jobs = statementToJobs.getOrElseUpdate(currentStatementId, Seq.empty) :+ - JobState(jobStart.jobId, isCompleted = false) - statementToJobs.put(currentStatementId, jobs) - jobIdToStatement(jobStart.jobId) = currentStatementId - - jobIdToStages(jobStart.jobId) = jobStart.stageInfos.map(_.stageId) - jobStart.stageInfos.foreach { s => stageIdToTaskCount(s.stageId) = TaskCount(0, s.numTasks) } - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - taskEnd.reason match { - case Success => - stageIdToTaskCount.get(taskEnd.stageId).foreach { t => t.currFinishedTasks += 1 } - case _ => - // If task is failed, it will run again, so don't count it. - } - } - - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { - // If stage is resubmitted, we should reset the task count of this stage. - stageIdToTaskCount.get(stageSubmitted.stageInfo.stageId).foreach { t => - t.currFinishedTasks = 0 - t.totalTasks = stageSubmitted.stageInfo.numTasks - } - } - - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { - stageIdToTaskCount.get(stageCompleted.stageInfo.stageId).foreach { t => - t.currFinishedTasks = t.totalTasks - } - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { - jobIdToStatement.get(jobEnd.jobId).foreach { stmtId => - statementToJobs.get(stmtId).foreach { jobs => - jobs.filter(_.jobId == jobEnd.jobId).foreach(_.isCompleted = true) - } - } - - // Try to clean the old data when job is finished. This will trigger data cleaning in LRU - // policy. 
- cleanOldMetadata() - } - - private def cleanOldMetadata(): Unit = { - if (statementToJobs.size > retainedStatements) { - val toRemove = statementToJobs.size - retainedStatements - statementToJobs.take(toRemove).foreach { case (_, jobs) => - jobs.foreach { job => - jobIdToStatement.remove(job.jobId) - jobIdToStages.remove(job.jobId).foreach { stages => - stages.foreach(s => stageIdToTaskCount.remove(s)) - } - } - } - (0 until toRemove).foreach(_ => statementToJobs.remove(statementToJobs.head._1)) - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala index c67d580..a4a40af 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/PythonInterpreterSpec.scala @@ -245,8 +245,7 @@ class Python2InterpreterSpec extends PythonBaseInterpreterSpec { implicit val formats = DefaultFormats - override def createInterpreter(): Interpreter = - PythonInterpreter(new SparkConf(), PySpark(), new StatementProgressListener(new RSCConf())) + override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark()) // Scalastyle is treating unicode escape as non ascii characters. Turn off the check. // scalastyle:off non.ascii.character.disallowed @@ -273,8 +272,7 @@ class Python3InterpreterSpec extends PythonBaseInterpreterSpec { test() } - override def createInterpreter(): Interpreter = - PythonInterpreter(new SparkConf(), PySpark3(), new StatementProgressListener(new RSCConf())) + override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark3()) it should "check python version is 3.x" in withInterpreter { interpreter => val response = interpreter.execute("""import sys http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala index 28f457f..1e5958d 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/PythonSessionSpec.scala @@ -174,8 +174,7 @@ abstract class PythonSessionSpec extends BaseSessionSpec { } class Python2SessionSpec extends PythonSessionSpec { - override def createInterpreter(): Interpreter = - PythonInterpreter(new SparkConf(), PySpark(), new StatementProgressListener(new RSCConf())) + override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark()) } class Python3SessionSpec extends PythonSessionSpec { @@ -185,8 +184,7 @@ class Python3SessionSpec extends PythonSessionSpec { test() } - override def createInterpreter(): Interpreter = - PythonInterpreter(new SparkConf(), PySpark3(), new StatementProgressListener(new RSCConf())) + override def createInterpreter(): Interpreter = PythonInterpreter(new SparkConf(), PySpark3()) it should "check python version is 3.x" in withSession { session => val statement = execute(session)( http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/test/scala/com/cloudera/livy/repl/ScalaInterpreterSpec.scala 
---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/ScalaInterpreterSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/ScalaInterpreterSpec.scala index a9e1e8b..0126796 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/ScalaInterpreterSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/ScalaInterpreterSpec.scala @@ -29,7 +29,7 @@ class ScalaInterpreterSpec extends BaseInterpreterSpec { implicit val formats = DefaultFormats override def createInterpreter(): Interpreter = - new SparkInterpreter(new SparkConf(), new StatementProgressListener(new RSCConf())) + new SparkInterpreter(new SparkConf()) it should "execute `1 + 2` == 3" in withInterpreter { interpreter => val response = interpreter.execute("1 + 2") http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/test/scala/com/cloudera/livy/repl/SparkRInterpreterSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/SparkRInterpreterSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/SparkRInterpreterSpec.scala index a513867..61f1a36 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/SparkRInterpreterSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/SparkRInterpreterSpec.scala @@ -34,8 +34,7 @@ class SparkRInterpreterSpec extends BaseInterpreterSpec { super.withFixture(test) } - override def createInterpreter(): Interpreter = - SparkRInterpreter(new SparkConf(), new StatementProgressListener(new RSCConf())) + override def createInterpreter(): Interpreter = SparkRInterpreter(new SparkConf()) it should "execute `1 + 2` == 3" in withInterpreter { interpreter => val response = interpreter.execute("1 + 2") http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala index a6091d0..c604205 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/SparkRSessionSpec.scala @@ -31,8 +31,7 @@ class SparkRSessionSpec extends BaseSessionSpec { super.withFixture(test) } - override def createInterpreter(): Interpreter = - SparkRInterpreter(new SparkConf(), new StatementProgressListener(new RSCConf())) + override def createInterpreter(): Interpreter = SparkRInterpreter(new SparkConf()) it should "execute `1 + 2` == 3" in withSession { session => val statement = execute(session)("1 + 2") http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala index a051513..52b6b42 100644 --- a/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala +++ b/repl/src/test/scala/com/cloudera/livy/repl/SparkSessionSpec.scala @@ -32,8 +32,7 @@ import com.cloudera.livy.rsc.driver.StatementState class SparkSessionSpec extends BaseSessionSpec { - override def createInterpreter(): Interpreter = - new SparkInterpreter(new SparkConf(), new StatementProgressListener(new RSCConf())) + override def createInterpreter(): 
Interpreter = new SparkInterpreter(new SparkConf()) it should "execute `1 + 2` == 3" in withSession { session => val statement = execute(session)("1 + 2") @@ -240,4 +239,36 @@ class SparkSessionSpec extends BaseSessionSpec { "Job 0 cancelled part of cancelled job group 0") } } + + it should "correctly calculate progress" in withSession { session => + val executeCode = + """ + |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() + """.stripMargin + + val stmtId = session.execute(executeCode) + eventually(timeout(30 seconds), interval(100 millis)) { + session.progressOfStatement(stmtId) should be(1.0) + } + } + + it should "not generate Spark jobs for plain Scala code" in withSession { session => + val executeCode = """1 + 1""" + + val stmtId = session.execute(executeCode) + session.progressOfStatement(stmtId) should be (0.0) + } + + it should "handle multiple jobs in one statement" in withSession { session => + val executeCode = + """ + |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() + |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() + """.stripMargin + + val stmtId = session.execute(executeCode) + eventually(timeout(30 seconds), interval(100 millis)) { + session.progressOfStatement(stmtId) should be(1.0) + } + } } http://git-wip-us.apache.org/repos/asf/incubator-livy/blob/0ddcaf68/repl/src/test/scala/com/cloudera/livy/repl/StatementProgressListenerSpec.scala ---------------------------------------------------------------------- diff --git a/repl/src/test/scala/com/cloudera/livy/repl/StatementProgressListenerSpec.scala b/repl/src/test/scala/com/cloudera/livy/repl/StatementProgressListenerSpec.scala deleted file mode 100644 index 2acee4c..0000000 --- a/repl/src/test/scala/com/cloudera/livy/repl/StatementProgressListenerSpec.scala +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to Cloudera, Inc. under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. Cloudera, Inc. licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.cloudera.livy.repl - -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration._ -import scala.language.{postfixOps, reflectiveCalls} - -import org.apache.spark.SparkConf -import org.apache.spark.scheduler._ -import org.scalatest._ -import org.scalatest.concurrent.Eventually._ - -import com.cloudera.livy.LivyBaseUnitTestSuite -import com.cloudera.livy.rsc.RSCConf - -class StatementProgressListenerSpec extends FlatSpec - with Matchers - with BeforeAndAfterAll - with BeforeAndAfter - with LivyBaseUnitTestSuite { - private val rscConf = new RSCConf() - .set(RSCConf.Entry.RETAINED_STATEMENT_NUMBER, 2) - - private val testListener = new StatementProgressListener(rscConf) { - var onJobStartedCallback: Option[() => Unit] = None - var onJobEndCallback: Option[() => Unit] = None - var onStageEndCallback: Option[() => Unit] = None - var onTaskEndCallback: Option[() => Unit] = None - - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - super.onJobStart(jobStart) - onJobStartedCallback.foreach(f => f()) - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - super.onJobEnd(jobEnd) - onJobEndCallback.foreach(f => f()) - } - - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { - super.onStageCompleted(stageCompleted) - onStageEndCallback.foreach(f => f()) - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { - super.onTaskEnd(taskEnd) - onTaskEndCallback.foreach(f => f()) - } - } - - private val statementId = new AtomicInteger(0) - - private def getStatementId = statementId.getAndIncrement() - - private var sparkInterpreter: SparkInterpreter = _ - - override def beforeAll(): Unit = { - super.beforeAll() - sparkInterpreter = new SparkInterpreter(new SparkConf(), testListener) - sparkInterpreter.start() - } - - override def afterAll(): Unit = { - sparkInterpreter.close() - super.afterAll() - } - - after { - testListener.onJobStartedCallback = None - testListener.onJobEndCallback = None - testListener.onStageEndCallback = None - testListener.onTaskEndCallback = None - } - - it should "correctly calculate progress" in { - val executeCode = - """ - |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() - """.stripMargin - val stmtId = getStatementId - - def verifyJobs(): Unit = { - testListener.statementToJobs.get(stmtId) should not be (None) - - // One job will be submitted - testListener.statementToJobs(stmtId).size should be (1) - val jobId = testListener.statementToJobs(stmtId).head.jobId - testListener.jobIdToStatement(jobId) should be (stmtId) - - // 1 stage will be generated - testListener.jobIdToStages(jobId).size should be (1) - val stageIds = testListener.jobIdToStages(jobId) - - // 2 tasks per stage will be generated - stageIds.foreach { id => - testListener.stageIdToTaskCount(id).currFinishedTasks should be (0) - testListener.stageIdToTaskCount(id).totalTasks should be (2) - } - } - - var taskEndCalls = 0 - def verifyTasks(): Unit = { - taskEndCalls += 1 - testListener.progressOfStatement(stmtId) should be (taskEndCalls.toDouble / 2) - } - - var stageEndCalls = 0 - def verifyStages(): Unit = { - stageEndCalls += 1 - testListener.progressOfStatement(stmtId) should be (stageEndCalls.toDouble / 1) - } - - testListener.onJobStartedCallback = Some(verifyJobs) - testListener.onTaskEndCallback = Some(verifyTasks) - testListener.onStageEndCallback = Some(verifyStages) - sparkInterpreter.execute(stmtId, executeCode) - - 
eventually(timeout(30 seconds), interval(100 millis)) { - testListener.progressOfStatement(stmtId) should be(1.0) - } - } - - it should "not generate Spark jobs for plain Scala code" in { - val executeCode = """1 + 1""" - val stmtId = getStatementId - - def verifyJobs(): Unit = { - fail("No job will be submitted") - } - - testListener.onJobStartedCallback = Some(verifyJobs) - testListener.progressOfStatement(stmtId) should be (0.0) - sparkInterpreter.execute(stmtId, executeCode) - testListener.progressOfStatement(stmtId) should be (0.0) - } - - it should "handle multiple jobs in one statement" in { - val executeCode = - """ - |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() - |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() - """.stripMargin - val stmtId = getStatementId - - var jobs = 0 - def verifyJobs(): Unit = { - jobs += 1 - - testListener.statementToJobs.get(stmtId) should not be (None) - // One job will be submitted - testListener.statementToJobs(stmtId).size should be (jobs) - val jobId = testListener.statementToJobs(stmtId)(jobs - 1).jobId - testListener.jobIdToStatement(jobId) should be (stmtId) - - // 1 stages will be generated - testListener.jobIdToStages(jobId).size should be (1) - val stageIds = testListener.jobIdToStages(jobId) - - // 2 tasks per stage will be generated - stageIds.foreach { id => - testListener.stageIdToTaskCount(id).currFinishedTasks should be (0) - testListener.stageIdToTaskCount(id).totalTasks should be (2) - } - } - - val taskProgress = ArrayBuffer[Double]() - def verifyTasks(): Unit = { - taskProgress += testListener.progressOfStatement(stmtId) - } - - val stageProgress = ArrayBuffer[Double]() - def verifyStages(): Unit = { - stageProgress += testListener.progressOfStatement(stmtId) - } - - testListener.onJobStartedCallback = Some(verifyJobs) - testListener.onTaskEndCallback = Some(verifyTasks) - testListener.onStageEndCallback = Some(verifyStages) - sparkInterpreter.execute(stmtId, executeCode) - - taskProgress.toArray should be (Array(0.5, 1.0, 0.75, 1.0)) - stageProgress.toArray should be (Array(1.0, 1.0)) - - eventually(timeout(30 seconds), interval(100 millis)) { - testListener.progressOfStatement(stmtId) should be(1.0) - } - } - - it should "remove old statement progress" in { - val executeCode = - """ - |sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect() - """.stripMargin - val stmtId = getStatementId - - def onJobEnd(): Unit = { - testListener.statementToJobs(stmtId).size should be (1) - testListener.statementToJobs(stmtId).head.isCompleted should be (true) - - testListener.statementToJobs.size should be (2) - testListener.statementToJobs.get(0) should be (None) - testListener.jobIdToStatement.filter(_._2 == 0) should be (Map.empty) - } - - testListener.onJobEndCallback = Some(onJobEnd) - sparkInterpreter.execute(stmtId, executeCode) - } -}
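----------------------------------------------------------------------

For reference only (not part of this commit): a minimal standalone sketch of the
approach the new Session.progressOfStatement takes after this refactor. Each
statement is run under a job group named after its statement id, and progress is
then read back from SparkStatusTracker as completed tasks over total tasks across
that group's jobs, so no custom SparkListener needs to cross the binary boundary.
The object and method names below (ProgressSketch, progressOfJobGroup) are
illustrative, not part of the Livy code base.

import org.apache.spark.{SparkConf, SparkContext}

object ProgressSketch {

  // Progress of all jobs submitted under `jobGroup`: completed tasks / total tasks.
  // Returns 0.0 while the group has produced no tasks (e.g. plain Scala code).
  def progressOfJobGroup(sc: SparkContext, jobGroup: String): Double = {
    val tracker = sc.statusTracker
    val jobs = tracker.getJobIdsForGroup(jobGroup).flatMap(id => tracker.getJobInfo(id))
    val stages = jobs.flatMap(_.stageIds()).flatMap(id => tracker.getStageInfo(id))

    val totalTasks = stages.map(_.numTasks).sum
    val completedTasks = stages.map(_.numCompletedTasks).sum
    if (totalTasks == 0) 0.0 else completedTasks.toDouble / totalTasks
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("progress-sketch").setMaster("local[2]"))

    // As in Session.setJobGroup, the statement id doubles as the job group id.
    val statementId = 0
    sc.setJobGroup(statementId.toString, s"Job group for statement $statementId")

    // collect() is synchronous, so by the time it returns both tasks are done.
    sc.parallelize(1 to 2, 2).map(i => (i, 1)).collect()
    println(progressOfJobGroup(sc, statementId.toString)) // expected: 1.0

    sc.stop()
  }
}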