Repository: spark
Updated Branches:
  refs/heads/master 0171b71e9 -> ba5f81859


[SPARK-11259][ML] Params.validateParams() should be called automatically

See JIRA: https://issues.apache.org/jira/browse/SPARK-11259

Author: Yanbo Liang <yblia...@gmail.com>

Closes #9224 from yanboliang/spark-11259.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ba5f8185
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ba5f8185
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ba5f8185

Branch: refs/heads/master
Commit: ba5f81859d6ba37a228a1c43d26c47e64c0382cd
Parents: 0171b71
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Mon Jan 4 13:30:17 2016 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Jan 4 13:30:17 2016 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Pipeline.scala    |  2 ++
 .../scala/org/apache/spark/ml/Predictor.scala   |  1 +
 .../scala/org/apache/spark/ml/Transformer.scala |  1 +
 .../org/apache/spark/ml/clustering/KMeans.scala |  1 +
 .../org/apache/spark/ml/clustering/LDA.scala    |  1 +
 .../org/apache/spark/ml/feature/Binarizer.scala |  1 +
 .../apache/spark/ml/feature/Bucketizer.scala    |  1 +
 .../apache/spark/ml/feature/ChiSqSelector.scala |  2 ++
 .../spark/ml/feature/CountVectorizer.scala      |  1 +
 .../org/apache/spark/ml/feature/HashingTF.scala |  1 +
 .../scala/org/apache/spark/ml/feature/IDF.scala |  1 +
 .../apache/spark/ml/feature/MinMaxScaler.scala  |  1 +
 .../apache/spark/ml/feature/OneHotEncoder.scala |  1 +
 .../scala/org/apache/spark/ml/feature/PCA.scala |  2 ++
 .../spark/ml/feature/QuantileDiscretizer.scala  |  1 +
 .../org/apache/spark/ml/feature/RFormula.scala  |  4 ++++
 .../spark/ml/feature/SQLTransformer.scala       |  1 +
 .../spark/ml/feature/StandardScaler.scala       |  2 ++
 .../spark/ml/feature/StopWordsRemover.scala     |  1 +
 .../apache/spark/ml/feature/StringIndexer.scala |  2 ++
 .../spark/ml/feature/VectorAssembler.scala      |  1 +
 .../apache/spark/ml/feature/VectorIndexer.scala |  2 ++
 .../apache/spark/ml/feature/VectorSlicer.scala  |  1 +
 .../org/apache/spark/ml/feature/Word2Vec.scala  |  1 +
 .../apache/spark/ml/recommendation/ALS.scala    |  2 ++
 .../ml/regression/AFTSurvivalRegression.scala   |  1 +
 .../ml/regression/IsotonicRegression.scala      |  1 +
 .../apache/spark/ml/tuning/CrossValidator.scala |  2 ++
 .../spark/ml/tuning/TrainValidationSplit.scala  |  2 ++
 .../org/apache/spark/ml/PipelineSuite.scala     | 23 +++++++++++++++++++-
 30 files changed, 63 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 3acc60d..32570a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -165,6 +165,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val theStages = $(stages)
     require(theStages.toSet.size == theStages.length,
       "Cannot have duplicate components in a pipeline.")
