[ https://issues.apache.org/jira/browse/SPARK-23674?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16718526#comment-16718526 ]
ASF GitHub Bot commented on SPARK-23674: ---------------------------------------- felixcheung closed pull request #23263: [SPARK-23674][ML] Adds Spark ML Events URL: https://github.com/apache/spark/pull/23263 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 1247882d6c1bd..a3c4db06862f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -65,7 +65,19 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * Fits a model to the input data. */ @Since("2.0.0") - def fit(dataset: Dataset[_]): M + def fit(dataset: Dataset[_]): M = MLEvents.withFitEvent(this, dataset) { + fitImpl(dataset) + } + + /** + * `fit()` handles events and then calls this method. Subclasses should override this + * method to implement the actual fiting a model to the input data. + */ + @Since("3.0.0") + protected def fitImpl(dataset: Dataset[_]): M = { + // Keep this default body for backward compatibility. + throw new UnsupportedOperationException("fitImpl is not implemented.") + } /** * Fits multiple models to the input data with multiple sets of parameters. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 103082b7b9766..1c781faff129e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -132,7 +132,8 @@ class Pipeline @Since("1.4.0") ( * @return fitted pipeline */ @Since("2.0.0") - override def fit(dataset: Dataset[_]): PipelineModel = { + override def fit(dataset: Dataset[_]): PipelineModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): PipelineModel = { transformSchema(dataset.schema, logging = true) val theStages = $(stages) // Search for the last estimator. @@ -210,7 +211,7 @@ object Pipeline extends MLReadable[Pipeline] { /** Checked against metadata when loading model */ private val className = classOf[Pipeline].getName - override def load(path: String): Pipeline = { + override protected def loadImpl(path: String): Pipeline = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) new Pipeline(uid).setStages(stages) } @@ -301,7 +302,8 @@ class PipelineModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur)) } @@ -344,7 +346,7 @@ object PipelineModel extends MLReadable[PipelineModel] { /** Checked against metadata when loading model */ private val className = classOf[PipelineModel].getName - override def load(path: String): PipelineModel = { + override protected def loadImpl(path: String): PipelineModel = { val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path) val transformers = stages map { case stage: Transformer => stage 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index d8f3dfa874439..3731ddae0160c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -94,7 +94,10 @@ abstract class Predictor[ /** @group setParam */ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] - override def fit(dataset: Dataset[_]): M = { + // Explictly call parent's load. Otherwise, MiMa complains. + override def fit(dataset: Dataset[_]): M = super.fit(dataset) + + override protected def fitImpl(dataset: Dataset[_]): M = { // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) @@ -199,7 +202,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, * @param dataset input dataset * @return transformed dataset with [[predictionCol]] of type `Double` */ - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform( + dataset: Dataset[_]): DataFrame = MLEvents.withTransformEvent(this, dataset) { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { transformImpl(dataset) @@ -210,7 +214,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, } } - protected def transformImpl(dataset: Dataset[_]): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val predictUDF = udf { (features: Any) => predict(features.asInstanceOf[FeaturesType]) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index a3a2b55adc25d..2467383227a66 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -68,7 +68,19 @@ abstract class 
Transformer extends PipelineStage { * Transforms the input dataset. */ @Since("2.0.0") - def transform(dataset: Dataset[_]): DataFrame + def transform(dataset: Dataset[_]): DataFrame = MLEvents.withTransformEvent(this, dataset) { + transformImpl(dataset) + } + + /** + * `transform()` handles events and then calls this method. Subclasses should override this + * method to implement the actual transformation. + */ + @Since("3.0.0") + protected def transformImpl(dataset: Dataset[_]): DataFrame = { + // Keep this default body for backward compatibility. + throw new UnsupportedOperationException("transformImpl is not implemented.") + } override def copy(extra: ParamMap): Transformer } @@ -116,7 +128,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] StructType(outputFields) } - override def transform(dataset: Dataset[_]): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val transformUDF = udf(this.createTransformFunc, outputDataType) dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 7e5790ab70ee9..4de6ae7990490 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.ml.{MLEvents, PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.shared.HasRawPredictionCol @@ -156,7 +156,8 @@ abstract 
class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform( + dataset: Dataset[_]): DataFrame = MLEvents.withTransformEvent(this, dataset) { transformSchema(dataset.schema, logging = true) // Output selected columns only. diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index d9292a5476767..ab91b942a06ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -275,7 +275,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica /** Checked against metadata when loading model */ private val className = classOf[DecisionTreeClassificationModel].getName - override def load(path: String): DecisionTreeClassificationModel = { + override protected def loadImpl(path: String): DecisionTreeClassificationModel = { implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index abe2d1febfdf8..5f4bac04f16bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -410,7 +410,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { private val className = classOf[GBTClassificationModel].getName private val treeClassName = classOf[DecisionTreeRegressionModel].getName - override def load(path: String): 
GBTClassificationModel = { + override protected def loadImpl(path: String): GBTClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index ff801abef9a94..b6e9b8833022d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -373,7 +373,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] { /** Checked against metadata when loading model */ private val className = classOf[LinearSVCModel].getName - override def load(path: String): LinearSVCModel = { + override protected def loadImpl(path: String): LinearSVCModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.format("parquet").load(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 27a7db0b2f5d4..1367253aa4f9f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1248,7 +1248,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { /** Checked against metadata when loading model */ private val className = classOf[LogisticRegressionModel].getName - override def load(path: String): LogisticRegressionModel = { + override protected def loadImpl(path: String): LogisticRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val (major, minor) = 
VersionUtils.majorMinorVersion(metadata.sparkVersion) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 47b8a8df637b9..a912ad268ba7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -359,7 +359,7 @@ object MultilayerPerceptronClassificationModel /** Checked against metadata when loading model */ private val className = classOf[MultilayerPerceptronClassificationModel].getName - override def load(path: String): MultilayerPerceptronClassificationModel = { + override protected def loadImpl(path: String): MultilayerPerceptronClassificationModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 1a7a5e7a52344..04bb1265962c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -399,7 +399,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { /** Checked against metadata when loading model */ private val className = classOf[NaiveBayesModel].getName - override def load(path: String): NaiveBayesModel = { + override protected def loadImpl(path: String): NaiveBayesModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index e1fceb1fc96a4..6b0a4d0832007 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -165,7 +165,8 @@ final class OneVsRestModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { // Check schema transformSchema(dataset.schema, logging = true) @@ -289,7 +290,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { /** Checked against metadata when loading model */ private val className = classOf[OneVsRestModel].getName - override def load(path: String): OneVsRestModel = { + override protected def loadImpl(path: String): OneVsRestModel = { implicit val format = DefaultFormats val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String]) @@ -372,7 +373,8 @@ final class OneVsRest @Since("1.4.0") ( } @Since("2.0.0") - override def fit(dataset: Dataset[_]): OneVsRestModel = instrumented { instr => + override def fit(dataset: Dataset[_]): OneVsRestModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): OneVsRestModel = instrumented { instr => transformSchema(dataset.schema) instr.logPipelineStage(this) @@ -492,7 +494,7 @@ object OneVsRest extends MLReadable[OneVsRest] { /** Checked against metadata when loading model */ private val className = classOf[OneVsRest].getName - override def load(path: String): OneVsRest = { + override protected def loadImpl(path: String): OneVsRest = { val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) val ovr = new OneVsRest(metadata.uid) metadata.getAndSetParams(ovr) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 730fcab333e11..24781005682dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.MLEvents import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils @@ -100,7 +101,8 @@ abstract class ProbabilisticClassificationModel[ * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform( + dataset: Dataset[_]): DataFrame = MLEvents.withTransformEvent(this, dataset) { transformSchema(dataset.schema, logging = true) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 0a3bfd1f85e08..145b73e1816ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -310,7 +310,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica private val className = classOf[RandomForestClassificationModel].getName private val treeClassName = classOf[DecisionTreeClassificationModel].getName - override def load(path: String): RandomForestClassificationModel = { + override protected def loadImpl(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = 
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 1a94aefa3f563..b4c1246e780f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -105,7 +105,8 @@ class BisectingKMeansModel private[ml] ( def setPredictionCol(value: String): this.type = set(predictionCol, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), @@ -191,7 +192,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { /** Checked against metadata when loading model */ private val className = classOf[BisectingKMeansModel].getName - override def load(path: String): BisectingKMeansModel = { + override protected def loadImpl(path: String): BisectingKMeansModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) @@ -261,7 +262,9 @@ class BisectingKMeans @Since("2.0.0") ( def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): BisectingKMeansModel = instrumented { instr => + override def fit(dataset: Dataset[_]): BisectingKMeansModel = super.fit(dataset) + override protected def fitImpl( + dataset: Dataset[_]): BisectingKMeansModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val rdd = 
DatasetUtils.columnToOldVector(dataset, getFeaturesCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 88abc1605d69f..4908f03574d14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -106,7 +106,8 @@ class GaussianMixtureModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) @@ -218,7 +219,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { /** Checked against metadata when loading model */ private val className = classOf[GaussianMixtureModel].getName - override def load(path: String): GaussianMixtureModel = { + override protected def loadImpl(path: String): GaussianMixtureModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString @@ -336,7 +337,9 @@ class GaussianMixture @Since("2.0.0") ( private val numSamples = 5 @Since("2.0.0") - override def fit(dataset: Dataset[_]): GaussianMixtureModel = instrumented { instr => + override def fit(dataset: Dataset[_]): GaussianMixtureModel = super.fit(dataset) + override protected def fitImpl( + dataset: Dataset[_]): GaussianMixtureModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val sc = dataset.sparkSession.sparkContext diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 2eed84d51782a..1c8a17c04f99e 100644 
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -124,7 +124,8 @@ class KMeansModel private[ml] ( def setPredictionCol(value: String): this.type = set(predictionCol, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) @@ -238,7 +239,7 @@ object KMeansModel extends MLReadable[KMeansModel] { /** Checked against metadata when loading model */ private val className = classOf[KMeansModel].getName - override def load(path: String): KMeansModel = { + override protected def loadImpl(path: String): KMeansModel = { // Import implicits for Dataset Encoder val sparkSession = super.sparkSession import sparkSession.implicits._ @@ -321,7 +322,8 @@ class KMeans @Since("1.5.0") ( def setSeed(value: Long): this.type = set(seed, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): KMeansModel = instrumented { instr => + override def fit(dataset: Dataset[_]): KMeansModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): KMeansModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val handlePersistence = dataset.storageLevel == StorageLevel.NONE diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 84e73dc19a392..fde97d57fdb75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -455,7 +455,8 @@ abstract class LDAModel private[ml] ( * This implementation may be changed in the future. 
*/ @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { if ($(topicDistributionCol).nonEmpty) { // TODO: Make the transformer natively in ml framework to avoid extra conversion. @@ -619,7 +620,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] { private val className = classOf[LocalLDAModel].getName - override def load(path: String): LocalLDAModel = { + override protected def loadImpl(path: String): LocalLDAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) @@ -772,7 +773,7 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] { private val className = classOf[DistributedLDAModel].getName - override def load(path: String): DistributedLDAModel = { + override protected def loadImpl(path: String): DistributedLDAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val modelPath = new Path(path, "oldModel").toString val oldModel = OldDistributedLDAModel.load(sc, modelPath) @@ -895,7 +896,8 @@ class LDA @Since("1.6.0") ( override def copy(extra: ParamMap): LDA = defaultCopy(extra) @Since("2.0.0") - override def fit(dataset: Dataset[_]): LDAModel = instrumented { instr => + override def fit(dataset: Dataset[_]): LDAModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): LDAModel = instrumented { instr => transformSchema(dataset.schema, logging = true) instr.logPipelineStage(this) @@ -952,7 +954,7 @@ object LDA extends MLReadable[LDA] { private val className = classOf[LDA].getName - override def load(path: String): LDA = { + override protected def loadImpl(path: String): LDA = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val model = new LDA(metadata.uid) 
LDAParams.getAndSetParams(model, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/events.scala b/mllib/src/main/scala/org/apache/spark/ml/events.scala new file mode 100644 index 0000000000000..fe009394ca081 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/events.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Unstable +import org.apache.spark.ml.util.{MLReader, MLWriter} +import org.apache.spark.scheduler.SparkListenerEvent +import org.apache.spark.sql.{DataFrame, Dataset} + +/** + * Event emitted by ML operations. Events are either fired before and/or + * after each operation (the event should document this). + */ +@Unstable +sealed trait MLEvent extends SparkListenerEvent + +/** + * Event fired before `Transformer.transform`. + */ +@Unstable +case class TransformStart(transformer: Transformer, input: Dataset[_]) extends MLEvent +/** + * Event fired after `Transformer.transform`. + */ +@Unstable +case class TransformEnd(transformer: Transformer, output: Dataset[_]) extends MLEvent + +/** + * Event fired before `Estimator.fit`. 
+ */ +@Unstable +case class FitStart[M <: Model[M]](estimator: Estimator[M], dataset: Dataset[_]) extends MLEvent +/** + * Event fired after `Estimator.fit`. + */ +@Unstable +case class FitEnd[M <: Model[M]](estimator: Estimator[M], model: M) extends MLEvent + +/** + * Event fired before `MLReader.load`. + */ +@Unstable +case class LoadInstanceStart[T](reader: MLReader[T], path: String) extends MLEvent +/** + * Event fired after `MLReader.load`. + */ +@Unstable +case class LoadInstanceEnd[T](reader: MLReader[T], instance: T) extends MLEvent + +/** + * Event fired before `MLWriter.save`. + */ +@Unstable +case class SaveInstanceStart(writer: MLWriter, path: String) extends MLEvent +/** + * Event fired after `MLWriter.save`. + */ +@Unstable +case class SaveInstanceEnd(writer: MLWriter, path: String) extends MLEvent + + +private[ml] object MLEvents { + private def listenerBus = SparkContext.getOrCreate().listenerBus + + def withFitEvent[M <: Model[M]]( + estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = { + listenerBus.post(FitStart(estimator, dataset)) + val model: M = func + listenerBus.post(FitEnd(estimator, model)) + model + } + + def withTransformEvent( + transformer: Transformer, input: Dataset[_])(func: => DataFrame): DataFrame = { + listenerBus.post(TransformStart(transformer, input)) + val output: DataFrame = func + listenerBus.post(TransformEnd(transformer, output)) + output + } + + def withLoadInstanceEvent[T](reader: MLReader[T], path: String)(func: => T): T = { + listenerBus.post(LoadInstanceStart(reader, path)) + val instance: T = func + listenerBus.post(LoadInstanceEnd(reader, instance)) + instance + } + + def withSaveInstanceEvent(writer: MLWriter, path: String)(func: => Unit): Unit = { + listenerBus.post(SaveInstanceStart(writer, path)) + func + listenerBus.post(SaveInstanceEnd(writer, path)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 2b0862c60fdf7..3a02b640e359a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -70,7 +70,8 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema, logging = true) val schema = dataset.schema val inputType = schema($(inputCol)).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala index 0554455a66d7f..f705418432a48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala @@ -225,7 +225,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject /** Checked against metadata when loading model */ private val className = classOf[BucketedRandomProjectionLSHModel].getName - override def load(path: String): BucketedRandomProjectionLSHModel = { + override protected def loadImpl(path: String): BucketedRandomProjectionLSHModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 0b989b0d7d253..eaa9ae8d81c45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -143,7 +143,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String def setOutputCols(value: Array[String]): this.type = set(outputCols, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema) val (inputColumns, outputColumns) = if (isSet(inputCols)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index dbfb199ccd58f..ba8482d50a9cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -197,7 +197,8 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str def setLabelCol(value: String): this.type = set(labelCol, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): ChiSqSelectorModel = { + override def fit(dataset: Dataset[_]): ChiSqSelectorModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) val input: RDD[OldLabeledPoint] = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { @@ -263,7 +264,8 @@ final class ChiSqSelectorModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema, logging = true) val newField = transformedSchema.last @@ -327,7 
+329,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { private val className = classOf[ChiSqSelectorModel].getName - override def load(path: String): ChiSqSelectorModel = { + override protected def loadImpl(path: String): ChiSqSelectorModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath).select("selectedFeatures").head() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index dc8eb8261dbe2..a9ea509ada077 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -181,7 +181,8 @@ class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) def setBinary(value: Boolean): this.type = set(binary, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): CountVectorizerModel = { + override def fit(dataset: Dataset[_]): CountVectorizerModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) @@ -291,7 +292,8 @@ class CountVectorizerModel( private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if (broadcastDict.isEmpty) { val dict = vocabulary.zipWithIndex.toMap @@ -358,7 +360,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { private val className = classOf[CountVectorizerModel].getName - 
override def load(path: String): CountVectorizerModel = { + override protected def loadImpl(path: String): CountVectorizerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala index dc18e1d34880a..a58a1ff8e053c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala @@ -140,7 +140,8 @@ class FeatureHasher(@Since("2.3.0") override val uid: String) extends Transforme def setCategoricalCols(value: Array[String]): this.type = set(categoricalCols, value) @Since("2.3.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val hashFunc: Any => Int = FeatureHasher.murmur3Hash val n = $(numFeatures) val localInputCols = $(inputCols) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index dbda5b8d8fd4a..57b6bb1dbe432 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -91,7 +91,8 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) def setBinary(value: Boolean): this.type = set(binary, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val hashingTF = new 
feature.HashingTF($(numFeatures)).setBinary($(binary)) // TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 58897cca4e5c6..9c91cc0417e48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -84,7 +84,8 @@ final class IDF @Since("1.4.0") (@Since("1.4.0") override val uid: String) def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): IDFModel = { + override def fit(dataset: Dataset[_]): IDFModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): IDFModel = { transformSchema(dataset.schema, logging = true) val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => OldVectors.fromML(v) @@ -129,7 +130,8 @@ class IDFModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // TODO: Make the idfModel.transform natively in ml framework to avoid extra conversion. 
val idf = udf { vec: Vector => idfModel.transform(OldVectors.fromML(vec)).asML } @@ -174,7 +176,7 @@ object IDFModel extends MLReadable[IDFModel] { private val className = classOf[IDFModel].getName - override def load(path: String): IDFModel = { + override protected def loadImpl(path: String): IDFModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 1c074e204ad99..761947a7c394c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -120,7 +120,10 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN) - override def fit(dataset: Dataset[_]): ImputerModel = { + // Explicitly call parent's load. Otherwise, MiMa complains. 
+ override def fit(dataset: Dataset[_]): ImputerModel = super.fit(dataset) + + override protected def fitImpl(dataset: Dataset[_]): ImputerModel = { transformSchema(dataset.schema, logging = true) val spark = dataset.sparkSession @@ -211,7 +214,7 @@ class ImputerModel private[ml] ( /** @group setParam */ def setOutputCols(value: Array[String]): this.type = set(outputCols, value) - override def transform(dataset: Dataset[_]): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq @@ -257,7 +260,7 @@ object ImputerModel extends MLReadable[ImputerModel] { private val className = classOf[ImputerModel].getName - override def load(path: String): ImputerModel = { + override protected def loadImpl(path: String): ImputerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val surrogateDF = sqlContext.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 611f1b691b782..319927e9dd228 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -67,7 +67,8 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) ext } @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val inputFeatures = $(inputCols).map(c => dataset.schema(c)) val featureEncoders = getFeatureEncoders(inputFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index b20852383a6ff..b7cc016ebcaf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -95,7 +95,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]] */ protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double - override def transform(dataset: Dataset[_]): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val transformUDF = udf(hashFunction(_: Vector), DataTypes.createArrayType(new VectorUDT)) dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) @@ -323,7 +323,7 @@ private[ml] abstract class LSH[T <: LSHModel[T]] */ protected[this] def createRawLSHModel(inputDim: Int): T - override def fit(dataset: Dataset[_]): T = { + override protected def fitImpl(dataset: Dataset[_]): T = { transformSchema(dataset.schema, logging = true) val inputDim = dataset.select(col($(inputCol))).head().get(0).asInstanceOf[Vector].size val model = createRawLSHModel(inputDim).setParent(this) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 90eceb0d61b40..295d3d0ccab3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -68,7 +68,8 @@ class MaxAbsScaler @Since("2.0.0") (@Since("2.0.0") override val uid: String) def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): MaxAbsScalerModel = { + override def fit(dataset: Dataset[_]): MaxAbsScalerModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): MaxAbsScalerModel = { transformSchema(dataset.schema, logging = true) val input: RDD[OldVector] = 
dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => OldVectors.fromML(v) @@ -119,7 +120,8 @@ class MaxAbsScalerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // TODO: this looks hack, we may have to handle sparse and dense vectors separately. val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x)) @@ -165,7 +167,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { private val className = classOf[MaxAbsScalerModel].getName - override def load(path: String): MaxAbsScalerModel = { + override protected def loadImpl(path: String): MaxAbsScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala index 21cde66d8db6b..168accb1a1722 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala @@ -194,7 +194,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] { /** Checked against metadata when loading model */ private val className = classOf[MinHashLSHModel].getName - override def load(path: String): MinHashLSHModel = { + override protected def loadImpl(path: String): MinHashLSHModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 2e0ae4af66f06..d76fc5ec05c52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -115,7 +115,8 @@ class MinMaxScaler @Since("1.5.0") (@Since("1.5.0") override val uid: String) def setMax(value: Double): this.type = set(max, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): MinMaxScalerModel = { + override def fit(dataset: Dataset[_]): MinMaxScalerModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): MinMaxScalerModel = { transformSchema(dataset.schema, logging = true) val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => OldVectors.fromML(v) @@ -174,7 +175,8 @@ class MinMaxScalerModel private[ml] ( def setMax(value: Double): this.type = set(max, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val originalRange = (originalMax.asBreeze - originalMin.asBreeze).toArray val minArray = originalMin.toArray @@ -234,7 +236,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { private val className = classOf[MinMaxScalerModel].getName - override def load(path: String): MinMaxScalerModel = { + override protected def loadImpl(path: String): MinMaxScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index ec9792cbbda8f..4e4a259c95cce 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -147,7 +147,8 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String) } @Since("3.0.0") - override def fit(dataset: Dataset[_]): OneHotEncoderModel = { + override def fit(dataset: Dataset[_]): OneHotEncoderModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): OneHotEncoderModel = { transformSchema(dataset.schema) // Compute the plain number of categories without `handleInvalid` and @@ -324,7 +325,8 @@ class OneHotEncoderModel private[ml] ( } @Since("3.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema, logging = true) val keepInvalid = $(handleInvalid) == OneHotEncoder.KEEP_INVALID @@ -378,7 +380,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { private val className = classOf[OneHotEncoderModel].getName - override def load(path: String): OneHotEncoderModel = { + override protected def loadImpl(path: String): OneHotEncoderModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 8172491a517d1..0a79b49b984c3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -90,7 +90,8 @@ class PCA @Since("1.5.0") ( * Computes a [[PCAModel]] that contains the principal components of the input vectors. 
*/ @Since("2.0.0") - override def fit(dataset: Dataset[_]): PCAModel = { + override def fit(dataset: Dataset[_]): PCAModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): PCAModel = { transformSchema(dataset.schema, logging = true) val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => OldVectors.fromML(v) @@ -147,7 +148,8 @@ class PCAModel private[ml] ( * to `PCA.fit()`. */ @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val pcaModel = new feature.PCAModel($(k), OldMatrices.fromML(pc).asInstanceOf[OldDenseMatrix], @@ -203,7 +205,7 @@ object PCAModel extends MLReadable[PCAModel] { * @param path path to serialized model data * @return a [[PCAModel]] */ - override def load(path: String): PCAModel = { + override protected def loadImpl(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 5bfaa3b7f3f52..0f29ae9687c02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -199,7 +199,8 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui } @Since("2.0.0") - override def fit(dataset: Dataset[_]): Bucketizer = { + override def fit(dataset: Dataset[_]): Bucketizer = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): Bucketizer = { transformSchema(dataset.schema, logging = true) val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid)) if 
(isSet(inputCols)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d7eb13772aa64..f5132f09f95f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -193,7 +193,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) } @Since("2.0.0") - override def fit(dataset: Dataset[_]): RFormulaModel = { + override def fit(dataset: Dataset[_]): RFormulaModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): RFormulaModel = { transformSchema(dataset.schema, logging = true) require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) @@ -338,7 +339,8 @@ class RFormulaModel private[feature]( extends Model[RFormulaModel] with RFormulaBase with MLWritable { @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { checkCanTransform(dataset.schema) transformLabel(pipelineModel.transform(dataset)) } @@ -431,7 +433,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { /** Checked against metadata when loading model */ private val className = classOf[RFormulaModel].getName - override def load(path: String): RFormulaModel = { + override protected def loadImpl(path: String): RFormulaModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString @@ -462,7 +464,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str def this(columnsToPrune: Set[String]) = this(Identifiable.randomUID("columnPruner"), columnsToPrune) - override def transform(dataset: Dataset[_]): DataFrame = { + override protected def 
transformImpl(dataset: Dataset[_]): DataFrame = { val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) dataset.select(columnsToKeep.map(dataset.col): _*) } @@ -502,7 +504,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { /** Checked against metadata when loading model */ private val className = classOf[ColumnPruner].getName - override def load(path: String): ColumnPruner = { + override protected def loadImpl(path: String): ColumnPruner = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString @@ -535,7 +537,7 @@ private class VectorAttributeRewriter( def this(vectorCol: String, prefixesToRewrite: Map[String, String]) = this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite) - override def transform(dataset: Dataset[_]): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val metadata = { val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) val attrs = group.attributes.get.map { attr => @@ -593,7 +595,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite /** Checked against metadata when loading model */ private val className = classOf[VectorAttributeRewriter].getName - override def load(path: String): VectorAttributeRewriter = { + override protected def loadImpl(path: String): VectorAttributeRewriter = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 0fb1d8c5dc579..d453ab3723c0c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -64,7 +64,8 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String) 
private val tableIdentifier: String = "__THIS__" @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val tableName = Identifiable.randomUID(uid) dataset.createOrReplaceTempView(tableName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 91b0707dec3f3..c96726faab4f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -108,7 +108,8 @@ class StandardScaler @Since("1.4.0") ( def setWithStd(value: Boolean): this.type = set(withStd, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): StandardScalerModel = { + override def fit(dataset: Dataset[_]): StandardScalerModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): StandardScalerModel = { transformSchema(dataset.schema, logging = true) val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => OldVectors.fromML(v) @@ -158,7 +159,8 @@ class StandardScalerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) @@ -204,7 +206,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { private val className = classOf[StandardScalerModel].getName - override def load(path: String): StandardScalerModel = { + 
override protected def loadImpl(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 6669d402cd996..8110cccf27167 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -109,7 +109,8 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String caseSensitive -> false, locale -> Locale.getDefault.toString) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val t = if ($(caseSensitive)) { val stopWordsSet = $(stopWords).toSet diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index a833d8b270cf1..74843bac1ea82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -131,7 +131,8 @@ class StringIndexer @Since("1.4.0") ( def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): StringIndexerModel = { + override def fit(dataset: Dataset[_]): StringIndexerModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) val values = dataset.na.drop(Array($(inputCol))) .select(col($(inputCol)).cast(StringType)) @@ -218,7 +219,8 @@ class 
StringIndexerModel ( def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { if (!dataset.schema.fieldNames.contains($(inputCol))) { logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + "Skip StringIndexerModel.") @@ -307,7 +309,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { private val className = classOf[StringIndexerModel].getName - override def load(path: String): StringIndexerModel = { + override protected def loadImpl(path: String): StringIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) @@ -386,7 +388,8 @@ class IndexToString @Since("2.2.0") (@Since("1.5.0") override val uid: String) } @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 57e23d5072b88..6add5f8d8c4e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -82,7 +82,8 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { 
+ override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Schema transformation. val schema = dataset.schema diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 0e7396a621dbd..fd6e79c4e21fb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -140,7 +140,8 @@ class VectorIndexer @Since("1.4.0") ( def setHandleInvalid(value: String): this.type = set(handleInvalid, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): VectorIndexerModel = { + override def fit(dataset: Dataset[_]): VectorIndexerModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): VectorIndexerModel = { transformSchema(dataset.schema, logging = true) val firstRow = dataset.select($(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") @@ -425,7 +426,8 @@ class VectorIndexerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } @@ -528,7 +530,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { private val className = classOf[VectorIndexerModel].getName - override def load(path: String): VectorIndexerModel = { + override protected def loadImpl(path: String): VectorIndexerModel = { val metadata = 
DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala index f5947d61fe349..854fd1b4d3776 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala @@ -96,7 +96,8 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String) setDefault(handleInvalid, VectorSizeHint.ERROR_INVALID) @Since("2.3.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val localInputCol = getInputCol val localSize = getSize val localHandleInvalid = getHandleInvalid diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index e3e462d07e10c..cd1d49c1a2cb0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -98,7 +98,8 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { // Validity checks transformSchema(dataset.schema) val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala 
index fc9996d69ba72..31e5b00765c68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -171,7 +171,8 @@ final class Word2Vec @Since("1.4.0") ( def setMaxSentenceLength(value: Int): this.type = set(maxSentenceLength, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): Word2VecModel = { + override def fit(dataset: Dataset[_]): Word2VecModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): Word2VecModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val wordVectors = new feature.Word2Vec() @@ -286,7 +287,8 @@ class Word2VecModel private[ml] ( * is performed by averaging all word vectors it contains. */ @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val vectors = wordVectors.getVectors .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) @@ -385,7 +387,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { private val className = classOf[Word2VecModel].getName - override def load(path: String): Word2VecModel = { + override protected def loadImpl(path: String): Word2VecModel = { val spark = sparkSession import spark.implicits._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 7322815c12ab8..d2ca4bc53957f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -157,7 +157,8 @@ class FPGrowth @Since("2.2.0") ( def setPredictionCol(value: String): this.type = set(predictionCol, value) @Since("2.2.0") - override def fit(dataset: Dataset[_]): 
FPGrowthModel = { + override def fit(dataset: Dataset[_]): FPGrowthModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): FPGrowthModel = { transformSchema(dataset.schema, logging = true) genericFit(dataset) } @@ -277,7 +278,8 @@ class FPGrowthModel private[ml] ( * efficiency. This may bring pressure to driver memory for large set of association rules. */ @Since("2.2.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) genericTransform(dataset) } @@ -342,7 +344,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { /** Checked against metadata when loading model */ private val className = classOf[FPGrowthModel].getName - override def load(path: String): FPGrowthModel = { + override protected def loadImpl(path: String): FPGrowthModel = { implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 1b5f77a9ae897..d45a5fa14df90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -136,7 +136,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg class AFTSurvivalRegressionWrapperReader extends MLReader[AFTSurvivalRegressionWrapper] { - override def load(path: String): AFTSurvivalRegressionWrapper = { + override protected def loadImpl(path: String): AFTSurvivalRegressionWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString 
val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala index ad13cced4667b..8d12d80ab4734 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala @@ -102,7 +102,7 @@ private[r] object ALSWrapper extends MLReadable[ALSWrapper] { class ALSWrapperReader extends MLReader[ALSWrapper] { - override def load(path: String): ALSWrapper = { + override protected def loadImpl(path: String): ALSWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val modelPath = new Path(path, "model").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala index 71712c1c5eec5..8478c7b6e8c77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala @@ -126,7 +126,7 @@ private[r] object BisectingKMeansWrapper extends MLReadable[BisectingKMeansWrapp class BisectingKMeansWrapperReader extends MLReader[BisectingKMeansWrapper] { - override def load(path: String): BisectingKMeansWrapper = { + override protected def loadImpl(path: String): BisectingKMeansWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala index a90cae5869b2a..ad9476c1c000c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala @@ -137,7 +137,7 
@@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC class DecisionTreeClassifierWrapperReader extends MLReader[DecisionTreeClassifierWrapper] { - override def load(path: String): DecisionTreeClassifierWrapper = { + override protected def loadImpl(path: String): DecisionTreeClassifierWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala index de712d67e6df5..efe7884d5a170 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala @@ -120,7 +120,7 @@ private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRe class DecisionTreeRegressorWrapperReader extends MLReader[DecisionTreeRegressorWrapper] { - override def load(path: String): DecisionTreeRegressorWrapper = { + override protected def loadImpl(path: String): DecisionTreeRegressorWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala index b8151d8d90702..27a859678372f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala @@ -61,7 +61,7 @@ private[r] object FPGrowthWrapper extends MLReadable[FPGrowthWrapper] { override def read: MLReader[FPGrowthWrapper] = new FPGrowthWrapperReader class FPGrowthWrapperReader extends MLReader[FPGrowthWrapper] { - override def load(path: String): FPGrowthWrapper = { + override protected def 
loadImpl(path: String): FPGrowthWrapper = { val modelPath = new Path(path, "model").toString val fPGrowthModel = FPGrowthModel.load(modelPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala index ecaeec5a7791a..ff55d32019030 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -144,7 +144,7 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] class GBTClassifierWrapperReader extends MLReader[GBTClassifierWrapper] { - override def load(path: String): GBTClassifierWrapper = { + override protected def loadImpl(path: String): GBTClassifierWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala index b568d7859221f..fef2c5755e8f9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala @@ -128,7 +128,7 @@ private[r] object GBTRegressorWrapper extends MLReadable[GBTRegressorWrapper] { class GBTRegressorWrapperReader extends MLReader[GBTRegressorWrapper] { - override def load(path: String): GBTRegressorWrapper = { + override protected def loadImpl(path: String): GBTRegressorWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala index 9a98a8b18b141..656a0062cd14f 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala @@ -120,7 +120,7 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp class GaussianMixtureWrapperReader extends MLReader[GaussianMixtureWrapper] { - override def load(path: String): GaussianMixtureWrapper = { + override protected def loadImpl(path: String): GaussianMixtureWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 64575b0cb0cb5..c974519211288 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -179,7 +179,7 @@ private[r] object GeneralizedLinearRegressionWrapper class GeneralizedLinearRegressionWrapperReader extends MLReader[GeneralizedLinearRegressionWrapper] { - override def load(path: String): GeneralizedLinearRegressionWrapper = { + override protected def loadImpl(path: String): GeneralizedLinearRegressionWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala index d31ebb46afb97..dc1f3e62056f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala @@ -106,7 +106,7 @@ private[r] object IsotonicRegressionWrapper class IsotonicRegressionWrapperReader extends 
MLReader[IsotonicRegressionWrapper] { - override def load(path: String): IsotonicRegressionWrapper = { + override protected def loadImpl(path: String): IsotonicRegressionWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index 8d596863b459e..2b239d7a786ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -129,7 +129,7 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] { class KMeansWrapperReader extends MLReader[KMeansWrapper] { - override def load(path: String): KMeansWrapper = { + override protected def loadImpl(path: String): KMeansWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala index e096bf1f29f3e..210baec3cb33c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala @@ -206,7 +206,7 @@ private[r] object LDAWrapper extends MLReadable[LDAWrapper] { class LDAWrapperReader extends MLReader[LDAWrapper] { - override def load(path: String): LDAWrapper = { + override protected def loadImpl(path: String): LDAWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala index 7a22a71c3a819..7ec2717f515de 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala @@ -144,7 +144,7 @@ private[r] object LinearSVCWrapper class LinearSVCWrapperReader extends MLReader[LinearSVCWrapper] { - override def load(path: String): LinearSVCWrapper = { + override protected def loadImpl(path: String): LinearSVCWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index 18acf7d21656f..a7dfcb3953e19 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -199,7 +199,7 @@ private[r] object LogisticRegressionWrapper class LogisticRegressionWrapperReader extends MLReader[LogisticRegressionWrapper] { - override def load(path: String): LogisticRegressionWrapper = { + override protected def loadImpl(path: String): LogisticRegressionWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index 62f642142701b..6ac25957b9c81 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -124,7 +124,7 @@ private[r] object MultilayerPerceptronClassifierWrapper class MultilayerPerceptronClassifierWrapperReader extends MLReader[MultilayerPerceptronClassifierWrapper]{ - override def load(path: String): 
MultilayerPerceptronClassifierWrapper = { + override protected def loadImpl(path: String): MultilayerPerceptronClassifierWrapper = { implicit val format = DefaultFormats val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index fbf9f462ff5f6..e84b1fc297e5e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -109,7 +109,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { class NaiveBayesWrapperReader extends MLReader[NaiveBayesWrapper] { - override def load(path: String): NaiveBayesWrapper = { + override protected def loadImpl(path: String): NaiveBayesWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index ba6445a730306..7dd93f1420f6f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -30,7 +30,7 @@ import org.apache.spark.ml.util.MLReader */ private[r] object RWrappers extends MLReader[Object] { - override def load(path: String): Object = { + override protected def loadImpl(path: String): Object = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val rMetadataStr = sc.textFile(rMetadataPath, 1).first() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 132345fb9a6d9..c905319ba0ec2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -145,7 +145,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC class RandomForestClassifierWrapperReader extends MLReader[RandomForestClassifierWrapper] { - override def load(path: String): RandomForestClassifierWrapper = { + override protected def loadImpl(path: String): RandomForestClassifierWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala index 038bd79c7022b..1f432da570560 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -128,7 +128,7 @@ private[r] object RandomForestRegressorWrapper extends MLReadable[RandomForestRe class RandomForestRegressorWrapperReader extends MLReader[RandomForestRegressorWrapper] { - override def load(path: String): RandomForestRegressorWrapper = { + override protected def loadImpl(path: String): RandomForestRegressorWrapper = { implicit val format = DefaultFormats val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 50ef4330ddc80..fbcf39dcbe3ba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -303,7 +303,8 @@ class ALSModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) 
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) // create a new column named map(predictionCol) by running the predict UDF. val predictions = dataset @@ -519,7 +520,7 @@ object ALSModel extends MLReadable[ALSModel] { /** Checked against metadata when loading model */ private val className = classOf[ALSModel].getName - override def load(path: String): ALSModel = { + override protected def loadImpl(path: String): ALSModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) implicit val format = DefaultFormats val rank = (metadata.metadata \ "rank").extract[Int] @@ -655,7 +656,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] } @Since("2.0.0") - override def fit(dataset: Dataset[_]): ALSModel = instrumented { instr => + override def fit(dataset: Dataset[_]): ALSModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): ALSModel = instrumented { instr => transformSchema(dataset.schema) import dataset.sparkSession.implicits._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 8d6e36697d2cc..ebda26e8ffc8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -211,7 +211,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S } @Since("2.0.0") - override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr => + override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = super.fit(dataset) + override protected def fitImpl( + dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr => transformSchema(dataset.schema, logging = true) val instances = extractAFTPoints(dataset) val handlePersistence = 
dataset.storageLevel == StorageLevel.NONE @@ -353,7 +355,8 @@ class AFTSurvivalRegressionModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predictUDF = udf { features: Vector => predict(features) } val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} @@ -412,7 +415,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] /** Checked against metadata when loading model */ private val className = classOf[AFTSurvivalRegressionModel].getName - override def load(path: String): AFTSurvivalRegressionModel = { + override protected def loadImpl(path: String): AFTSurvivalRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index faadc4d7b4ccc..155c979e28250 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -22,7 +22,7 @@ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since -import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.{MLEvents, PredictionModel, Predictor} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap @@ -188,7 +188,8 @@ class DecisionTreeRegressionModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform( + dataset: Dataset[_]): 
DataFrame = MLEvents.withTransformEvent(this, dataset) { transformSchema(dataset.schema, logging = true) transformImpl(dataset) } @@ -275,7 +276,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode /** Checked against metadata when loading model */ private val className = classOf[DecisionTreeRegressionModel].getName - override def load(path: String): DecisionTreeRegressionModel = { + override protected def loadImpl(path: String): DecisionTreeRegressionModel = { implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 9a5b7d59e9aef..befef2737b9bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -337,7 +337,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { private val className = classOf[GBTRegressionModel].getName private val treeClassName = classOf[DecisionTreeRegressionModel].getName - override def load(path: String): GBTRegressionModel = { + override protected def loadImpl(path: String): GBTRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index abb60ea205751..648167fe3e012 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ 
-26,7 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging -import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.{MLEvents, PredictorParams} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} @@ -1034,7 +1034,8 @@ class GeneralizedLinearRegressionModel private[ml] ( BLAS.dot(features, coefficients) + intercept + offset } - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform( + dataset: Dataset[_]): DataFrame = MLEvents.withTransformEvent(this, dataset) { transformSchema(dataset.schema) transformImpl(dataset) } @@ -1140,7 +1141,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr /** Checked against metadata when loading model */ private val className = classOf[GeneralizedLinearRegressionModel].getName - override def load(path: String): GeneralizedLinearRegressionModel = { + override protected def loadImpl(path: String): GeneralizedLinearRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 8b9233dcdc4d1..422eaf8d4b9b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -162,7 +162,9 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) @Since("2.0.0") - override def fit(dataset: Dataset[_]): IsotonicRegressionModel = instrumented { instr => + override def 
fit(dataset: Dataset[_]): IsotonicRegressionModel = super.fit(dataset) + override protected def fitImpl( + dataset: Dataset[_]): IsotonicRegressionModel = instrumented { instr => transformSchema(dataset.schema, logging = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) @@ -239,7 +241,8 @@ class IsotonicRegressionModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => @@ -296,7 +299,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { /** Checked against metadata when loading model */ private val className = classOf[IsotonicRegressionModel].getName - override def load(path: String): IsotonicRegressionModel = { + override protected def loadImpl(path: String): IsotonicRegressionModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ce6c12cc368dd..395ebd60c1728 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -782,7 +782,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { /** Checked against metadata when loading model */ private val className = classOf[LinearRegressionModel].getName - override def load(path: String): LinearRegressionModel = { + override protected def loadImpl(path: String): LinearRegressionModel = { val metadata = 
DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index afa9a646412b3..d1132f09d1975 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -268,7 +268,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode private val className = classOf[RandomForestRegressionModel].getName private val treeClassName = classOf[DecisionTreeRegressionModel].getName - override def load(path: String): RandomForestRegressionModel = { + override protected def loadImpl(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index e60a14f976a5c..dd5124a9fd584 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -119,7 +119,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): CrossValidatorModel = instrumented { instr => + override def fit(dataset: Dataset[_]): CrossValidatorModel = super.fit(dataset) + override protected def fitImpl(dataset: Dataset[_]): CrossValidatorModel = instrumented { instr => val schema = dataset.schema transformSchema(schema, logging = true) val sparkSession = 
dataset.sparkSession @@ -226,7 +227,7 @@ object CrossValidator extends MLReadable[CrossValidator] { /** Checked against metadata when loading model */ private val className = classOf[CrossValidator].getName - override def load(path: String): CrossValidator = { + override protected def loadImpl(path: String): CrossValidator = { implicit val format = DefaultFormats val (metadata, estimator, evaluator, estimatorParamMaps) = @@ -299,7 +300,8 @@ class CrossValidatorModel private[ml] ( def hasSubModels: Boolean = _subModels.isDefined @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } @@ -392,7 +394,10 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { /** Checked against metadata when loading model */ private val className = classOf[CrossValidatorModel].getName - override def load(path: String): CrossValidatorModel = { + // Explicitly call parent's load. Otherwise, MiMa complains. 
+ override def load(path: String): CrossValidatorModel = super.load(path) + + override protected def loadImpl(path: String): CrossValidatorModel = { implicit val format = DefaultFormats val (metadata, estimator, evaluator, estimatorParamMaps) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 8b251197afbef..7fc662de85a77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -118,7 +118,9 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") - override def fit(dataset: Dataset[_]): TrainValidationSplitModel = instrumented { instr => + override def fit(dataset: Dataset[_]): TrainValidationSplitModel = super.fit(dataset) + override protected def fitImpl( + dataset: Dataset[_]): TrainValidationSplitModel = instrumented { instr => val schema = dataset.schema transformSchema(schema, logging = true) val est = $(estimator) @@ -220,7 +222,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { /** Checked against metadata when loading model */ private val className = classOf[TrainValidationSplit].getName - override def load(path: String): TrainValidationSplit = { + override protected def loadImpl(path: String): TrainValidationSplit = { implicit val format = DefaultFormats val (metadata, estimator, evaluator, estimatorParamMaps) = @@ -290,7 +292,8 @@ class TrainValidationSplitModel private[ml] ( def hasSubModels: Boolean = _subModels.isDefined @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = super.transform(dataset) + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { 
transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } @@ -380,7 +383,10 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { /** Checked against metadata when loading model */ private val className = classOf[TrainValidationSplitModel].getName - override def load(path: String): TrainValidationSplitModel = { + // Explicitly call parent's load. Otherwise, MiMa complains. + override def load(path: String): TrainValidationSplitModel = super.load(path) + + override protected def loadImpl(path: String): TrainValidationSplitModel = { implicit val format = DefaultFormats val (metadata, estimator, evaluator, estimatorParamMaps) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index fbc7be25a5640..0a45fa8e267ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -163,7 +163,7 @@ abstract class MLWriter extends BaseReadWrite with Logging { */ @Since("1.6.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") - def save(path: String): Unit = { + def save(path: String): Unit = MLEvents.withSaveInstanceEvent(this, path) { new FileSystemOverwrite().handleOverwrite(path, shouldOverwrite, sc) saveImpl(path) } @@ -329,7 +329,19 @@ abstract class MLReader[T] extends BaseReadWrite { * Loads the ML component from the input path. */ @Since("1.6.0") - def load(path: String): T + def load(path: String): T = MLEvents.withLoadInstanceEvent(this, path) { + loadImpl(path) + } + + /** + * `load()` handles events and then calls this method. Subclasses should override this + * method to implement the actual loading of the instance. + */ + @Since("3.0.0") + protected def loadImpl(path: String): T = { + // Keep this default body for backward compatibility. 
+ throw new UnsupportedOperationException("loadImpl is not implemented.") + } // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) @@ -467,7 +479,7 @@ private[ml] object DefaultParamsWriter { */ private[ml] class DefaultParamsReader[T] extends MLReader[T] { - override def load(path: String): T = { + override protected def loadImpl(path: String): T = { val metadata = DefaultParamsReader.loadMetadata(path, sc) val cls = Utils.classForName(metadata.className) val instance = diff --git a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala new file mode 100644 index 0000000000000..e63eef870d985 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.hadoop.fs.Path +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually +import org.scalatest.mockito.MockitoSugar.mock + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} +import org.apache.spark.sql._ + + +class MLEventsSuite + extends SparkFunSuite + with BeforeAndAfterEach + with MLlibTestSparkContext + with Eventually { + + private val dirName: String = "pipeline" + private val events = mutable.ArrayBuffer.empty[MLEvent] + private val listener: SparkListener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case e: FitStart[_] => events.append(e) + case e: FitEnd[_] => events.append(e) + case e: TransformStart => events.append(e) + case e: TransformEnd => events.append(e) + case e: SaveInstanceStart if e.path.endsWith(dirName) => events.append(e) + case e: SaveInstanceEnd if e.path.endsWith(dirName) => events.append(e) + case _ => + } + } + + override def beforeAll(): Unit = { + super.beforeAll() + spark.sparkContext.addSparkListener(listener) + } + + override def afterEach(): Unit = { + try { + events.clear() + } finally { + super.afterEach() + } + } + + override def afterAll(): Unit = { + try { + if (spark != null) { + spark.sparkContext.removeSparkListener(listener) + } + } finally { + super.afterAll() + } + } + + abstract class MyModel extends Model[MyModel] + + test("pipeline fit events") { + val estimator0 = mock[Estimator[MyModel]] + val model0 = mock[MyModel] + val transformer1 = mock[Transformer] + val estimator2 = mock[Estimator[MyModel]] + val model2 = mock[MyModel] + val 
transformer3 = mock[Transformer] + + when(estimator0.copy(any[ParamMap])).thenReturn(estimator0) + when(model0.copy(any[ParamMap])).thenReturn(model0) + when(transformer1.copy(any[ParamMap])).thenReturn(transformer1) + when(estimator2.copy(any[ParamMap])).thenReturn(estimator2) + when(model2.copy(any[ParamMap])).thenReturn(model2) + when(transformer3.copy(any[ParamMap])).thenReturn(transformer3) + + val dataset0 = mock[DataFrame] + val dataset1 = mock[DataFrame] + val dataset2 = mock[DataFrame] + val dataset3 = mock[DataFrame] + val dataset4 = mock[DataFrame] + + when(dataset0.toDF).thenReturn(dataset0) + when(dataset1.toDF).thenReturn(dataset1) + when(dataset2.toDF).thenReturn(dataset2) + when(dataset3.toDF).thenReturn(dataset3) + when(dataset4.toDF).thenReturn(dataset4) + + when(estimator0.fit(meq(dataset0))).thenReturn(model0) + when(model0.transform(meq(dataset0))).thenReturn(dataset1) + when(model0.parent).thenReturn(estimator0) + when(transformer1.transform(meq(dataset1))).thenReturn(dataset2) + when(estimator2.fit(meq(dataset2))).thenReturn(model2) + when(model2.transform(meq(dataset2))).thenReturn(dataset3) + when(model2.parent).thenReturn(estimator2) + when(transformer3.transform(meq(dataset3))).thenReturn(dataset4) + + val pipeline = new Pipeline() + .setStages(Array(estimator0, transformer1, estimator2, transformer3)) + val pipelineModel = pipeline.fit(dataset0) + + val expected = + FitStart(pipeline, dataset0) :: + FitEnd(pipeline, pipelineModel) :: Nil + eventually(timeout(10 seconds), interval(1 second)) { + assert(expected === events) + } + } + + test("pipeline model transform events") { + val dataset = mock[DataFrame] + when(dataset.toDF).thenReturn(dataset) + val transform1 = mock[Transformer] + val model = mock[MyModel] + val transform2 = mock[Transformer] + val stages = Array(transform1, model, transform2) + val newPipelineModel = new PipelineModel("pipeline0", stages) + val output = newPipelineModel.transform(dataset) + + val expected = + 
TransformStart(newPipelineModel, dataset) :: + TransformEnd(newPipelineModel, output) :: Nil + eventually(timeout(10 seconds), interval(1 second)) { + assert(expected === events) + } + } + + test("pipeline read/write events") { + withTempDir { dir => + val path = new Path(dir.getCanonicalPath, dirName).toUri.toString + val writableStage = new WritableStage("writableStage") + val newPipeline = new Pipeline().setStages(Array(writableStage)) + val pipelineWriter = newPipeline.write + pipelineWriter.save(path) + + val expected = + SaveInstanceStart(pipelineWriter, path) :: + SaveInstanceEnd(pipelineWriter, path) :: Nil + eventually(timeout(10 seconds), interval(1 second)) { + assert(expected === events) + } + } + } + + test("pipeline model read/write events") { + withTempDir { dir => + val path = new Path(dir.getCanonicalPath, dirName).toUri.toString + val writableStage = new WritableStage("writableStage") + val pipelineModel = + new PipelineModel("pipeline_89329329", Array(writableStage.asInstanceOf[Transformer])) + val pipelineWriter = pipelineModel.write + pipelineWriter.save(path) + + val expected = + SaveInstanceStart(pipelineWriter, path) :: + SaveInstanceEnd(pipelineWriter, path) :: Nil + eventually(timeout(10 seconds), interval(1 second)) { + assert(expected === events) + } + } + } +} ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: us...@infra.apache.org > Add Spark ML Listener for Tracking ML Pipeline Status > ----------------------------------------------------- > > Key: SPARK-23674 > URL: https://issues.apache.org/jira/browse/SPARK-23674 > Project: Spark > Issue Type: Improvement > Components: ML > Affects Versions: 2.3.0 > Reporter: Mingjie Tang > Priority: Major > > Currently, Spark provides status monitoring for different components of > Spark, like the Spark history server, streaming listener, SQL listener, etc. > The use case would be (1) a front UI to track the status of the training coverage > rate during iteration, so data scientists can understand how the job converges when > training, e.g. K-means, logistic regression and other linear regression models. (2) > tracking the data lineage for the input and output of training data. > In this proposal, we hope to provide a Spark ML pipeline listener to track the > status of the Spark ML pipeline, which includes: > # ML pipeline created and saved > # ML pipeline model created, saved and loaded > # ML model training status monitoring -- This message was sent by Atlassian JIRA (v7.6.3#76005) --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org For additional commands, e-mail: issues-h...@spark.apache.org