Repository: spark Updated Branches: refs/heads/master 6fe690d5a -> 1be207078
[SPARK-5924] Add the ability to specify withMean or withStd parameters with StandardScaler The current implementation calls the default constructor of mllib.feature.StandardScaler without the possibility to specify withMean or withStd options. Author: jrabary <jaon...@gmail.com> Closes #4704 from jrabary/master and squashes the following commits: fae8568 [jrabary] style fix 8896b0e [jrabary] Comments fix ef96d73 [jrabary] style fix 8e52607 [jrabary] style fix edd9d48 [jrabary] Fix default param initialization 17e1a76 [jrabary] Fix default param initialization 298f405 [jrabary] Typo fix 45ed914 [jrabary] Add withMean and withStd params to StandardScaler Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1be20707 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1be20707 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1be20707 Branch: refs/heads/master Commit: 1be207078cef48c5935595969bf9f6b1ec1334ca Parents: 6fe690d Author: jrabary <jaon...@gmail.com> Authored: Mon Apr 20 09:47:56 2015 -0700 Committer: Joseph K. 
Bradley <jos...@databricks.com> Committed: Mon Apr 20 09:47:56 2015 -0700 ---------------------------------------------------------------------- .../spark/ml/feature/StandardScaler.scala | 32 +++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1be20707/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1b10261..447851e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -30,7 +30,22 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * Params for [[StandardScaler]] and [[StandardScalerModel]]. */ -private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol +private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { + + /** + * False by default. Centers the data with mean before scaling. + * It will build a dense output, so this does not work on sparse input + * and will raise an exception. + * @group param + */ + val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + + /** + * True by default. Scales the data to unit standard deviation. 
+ * @group param + */ + val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") +} /** * :: AlphaComponent :: @@ -40,18 +55,27 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with @AlphaComponent class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + setDefault(withMean -> false, withStd -> true) + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - + + /** @group setParam */ + def setWithMean(value: Boolean): this.type = set(withMean, value) + + /** @group setParam */ + def setWithStd(value: Boolean): this.type = set(withStd, value) + override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) val map = extractParamMap(paramMap) val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } - val scaler = new feature.StandardScaler().fit(input) - val model = new StandardScalerModel(this, map, scaler) + val scaler = new feature.StandardScaler(withMean = map(withMean), withStd = map(withStd)) + val scalerModel = scaler.fit(input) + val model = new StandardScalerModel(this, map, scalerModel) Params.inheritValues(map, this, model) model } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org