Repository: spark
Updated Branches:
  refs/heads/master d4d84936f -> 92b70576e


[SPARK-13761][ML] Deprecate validateParams

## What changes were proposed in this pull request?

Deprecate the validateParams() method here:
https://github.com/apache/spark/blob/035d3acdf3c1be5b309a861d5c5beb803b946b5e/mllib/src/main/scala/org/apache/spark/ml/param/params.scala#L553
Move all functionality from the overridden validateParams() methods into
transformSchema(), and check the docs to make sure they indicate that complex
Param interaction checks should be done in transformSchema.
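
For a developer's own stage, the migration looks roughly like the following.
This is a minimal sketch with a hypothetical `MyScaler` stage (not code from
this patch); only the placement of the `require` check mirrors the changes
below:

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{DoubleParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

// Hypothetical stage, for illustration only. The cross-Param check that
// would previously have lived in a validateParams() override now runs at
// the start of transformSchema(), so it fires during schema validation
// as well as before transform().
class MyScaler(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("myScaler"))

  val min = new DoubleParam(this, "min", "lower bound of the output range")
  val max = new DoubleParam(this, "max", "upper bound of the output range")
  setDefault(min -> 0.0, max -> 1.0)

  override def transformSchema(schema: StructType): StructType = {
    // Param interaction check, formerly in validateParams().
    require($(min) < $(max), s"min (${$(min)}) must be less than max (${$(max)})")
    schema
  }

  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    dataset // identity transform; a real stage would rescale here
  }

  override def copy(extra: ParamMap): MyScaler = defaultCopy(extra)
}
```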

## How was this patch tested?

Unit tests; the existing suites (LDASuite, MinMaxScalerSuite, VectorSlicerSuite)
were updated to exercise validation through transformSchema().
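
The updated suites trigger these checks by calling transformSchema() on a dummy
schema instead of calling validateParams(). A condensed sketch of the pattern,
assuming it runs inside a ScalaTest suite (the column name is illustrative):

```scala
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.{StructField, StructType}

// An invalid min/max pair now surfaces through transformSchema() rather
// than through the deprecated validateParams().
val schema = StructType(Seq(StructField("feature", new VectorUDT, true)))
val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")
intercept[IllegalArgumentException] {
  scaler.transformSchema(schema)
}
```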

Author: Yuhao Yang <hhb...@gmail.com>

Closes #11620 from hhbyyh/depreValid.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/92b70576
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/92b70576
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/92b70576

Branch: refs/heads/master
Commit: 92b70576eabf8ff94ac476e2b3c66f8b3d28e79e
Parents: d4d8493
Author: Yuhao Yang <hhb...@gmail.com>
Authored: Wed Mar 16 17:31:55 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Mar 16 17:31:55 2016 -0700

----------------------------------------------------------------------
 .../src/main/scala/org/apache/spark/ml/Pipeline.scala | 14 --------------
 .../main/scala/org/apache/spark/ml/Predictor.scala    |  1 -
 .../main/scala/org/apache/spark/ml/Transformer.scala  |  1 -
 .../scala/org/apache/spark/ml/clustering/KMeans.scala |  1 -
 .../scala/org/apache/spark/ml/clustering/LDA.scala    |  9 ++-------
 .../org/apache/spark/ml/feature/Bucketizer.scala      |  1 -
 .../org/apache/spark/ml/feature/ChiSqSelector.scala   |  2 --
 .../org/apache/spark/ml/feature/CountVectorizer.scala |  1 -
 .../scala/org/apache/spark/ml/feature/HashingTF.scala |  1 -
 .../main/scala/org/apache/spark/ml/feature/IDF.scala  |  1 -
 .../org/apache/spark/ml/feature/Interaction.scala     | 13 ++++---------
 .../org/apache/spark/ml/feature/MaxAbsScaler.scala    |  1 -
 .../org/apache/spark/ml/feature/MinMaxScaler.scala    |  5 +----
 .../org/apache/spark/ml/feature/OneHotEncoder.scala   |  1 -
 .../main/scala/org/apache/spark/ml/feature/PCA.scala  |  2 --
 .../apache/spark/ml/feature/QuantileDiscretizer.scala |  1 -
 .../scala/org/apache/spark/ml/feature/RFormula.scala  |  4 ----
 .../org/apache/spark/ml/feature/SQLTransformer.scala  |  1 -
 .../org/apache/spark/ml/feature/StandardScaler.scala  |  2 --
 .../apache/spark/ml/feature/StopWordsRemover.scala    |  1 -
 .../org/apache/spark/ml/feature/StringIndexer.scala   |  2 --
 .../org/apache/spark/ml/feature/VectorAssembler.scala |  1 -
 .../org/apache/spark/ml/feature/VectorIndexer.scala   |  2 --
 .../org/apache/spark/ml/feature/VectorSlicer.scala    |  8 ++------
 .../scala/org/apache/spark/ml/feature/Word2Vec.scala  |  1 -
 .../main/scala/org/apache/spark/ml/param/params.scala |  7 ++++---
 .../org/apache/spark/ml/recommendation/ALS.scala      |  2 --
 .../spark/ml/regression/AFTSurvivalRegression.scala   |  1 -
 .../ml/regression/GeneralizedLinearRegression.scala   |  7 ++++++-
 .../spark/ml/regression/IsotonicRegression.scala      |  1 -
 .../org/apache/spark/ml/clustering/LDASuite.scala     | 12 +++++++-----
 .../apache/spark/ml/feature/MinMaxScalerSuite.scala   | 10 ++++++----
 .../apache/spark/ml/feature/VectorSlicerSuite.scala   |  8 ++++----
 33 files changed, 36 insertions(+), 89 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index cbac7bb..f4c6214 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -110,12 +110,6 @@ class Pipeline @Since("1.4.0") (
   @Since("1.2.0")
   def getStages: Array[PipelineStage] = $(stages).clone()
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    $(stages).foreach(_.validateParams())
-  }
-
   /**
   * Fits the pipeline to the input dataset with additional parameters. If a stage is an
   * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -175,7 +169,6 @@ class Pipeline @Since("1.4.0") (
 
   @Since("1.2.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val theStages = $(stages)
     require(theStages.toSet.size == theStages.length,
       "Cannot have duplicate components in a pipeline.")
@@ -297,12 +290,6 @@ class PipelineModel private[ml] (
     this(uid, stages.asScala.toArray)
   }
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    stages.foreach(_.validateParams())
-  }
-
   @Since("1.2.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -311,7 +298,6 @@ class PipelineModel private[ml] (
 
   @Since("1.2.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 4b27ee6..ebe4870 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -46,7 +46,6 @@ private[ml] trait PredictorParams extends Params
       schema: StructType,
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
-    validateParams()
     // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
     SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
     if (fitting) {

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index fdce273..1f3325a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -103,7 +103,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   protected def validateInputType(inputType: DataType): Unit = {}
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     validateInputType(inputType)
     if (schema.fieldNames.contains($(outputCol))) {

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 79332b0..ab00127 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -81,7 +81,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 6304b20..0de82b4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -263,13 +263,6 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
-    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
-    SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
-  }
-
-  @Since("1.6.0")
-  override def validateParams(): Unit = {
     if (isSet(docConcentration)) {
       if (getDocConcentration.length != 1) {
        require(getDocConcentration.length == getK, s"LDA docConcentration was of length" +
@@ -297,6 +290,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
             s" must be >= 1.  Found value: $getTopicConcentration")
       }
     }
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
   }
 
  private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match {

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 0c75317..33abc7c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -86,7 +86,6 @@ final class Bucketizer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
     SchemaUtils.appendColumn(schema, prepOutputField(schema))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 4abc459..b9e9d56 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -88,7 +88,6 @@ final class ChiSqSelector(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
@@ -136,7 +135,6 @@ final class ChiSqSelectorModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     val newField = prepOutputField(schema)
     val outputFields = schema.fields :+ newField

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index cf15145..f7d08b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -70,7 +70,6 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
     SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 8af0058..61a78d7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -69,7 +69,6 @@ class HashingTF(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[ArrayType],
       s"The input column must be ArrayType, but got $inputType.")

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index cebbe5c..f36cf50 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -52,7 +52,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 7d2a1da..d3fe6e5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -61,13 +61,15 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
   // optimistic schema; does not contain any ML attributes
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
+    require(get(inputCols).isDefined, "Input cols must be defined first.")
+    require(get(outputCol).isDefined, "Output col must be defined first.")
+    require($(inputCols).length > 0, "Input cols must have non-zero length.")
+    require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
     StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
   }
 
   @Since("1.6.0")
   override def transform(dataset: DataFrame): DataFrame = {
-    validateParams()
     val inputFeatures = $(inputCols).map(c => dataset.schema(c))
     val featureEncoders = getFeatureEncoders(inputFeatures)
     val featureAttrs = getFeatureAttrs(inputFeatures)
@@ -217,13 +219,6 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
   @Since("1.6.0")
   override def copy(extra: ParamMap): Interaction = defaultCopy(extra)
 
-  @Since("1.6.0")
-  override def validateParams(): Unit = {
-    require(get(inputCols).isDefined, "Input cols must be defined first.")
-    require(get(outputCol).isDefined, "Output col must be defined first.")
-    require($(inputCols).length > 0, "Input cols must have non-zero length.")
-    require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
-  }
 }
 
 @Since("1.6.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 09fad23..7de5a4d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -37,7 +37,6 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H
 
    /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 3b4209b..b13684a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -59,7 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
+    require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -69,9 +69,6 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
     StructType(outputFields)
   }
 
-  override def validateParams(): Unit = {
-    require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
-  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index fa5013d..4f67042 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 80b124f..305c3d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -77,7 +77,6 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -133,7 +132,6 @@ class PCAModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 18896fc..e830d2a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -78,7 +78,6 @@ final class QuantileDiscretizer(override val uid: String)
   def setSeed(value: Long): this.type = set(seed, value)
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
     val inputFields = schema.fields
     require(inputFields.forall(_.name != $(outputCol)),

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index c21da21..ab5f4a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -167,7 +167,6 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
 
   // optimistic schema; does not contain any ML attributes
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     if (hasLabelCol(schema)) {
       StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
     } else {
@@ -200,7 +199,6 @@ class RFormulaModel private[feature](
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     checkCanTransform(schema)
     val withFeatures = pipelineModel.transformSchema(schema)
     if (hasLabelCol(withFeatures)) {
@@ -263,7 +261,6 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
   }
 
@@ -312,7 +309,6 @@ private class VectorAttributeRewriter(
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     StructType(
       schema.fields.filter(_.name != vectorCol) ++
       schema.fields.filter(_.name == vectorCol))

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index af6494b..e0ca45b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -74,7 +74,6 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
 
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val sc = SparkContext.getOrCreate()
     val sqlContext = SQLContext.getOrCreate(sc)
     val dummyRDD = sc.parallelize(Seq(Row.empty))

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 9952d3b..26ee8e1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -94,7 +94,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -144,7 +143,6 @@ class StandardScalerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 0d4c968..0a0e0b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -145,7 +145,6 @@ class StopWordsRemover(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.sameType(ArrayType(StringType)),
       s"Input type must be ArrayType(StringType) but got $inputType.")

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 7dd794b..c579a0d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -39,7 +39,6 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputColName = $(inputCol)
     val inputDataType = schema(inputColName).dataType
     require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
@@ -275,7 +274,6 @@ class IndexToString private[ml] (override val uid: String)
   final def getLabels: Array[String] = $(labels)
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputColName = $(inputCol)
     val inputDataType = schema(inputColName).dataType
     require(inputDataType.isInstanceOf[NumericType],

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 7ff5ad1..957e8e7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -106,7 +106,6 @@ class VectorAssembler(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val inputColNames = $(inputCols)
     val outputColName = $(outputCol)
     val inputDataTypes = inputColNames.map(name => schema(name).dataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 5c11760..bf4aef2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -126,7 +126,6 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     // We do not transfer feature metadata since we do not know what types of features we will
     // produce in transform().
     val dataType = new VectorUDT
@@ -355,7 +354,6 @@ class VectorIndexerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     val dataType = new VectorUDT
     require(isDefined(inputCol),
       s"VectorIndexerModel requires input column parameter: $inputCol")

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 300d63b..b60e82d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -89,11 +89,6 @@ final class VectorSlicer(override val uid: String)
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  override def validateParams(): Unit = {
-    require($(indices).length > 0 || $(names).length > 0,
-      s"VectorSlicer requires that at least one feature be selected.")
-  }
-
   override def transform(dataset: DataFrame): DataFrame = {
     // Validity checks
     transformSchema(dataset.schema)
@@ -139,7 +134,8 @@ final class VectorSlicer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
+    require($(indices).length > 0 || $(names).length > 0,
+      s"VectorSlicer requires that at least one feature be selected.")
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
 
     if (schema.fieldNames.contains($(outputCol))) {

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 3d3c7bd..95bae1c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -92,7 +92,6 @@ private[feature] trait Word2VecBase extends Params
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 42411d2..d7837b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -58,9 +58,8 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
   /**
    * Assert that the given value is valid for this parameter.
    *
-   * Note: Parameter checks involving interactions between multiple parameters should be
-   *       implemented in [[Params.validateParams()]].  Checks for input/output columns should be
-   *       implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
+   * Note: Parameter checks involving interactions between multiple parameters and input/output
+   * columns should be implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
    *
   * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
    *             should be specified via [[ParamPair]].
@@ -555,7 +554,9 @@ trait Params extends Identifiable with Serializable {
   * Parameter value checks which do not depend on other parameters are handled by
   * [[Param.validate()]].  This method does not handle input/output column parameters;
   * those are checked during schema validation.
+   * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
    */
+  @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
   def validateParams(): Unit = {
     // Do nothing by default.  Override to handle Param interactions.
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index dacdac9..f3bc9f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -162,7 +162,6 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
     SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     val ratingType = schema($(ratingCol)).dataType
@@ -220,7 +219,6 @@ class ALSModel private[ml] (
 
   @Since("1.3.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
     SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index e4339d6..0901642 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -99,7 +99,6 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
   protected def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean): StructType = {
-    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index b4e47c8..46ba558 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * Params for Generalized Linear Regression.
@@ -77,7 +78,10 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
   import GeneralizedLinearRegression._
 
   @Since("2.0.0")
-  override def validateParams(): Unit = {
+  override def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
     if ($(solver) == "irls") {
       setDefault(maxIter -> 25)
     }
@@ -86,6 +90,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
         Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
         s"with ${$(family)} family does not support ${$(link)} link function.")
     }
+    super.validateAndTransformSchema(schema, fitting, featuresDataType)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 36b006c..20a0998 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -105,7 +105,6 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
   protected[ml] def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean): StructType = {
-    validateParams()
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
       if (hasWeightCol) {

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index a3a8f65..dd3f4c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -138,16 +138,18 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
       new LDA().setTopicConcentration(-1.1)
     }
 
-    // validateParams()
-    lda.validateParams()
+    val dummyDF = sqlContext.createDataFrame(Seq(
+      (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
+    // validate parameters
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(1.1)
-    lda.validateParams()
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray)
-    lda.validateParams()
+    lda.transformSchema(dummyDF.schema)
     lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray)
     withClue("LDA docConcentration validity check failed for bad array 
length") {
       intercept[IllegalArgumentException] {
-        lda.validateParams()
+        lda.transformSchema(dummyDF.schema)
       }
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 035bfc0..87206c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -57,13 +57,15 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
 
   test("MinMaxScaler arguments max must be larger than min") {
     withClue("arguments max must be larger than min") {
+      val dummyDF = sqlContext.createDataFrame(Seq(
+        (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
       intercept[IllegalArgumentException] {
-        val scaler = new MinMaxScaler().setMin(10).setMax(0)
-        scaler.validateParams()
+        val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")
+        scaler.transformSchema(dummyDF.schema)
       }
       intercept[IllegalArgumentException] {
-        val scaler = new MinMaxScaler().setMin(0).setMax(0)
-        scaler.validateParams()
+        val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("feature")
+        scaler.transformSchema(dummyDF.schema)
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/92b70576/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 94191e5..6bb4678 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -21,21 +21,21 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("params") {
-    val slicer = new VectorSlicer
+    val slicer = new VectorSlicer().setInputCol("feature")
     ParamsSuite.checkParams(slicer)
     assert(slicer.getIndices.length === 0)
     assert(slicer.getNames.length === 0)
     withClue("VectorSlicer should not have any features selected by default") {
       intercept[IllegalArgumentException] {
-        slicer.validateParams()
+        slicer.transformSchema(StructType(Seq(StructField("feature", new VectorUDT, true))))
       }
     }
   }

