This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d2ff10c  [SPARK-23674][ML] Adds Spark ML Events to Instrumentation
d2ff10c is described below

commit d2ff10cbe1c22f919a7b1999fe54db13f4178979
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Jan 25 10:11:49 2019 +0800

    [SPARK-23674][ML] Adds Spark ML Events to Instrumentation
    
    ## What changes were proposed in this pull request?
    
    This PR proposes to add ML events to `Instrumentation` and to use them in `Pipeline`, so that other developers can track these ML operations and act on them.
    
    ## Introduction
    
    ML events (like SQL events) can be quite useful when people want to track ML operations and act on them. For instance, I have been working on integrating
    Apache Spark with [Apache Atlas](https://atlas.apache.org/QuickStart.html). With some custom changes on top of this PR, I can visualise an ML pipeline as below:
    
    
    ![spark_ml_streaming_lineage](https://user-images.githubusercontent.com/6477701/49682779-394bca80-faf5-11e8-85b8-5fae28b784b3.png)
    
    Another point worth considering is that these events can interact with other SQL/Streaming events, for instance to trace where the input `Dataset` originated. With the current Apache Spark, I can already visualise SQL operations as below:
    
    ![screen shot 2018-12-10 at 9 41 36 am](https://user-images.githubusercontent.com/6477701/49706269-d9bdfe00-fc5f-11e8-943a-3309d1856ba5.png)
    
    I think we can combine those existing lineages to easily understand where the data comes from and where it goes. Currently, the ML side is a gap, so these lineages cannot be connected in the current Apache Spark.
    
    In addition, the usefulness of tracking SQL/Streaming operations goes without saying. Likewise, I would like to propose ML events as well (as lowest-stability `Unstable` APIs for now, with no guarantee about stability).
    
    ## Implementation Details
    
    ### Send events (but do not expose an ML-specific listener)
    
    **`mllib/src/main/scala/org/apache/spark/ml/events.scala`**
    
    ```scala
    @Unstable
    case class ...StartEvent(caller, input)
    @Unstable
    case class ...EndEvent(caller, output)
    
    trait MLEvents {
      // Wrappers to send events:
      // def with...Event(body) = {
      //   body()
      //   SparkContext.getOrCreate().listenerBus.post(event)
      // }
    }
    ```
    
    This trait is used by `Instrumentation`.
    
    ```scala
    class Instrumentation ... with MLEvents {
    ```
    
    and used as below:
    
    ```scala
    instrumented { instr =>
      instr.with...Event(...) {
        ...
      }
    }
    ```
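
    Concretely, each `with...Event` wrapper posts a start event, evaluates the wrapped body, and then posts an end event. A condensed sketch of `withFitEvent` (logging via `logEvent` omitted; the full version is in `events.scala` in this diff):

    ```scala
    private[ml] trait MLEvents extends Logging {
      // Events are posted to the shared Spark listener bus.
      private def listenerBus = SparkContext.getOrCreate().listenerBus

      def withFitEvent[M <: Model[M]](
          estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = {
        listenerBus.post(FitStart(estimator, dataset))
        val model: M = func  // runs Estimator.fit
        listenerBus.post(FitEnd(estimator, model))
        model
      }
    }
    ```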
    
    This approach mimics both:
    
    **1. Catalog events (see `org/apache/spark/sql/catalyst/catalog/events.scala`)**
    
    - This allows a Catalog-specific listener, `ExternalCatalogEventListener`, to be added.

    - It is implemented by wrapping the whole `ExternalCatalog` in `ExternalCatalogWithListener`, which delegates the operations to `ExternalCatalog`.
    
    This is not quite possible here because most instances (like `Pipeline`) are created directly in most cases. We might be able to do that by extending `ListenerBus` for all possible instances, but IMHO that is too invasive. Also, exposing another ML-specific listener sounds like a bit too much at this stage. Therefore, I simply borrowed the file name and structure here.
    
    **2. SQL execution events (see `org/apache/spark/sql/execution/SQLExecution.scala`)**
    
    - Add an object that wraps a body to send events
    
    The current approach is rather close to this: it has a `with...` wrapper to send events. I borrowed this approach to be consistent.
    
    ## Usage
    
    It needs a custom listener implementation. For instance, with the custom listener below:
    
    ```scala
    class CustomMLListener extends SparkListener {
      override def onOtherEvent(e: SparkListenerEvent): Unit = e match {
        case e: MLEvent => // do something
        case _ => // pass
      }
    }
    ```
    
    There are two (existing) ways to register this listener.
    
    ```scala
    spark.sparkContext.addSparkListener(new CustomMLListener)
    ```
    
    ```bash
    spark-submit ... \
      --conf spark.extraListeners=CustomMLListener \
      ...
    ```
    
    This is also similar to the existing implementations on the SQL side.
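
    For illustration, a listener can also simply buffer the events for later inspection, along the lines of the anonymous listener in the new `MLEventsSuite` (a minimal sketch; the name `BufferingMLListener` is only for this example):

    ```scala
    import scala.collection.mutable

    import org.apache.spark.ml.MLEvent
    import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}

    class BufferingMLListener extends SparkListener {
      val events = mutable.ArrayBuffer.empty[MLEvent]

      override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
        case e: MLEvent => events.append(e)  // keep ML events only
        case _ => // ignore everything else
      }
    }

    // spark.sparkContext.addSparkListener(new BufferingMLListener)
    ```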
    
    ## Target users
    
    1. I think people in general would likely utilise this feature like other event listeners. At least, I can see some interest outside:
    
        - SQL Listener
          - https://stackoverflow.com/questions/46409339/spark-listener-to-an-sql-query
          - http://apache-spark-user-list.1001560.n3.nabble.com/spark-sql-Custom-Query-Execution-listener-via-conf-properties-td30979.html

        - Streaming Query Listener
          - https://jhui.github.io/2017/01/15/Apache-Spark-Streaming/
          - http://apache-spark-developers-list.1001551.n3.nabble.com/Structured-Streaming-with-Watermark-td25413.html#a25416
    
    2. Someone would likely run this via Atlas. The plugin mirror is intentionally exposed at [spark-atlas-connector](https://github.com/hortonworks-spark/spark-atlas-connector) so that anyone can work on lineage and governance in Atlas. I'm trying to show integrated lineages in Apache Spark, but the ML side is currently the missing piece.
    
    ## How was this patch tested?
    
    Manually tested and unit tests were added.
    
    Closes #23263 from HyukjinKwon/SPARK-23674-1.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/scala/org/apache/spark/ml/Pipeline.scala  |  48 ++--
 .../main/scala/org/apache/spark/ml/events.scala    | 137 +++++++++++
 .../org/apache/spark/ml/util/Instrumentation.scala |   9 +-
 .../scala/org/apache/spark/ml/util/ReadWrite.scala |  11 +-
 .../scala/org/apache/spark/ml/MLEventsSuite.scala  | 255 +++++++++++++++++++++
 5 files changed, 436 insertions(+), 24 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 103082b..69a4dbe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
 
@@ -132,7 +133,8 @@ class Pipeline @Since("1.4.0") (
    * @return fitted pipeline
    */
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): PipelineModel = {
+  override def fit(dataset: Dataset[_]): PipelineModel = instrumented(
+      instr => instr.withFitEvent(this, dataset) {
     transformSchema(dataset.schema, logging = true)
     val theStages = $(stages)
     // Search for the last estimator.
@@ -150,7 +152,7 @@ class Pipeline @Since("1.4.0") (
       if (index <= indexOfLastEstimator) {
         val transformer = stage match {
           case estimator: Estimator[_] =>
-            estimator.fit(curDataset)
+            instr.withFitEvent(estimator, curDataset)(estimator.fit(curDataset))
           case t: Transformer =>
             t
           case _ =>
@@ -158,7 +160,8 @@ class Pipeline @Since("1.4.0") (
               s"Does not support stage $stage of type ${stage.getClass}")
         }
         if (index < indexOfLastEstimator) {
-          curDataset = transformer.transform(curDataset)
+          curDataset = instr.withTransformEvent(
+            transformer, curDataset)(transformer.transform(curDataset))
         }
         transformers += transformer
       } else {
@@ -167,7 +170,7 @@ class Pipeline @Since("1.4.0") (
     }
 
     new PipelineModel(uid, transformers.toArray).setParent(this)
-  }
+  })
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): Pipeline = {
@@ -197,10 +200,12 @@ object Pipeline extends MLReadable[Pipeline] {
   @Since("1.6.0")
   override def load(path: String): Pipeline = super.load(path)
 
-  private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter {
+  private[Pipeline] class PipelineWriter(val instance: Pipeline) extends MLWriter {
 
     SharedReadWrite.validateStages(instance.getStages)
 
+    override def save(path: String): Unit =
+      instrumented(_.withSaveInstanceEvent(this, path)(super.save(path)))
     override protected def saveImpl(path: String): Unit =
       SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
   }
@@ -210,10 +215,10 @@ object Pipeline extends MLReadable[Pipeline] {
     /** Checked against metadata when loading model */
     private val className = classOf[Pipeline].getName
 
-    override def load(path: String): Pipeline = {
+    override def load(path: String): Pipeline = instrumented(_.withLoadInstanceEvent(this, path) {
       val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
       new Pipeline(uid).setStages(stages)
-    }
+    })
   }
 
   /**
@@ -243,7 +248,7 @@ object Pipeline extends MLReadable[Pipeline] {
         instance: Params,
         stages: Array[PipelineStage],
         sc: SparkContext,
-        path: String): Unit = {
+        path: String): Unit = instrumented { instr =>
       val stageUids = stages.map(_.uid)
       val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
       DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))
@@ -251,8 +256,9 @@ object Pipeline extends MLReadable[Pipeline] {
       // Save stages
       val stagesDir = new Path(path, "stages").toString
       stages.zipWithIndex.foreach { case (stage, idx) =>
-        stage.asInstanceOf[MLWritable].write.save(
-          getStagePath(stage.uid, idx, stages.length, stagesDir))
+        val writer = stage.asInstanceOf[MLWritable].write
+        val stagePath = getStagePath(stage.uid, idx, stages.length, stagesDir)
+        instr.withSaveInstanceEvent(writer, stagePath)(writer.save(stagePath))
       }
     }
 
@@ -263,7 +269,7 @@ object Pipeline extends MLReadable[Pipeline] {
     def load(
         expectedClassName: String,
         sc: SparkContext,
-        path: String): (String, Array[PipelineStage]) = {
+        path: String): (String, Array[PipelineStage]) = instrumented { instr =>
       val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
 
       implicit val format = DefaultFormats
@@ -271,7 +277,8 @@ object Pipeline extends MLReadable[Pipeline] {
       val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
       val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
         val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
-        DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc)
+        val reader = DefaultParamsReader.loadParamsInstanceReader[PipelineStage](stagePath, sc)
+        instr.withLoadInstanceEvent(reader, stagePath)(reader.load(stagePath))
       }
       (metadata.uid, stages)
     }
@@ -301,10 +308,12 @@ class PipelineModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = instrumented(instr =>
+      instr.withTransformEvent(this, dataset) {
     transformSchema(dataset.schema, logging = true)
-    stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
-  }
+    stages.foldLeft(dataset.toDF)((cur, transformer) =>
+      instr.withTransformEvent(transformer, cur)(transformer.transform(cur)))
+  })
 
   @Since("1.2.0")
   override def transformSchema(schema: StructType): StructType = {
@@ -331,10 +340,12 @@ object PipelineModel extends MLReadable[PipelineModel] {
   @Since("1.6.0")
   override def load(path: String): PipelineModel = super.load(path)
 
-  private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {
+  private[PipelineModel] class PipelineModelWriter(val instance: PipelineModel) extends MLWriter {
 
     SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
 
+    override def save(path: String): Unit =
+      instrumented(_.withSaveInstanceEvent(this, path)(super.save(path)))
+    override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
       instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
   }
@@ -344,7 +355,8 @@ object PipelineModel extends MLReadable[PipelineModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[PipelineModel].getName
 
-    override def load(path: String): PipelineModel = {
+    override def load(path: String): PipelineModel = instrumented(_.withLoadInstanceEvent(
+        this, path) {
       val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
       val transformers = stages map {
         case stage: Transformer => stage
@@ -352,6 +364,6 @@ object PipelineModel extends MLReadable[PipelineModel] {
           s" was not a Transformer.  Bad stage ${other.uid} of type ${other.getClass}")
       }
       new PipelineModel(uid, transformers)
-    }
+    })
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/events.scala b/mllib/src/main/scala/org/apache/spark/ml/events.scala
new file mode 100644
index 0000000..c51600f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/events.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Unstable
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.util.{MLReader, MLWriter}
+import org.apache.spark.scheduler.SparkListenerEvent
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+/**
+ * Event emitted by ML operations. Events are either fired before and/or
+ * after each operation (the event should document this).
+ *
+ * @note This is supported via [[Pipeline]] and [[PipelineModel]].
+ */
+@Unstable
+sealed trait MLEvent extends SparkListenerEvent
+
+/**
+ * Event fired before `Transformer.transform`.
+ */
+@Unstable
+case class TransformStart(transformer: Transformer, input: Dataset[_]) extends MLEvent
+/**
+ * Event fired after `Transformer.transform`.
+ */
+@Unstable
+case class TransformEnd(transformer: Transformer, output: Dataset[_]) extends MLEvent
+
+/**
+ * Event fired before `Estimator.fit`.
+ */
+@Unstable
+case class FitStart[M <: Model[M]](estimator: Estimator[M], dataset: Dataset[_]) extends MLEvent
+/**
+ * Event fired after `Estimator.fit`.
+ */
+@Unstable
+case class FitEnd[M <: Model[M]](estimator: Estimator[M], model: M) extends MLEvent
+
+/**
+ * Event fired before `MLReader.load`.
+ */
+@Unstable
+case class LoadInstanceStart[T](reader: MLReader[T], path: String) extends MLEvent
+/**
+ * Event fired after `MLReader.load`.
+ */
+@Unstable
+case class LoadInstanceEnd[T](reader: MLReader[T], instance: T) extends MLEvent
+
+/**
+ * Event fired before `MLWriter.save`.
+ */
+@Unstable
+case class SaveInstanceStart(writer: MLWriter, path: String) extends MLEvent
+/**
+ * Event fired after `MLWriter.save`.
+ */
+@Unstable
+case class SaveInstanceEnd(writer: MLWriter, path: String) extends MLEvent
+
+/**
+ * A small trait that defines some methods to send [[org.apache.spark.ml.MLEvent]].
+ */
+private[ml] trait MLEvents extends Logging {
+
+  private def listenerBus = SparkContext.getOrCreate().listenerBus
+
+  /**
+   * Log [[MLEvent]] to send. By default, it emits a debug-level log.
+   */
+  def logEvent(event: MLEvent): Unit = logDebug(s"Sending an MLEvent: $event")
+
+  def withFitEvent[M <: Model[M]](
+      estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = {
+    val startEvent = FitStart(estimator, dataset)
+    logEvent(startEvent)
+    listenerBus.post(startEvent)
+    val model: M = func
+    val endEvent = FitEnd(estimator, model)
+    logEvent(endEvent)
+    listenerBus.post(endEvent)
+    model
+  }
+
+  def withTransformEvent(
+      transformer: Transformer, input: Dataset[_])(func: => DataFrame): DataFrame = {
+    val startEvent = TransformStart(transformer, input)
+    logEvent(startEvent)
+    listenerBus.post(startEvent)
+    val output: DataFrame = func
+    val endEvent = TransformEnd(transformer, output)
+    logEvent(endEvent)
+    listenerBus.post(endEvent)
+    output
+  }
+
+  def withLoadInstanceEvent[T](reader: MLReader[T], path: String)(func: => T): T = {
+    val startEvent = LoadInstanceStart(reader, path)
+    logEvent(startEvent)
+    listenerBus.post(startEvent)
+    val instance: T = func
+    val endEvent = LoadInstanceEnd(reader, instance)
+    logEvent(endEvent)
+    listenerBus.post(endEvent)
+    instance
+  }
+
+  def withSaveInstanceEvent(writer: MLWriter, path: String)(func: => Unit): Unit = {
+    val startEvent = SaveInstanceStart(writer, path)
+    logEvent(startEvent)
+    listenerBus.post(startEvent)
+    func
+    val endEvent = SaveInstanceEnd(writer, path)
+    logEvent(endEvent)
+    listenerBus.post(endEvent)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 4965491..780650d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -27,17 +27,18 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.PipelineStage
+import org.apache.spark.ml.{MLEvents, PipelineStage}
 import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.util.Utils
 
 /**
- * A small wrapper that defines a training session for an estimator, and some methods to log
- * useful information during this session.
+ * A small wrapper that defines a training session for an estimator, some methods to log
+ * useful information during this session, and some methods to send
+ * [[org.apache.spark.ml.MLEvent]].
  */
-private[spark] class Instrumentation private () extends Logging {
+private[spark] class Instrumentation private () extends Logging with MLEvents {
 
   private val id = UUID.randomUUID()
   private val shortId = id.toString.take(8)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index fbc7be2..ce8f346 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -624,10 +624,17 @@ private[ml] object DefaultParamsReader {
    * Load a `Params` instance from the given path, and return it.
    * This assumes the instance implements [[MLReadable]].
    */
-  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
+  def loadParamsInstance[T](path: String, sc: SparkContext): T =
+    loadParamsInstanceReader(path, sc).load(path)
+
+  /**
+   * Load a `Params` instance reader from the given path, and return it.
+   * This assumes the instance implements [[MLReadable]].
+   */
+  def loadParamsInstanceReader[T](path: String, sc: SparkContext): MLReader[T] = {
     val metadata = DefaultParamsReader.loadMetadata(path, sc)
     val cls = Utils.classForName(metadata.className)
-    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
+    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
new file mode 100644
index 0000000..0a87328
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.apache.hadoop.fs.Path
+import org.mockito.ArgumentMatchers.{any, eq => meq}
+import org.mockito.Mockito.when
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.concurrent.Eventually
+import org.scalatest.mockito.MockitoSugar.mock
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MLWriter}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql._
+
+
+class MLEventsSuite
+  extends SparkFunSuite with BeforeAndAfterEach with MLlibTestSparkContext with Eventually {
+
+  private val events = mutable.ArrayBuffer.empty[MLEvent]
+  private val listener: SparkListener = new SparkListener {
+    override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+      case e: MLEvent => events.append(e)
+      case _ =>
+    }
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.sparkContext.addSparkListener(listener)
+  }
+
+  override def afterEach(): Unit = {
+    try {
+      events.clear()
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      if (spark != null) {
+        spark.sparkContext.removeSparkListener(listener)
+      }
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  abstract class MyModel extends Model[MyModel]
+
+  test("pipeline fit events") {
+    val estimator1 = mock[Estimator[MyModel]]
+    val model1 = mock[MyModel]
+    val transformer1 = mock[Transformer]
+    val estimator2 = mock[Estimator[MyModel]]
+    val model2 = mock[MyModel]
+
+    when(estimator1.copy(any[ParamMap])).thenReturn(estimator1)
+    when(model1.copy(any[ParamMap])).thenReturn(model1)
+    when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
+    when(estimator2.copy(any[ParamMap])).thenReturn(estimator2)
+    when(model2.copy(any[ParamMap])).thenReturn(model2)
+
+    val dataset1 = mock[DataFrame]
+    val dataset2 = mock[DataFrame]
+    val dataset3 = mock[DataFrame]
+    val dataset4 = mock[DataFrame]
+    val dataset5 = mock[DataFrame]
+
+    when(dataset1.toDF).thenReturn(dataset1)
+    when(dataset2.toDF).thenReturn(dataset2)
+    when(dataset3.toDF).thenReturn(dataset3)
+    when(dataset4.toDF).thenReturn(dataset4)
+    when(dataset5.toDF).thenReturn(dataset5)
+
+    when(estimator1.fit(meq(dataset1))).thenReturn(model1)
+    when(model1.transform(meq(dataset1))).thenReturn(dataset2)
+    when(model1.parent).thenReturn(estimator1)
+    when(transformer1.transform(meq(dataset2))).thenReturn(dataset3)
+    when(estimator2.fit(meq(dataset3))).thenReturn(model2)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(estimator1, transformer1, estimator2))
+    assert(events.isEmpty)
+    val pipelineModel = pipeline.fit(dataset1)
+    val expected =
+      FitStart(pipeline, dataset1) ::
+      FitStart(estimator1, dataset1) ::
+      FitEnd(estimator1, model1) ::
+      TransformStart(model1, dataset1) ::
+      TransformEnd(model1, dataset2) ::
+      TransformStart(transformer1, dataset2) ::
+      TransformEnd(transformer1, dataset3) ::
+      FitStart(estimator2, dataset3) ::
+      FitEnd(estimator2, model2) ::
+      FitEnd(pipeline, pipelineModel) :: Nil
+    eventually(timeout(10 seconds), interval(1 second)) {
+      assert(events === expected)
+    }
+  }
+
+  test("pipeline model transform events") {
+    val dataset1 = mock[DataFrame]
+    val dataset2 = mock[DataFrame]
+    val dataset3 = mock[DataFrame]
+    val dataset4 = mock[DataFrame]
+    when(dataset1.toDF).thenReturn(dataset1)
+    when(dataset2.toDF).thenReturn(dataset2)
+    when(dataset3.toDF).thenReturn(dataset3)
+    when(dataset4.toDF).thenReturn(dataset4)
+
+    val transformer1 = mock[Transformer]
+    val model = mock[MyModel]
+    val transformer2 = mock[Transformer]
+    when(transformer1.transform(meq(dataset1))).thenReturn(dataset2)
+    when(model.transform(meq(dataset2))).thenReturn(dataset3)
+    when(transformer2.transform(meq(dataset3))).thenReturn(dataset4)
+
+    val newPipelineModel = new PipelineModel(
+      "pipeline0", Array(transformer1, model, transformer2))
+    assert(events.isEmpty)
+    val output = newPipelineModel.transform(dataset1)
+    val expected =
+      TransformStart(newPipelineModel, dataset1) ::
+      TransformStart(transformer1, dataset1) ::
+      TransformEnd(transformer1, dataset2) ::
+      TransformStart(model, dataset2) ::
+      TransformEnd(model, dataset3) ::
+      TransformStart(transformer2, dataset3) ::
+      TransformEnd(transformer2, dataset4) ::
+      TransformEnd(newPipelineModel, output) :: Nil
+    eventually(timeout(10 seconds), interval(1 second)) {
+      assert(events === expected)
+    }
+  }
+
+  test("pipeline read/write events") {
+    def getInstance(w: MLWriter): AnyRef =
+      w.getClass.getDeclaredMethod("instance").invoke(w)
+
+    withTempDir { dir =>
+      val path = new Path(dir.getCanonicalPath, "pipeline").toUri.toString
+      val writableStage = new WritableStage("writableStage")
+      val newPipeline = new Pipeline().setStages(Array(writableStage))
+      val pipelineWriter = newPipeline.write
+      assert(events.isEmpty)
+      pipelineWriter.save(path)
+      eventually(timeout(10 seconds), interval(1 second)) {
+        events.foreach {
+          case e: SaveInstanceStart if e.writer.isInstanceOf[DefaultParamsWriter] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: SaveInstanceEnd if e.writer.isInstanceOf[DefaultParamsWriter] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: SaveInstanceStart if getInstance(e.writer).isInstanceOf[Pipeline] =>
+            assert(getInstance(e.writer).asInstanceOf[Pipeline].uid === newPipeline.uid)
+          case e: SaveInstanceEnd if getInstance(e.writer).isInstanceOf[Pipeline] =>
+            assert(getInstance(e.writer).asInstanceOf[Pipeline].uid === newPipeline.uid)
+          case e => fail(s"Unexpected event thrown: $e")
+        }
+      }
+
+      events.clear()
+      val pipelineReader = Pipeline.read
+      assert(events.isEmpty)
+      pipelineReader.load(path)
+      eventually(timeout(10 seconds), interval(1 second)) {
+        events.foreach {
+          case e: LoadInstanceStart[PipelineStage]
+              if e.reader.isInstanceOf[DefaultParamsReader[PipelineStage]] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: LoadInstanceEnd[PipelineStage]
+              if e.reader.isInstanceOf[DefaultParamsReader[PipelineStage]] =>
+            assert(e.instance.isInstanceOf[PipelineStage])
+          case e: LoadInstanceStart[Pipeline] =>
+            assert(e.reader === pipelineReader)
+          case e: LoadInstanceEnd[Pipeline] =>
+            assert(e.instance.uid === newPipeline.uid)
+          case e => fail(s"Unexpected event thrown: $e")
+        }
+      }
+    }
+  }
+
+  test("pipeline model read/write events") {
+    def getInstance(w: MLWriter): AnyRef =
+      w.getClass.getDeclaredMethod("instance").invoke(w)
+
+    withTempDir { dir =>
+      val path = new Path(dir.getCanonicalPath, "pipeline").toUri.toString
+      val writableStage = new WritableStage("writableStage")
+      val pipelineModel =
+        new PipelineModel("pipeline_89329329", Array(writableStage.asInstanceOf[Transformer]))
+      val pipelineWriter = pipelineModel.write
+      assert(events.isEmpty)
+      pipelineWriter.save(path)
+      eventually(timeout(10 seconds), interval(1 second)) {
+        events.foreach {
+          case e: SaveInstanceStart if e.writer.isInstanceOf[DefaultParamsWriter] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: SaveInstanceEnd if e.writer.isInstanceOf[DefaultParamsWriter] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: SaveInstanceStart if getInstance(e.writer).isInstanceOf[PipelineModel] =>
+            assert(getInstance(e.writer).asInstanceOf[PipelineModel].uid === pipelineModel.uid)
+          case e: SaveInstanceEnd if getInstance(e.writer).isInstanceOf[PipelineModel] =>
+            assert(getInstance(e.writer).asInstanceOf[PipelineModel].uid === pipelineModel.uid)
+          case e => fail(s"Unexpected event thrown: $e")
+        }
+      }
+
+      events.clear()
+      val pipelineModelReader = PipelineModel.read
+      assert(events.isEmpty)
+      pipelineModelReader.load(path)
+      eventually(timeout(10 seconds), interval(1 second)) {
+        events.foreach {
+          case e: LoadInstanceStart[PipelineStage]
+            if e.reader.isInstanceOf[DefaultParamsReader[PipelineStage]] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: LoadInstanceEnd[PipelineStage]
+            if e.reader.isInstanceOf[DefaultParamsReader[PipelineStage]] =>
+            assert(e.instance.isInstanceOf[PipelineStage])
+          case e: LoadInstanceStart[PipelineModel] =>
+            assert(e.reader === pipelineModelReader)
+          case e: LoadInstanceEnd[PipelineModel] =>
+            assert(e.instance.uid === pipelineModel.uid)
+          case e => fail(s"Unexpected event thrown: $e")
+        }
+      }
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
