Repository: spark
Updated Branches:
  refs/heads/master 0171b71e9 -> ba5f81859
[SPARK-11259][ML] Params.validateParams() should be called automatically See JIRA: https://issues.apache.org/jira/browse/SPARK-11259 Author: Yanbo Liang <yblia...@gmail.com> Closes #9224 from yanboliang/spark-11259. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ba5f8185 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ba5f8185 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ba5f8185 Branch: refs/heads/master Commit: ba5f81859d6ba37a228a1c43d26c47e64c0382cd Parents: 0171b71 Author: Yanbo Liang <yblia...@gmail.com> Authored: Mon Jan 4 13:30:17 2016 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Mon Jan 4 13:30:17 2016 -0800 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/Pipeline.scala | 2 ++ .../scala/org/apache/spark/ml/Predictor.scala | 1 + .../scala/org/apache/spark/ml/Transformer.scala | 1 + .../org/apache/spark/ml/clustering/KMeans.scala | 1 + .../org/apache/spark/ml/clustering/LDA.scala | 1 + .../org/apache/spark/ml/feature/Binarizer.scala | 1 + .../apache/spark/ml/feature/Bucketizer.scala | 1 + .../apache/spark/ml/feature/ChiSqSelector.scala | 2 ++ .../spark/ml/feature/CountVectorizer.scala | 1 + .../org/apache/spark/ml/feature/HashingTF.scala | 1 + .../scala/org/apache/spark/ml/feature/IDF.scala | 1 + .../apache/spark/ml/feature/MinMaxScaler.scala | 1 + .../apache/spark/ml/feature/OneHotEncoder.scala | 1 + .../scala/org/apache/spark/ml/feature/PCA.scala | 2 ++ .../spark/ml/feature/QuantileDiscretizer.scala | 1 + .../org/apache/spark/ml/feature/RFormula.scala | 4 ++++ .../spark/ml/feature/SQLTransformer.scala | 1 + .../spark/ml/feature/StandardScaler.scala | 2 ++ .../spark/ml/feature/StopWordsRemover.scala | 1 + .../apache/spark/ml/feature/StringIndexer.scala | 2 ++ .../spark/ml/feature/VectorAssembler.scala | 1 + .../apache/spark/ml/feature/VectorIndexer.scala | 2 ++ 
.../apache/spark/ml/feature/VectorSlicer.scala | 1 + .../org/apache/spark/ml/feature/Word2Vec.scala | 1 + .../apache/spark/ml/recommendation/ALS.scala | 2 ++ .../ml/regression/AFTSurvivalRegression.scala | 1 + .../ml/regression/IsotonicRegression.scala | 1 + .../apache/spark/ml/tuning/CrossValidator.scala | 2 ++ .../spark/ml/tuning/TrainValidationSplit.scala | 2 ++ .../org/apache/spark/ml/PipelineSuite.scala | 23 +++++++++++++++++++- 30 files changed, 63 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 3acc60d..32570a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -165,6 +165,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M } override def transformSchema(schema: StructType): StructType = { + validateParams() val theStages = $(stages) require(theStages.toSet.size == theStages.length, "Cannot have duplicate components in a pipeline.") @@ -296,6 +297,7 @@ class PipelineModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur)) } http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 6aacffd..d1388b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -46,6 +46,7 @@ private[ml] trait PredictorParams extends Params schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { + validateParams() // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) if (fitting) { http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 1f3325a..fdce273 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -103,6 +103,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] protected def validateInputType(inputType: DataType): Unit = {} override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains($(outputCol))) { http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 6e5abb2..dc6d5d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -80,6 +80,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() 
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index af0b3e1..99383e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -263,6 +263,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 5b17d34..544cf05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -72,6 +72,7 @@ final class Binarizer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) val inputFields = schema.fields http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 33abc7c..0c75317 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -86,6 +86,7 @@ final class Bucketizer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index dfec038..7b565ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -88,6 +88,7 @@ final class ChiSqSelector(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) @@ -135,6 +136,7 @@ final class ChiSqSelectorModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) val newField = prepOutputField(schema) val outputFields = schema.fields :+ newField http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 1268c87..10dcda2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -70,6 +70,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 61a78d7..8af0058 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -69,6 +69,7 @@ class HashingTF(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], s"The input column must be ArrayType, but got $inputType.") http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index f7b0f29..9e7eee4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -52,6 +52,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Validate and transform the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 559a025..ad0458d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -59,6 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index c01e29a..3425404 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -66,6 +66,7 @@ class OneHotEncoder(override val uid: String) extends Transformer def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { + validateParams() val inputColName = $(inputCol) val outputColName = $(outputCol) http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index f653798..7020397 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -77,6 +77,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") @@ -130,6 +131,7 @@ class PCAModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { 
+ validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 39de846..8fd0ce2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -74,6 +74,7 @@ final class QuantileDiscretizer(override val uid: String) def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) val inputFields = schema.fields require(inputFields.forall(_.name != $(outputCol)), http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 2b578c2..f995243 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -146,6 +146,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R // optimistic schema; does not contain any ML attributes override def transformSchema(schema: StructType): StructType = { + validateParams() if (hasLabelCol(schema)) { StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true)) } else { @@ -178,6 +179,7 @@ class RFormulaModel 
private[feature]( } override def transformSchema(schema: StructType): StructType = { + validateParams() checkCanTransform(schema) val withFeatures = pipelineModel.transformSchema(schema) if (hasLabelCol(withFeatures)) { @@ -240,6 +242,7 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { } override def transformSchema(schema: StructType): StructType = { + validateParams() StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name))) } @@ -288,6 +291,7 @@ private class VectorAttributeRewriter( } override def transformSchema(schema: StructType): StructType = { + validateParams() StructType( schema.fields.filter(_.name != vectorCol) ++ schema.fields.filter(_.name == vectorCol)) http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index e0ca45b..af6494b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -74,6 +74,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { + validateParams() val sc = SparkContext.getOrCreate() val sqlContext = SQLContext.getOrCreate(sc) val dummyRDD = sc.parallelize(Seq(Row.empty)) http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index d76a9c6..6a0b6c2 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -94,6 +94,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") @@ -143,6 +144,7 @@ class StandardScalerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 5d6936d..b93c9ed 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -145,6 +145,7 @@ class StopWordsRemover(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 
5c40c35..912bd95 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -39,6 +39,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], @@ -272,6 +273,7 @@ class IndexToString private[ml] (override val uid: String) final def getLabels: Array[String] = $(labels) override def transformSchema(schema: StructType): StructType = { + validateParams() val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType require(inputDataType.isInstanceOf[NumericType], http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index e9d1b57..0b21565 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -106,6 +106,7 @@ class VectorAssembler(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputColNames = $(inputCols) val outputColName = $(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index a637a6f..2a52684 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -126,6 +126,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod } override def transformSchema(schema: StructType): StructType = { + validateParams() // We do not transfer feature metadata since we do not know what types of features we will // produce in transform(). val dataType = new VectorUDT @@ -354,6 +355,7 @@ class VectorIndexerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() val dataType = new VectorUDT require(isDefined(inputCol), s"VectorIndexerModel requires input column parameter: $inputCol") http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 4813d8a..300d63b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -139,6 +139,7 @@ final class VectorSlicer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) if (schema.fieldNames.contains($(outputCol))) { http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 59c34cd..2b6b3c3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -92,6 +92,7 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 14a28b8..472c185 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -162,6 +162,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) val ratingType = schema($(ratingCol)).dataType @@ -213,6 +214,7 @@ class ALSModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala 
---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 3787ca4..e8a1ff2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -99,6 +99,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params protected def validateAndTransformSchema( schema: StructType, fitting: Boolean): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) if (fitting) { SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index e8d361b..1573bb4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -105,6 +105,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures protected[ml] def validateAndTransformSchema( schema: StructType, fitting: Boolean): StructType = { + validateParams() if (fitting) { SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) if (hasWeightCol) { http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 477675c..3eac616 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -131,6 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { + validateParams() $(estimator).transformSchema(schema) } @@ -345,6 +346,7 @@ class CrossValidatorModel private[ml] ( @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { + validateParams() bestModel.transformSchema(schema) } http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index f346ea6..4f67e8c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -118,6 +118,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { + validateParams() $(estimator).transformSchema(schema) } @@ -172,6 +173,7 @@ class TrainValidationSplitModel private[ml] ( @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { + validateParams() bestModel.transformSchema(schema) } http://git-wip-us.apache.org/repos/asf/spark/blob/ba5f8185/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 8c86767..f3321fb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -26,9 +26,10 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.Pipeline.SharedReadWrite -import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler} import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -174,6 +175,26 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } } + + test("pipeline validateParams") { + val df = sqlContext.createDataFrame( + Seq( + (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + ).toDF("id", "features", "label") + + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("features_scaled") + .setMin(10) + .setMax(0) + val pipeline = new Pipeline().setStages(Array(scaler)) + pipeline.fit(df) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org