[ https://issues.apache.org/jira/browse/SPARK-23674?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16718526#comment-16718526 ]

ASF GitHub Bot commented on SPARK-23674:
----------------------------------------

felixcheung closed pull request #23263: [SPARK-23674][ML] Adds Spark ML Events
URL: https://github.com/apache/spark/pull/23263
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 1247882d6c1bd..a3c4db06862f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -65,7 +65,19 @@ abstract class Estimator[M <: Model[M]] extends 
PipelineStage {
    * Fits a model to the input data.
    */
   @Since("2.0.0")
-  def fit(dataset: Dataset[_]): M
+  def fit(dataset: Dataset[_]): M = MLEvents.withFitEvent(this, dataset) {
+    fitImpl(dataset)
+  }
+
+  /**
+   * `fit()` handles events and then calls this method. Subclasses should 
override this
+   * method to implement the actual fitting of a model to the input data.
+   */
+  @Since("3.0.0")
+  protected def fitImpl(dataset: Dataset[_]): M = {
+    // Keep this default body for backward compatibility.
+    throw new UnsupportedOperationException("fitImpl is not implemented.")
+  }
 
   /**
    * Fits multiple models to the input data with multiple sets of parameters.
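
To make the new hook concrete, here is a minimal sketch of a third-party estimator written against this change. The IdentityEstimator/IdentityModel names are made up for illustration and are not part of the patch; the subclass implements fitImpl()/transformImpl(), while the base Estimator.fit() and Transformer.transform() post the FitStart/FitEnd and TransformStart/TransformEnd events around them.

import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Hypothetical model: transformImpl() is the new hook; transform() in the
// base Transformer posts TransformStart/TransformEnd around it.
class IdentityModel(override val uid: String) extends Model[IdentityModel] {
  override def copy(extra: ParamMap): IdentityModel = defaultCopy(extra)
  override def transformSchema(schema: StructType): StructType = schema
  override protected def transformImpl(dataset: Dataset[_]): DataFrame = dataset.toDF
}

// Hypothetical estimator: fitImpl() is the new hook; fit() in the base
// Estimator posts FitStart/FitEnd around it.
class IdentityEstimator(override val uid: String) extends Estimator[IdentityModel] {
  def this() = this(Identifiable.randomUID("identity"))
  override def copy(extra: ParamMap): IdentityEstimator = defaultCopy(extra)
  override def transformSchema(schema: StructType): StructType = schema
  override protected def fitImpl(dataset: Dataset[_]): IdentityModel =
    new IdentityModel(uid).setParent(this)
}
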
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala 
b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 103082b7b9766..1c781faff129e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -132,7 +132,8 @@ class Pipeline @Since("1.4.0") (
    * @return fitted pipeline
    */
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): PipelineModel = {
+  override def fit(dataset: Dataset[_]): PipelineModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): PipelineModel = {
     transformSchema(dataset.schema, logging = true)
     val theStages = $(stages)
     // Search for the last estimator.
@@ -210,7 +211,7 @@ object Pipeline extends MLReadable[Pipeline] {
     /** Checked against metadata when loading model */
     private val className = classOf[Pipeline].getName
 
-    override def load(path: String): Pipeline = {
+    override protected def loadImpl(path: String): Pipeline = {
       val (uid: String, stages: Array[PipelineStage]) = 
SharedReadWrite.load(className, sc, path)
       new Pipeline(uid).setStages(stages)
     }
@@ -301,7 +302,8 @@ class PipelineModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     stages.foldLeft(dataset.toDF)((cur, transformer) => 
transformer.transform(cur))
   }
@@ -344,7 +346,7 @@ object PipelineModel extends MLReadable[PipelineModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[PipelineModel].getName
 
-    override def load(path: String): PipelineModel = {
+    override protected def loadImpl(path: String): PipelineModel = {
       val (uid: String, stages: Array[PipelineStage]) = 
SharedReadWrite.load(className, sc, path)
       val transformers = stages map {
         case stage: Transformer => stage
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index d8f3dfa874439..3731ddae0160c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -94,7 +94,10 @@ abstract class Predictor[
   /** @group setParam */
   def setPredictionCol(value: String): Learner = set(predictionCol, 
value).asInstanceOf[Learner]
 
-  override def fit(dataset: Dataset[_]): M = {
+  // Explicitly call parent's fit. Otherwise, MiMa complains.
+  override def fit(dataset: Dataset[_]): M = super.fit(dataset)
+
+  override protected def fitImpl(dataset: Dataset[_]): M = {
     // This handles a few items such as schema validation.
     // Developers only need to implement train().
     transformSchema(dataset.schema, logging = true)
@@ -199,7 +202,8 @@ abstract class PredictionModel[FeaturesType, M <: 
PredictionModel[FeaturesType,
    * @param dataset input dataset
    * @return transformed dataset with [[predictionCol]] of type `Double`
    */
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(
+      dataset: Dataset[_]): DataFrame = MLEvents.withTransformEvent(this, 
dataset) {
     transformSchema(dataset.schema, logging = true)
     if ($(predictionCol).nonEmpty) {
       transformImpl(dataset)
@@ -210,7 +214,7 @@ abstract class PredictionModel[FeaturesType, M <: 
PredictionModel[FeaturesType,
     }
   }
 
-  protected def transformImpl(dataset: Dataset[_]): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val predictUDF = udf { (features: Any) =>
       predict(features.asInstanceOf[FeaturesType])
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index a3a2b55adc25d..2467383227a66 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -68,7 +68,19 @@ abstract class Transformer extends PipelineStage {
    * Transforms the input dataset.
    */
   @Since("2.0.0")
-  def transform(dataset: Dataset[_]): DataFrame
+  def transform(dataset: Dataset[_]): DataFrame = 
MLEvents.withTransformEvent(this, dataset) {
+    transformImpl(dataset)
+  }
+
+  /**
+   * `transform()` handles events and then calls this method. Subclasses 
should override this
+   * method to implement the actual transformation.
+   */
+  @Since("3.0.0")
+  protected def transformImpl(dataset: Dataset[_]): DataFrame = {
+    // Keep this default body for backward compatibility.
+    throw new UnsupportedOperationException("transformImpl is not 
implemented.")
+  }
 
   override def copy(extra: ParamMap): Transformer
 }
@@ -116,7 +128,7 @@ abstract class UnaryTransformer[IN, OUT, T <: 
UnaryTransformer[IN, OUT, T]]
     StructType(outputFields)
   }
 
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val transformUDF = udf(this.createTransformFunc, outputDataType)
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 7e5790ab70ee9..4de6ae7990490 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{DeveloperApi, Since}
-import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
+import org.apache.spark.ml.{MLEvents, PredictionModel, Predictor, 
PredictorParams}
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
@@ -156,7 +156,8 @@ abstract class ClassificationModel[FeaturesType, M <: 
ClassificationModel[Featur
    * @param dataset input dataset
    * @return transformed dataset
    */
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(
+      dataset: Dataset[_]): DataFrame = MLEvents.withTransformEvent(this, 
dataset) {
     transformSchema(dataset.schema, logging = true)
 
     // Output selected columns only.
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index d9292a5476767..ab91b942a06ce 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -275,7 +275,7 @@ object DecisionTreeClassificationModel extends 
MLReadable[DecisionTreeClassifica
     /** Checked against metadata when loading model */
     private val className = classOf[DecisionTreeClassificationModel].getName
 
-    override def load(path: String): DecisionTreeClassificationModel = {
+    override protected def loadImpl(path: String): 
DecisionTreeClassificationModel = {
       implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index abe2d1febfdf8..5f4bac04f16bb 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -410,7 +410,7 @@ object GBTClassificationModel extends 
MLReadable[GBTClassificationModel] {
     private val className = classOf[GBTClassificationModel].getName
     private val treeClassName = classOf[DecisionTreeRegressionModel].getName
 
-    override def load(path: String): GBTClassificationModel = {
+    override protected def loadImpl(path: String): GBTClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], 
treeWeights: Array[Double]) =
         EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index ff801abef9a94..b6e9b8833022d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -373,7 +373,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[LinearSVCModel].getName
 
-    override def load(path: String): LinearSVCModel = {
+    override protected def loadImpl(path: String): LinearSVCModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.format("parquet").load(dataPath)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 27a7db0b2f5d4..1367253aa4f9f 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1248,7 +1248,7 @@ object LogisticRegressionModel extends 
MLReadable[LogisticRegressionModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[LogisticRegressionModel].getName
 
-    override def load(path: String): LogisticRegressionModel = {
+    override protected def loadImpl(path: String): LogisticRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val (major, minor) = 
VersionUtils.majorMinorVersion(metadata.sparkVersion)
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 47b8a8df637b9..a912ad268ba7e 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -359,7 +359,7 @@ object MultilayerPerceptronClassificationModel
     /** Checked against metadata when loading model */
     private val className = 
classOf[MultilayerPerceptronClassificationModel].getName
 
-    override def load(path: String): MultilayerPerceptronClassificationModel = 
{
+    override protected def loadImpl(path: String): 
MultilayerPerceptronClassificationModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 1a7a5e7a52344..04bb1265962c2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -399,7 +399,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[NaiveBayesModel].getName
 
-    override def load(path: String): NaiveBayesModel = {
+    override protected def loadImpl(path: String): NaiveBayesModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index e1fceb1fc96a4..6b0a4d0832007 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -165,7 +165,8 @@ final class OneVsRestModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     // Check schema
     transformSchema(dataset.schema, logging = true)
 
@@ -289,7 +290,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[OneVsRestModel].getName
 
-    override def load(path: String): OneVsRestModel = {
+    override protected def loadImpl(path: String): OneVsRestModel = {
       implicit val format = DefaultFormats
       val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, 
className)
       val labelMetadata = Metadata.fromJson((metadata.metadata \ 
"labelMetadata").extract[String])
@@ -372,7 +373,8 @@ final class OneVsRest @Since("1.4.0") (
   }
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): OneVsRestModel = instrumented { instr 
=>
+  override def fit(dataset: Dataset[_]): OneVsRestModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): OneVsRestModel = 
instrumented { instr =>
     transformSchema(dataset.schema)
 
     instr.logPipelineStage(this)
@@ -492,7 +494,7 @@ object OneVsRest extends MLReadable[OneVsRest] {
     /** Checked against metadata when loading model */
     private val className = classOf[OneVsRest].getName
 
-    override def load(path: String): OneVsRest = {
+    override protected def loadImpl(path: String): OneVsRest = {
       val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, 
className)
       val ovr = new OneVsRest(metadata.uid)
       metadata.getAndSetParams(ovr)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 730fcab333e11..24781005682dd 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.MLEvents
 import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
@@ -100,7 +101,8 @@ abstract class ProbabilisticClassificationModel[
    * @param dataset input dataset
    * @return transformed dataset
    */
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(
+      dataset: Dataset[_]): DataFrame = MLEvents.withTransformEvent(this, 
dataset) {
     transformSchema(dataset.schema, logging = true)
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 0a3bfd1f85e08..145b73e1816ec 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -310,7 +310,7 @@ object RandomForestClassificationModel extends 
MLReadable[RandomForestClassifica
     private val className = classOf[RandomForestClassificationModel].getName
     private val treeClassName = 
classOf[DecisionTreeClassificationModel].getName
 
-    override def load(path: String): RandomForestClassificationModel = {
+    override protected def loadImpl(path: String): 
RandomForestClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
         EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 1a94aefa3f563..b4c1246e780f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -105,7 +105,8 @@ class BisectingKMeansModel private[ml] (
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol),
@@ -191,7 +192,7 @@ object BisectingKMeansModel extends 
MLReadable[BisectingKMeansModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[BisectingKMeansModel].getName
 
-    override def load(path: String): BisectingKMeansModel = {
+    override protected def loadImpl(path: String): BisectingKMeansModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
@@ -261,7 +262,9 @@ class BisectingKMeans @Since("2.0.0") (
   def setDistanceMeasure(value: String): this.type = set(distanceMeasure, 
value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): BisectingKMeansModel = instrumented { 
instr =>
+  override def fit(dataset: Dataset[_]): BisectingKMeansModel = 
super.fit(dataset)
+  override protected def fitImpl(
+      dataset: Dataset[_]): BisectingKMeansModel = instrumented { instr =>
     transformSchema(dataset.schema, logging = true)
     val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 88abc1605d69f..4908f03574d14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -106,7 +106,8 @@ class GaussianMixtureModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val predUDF = udf((vector: Vector) => predict(vector))
     val probUDF = udf((vector: Vector) => predictProbability(vector))
@@ -218,7 +219,7 @@ object GaussianMixtureModel extends 
MLReadable[GaussianMixtureModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[GaussianMixtureModel].getName
 
-    override def load(path: String): GaussianMixtureModel = {
+    override protected def loadImpl(path: String): GaussianMixtureModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
@@ -336,7 +337,9 @@ class GaussianMixture @Since("2.0.0") (
   private val numSamples = 5
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): GaussianMixtureModel = instrumented { 
instr =>
+  override def fit(dataset: Dataset[_]): GaussianMixtureModel = 
super.fit(dataset)
+  override protected def fitImpl(
+      dataset: Dataset[_]): GaussianMixtureModel = instrumented { instr =>
     transformSchema(dataset.schema, logging = true)
 
     val sc = dataset.sparkSession.sparkContext
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 2eed84d51782a..1c8a17c04f99e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -124,7 +124,8 @@ class KMeansModel private[ml] (
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
 
     val predictUDF = udf((vector: Vector) => predict(vector))
@@ -238,7 +239,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[KMeansModel].getName
 
-    override def load(path: String): KMeansModel = {
+    override protected def loadImpl(path: String): KMeansModel = {
       // Import implicits for Dataset Encoder
       val sparkSession = super.sparkSession
       import sparkSession.implicits._
@@ -321,7 +322,8 @@ class KMeans @Since("1.5.0") (
   def setSeed(value: Long): this.type = set(seed, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): KMeansModel = instrumented { instr =>
+  override def fit(dataset: Dataset[_]): KMeansModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): KMeansModel = 
instrumented { instr =>
     transformSchema(dataset.schema, logging = true)
 
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 84e73dc19a392..fde97d57fdb75 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -455,7 +455,8 @@ abstract class LDAModel private[ml] (
    *          This implementation may be changed in the future.
    */
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     if ($(topicDistributionCol).nonEmpty) {
 
       // TODO: Make the transformer natively in ml framework to avoid extra 
conversion.
@@ -619,7 +620,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
 
     private val className = classOf[LocalLDAModel].getName
 
-    override def load(path: String): LocalLDAModel = {
+    override protected def loadImpl(path: String): LocalLDAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
@@ -772,7 +773,7 @@ object DistributedLDAModel extends 
MLReadable[DistributedLDAModel] {
 
     private val className = classOf[DistributedLDAModel].getName
 
-    override def load(path: String): DistributedLDAModel = {
+    override protected def loadImpl(path: String): DistributedLDAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val modelPath = new Path(path, "oldModel").toString
       val oldModel = OldDistributedLDAModel.load(sc, modelPath)
@@ -895,7 +896,8 @@ class LDA @Since("1.6.0") (
   override def copy(extra: ParamMap): LDA = defaultCopy(extra)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): LDAModel = instrumented { instr =>
+  override def fit(dataset: Dataset[_]): LDAModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): LDAModel = instrumented 
{ instr =>
     transformSchema(dataset.schema, logging = true)
 
     instr.logPipelineStage(this)
@@ -952,7 +954,7 @@ object LDA extends MLReadable[LDA] {
 
     private val className = classOf[LDA].getName
 
-    override def load(path: String): LDA = {
+    override protected def loadImpl(path: String): LDA = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val model = new LDA(metadata.uid)
       LDAParams.getAndSetParams(model, metadata)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/events.scala 
b/mllib/src/main/scala/org/apache/spark/ml/events.scala
new file mode 100644
index 0000000000000..fe009394ca081
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/events.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Unstable
+import org.apache.spark.ml.util.{MLReader, MLWriter}
+import org.apache.spark.scheduler.SparkListenerEvent
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+/**
+ * Event emitted by ML operations. Events are either fired before and/or
+ * after each operation (the event should document this).
+ */
+@Unstable
+sealed trait MLEvent extends SparkListenerEvent
+
+/**
+ * Event fired before `Transformer.transform`.
+ */
+@Unstable
+case class TransformStart(transformer: Transformer, input: Dataset[_]) extends MLEvent
+/**
+ * Event fired after `Transformer.transform`.
+ */
+@Unstable
+case class TransformEnd(transformer: Transformer, output: Dataset[_]) extends MLEvent
+
+/**
+ * Event fired before `Estimator.fit`.
+ */
+@Unstable
+case class FitStart[M <: Model[M]](estimator: Estimator[M], dataset: Dataset[_]) extends MLEvent
+/**
+ * Event fired after `Estimator.fit`.
+ */
+@Unstable
+case class FitEnd[M <: Model[M]](estimator: Estimator[M], model: M) extends MLEvent
+
+/**
+ * Event fired before `MLReader.load`.
+ */
+@Unstable
+case class LoadInstanceStart[T](reader: MLReader[T], path: String) extends MLEvent
+/**
+ * Event fired after `MLReader.load`.
+ */
+@Unstable
+case class LoadInstanceEnd[T](reader: MLReader[T], instance: T) extends MLEvent
+
+/**
+ * Event fired before `MLWriter.save`.
+ */
+@Unstable
+case class SaveInstanceStart(writer: MLWriter, path: String) extends MLEvent
+/**
+ * Event fired after `MLWriter.save`.
+ */
+@Unstable
+case class SaveInstanceEnd(writer: MLWriter, path: String) extends MLEvent
+
+
+private[ml] object MLEvents {
+  private def listenerBus = SparkContext.getOrCreate().listenerBus
+
+  def withFitEvent[M <: Model[M]](
+      estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = {
+    listenerBus.post(FitStart(estimator, dataset))
+    val model: M = func
+    listenerBus.post(FitEnd(estimator, model))
+    model
+  }
+
+  def withTransformEvent(
+      transformer: Transformer, input: Dataset[_])(func: => DataFrame): DataFrame = {
+    listenerBus.post(TransformStart(transformer, input))
+    val output: DataFrame = func
+    listenerBus.post(TransformEnd(transformer, output))
+    output
+  }
+
+  def withLoadInstanceEvent[T](reader: MLReader[T], path: String)(func: => T): T = {
+    listenerBus.post(LoadInstanceStart(reader, path))
+    val instance: T = func
+    listenerBus.post(LoadInstanceEnd(reader, instance))
+    instance
+  }
+
+  def withSaveInstanceEvent(writer: MLWriter, path: String)(func: => Unit): Unit = {
+    listenerBus.post(SaveInstanceStart(writer, path))
+    func
+    listenerBus.post(SaveInstanceEnd(writer, path))
+  }
+}
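
Since MLEvent extends SparkListenerEvent and MLEvents posts to the SparkContext listener bus, the events can be observed with an ordinary SparkListener via onOtherEvent. A minimal sketch, e.g. for spark-shell; the MLEventLogger name is illustrative and not part of the patch:

import org.apache.spark.ml.{FitEnd, FitStart, TransformEnd, TransformStart}
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.SparkSession

// Illustrative listener (not part of the patch): logs the ML events
// that MLEvents posts to the listener bus.
class MLEventLogger extends SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case FitStart(estimator, _) => println(s"fit started: ${estimator.uid}")
    case FitEnd(estimator, _) => println(s"fit finished: ${estimator.uid}")
    case TransformStart(transformer, _) => println(s"transform started: ${transformer.uid}")
    case TransformEnd(transformer, _) => println(s"transform finished: ${transformer.uid}")
    case _ => // not an ML event; ignore
  }
}

// Register it once; subsequent fit()/transform() calls emit events.
val spark = SparkSession.builder().getOrCreate()
spark.sparkContext.addSparkListener(new MLEventLogger)
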
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 2b0862c60fdf7..3a02b640e359a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -70,7 +70,8 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") 
override val uid: String)
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema, logging = true)
     val schema = dataset.schema
     val inputType = schema($(inputCol)).dataType
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index 0554455a66d7f..f705418432a48 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -225,7 +225,7 @@ object BucketedRandomProjectionLSHModel extends 
MLReadable[BucketedRandomProject
     /** Checked against metadata when loading model */
     private val className = classOf[BucketedRandomProjectionLSHModel].getName
 
-    override def load(path: String): BucketedRandomProjectionLSHModel = {
+    override protected def loadImpl(path: String): 
BucketedRandomProjectionLSHModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 0b989b0d7d253..eaa9ae8d81c45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -143,7 +143,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") 
override val uid: String
   def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema)
 
     val (inputColumns, outputColumns) = if (isSet(inputCols)) {
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index dbfb199ccd58f..ba8482d50a9cf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -197,7 +197,8 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") 
override val uid: Str
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
+  override def fit(dataset: Dataset[_]): ChiSqSelectorModel = 
super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): ChiSqSelectorModel = {
     transformSchema(dataset.schema, logging = true)
     val input: RDD[OldLabeledPoint] =
       dataset.select(col($(labelCol)).cast(DoubleType), 
col($(featuresCol))).rdd.map {
@@ -263,7 +264,8 @@ final class ChiSqSelectorModel private[ml] (
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema, logging = true)
     val newField = transformedSchema.last
 
@@ -327,7 +329,7 @@ object ChiSqSelectorModel extends 
MLReadable[ChiSqSelectorModel] {
 
     private val className = classOf[ChiSqSelectorModel].getName
 
-    override def load(path: String): ChiSqSelectorModel = {
+    override protected def loadImpl(path: String): ChiSqSelectorModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = 
sparkSession.read.parquet(dataPath).select("selectedFeatures").head()
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index dc8eb8261dbe2..a9ea509ada077 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -181,7 +181,8 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") 
override val uid: String)
   def setBinary(value: Boolean): this.type = set(binary, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): CountVectorizerModel = {
+  override def fit(dataset: Dataset[_]): CountVectorizerModel = 
super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): CountVectorizerModel = {
     transformSchema(dataset.schema, logging = true)
     val vocSize = $(vocabSize)
     val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
@@ -291,7 +292,8 @@ class CountVectorizerModel(
   private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     if (broadcastDict.isEmpty) {
       val dict = vocabulary.zipWithIndex.toMap
@@ -358,7 +360,7 @@ object CountVectorizerModel extends 
MLReadable[CountVectorizerModel] {
 
     private val className = classOf[CountVectorizerModel].getName
 
-    override def load(path: String): CountVectorizerModel = {
+    override protected def loadImpl(path: String): CountVectorizerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
index dc18e1d34880a..a58a1ff8e053c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
@@ -140,7 +140,8 @@ class FeatureHasher(@Since("2.3.0") override val uid: 
String) extends Transforme
   def setCategoricalCols(value: Array[String]): this.type = 
set(categoricalCols, value)
 
   @Since("2.3.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val hashFunc: Any => Int = FeatureHasher.murmur3Hash
     val n = $(numFeatures)
     val localInputCols = $(inputCols)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index dbda5b8d8fd4a..57b6bb1dbe432 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -91,7 +91,8 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val 
uid: String)
   def setBinary(value: Boolean): this.type = set(binary, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
     val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
     // TODO: Make the hashingTF.transform natively in ml framework to avoid 
extra conversion.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 58897cca4e5c6..9c91cc0417e48 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -84,7 +84,8 @@ final class IDF @Since("1.4.0") (@Since("1.4.0") override val 
uid: String)
   def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): IDFModel = {
+  override def fit(dataset: Dataset[_]): IDFModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): IDFModel = {
     transformSchema(dataset.schema, logging = true)
     val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
       case Row(v: Vector) => OldVectors.fromML(v)
@@ -129,7 +130,8 @@ class IDFModel private[ml] (
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     // TODO: Make the idfModel.transform natively in ml framework to avoid 
extra conversion.
     val idf = udf { vec: Vector => 
idfModel.transform(OldVectors.fromML(vec)).asML }
@@ -174,7 +176,7 @@ object IDFModel extends MLReadable[IDFModel] {
 
     private val className = classOf[IDFModel].getName
 
-    override def load(path: String): IDFModel = {
+    override protected def loadImpl(path: String): IDFModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 1c074e204ad99..761947a7c394c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -120,7 +120,10 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override 
val uid: String)
 
   setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN)
 
-  override def fit(dataset: Dataset[_]): ImputerModel = {
+  // Explicitly call parent's fit. Otherwise, MiMa complains.
+  override def fit(dataset: Dataset[_]): ImputerModel = super.fit(dataset)
+
+  override protected def fitImpl(dataset: Dataset[_]): ImputerModel = {
     transformSchema(dataset.schema, logging = true)
     val spark = dataset.sparkSession
 
@@ -211,7 +214,7 @@ class ImputerModel private[ml] (
   /** @group setParam */
   def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
 
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
 
@@ -257,7 +260,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
 
     private val className = classOf[ImputerModel].getName
 
-    override def load(path: String): ImputerModel = {
+    override protected def loadImpl(path: String): ImputerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val surrogateDF = sqlContext.read.parquet(dataPath)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 611f1b691b782..319927e9dd228 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -67,7 +67,8 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override 
val uid: String) ext
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val inputFeatures = $(inputCols).map(c => dataset.schema(c))
     val featureEncoders = getFeatureEncoders(inputFeatures)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index b20852383a6ff..b7cc016ebcaf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -95,7 +95,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
    */
   protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double
 
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val transformUDF = udf(hashFunction(_: Vector), 
DataTypes.createArrayType(new VectorUDT))
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
@@ -323,7 +323,7 @@ private[ml] abstract class LSH[T <: LSHModel[T]]
    */
   protected[this] def createRawLSHModel(inputDim: Int): T
 
-  override def fit(dataset: Dataset[_]): T = {
+  override protected def fitImpl(dataset: Dataset[_]): T = {
     transformSchema(dataset.schema, logging = true)
     val inputDim = 
dataset.select(col($(inputCol))).head().get(0).asInstanceOf[Vector].size
     val model = createRawLSHModel(inputDim).setParent(this)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 90eceb0d61b40..295d3d0ccab3d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -68,7 +68,8 @@ class MaxAbsScaler @Since("2.0.0") (@Since("2.0.0") override 
val uid: String)
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): MaxAbsScalerModel = {
+  override def fit(dataset: Dataset[_]): MaxAbsScalerModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): MaxAbsScalerModel = {
     transformSchema(dataset.schema, logging = true)
     val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
       case Row(v: Vector) => OldVectors.fromML(v)
@@ -119,7 +120,8 @@ class MaxAbsScalerModel private[ml] (
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     // TODO: this looks hack, we may have to handle sparse and dense vectors 
separately.
     val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 
else x))
@@ -165,7 +167,7 @@ object MaxAbsScalerModel extends 
MLReadable[MaxAbsScalerModel] {
 
     private val className = classOf[MaxAbsScalerModel].getName
 
-    override def load(path: String): MaxAbsScalerModel = {
+    override protected def loadImpl(path: String): MaxAbsScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index 21cde66d8db6b..168accb1a1722 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -194,7 +194,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[MinHashLSHModel].getName
 
-    override def load(path: String): MinHashLSHModel = {
+    override protected def loadImpl(path: String): MinHashLSHModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 2e0ae4af66f06..d76fc5ec05c52 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -115,7 +115,8 @@ class MinMaxScaler @Since("1.5.0") (@Since("1.5.0") 
override val uid: String)
   def setMax(value: Double): this.type = set(max, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): MinMaxScalerModel = {
+  override def fit(dataset: Dataset[_]): MinMaxScalerModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): MinMaxScalerModel = {
     transformSchema(dataset.schema, logging = true)
     val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
       case Row(v: Vector) => OldVectors.fromML(v)
@@ -174,7 +175,8 @@ class MinMaxScalerModel private[ml] (
   def setMax(value: Double): this.type = set(max, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val originalRange = (originalMax.asBreeze - originalMin.asBreeze).toArray
     val minArray = originalMin.toArray
@@ -234,7 +236,7 @@ object MinMaxScalerModel extends 
MLReadable[MinMaxScalerModel] {
 
     private val className = classOf[MinMaxScalerModel].getName
 
-    override def load(path: String): MinMaxScalerModel = {
+    override protected def loadImpl(path: String): MinMaxScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index ec9792cbbda8f..4e4a259c95cce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -147,7 +147,8 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") 
override val uid: String)
   }
 
   @Since("3.0.0")
-  override def fit(dataset: Dataset[_]): OneHotEncoderModel = {
+  override def fit(dataset: Dataset[_]): OneHotEncoderModel = 
super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): OneHotEncoderModel = {
     transformSchema(dataset.schema)
 
     // Compute the plain number of categories without `handleInvalid` and
@@ -324,7 +325,8 @@ class OneHotEncoderModel private[ml] (
   }
 
   @Since("3.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema, logging = true)
     val keepInvalid = $(handleInvalid) == OneHotEncoder.KEEP_INVALID
 
@@ -378,7 +380,7 @@ object OneHotEncoderModel extends 
MLReadable[OneHotEncoderModel] {
 
     private val className = classOf[OneHotEncoderModel].getName
 
-    override def load(path: String): OneHotEncoderModel = {
+    override protected def loadImpl(path: String): OneHotEncoderModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 8172491a517d1..0a79b49b984c3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -90,7 +90,8 @@ class PCA @Since("1.5.0") (
    * Computes a [[PCAModel]] that contains the principal components of the 
input vectors.
    */
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): PCAModel = {
+  override def fit(dataset: Dataset[_]): PCAModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): PCAModel = {
     transformSchema(dataset.schema, logging = true)
     val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
       case Row(v: Vector) => OldVectors.fromML(v)
@@ -147,7 +148,8 @@ class PCAModel private[ml] (
    * to `PCA.fit()`.
    */
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val pcaModel = new feature.PCAModel($(k),
       OldMatrices.fromML(pc).asInstanceOf[OldDenseMatrix],
@@ -203,7 +205,7 @@ object PCAModel extends MLReadable[PCAModel] {
      * @param path path to serialized model data
      * @return a [[PCAModel]]
      */
-    override def load(path: String): PCAModel = {
+    override protected def loadImpl(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 5bfaa3b7f3f52..0f29ae9687c02 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -199,7 +199,8 @@ final class QuantileDiscretizer @Since("1.6.0") 
(@Since("1.6.0") override val ui
   }
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): Bucketizer = {
+  override def fit(dataset: Dataset[_]): Bucketizer = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): Bucketizer = {
     transformSchema(dataset.schema, logging = true)
     val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid))
     if (isSet(inputCols)) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index d7eb13772aa64..f5132f09f95f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -193,7 +193,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
   }
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): RFormulaModel = {
+  override def fit(dataset: Dataset[_]): RFormulaModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): RFormulaModel = {
     transformSchema(dataset.schema, logging = true)
     require(isDefined(formula), "Formula must be defined first.")
     val parsedFormula = RFormulaParser.parse($(formula))
@@ -338,7 +339,8 @@ class RFormulaModel private[feature](
   extends Model[RFormulaModel] with RFormulaBase with MLWritable {
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = 
super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     checkCanTransform(dataset.schema)
     transformLabel(pipelineModel.transform(dataset))
   }
@@ -431,7 +433,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[RFormulaModel].getName
 
-    override def load(path: String): RFormulaModel = {
+    override protected def loadImpl(path: String): RFormulaModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
@@ -462,7 +464,7 @@ private class ColumnPruner(override val uid: String, val 
columnsToPrune: Set[Str
   def this(columnsToPrune: Set[String]) =
     this(Identifiable.randomUID("columnPruner"), columnsToPrune)
 
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
     dataset.select(columnsToKeep.map(dataset.col): _*)
   }
@@ -502,7 +504,7 @@ private object ColumnPruner extends 
MLReadable[ColumnPruner] {
     /** Checked against metadata when loading model */
     private val className = classOf[ColumnPruner].getName
 
-    override def load(path: String): ColumnPruner = {
+    override protected def loadImpl(path: String): ColumnPruner = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
@@ -535,7 +537,7 @@ private class VectorAttributeRewriter(
   def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
    this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
 
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val metadata = {
       val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
       val attrs = group.attributes.get.map { attr =>
@@ -593,7 +595,7 @@ private object VectorAttributeRewriter extends 
MLReadable[VectorAttributeRewrite
     /** Checked against metadata when loading model */
     private val className = classOf[VectorAttributeRewriter].getName
 
-    override def load(path: String): VectorAttributeRewriter = {
+    override protected def loadImpl(path: String): VectorAttributeRewriter = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 0fb1d8c5dc579..d453ab3723c0c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -64,7 +64,8 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") 
override val uid: String)
   private val tableIdentifier: String = "__THIS__"
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val tableName = Identifiable.randomUID(uid)
     dataset.createOrReplaceTempView(tableName)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 91b0707dec3f3..c96726faab4f5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -108,7 +108,8 @@ class StandardScaler @Since("1.4.0") (
   def setWithStd(value: Boolean): this.type = set(withStd, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): StandardScalerModel = {
+  override def fit(dataset: Dataset[_]): StandardScalerModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): StandardScalerModel = {
     transformSchema(dataset.schema, logging = true)
     val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map {
       case Row(v: Vector) => OldVectors.fromML(v)
@@ -158,7 +159,8 @@ class StandardScalerModel private[ml] (
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
 
@@ -204,7 +206,7 @@ object StandardScalerModel extends 
MLReadable[StandardScalerModel] {
 
     private val className = classOf[StandardScalerModel].getName
 
-    override def load(path: String): StandardScalerModel = {
+    override protected def loadImpl(path: String): StandardScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 6669d402cd996..8110cccf27167 100755
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -109,7 +109,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") 
override val uid: String
     caseSensitive -> false, locale -> Locale.getDefault.toString)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
     val t = if ($(caseSensitive)) {
       val stopWordsSet = $(stopWords).toSet
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index a833d8b270cf1..74843bac1ea82 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -131,7 +131,8 @@ class StringIndexer @Since("1.4.0") (
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): StringIndexerModel = {
+  override def fit(dataset: Dataset[_]): StringIndexerModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): StringIndexerModel = {
     transformSchema(dataset.schema, logging = true)
     val values = dataset.na.drop(Array($(inputCol)))
       .select(col($(inputCol)).cast(StringType))
@@ -218,7 +219,8 @@ class StringIndexerModel (
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     if (!dataset.schema.fieldNames.contains($(inputCol))) {
       logInfo(s"Input column ${$(inputCol)} does not exist during 
transformation. " +
         "Skip StringIndexerModel.")
@@ -307,7 +309,7 @@ object StringIndexerModel extends 
MLReadable[StringIndexerModel] {
 
     private val className = classOf[StringIndexerModel].getName
 
-    override def load(path: String): StringIndexerModel = {
+    override protected def loadImpl(path: String): StringIndexerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
@@ -386,7 +388,8 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") 
override val uid: String)
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val inputColSchema = dataset.schema($(inputCol))
     // If the labels array is empty use column metadata
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 57e23d5072b88..6add5f8d8c4e6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -82,7 +82,8 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") 
override val uid: String)
   setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     // Schema transformation.
     val schema = dataset.schema
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 0e7396a621dbd..fd6e79c4e21fb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -140,7 +140,8 @@ class VectorIndexer @Since("1.4.0") (
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): VectorIndexerModel = {
+  override def fit(dataset: Dataset[_]): VectorIndexerModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): VectorIndexerModel = {
     transformSchema(dataset.schema, logging = true)
     val firstRow = dataset.select($(inputCol)).take(1)
    require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.")
@@ -425,7 +426,8 @@ class VectorIndexerModel private[ml] (
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val newField = prepOutputField(dataset.schema)
     val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
@@ -528,7 +530,7 @@ object VectorIndexerModel extends 
MLReadable[VectorIndexerModel] {
 
     private val className = classOf[VectorIndexerModel].getName
 
-    override def load(path: String): VectorIndexerModel = {
+    override protected def loadImpl(path: String): VectorIndexerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
index f5947d61fe349..854fd1b4d3776 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
@@ -96,7 +96,8 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") 
override val uid: String)
   setDefault(handleInvalid, VectorSizeHint.ERROR_INVALID)
 
   @Since("2.3.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val localInputCol = getInputCol
     val localSize = getSize
     val localHandleInvalid = getHandleInvalid
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index e3e462d07e10c..cd1d49c1a2cb0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -98,7 +98,8 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") 
override val uid: Stri
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     // Validity checks
     transformSchema(dataset.schema)
     val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index fc9996d69ba72..31e5b00765c68 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -171,7 +171,8 @@ final class Word2Vec @Since("1.4.0") (
   def setMaxSentenceLength(value: Int): this.type = set(maxSentenceLength, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): Word2VecModel = {
+  override def fit(dataset: Dataset[_]): Word2VecModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): Word2VecModel = {
     transformSchema(dataset.schema, logging = true)
     val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
     val wordVectors = new feature.Word2Vec()
@@ -286,7 +287,8 @@ class Word2VecModel private[ml] (
    * is performed by averaging all word vectors it contains.
    */
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val vectors = wordVectors.getVectors
       .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
@@ -385,7 +387,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
 
     private val className = classOf[Word2VecModel].getName
 
-    override def load(path: String): Word2VecModel = {
+    override protected def loadImpl(path: String): Word2VecModel = {
       val spark = sparkSession
       import spark.implicits._
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala 
b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 7322815c12ab8..d2ca4bc53957f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -157,7 +157,8 @@ class FPGrowth @Since("2.2.0") (
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   @Since("2.2.0")
-  override def fit(dataset: Dataset[_]): FPGrowthModel = {
+  override def fit(dataset: Dataset[_]): FPGrowthModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): FPGrowthModel = {
     transformSchema(dataset.schema, logging = true)
     genericFit(dataset)
   }
@@ -277,7 +278,8 @@ class FPGrowthModel private[ml] (
   * efficiency. This may bring pressure to driver memory for large set of association rules.
    */
   @Since("2.2.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     genericTransform(dataset)
   }
@@ -342,7 +344,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[FPGrowthModel].getName
 
-    override def load(path: String): FPGrowthModel = {
+    override protected def loadImpl(path: String): FPGrowthModel = {
       implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index 1b5f77a9ae897..d45a5fa14df90 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -136,7 +136,7 @@ private[r] object AFTSurvivalRegressionWrapper extends 
MLReadable[AFTSurvivalReg
 
  class AFTSurvivalRegressionWrapperReader extends MLReader[AFTSurvivalRegressionWrapper] {
 
-    override def load(path: String): AFTSurvivalRegressionWrapper = {
+    override protected def loadImpl(path: String): AFTSurvivalRegressionWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
index ad13cced4667b..8d12d80ab4734 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
@@ -102,7 +102,7 @@ private[r] object ALSWrapper extends MLReadable[ALSWrapper] 
{
 
   class ALSWrapperReader extends MLReader[ALSWrapper] {
 
-    override def load(path: String): ALSWrapper = {
+    override protected def loadImpl(path: String): ALSWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val modelPath = new Path(path, "model").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala
index 71712c1c5eec5..8478c7b6e8c77 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala
@@ -126,7 +126,7 @@ private[r] object BisectingKMeansWrapper extends 
MLReadable[BisectingKMeansWrapp
 
   class BisectingKMeansWrapperReader extends MLReader[BisectingKMeansWrapper] {
 
-    override def load(path: String): BisectingKMeansWrapper = {
+    override protected def loadImpl(path: String): BisectingKMeansWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
index a90cae5869b2a..ad9476c1c000c 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala
@@ -137,7 +137,7 @@ private[r] object DecisionTreeClassifierWrapper extends 
MLReadable[DecisionTreeC
 
  class DecisionTreeClassifierWrapperReader extends MLReader[DecisionTreeClassifierWrapper] {
 
-    override def load(path: String): DecisionTreeClassifierWrapper = {
+    override protected def loadImpl(path: String): DecisionTreeClassifierWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala
index de712d67e6df5..efe7884d5a170 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala
@@ -120,7 +120,7 @@ private[r] object DecisionTreeRegressorWrapper extends 
MLReadable[DecisionTreeRe
 
  class DecisionTreeRegressorWrapperReader extends MLReader[DecisionTreeRegressorWrapper] {
 
-    override def load(path: String): DecisionTreeRegressorWrapper = {
+    override protected def loadImpl(path: String): DecisionTreeRegressorWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala
index b8151d8d90702..27a859678372f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala
@@ -61,7 +61,7 @@ private[r] object FPGrowthWrapper extends 
MLReadable[FPGrowthWrapper] {
   override def read: MLReader[FPGrowthWrapper] = new FPGrowthWrapperReader
 
   class FPGrowthWrapperReader extends MLReader[FPGrowthWrapper] {
-    override def load(path: String): FPGrowthWrapper = {
+    override protected def loadImpl(path: String): FPGrowthWrapper = {
       val modelPath = new Path(path, "model").toString
       val fPGrowthModel = FPGrowthModel.load(modelPath)
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
index ecaeec5a7791a..ff55d32019030 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -144,7 +144,7 @@ private[r] object GBTClassifierWrapper extends 
MLReadable[GBTClassifierWrapper]
 
   class GBTClassifierWrapperReader extends MLReader[GBTClassifierWrapper] {
 
-    override def load(path: String): GBTClassifierWrapper = {
+    override protected def loadImpl(path: String): GBTClassifierWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
index b568d7859221f..fef2c5755e8f9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
@@ -128,7 +128,7 @@ private[r] object GBTRegressorWrapper extends 
MLReadable[GBTRegressorWrapper] {
 
   class GBTRegressorWrapperReader extends MLReader[GBTRegressorWrapper] {
 
-    override def load(path: String): GBTRegressorWrapper = {
+    override protected def loadImpl(path: String): GBTRegressorWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
index 9a98a8b18b141..656a0062cd14f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
@@ -120,7 +120,7 @@ private[r] object GaussianMixtureWrapper extends 
MLReadable[GaussianMixtureWrapp
 
   class GaussianMixtureWrapperReader extends MLReader[GaussianMixtureWrapper] {
 
-    override def load(path: String): GaussianMixtureWrapper = {
+    override protected def loadImpl(path: String): GaussianMixtureWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 64575b0cb0cb5..c974519211288 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -179,7 +179,7 @@ private[r] object GeneralizedLinearRegressionWrapper
   class GeneralizedLinearRegressionWrapperReader
     extends MLReader[GeneralizedLinearRegressionWrapper] {
 
-    override def load(path: String): GeneralizedLinearRegressionWrapper = {
+    override protected def loadImpl(path: String): GeneralizedLinearRegressionWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
index d31ebb46afb97..dc1f3e62056f6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
@@ -106,7 +106,7 @@ private[r] object IsotonicRegressionWrapper
 
  class IsotonicRegressionWrapperReader extends MLReader[IsotonicRegressionWrapper] {
 
-    override def load(path: String): IsotonicRegressionWrapper = {
+    override protected def loadImpl(path: String): IsotonicRegressionWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index 8d596863b459e..2b239d7a786ce 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -129,7 +129,7 @@ private[r] object KMeansWrapper extends 
MLReadable[KMeansWrapper] {
 
   class KMeansWrapperReader extends MLReader[KMeansWrapper] {
 
-    override def load(path: String): KMeansWrapper = {
+    override protected def loadImpl(path: String): KMeansWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
index e096bf1f29f3e..210baec3cb33c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
@@ -206,7 +206,7 @@ private[r] object LDAWrapper extends MLReadable[LDAWrapper] 
{
 
   class LDAWrapperReader extends MLReader[LDAWrapper] {
 
-    override def load(path: String): LDAWrapper = {
+    override protected def loadImpl(path: String): LDAWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
index 7a22a71c3a819..7ec2717f515de 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala
@@ -144,7 +144,7 @@ private[r] object LinearSVCWrapper
 
   class LinearSVCWrapperReader extends MLReader[LinearSVCWrapper] {
 
-    override def load(path: String): LinearSVCWrapper = {
+    override protected def loadImpl(path: String): LinearSVCWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
index 18acf7d21656f..a7dfcb3953e19 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala
@@ -199,7 +199,7 @@ private[r] object LogisticRegressionWrapper
 
  class LogisticRegressionWrapperReader extends MLReader[LogisticRegressionWrapper] {
 
-    override def load(path: String): LogisticRegressionWrapper = {
+    override protected def loadImpl(path: String): LogisticRegressionWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
index 62f642142701b..6ac25957b9c81 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
@@ -124,7 +124,7 @@ private[r] object MultilayerPerceptronClassifierWrapper
   class MultilayerPerceptronClassifierWrapperReader
     extends MLReader[MultilayerPerceptronClassifierWrapper]{
 
-    override def load(path: String): MultilayerPerceptronClassifierWrapper = {
+    override protected def loadImpl(path: String): MultilayerPerceptronClassifierWrapper = {
       implicit val format = DefaultFormats
       val pipelinePath = new Path(path, "pipeline").toString
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index fbf9f462ff5f6..e84b1fc297e5e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -109,7 +109,7 @@ private[r] object NaiveBayesWrapper extends 
MLReadable[NaiveBayesWrapper] {
 
   class NaiveBayesWrapperReader extends MLReader[NaiveBayesWrapper] {
 
-    override def load(path: String): NaiveBayesWrapper = {
+    override protected def loadImpl(path: String): NaiveBayesWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index ba6445a730306..7dd93f1420f6f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -30,7 +30,7 @@ import org.apache.spark.ml.util.MLReader
  */
 private[r] object RWrappers extends MLReader[Object] {
 
-  override def load(path: String): Object = {
+  override protected def loadImpl(path: String): Object = {
     implicit val format = DefaultFormats
     val rMetadataPath = new Path(path, "rMetadata").toString
     val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 132345fb9a6d9..c905319ba0ec2 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -145,7 +145,7 @@ private[r] object RandomForestClassifierWrapper extends 
MLReadable[RandomForestC
 
  class RandomForestClassifierWrapperReader extends MLReader[RandomForestClassifierWrapper] {
 
-    override def load(path: String): RandomForestClassifierWrapper = {
+    override protected def loadImpl(path: String): RandomForestClassifierWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
index 038bd79c7022b..1f432da570560 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
@@ -128,7 +128,7 @@ private[r] object RandomForestRegressorWrapper extends 
MLReadable[RandomForestRe
 
  class RandomForestRegressorWrapperReader extends MLReader[RandomForestRegressorWrapper] {
 
-    override def load(path: String): RandomForestRegressorWrapper = {
+    override protected def loadImpl(path: String): RandomForestRegressorWrapper = {
       implicit val format = DefaultFormats
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 50ef4330ddc80..fbcf39dcbe3ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -303,7 +303,8 @@ class ALSModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
     // create a new column named map(predictionCol) by running the predict UDF.
     val predictions = dataset
@@ -519,7 +520,7 @@ object ALSModel extends MLReadable[ALSModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[ALSModel].getName
 
-    override def load(path: String): ALSModel = {
+    override protected def loadImpl(path: String): ALSModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       implicit val format = DefaultFormats
       val rank = (metadata.metadata \ "rank").extract[Int]
@@ -655,7 +656,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends 
Estimator[ALSModel]
   }
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): ALSModel = instrumented { instr =>
+  override def fit(dataset: Dataset[_]): ALSModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): ALSModel = instrumented { instr =>
     transformSchema(dataset.schema)
     import dataset.sparkSession.implicits._
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 8d6e36697d2cc..ebda26e8ffc8e 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -211,7 +211,9 @@ class AFTSurvivalRegression @Since("1.6.0") 
(@Since("1.6.0") override val uid: S
   }
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr =>
+  override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = super.fit(dataset)
+  override protected def fitImpl(
+      dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr =>
     transformSchema(dataset.schema, logging = true)
     val instances = extractAFTPoints(dataset)
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
@@ -353,7 +355,8 @@ class AFTSurvivalRegressionModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val predictUDF = udf { features: Vector => predict(features) }
    val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
@@ -412,7 +415,7 @@ object AFTSurvivalRegressionModel extends 
MLReadable[AFTSurvivalRegressionModel]
     /** Checked against metadata when loading model */
     private val className = classOf[AFTSurvivalRegressionModel].getName
 
-    override def load(path: String): AFTSurvivalRegressionModel = {
+    override protected def loadImpl(path: String): AFTSurvivalRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index faadc4d7b4ccc..155c979e28250 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -22,7 +22,7 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.{PredictionModel, Predictor}
+import org.apache.spark.ml.{MLEvents, PredictionModel, Predictor}
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
@@ -188,7 +188,8 @@ class DecisionTreeRegressionModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(
+      dataset: Dataset[_]): DataFrame = MLEvents.withTransformEvent(this, dataset) {
     transformSchema(dataset.schema, logging = true)
     transformImpl(dataset)
   }
@@ -275,7 +276,7 @@ object DecisionTreeRegressionModel extends 
MLReadable[DecisionTreeRegressionMode
     /** Checked against metadata when loading model */
     private val className = classOf[DecisionTreeRegressionModel].getName
 
-    override def load(path: String): DecisionTreeRegressionModel = {
+    override protected def loadImpl(path: String): DecisionTreeRegressionModel = {
       implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 9a5b7d59e9aef..befef2737b9bf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -337,7 +337,7 @@ object GBTRegressionModel extends 
MLReadable[GBTRegressionModel] {
     private val className = classOf[GBTRegressionModel].getName
     private val treeClassName = classOf[DecisionTreeRegressionModel].getName
 
-    override def load(path: String): GBTRegressionModel = {
+    override protected def loadImpl(path: String): GBTRegressionModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
         EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index abb60ea205751..648167fe3e012 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.{MLEvents, PredictorParams}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.feature.{Instance, OffsetInstance}
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
@@ -1034,7 +1034,8 @@ class GeneralizedLinearRegressionModel private[ml] (
     BLAS.dot(features, coefficients) + intercept + offset
   }
 
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(
+      dataset: Dataset[_]): DataFrame = MLEvents.withTransformEvent(this, dataset) {
     transformSchema(dataset.schema)
     transformImpl(dataset)
   }
@@ -1140,7 +1141,7 @@ object GeneralizedLinearRegressionModel extends 
MLReadable[GeneralizedLinearRegr
     /** Checked against metadata when loading model */
     private val className = classOf[GeneralizedLinearRegressionModel].getName
 
-    override def load(path: String): GeneralizedLinearRegressionModel = {
+    override protected def loadImpl(path: String): GeneralizedLinearRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 8b9233dcdc4d1..422eaf8d4b9b3 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -162,7 +162,9 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") 
override val uid: Stri
   override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): IsotonicRegressionModel = instrumented { instr =>
+  override def fit(dataset: Dataset[_]): IsotonicRegressionModel = super.fit(dataset)
+  override protected def fitImpl(
+      dataset: Dataset[_]): IsotonicRegressionModel = instrumented { instr =>
     transformSchema(dataset.schema, logging = true)
    // Extract columns from data.  If dataset is persisted, do not persist oldDataset.
     val instances = extractWeightedLabeledPoints(dataset)
@@ -239,7 +241,8 @@ class IsotonicRegressionModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val predict = dataset.schema($(featuresCol)).dataType match {
       case DoubleType =>
@@ -296,7 +299,7 @@ object IsotonicRegressionModel extends 
MLReadable[IsotonicRegressionModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[IsotonicRegressionModel].getName
 
-    override def load(path: String): IsotonicRegressionModel = {
+    override protected def loadImpl(path: String): IsotonicRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index ce6c12cc368dd..395ebd60c1728 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -782,7 +782,7 @@ object LinearRegressionModel extends 
MLReadable[LinearRegressionModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[LinearRegressionModel].getName
 
-    override def load(path: String): LinearRegressionModel = {
+    override protected def loadImpl(path: String): LinearRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index afa9a646412b3..d1132f09d1975 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -268,7 +268,7 @@ object RandomForestRegressionModel extends 
MLReadable[RandomForestRegressionMode
     private val className = classOf[RandomForestRegressionModel].getName
     private val treeClassName = classOf[DecisionTreeRegressionModel].getName
 
-    override def load(path: String): RandomForestRegressionModel = {
+    override protected def loadImpl(path: String): RandomForestRegressionModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
         EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index e60a14f976a5c..dd5124a9fd584 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -119,7 +119,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") 
override val uid: String)
   def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): CrossValidatorModel = instrumented { instr =>
+  override def fit(dataset: Dataset[_]): CrossValidatorModel = super.fit(dataset)
+  override protected def fitImpl(dataset: Dataset[_]): CrossValidatorModel = instrumented { instr =>
     val schema = dataset.schema
     transformSchema(schema, logging = true)
     val sparkSession = dataset.sparkSession
@@ -226,7 +227,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
     /** Checked against metadata when loading model */
     private val className = classOf[CrossValidator].getName
 
-    override def load(path: String): CrossValidator = {
+    override protected def loadImpl(path: String): CrossValidator = {
       implicit val format = DefaultFormats
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
@@ -299,7 +300,8 @@ class CrossValidatorModel private[ml] (
   def hasSubModels: Boolean = _subModels.isDefined
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     bestModel.transform(dataset)
   }
@@ -392,7 +394,10 @@ object CrossValidatorModel extends 
MLReadable[CrossValidatorModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[CrossValidatorModel].getName
 
-    override def load(path: String): CrossValidatorModel = {
+    // Explicitly call parent's load. Otherwise, MiMa complains.
+    override def load(path: String): CrossValidatorModel = super.load(path)
+
+    override protected def loadImpl(path: String): CrossValidatorModel = {
       implicit val format = DefaultFormats
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 8b251197afbef..7fc662de85a77 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -118,7 +118,9 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") 
override val uid: St
   def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value)
 
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): TrainValidationSplitModel = instrumented { instr =>
+  override def fit(dataset: Dataset[_]): TrainValidationSplitModel = super.fit(dataset)
+  override protected def fitImpl(
+      dataset: Dataset[_]): TrainValidationSplitModel = instrumented { instr =>
     val schema = dataset.schema
     transformSchema(schema, logging = true)
     val est = $(estimator)
@@ -220,7 +222,7 @@ object TrainValidationSplit extends 
MLReadable[TrainValidationSplit] {
     /** Checked against metadata when loading model */
     private val className = classOf[TrainValidationSplit].getName
 
-    override def load(path: String): TrainValidationSplit = {
+    override protected def loadImpl(path: String): TrainValidationSplit = {
       implicit val format = DefaultFormats
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
@@ -290,7 +292,8 @@ class TrainValidationSplitModel private[ml] (
   def hasSubModels: Boolean = _subModels.isDefined
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset)
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     bestModel.transform(dataset)
   }
@@ -380,7 +383,10 @@ object TrainValidationSplitModel extends 
MLReadable[TrainValidationSplitModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[TrainValidationSplitModel].getName
 
-    override def load(path: String): TrainValidationSplitModel = {
+    // Explicitly call parent's load. Otherwise, MiMa complains.
+    override def load(path: String): TrainValidationSplitModel = super.load(path)
+
+    override protected def loadImpl(path: String): TrainValidationSplitModel = {
       implicit val format = DefaultFormats
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala 
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index fbc7be25a5640..0a45fa8e267ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -163,7 +163,7 @@ abstract class MLWriter extends BaseReadWrite with Logging {
    */
   @Since("1.6.0")
  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
-  def save(path: String): Unit = {
+  def save(path: String): Unit = MLEvents.withSaveInstanceEvent(this, path) {
     new FileSystemOverwrite().handleOverwrite(path, shouldOverwrite, sc)
     saveImpl(path)
   }
@@ -329,7 +329,19 @@ abstract class MLReader[T] extends BaseReadWrite {
    * Loads the ML component from the input path.
    */
   @Since("1.6.0")
-  def load(path: String): T
+  def load(path: String): T = MLEvents.withLoadInstanceEvent(this, path) {
+    loadImpl(path)
+  }
+
+  /**
+   * `load()` handles events and then calls this method. Subclasses should override this
+   * method to implement the actual loading of the instance.
+   */
+  @Since("3.0.0")
+  protected def loadImpl(path: String): T = {
+    // Keep this default body for backward compatibility.
+    throw new UnsupportedOperationException("loadImpl is not implemented.")
+  }
 
   // override for Java compatibility
  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
@@ -467,7 +479,7 @@ private[ml] object DefaultParamsWriter {
  */
 private[ml] class DefaultParamsReader[T] extends MLReader[T] {
 
-  override def load(path: String): T = {
+  override protected def loadImpl(path: String): T = {
     val metadata = DefaultParamsReader.loadMetadata(path, sc)
     val cls = Utils.classForName(metadata.className)
     val instance =
diff --git a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
new file mode 100644
index 0000000000000..e63eef870d985
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.apache.hadoop.fs.Path
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.when
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.concurrent.Eventually
+import org.scalatest.mockito.MockitoSugar.mock
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql._
+
+
+class MLEventsSuite
+    extends SparkFunSuite
+    with BeforeAndAfterEach
+    with MLlibTestSparkContext
+    with Eventually {
+
+  private val dirName: String = "pipeline"
+  private val events = mutable.ArrayBuffer.empty[MLEvent]
+  private val listener: SparkListener = new SparkListener {
+    override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+      case e: FitStart[_] => events.append(e)
+      case e: FitEnd[_] => events.append(e)
+      case e: TransformStart => events.append(e)
+      case e: TransformEnd => events.append(e)
+      case e: SaveInstanceStart if e.path.endsWith(dirName) => events.append(e)
+      case e: SaveInstanceEnd if e.path.endsWith(dirName) => events.append(e)
+      case _ =>
+    }
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.sparkContext.addSparkListener(listener)
+  }
+
+  override def afterEach(): Unit = {
+    try {
+      events.clear()
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      if (spark != null) {
+        spark.sparkContext.removeSparkListener(listener)
+      }
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  abstract class MyModel extends Model[MyModel]
+
+  test("pipeline fit events") {
+    val estimator0 = mock[Estimator[MyModel]]
+    val model0 = mock[MyModel]
+    val transformer1 = mock[Transformer]
+    val estimator2 = mock[Estimator[MyModel]]
+    val model2 = mock[MyModel]
+    val transformer3 = mock[Transformer]
+
+    when(estimator0.copy(any[ParamMap])).thenReturn(estimator0)
+    when(model0.copy(any[ParamMap])).thenReturn(model0)
+    when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
+    when(estimator2.copy(any[ParamMap])).thenReturn(estimator2)
+    when(model2.copy(any[ParamMap])).thenReturn(model2)
+    when(transformer3.copy(any[ParamMap])).thenReturn(transformer3)
+
+    val dataset0 = mock[DataFrame]
+    val dataset1 = mock[DataFrame]
+    val dataset2 = mock[DataFrame]
+    val dataset3 = mock[DataFrame]
+    val dataset4 = mock[DataFrame]
+
+    when(dataset0.toDF).thenReturn(dataset0)
+    when(dataset1.toDF).thenReturn(dataset1)
+    when(dataset2.toDF).thenReturn(dataset2)
+    when(dataset3.toDF).thenReturn(dataset3)
+    when(dataset4.toDF).thenReturn(dataset4)
+
+    when(estimator0.fit(meq(dataset0))).thenReturn(model0)
+    when(model0.transform(meq(dataset0))).thenReturn(dataset1)
+    when(model0.parent).thenReturn(estimator0)
+    when(transformer1.transform(meq(dataset1))).thenReturn(dataset2)
+    when(estimator2.fit(meq(dataset2))).thenReturn(model2)
+    when(model2.transform(meq(dataset2))).thenReturn(dataset3)
+    when(model2.parent).thenReturn(estimator2)
+    when(transformer3.transform(meq(dataset3))).thenReturn(dataset4)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(estimator0, transformer1, estimator2, transformer3))
+    val pipelineModel = pipeline.fit(dataset0)
+
+    val expected =
+      FitStart(pipeline, dataset0) ::
+      FitEnd(pipeline, pipelineModel) :: Nil
+    eventually(timeout(10 seconds), interval(1 second)) {
+      assert(expected === events)
+    }
+  }
+
+  test("pipeline model transform events") {
+    val dataset = mock[DataFrame]
+    when(dataset.toDF).thenReturn(dataset)
+    val transform1 = mock[Transformer]
+    val model = mock[MyModel]
+    val transform2 = mock[Transformer]
+    val stages = Array(transform1, model, transform2)
+    val newPipelineModel = new PipelineModel("pipeline0", stages)
+    val output = newPipelineModel.transform(dataset)
+
+    val expected =
+      TransformStart(newPipelineModel, dataset) ::
+      TransformEnd(newPipelineModel, output) :: Nil
+    eventually(timeout(10 seconds), interval(1 second)) {
+      assert(expected === events)
+    }
+  }
+
+  test("pipeline read/write events") {
+    withTempDir { dir =>
+      val path = new Path(dir.getCanonicalPath, dirName).toUri.toString
+      val writableStage = new WritableStage("writableStage")
+      val newPipeline = new Pipeline().setStages(Array(writableStage))
+      val pipelineWriter = newPipeline.write
+      pipelineWriter.save(path)
+
+      val expected =
+        SaveInstanceStart(pipelineWriter, path) ::
+        SaveInstanceEnd(pipelineWriter, path) :: Nil
+      eventually(timeout(10 seconds), interval(1 second)) {
+        assert(expected === events)
+      }
+    }
+  }
+
+  test("pipeline model read/write events") {
+    withTempDir { dir =>
+      val path = new Path(dir.getCanonicalPath, dirName).toUri.toString
+      val writableStage = new WritableStage("writableStage")
+      val pipelineModel =
+        new PipelineModel("pipeline_89329329", 
Array(writableStage.asInstanceOf[Transformer]))
+      val pipelineWriter = pipelineModel.write
+      pipelineWriter.save(path)
+
+      val expected =
+        SaveInstanceStart(pipelineWriter, path) ::
+        SaveInstanceEnd(pipelineWriter, path) :: Nil
+      eventually(timeout(10 seconds), interval(1 second)) {
+        assert(expected === events)
+      }
+    }
+  }
+}


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


> Add Spark ML Listener for Tracking ML Pipeline Status
> -----------------------------------------------------
>
>                 Key: SPARK-23674
>                 URL: https://issues.apache.org/jira/browse/SPARK-23674
>             Project: Spark
>          Issue Type: Improvement
>          Components: ML
>    Affects Versions: 2.3.0
>            Reporter: Mingjie Tang
>            Priority: Major
>
> Currently, Spark provides status monitoring for several of its components, such 
> as the Spark history server, the streaming listener, and the SQL listener. The 
> use cases here would be (1) a front-end UI that tracks training convergence 
> across iterations, so data scientists can see how a job converges while training 
> models such as K-means, logistic regression, and other linear models, and (2) 
> tracking the data lineage of the inputs and outputs of training data. 
> In this proposal, we hope to provide a Spark ML pipeline listener that tracks 
> the status of a Spark ML pipeline, including (a rough listener sketch follows 
> the list below): 
>  # ML pipeline created and saved 
>  # ML pipeline model created, saved, and loaded 
>  # ML model training status monitoring 
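
A rough sketch of how an application might consume such events, assuming the event
types added by the PR above (FitStart, FitEnd, TransformStart, TransformEnd,
SaveInstanceStart, SaveInstanceEnd) are visible to user code as SparkListenerEvents,
as they are in MLEventsSuite:

  import org.apache.spark.ml._
  import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}

  // Logs ML pipeline lifecycle events posted to the Spark listener bus.
  class MLStatusListener extends SparkListener {
    override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
      case e: FitStart[_] => println(s"fit started: $e")
      case e: FitEnd[_] => println(s"fit finished: $e")
      case e: TransformStart => println(s"transform started: $e")
      case e: TransformEnd => println(s"transform finished: $e")
      case e: SaveInstanceStart => println(s"saving instance to ${e.path}")
      case e: SaveInstanceEnd => println(s"saved instance to ${e.path}")
      case _ => // not an ML event; ignore
    }
  }

  // Registered like any other listener:
  //   spark.sparkContext.addSparkListener(new MLStatusListener())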



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org
