Repository: spark Updated Branches: refs/heads/master ccd298eb6 -> 5855e0057
[SPARK-15668][ML] ml.feature: update check schema to avoid confusion when user use MLlib.vector as input type ## What changes were proposed in this pull request? ml.feature: update the schema check to avoid confusion when a user passes an MLlib.vector column as the input type. The hand-written `require(inputType.isInstanceOf[VectorUDT], ...)` checks in MaxAbsScaler, MinMaxScaler, PCA and StandardScaler are replaced with `SchemaUtils.checkColumnType`. ## How was this patch tested? Existing unit tests. Author: Yuhao Yang <yuhao.y...@intel.com> Closes #13411 from hhbyyh/schemaCheck. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5855e005 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5855e005 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5855e005 Branch: refs/heads/master Commit: 5855e0057defeab8006ca4f7b0196003bbc9e899 Parents: ccd298e Author: Yuhao Yang <yuhao.y...@intel.com> Authored: Thu Jun 2 16:37:01 2016 -0700 Committer: Nick Pentreath <ni...@za.ibm.com> Committed: Thu Jun 2 16:37:01 2016 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/feature/MaxAbsScaler.scala | 4 +-- .../apache/spark/ml/feature/MinMaxScaler.scala | 4 +-- .../scala/org/apache/spark/ml/feature/PCA.scala | 28 +++++++++----------- .../spark/ml/feature/StandardScaler.scala | 25 ++++++++--------- 4 files changed, 25 insertions(+), 36 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5855e005/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 1b51599..7298a18 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -39,9 +39,7 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H /** Validates and transforms the 
input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) require(!schema.fieldNames.contains($(outputCol)), s"Output column ${$(outputCol)} already exists.") val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) http://git-wip-us.apache.org/repos/asf/spark/blob/5855e005/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index d15f1b8..a27bed5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -63,9 +63,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})") - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) require(!schema.fieldNames.contains($(outputCol)), s"Output column ${$(outputCol)} already exists.") val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) http://git-wip-us.apache.org/repos/asf/spark/blob/5855e005/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index dbbaa5a..2f667af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -49,8 +49,16 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC /** @group getParam */ def getK: Int = $(k) -} + /** Validates and transforms the input schema. 
*/ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} /** * :: Experimental :: * PCA trains a model to project vectors to a lower dimensional space of the top [[PCA!.k]] @@ -86,13 +94,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams } override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") - require(!schema.fieldNames.contains($(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) - StructType(outputFields) + validateAndTransformSchema(schema) } override def copy(extra: ParamMap): PCA = defaultCopy(extra) @@ -148,13 +150,7 @@ class PCAModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") - require(!schema.fieldNames.contains($(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) - StructType(outputFields) + validateAndTransformSchema(schema) } override def copy(extra: ParamMap): PCAModel = { @@ -201,7 +197,7 @@ object PCAModel extends MLReadable[PCAModel] { val versionRegex = "([0-9]+)\\.([0-9]+).*".r val hasExplainedVariance = metadata.sparkVersion match { case versionRegex(major, minor) => - (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)) + major.toInt >= 2 || (major.toInt == 
1 && minor.toInt > 6) case _ => false } http://git-wip-us.apache.org/repos/asf/spark/blob/5855e005/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 9d084b5..7cec369 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -62,6 +62,15 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with /** @group getParam */ def getWithStd: Boolean = $(withStd) + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + setDefault(withMean -> false, withStd -> true) } @@ -105,13 +114,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") - require(!schema.fieldNames.contains($(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) - StructType(outputFields) + validateAndTransformSchema(schema) } override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) @@ -159,13 +162,7 @@ class StandardScalerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { - val inputType 
= schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") - require(!schema.fieldNames.contains($(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) - StructType(outputFields) + validateAndTransformSchema(schema) } override def copy(extra: ParamMap): StandardScalerModel = { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org