@@ -296,6 +297,7 @@ class PipelineModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 6aacffd..d1388b5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -46,6 +46,7 @@ private[ml] trait PredictorParams extends Params
       schema: StructType,
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
+    validateParams()
     // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
     SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
     if (fitting) {

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 1f3325a..fdce273 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -103,6 +103,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   protected def validateInputType(inputType: DataType): Unit = {}
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     validateInputType(inputType)
     if (schema.fieldNames.contains($(outputCol))) {

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 6e5abb2..dc6d5d9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -80,6 +80,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index af0b3e1..99383e7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -263,6 +263,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 5b17d34..544cf05 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -72,6 +72,7 @@ final class Binarizer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
 
     val inputFields = schema.fields

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 33abc7c..0c75317 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -86,6 +86,7 @@ final class Bucketizer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
     SchemaUtils.appendColumn(schema, prepOutputField(schema))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index dfec038..7b565ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -88,6 +88,7 @@ final class ChiSqSelector(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
@@ -135,6 +136,7 @@ final class ChiSqSelectorModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     val newField = prepOutputField(schema)
     val outputFields = schema.fields :+ newField

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 1268c87..10dcda2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -70,6 +70,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 61a78d7..8af0058 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -69,6 +69,7 @@ class HashingTF(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[ArrayType],
       s"The input column must be ArrayType, but got $inputType.")

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index f7b0f29..9e7eee4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -52,6 +52,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 559a025..ad0458d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -59,6 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index c01e29a..3425404 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,6 +66,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index f653798..7020397 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -77,6 +77,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -130,6 +131,7 @@ class PCAModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 39de846..8fd0ce2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -74,6 +74,7 @@ final class QuantileDiscretizer(override val uid: String)
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
     val inputFields = schema.fields
     require(inputFields.forall(_.name != $(outputCol)),

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 2b578c2..f995243 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -146,6 +146,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
 
   // optimistic schema; does not contain any ML attributes
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     if (hasLabelCol(schema)) {
       StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
     } else {
@@ -178,6 +179,7 @@ class RFormulaModel private[feature](
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     checkCanTransform(schema)
     val withFeatures = pipelineModel.transformSchema(schema)
     if (hasLabelCol(withFeatures)) {
@@ -240,6 +242,7 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
   }
 
@@ -288,6 +291,7 @@ private class VectorAttributeRewriter(
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     StructType(
       schema.fields.filter(_.name != vectorCol) ++
       schema.fields.filter(_.name == vectorCol))

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index e0ca45b..af6494b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -74,6 +74,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
 
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val sc = SparkContext.getOrCreate()
     val sqlContext = SQLContext.getOrCreate(sc)
     val dummyRDD = sc.parallelize(Seq(Row.empty))

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index d76a9c6..6a0b6c2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -94,6 +94,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -143,6 +144,7 @@ class StandardScalerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 5d6936d..b93c9ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -145,6 +145,7 @@ class StopWordsRemover(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.sameType(ArrayType(StringType)),
       s"Input type must be ArrayType(StringType) but got $inputType.")

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 5c40c35..912bd95 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -39,6 +39,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputColName = $(inputCol)
     val inputDataType = schema(inputColName).dataType
     require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
@@ -272,6 +273,7 @@ class IndexToString private[ml] (override val uid: String)
   final def getLabels: Array[String] = $(labels)
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputColName = $(inputCol)
     val inputDataType = schema(inputColName).dataType
     require(inputDataType.isInstanceOf[NumericType],

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index e9d1b57..0b21565 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -106,6 +106,7 @@ class VectorAssembler(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputColNames = $(inputCols)
     val outputColName = $(outputCol)
     val inputDataTypes = inputColNames.map(name => schema(name).dataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index a637a6f..2a52684 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -126,6 +126,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     // We do not transfer feature metadata since we do not know what types of features we will
     // produce in transform().
     val dataType = new VectorUDT
@@ -354,6 +355,7 @@ class VectorIndexerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val dataType = new VectorUDT
     require(isDefined(inputCol),
       s"VectorIndexerModel requires input column parameter: $inputCol")

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 4813d8a..300d63b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -139,6 +139,7 @@ final class VectorSlicer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
 
     if (schema.fieldNames.contains($(outputCol))) {

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 59c34cd..2b6b3c3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -92,6 +92,7 @@ private[feature] trait Word2VecBase extends Params
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 14a28b8..472c185 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -162,6 +162,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
     SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     val ratingType = schema($(ratingCol)).dataType
@@ -213,6 +214,7 @@ class ALSModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
     SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 3787ca4..e8a1ff2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -99,6 +99,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
   protected def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index e8d361b..1573bb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -105,6 +105,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
   protected[ml] def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean): StructType = {
+    validateParams()
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
       if (hasWeightCol) {

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 477675c..3eac616 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -131,6 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     $(estimator).transformSchema(schema)
   }
 
@@ -345,6 +346,7 @@ class CrossValidatorModel private[ml] (
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     bestModel.transformSchema(schema)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index f346ea6..4f67e8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -118,6 +118,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
 
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     $(estimator).transformSchema(schema)
   }
 
@@ -172,6 +173,7 @@ class TrainValidationSplitModel private[ml] (
 
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     bestModel.transformSchema(schema)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 8c86767..f3321fb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -26,9 +26,10 @@ import org.scalatest.mock.MockitoSugar.mock
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.Pipeline.SharedReadWrite
-import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
 import org.apache.spark.ml.param.{IntParam, ParamMap}
 import org.apache.spark.ml.util._
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -174,6 +175,26 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       }
     }
   }
+
+  test("pipeline validateParams") {
+    val df = sqlContext.createDataFrame(
+      Seq(
+        (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+        (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
+        (3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
+        (4, Vectors.dense(0.0, 0.0, 5.0), 4.0))
+    ).toDF("id", "features", "label")
+
+    intercept[IllegalArgumentException] {
+       val scaler = new MinMaxScaler()
+         .setInputCol("features")
+         .setOutputCol("features_scaled")
+         .setMin(10)
+         .setMax(0)
+       val pipeline = new Pipeline().setStages(Array(scaler))
+       pipeline.fit(df)
+    }
+  }
 }
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to