Repository: spark Updated Branches: refs/heads/master 74cc16dbc -> 61d7b533d
[SPARK-7514] [MLLIB] Add MinMaxScaler to feature transformation jira: https://issues.apache.org/jira/browse/SPARK-7514 Add a popular scaling method to the feature component, commonly known as min-max normalization or rescaling. The core function is Normalized(x) = (x - min) / (max - min) * scale + newBase, where `newBase` and `scale` are parameters (type Double) of the `VectorTransformer`. `newBase` is the new minimum value for the features, and `scale` controls the range after transformation. This is a little more complicated than basic min-max normalization, yet it provides flexibility so that users can control the output range more precisely, e.g. [0.1, 0.9] in some neural-network applications. For the case where `max == min`, 0.5 is used as the raw value, i.e. the result is 0.5 * scale + newBase. I'll add unit tests once the design is settled (and provided this is not considered too naive). References: http://en.wikipedia.org/wiki/Feature_scaling http://stn.spotfire.com/spotfire_client_help/index.htm#norm/norm_scale_between_0_and_1.htm Author: Yuhao Yang <[email protected]> Closes #6039 from hhbyyh/minMaxNorm and squashes the following commits: f942e9f [Yuhao Yang] add todo for metadata 8b37bbc [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 4894dbc [Yuhao Yang] add copy fa2989f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 29db415 [Yuhao Yang] add clue and minor adjustment 5b8f7cc [Yuhao Yang] style fix 9b133d0 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 22f20f2 [Yuhao Yang] style change and bug fix 747c9bb [Yuhao Yang] add ut and remove mllib version a5ba0aa [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 585cc07 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 1c6dcb1 [Yuhao Yang] minor change 0f1bc80 [Yuhao Yang] add MinMaxScaler to ml 8e7436e [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 3663165 [Yuhao Yang] Merge remote-tracking branch 
'upstream/master' into minMaxNorm 1247c27 [Yuhao Yang] some comments improvement d285a19 [Yuhao Yang] initial checkin for minMaxNorm Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/61d7b533 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/61d7b533 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/61d7b533 Branch: refs/heads/master Commit: 61d7b533dd50bfac2162b4edcea94724bbd8fcb1 Parents: 74cc16d Author: Yuhao Yang <[email protected]> Authored: Tue Jun 30 12:44:43 2015 -0700 Committer: Joseph K. Bradley <[email protected]> Committed: Tue Jun 30 12:44:43 2015 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/feature/MinMaxScaler.scala | 170 +++++++++++++++++++ .../spark/ml/feature/MinMaxScalerSuite.scala | 68 ++++++++ 2 files changed, 238 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/61d7b533/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala new file mode 100644 index 0000000..b30adf3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[MinMaxScaler]] and [[MinMaxScalerModel]]. + */ +private[feature] trait MinMaxScalerParams extends Params with HasInputCol with HasOutputCol { + + /** + * lower bound after transformation, shared by all features + * Default: 0.0 + * @group param + */ + val min: DoubleParam = new DoubleParam(this, "min", + "lower bound of the output feature range") + + /** + * upper bound after transformation, shared by all features + * Default: 1.0 + * @group param + */ + val max: DoubleParam = new DoubleParam(this, "max", + "upper bound of the output feature range") + + /** Validates and transforms the input schema. 
*/ + protected def validateAndTransformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def validateParams(): Unit = { + require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})") + } +} + +/** + * :: Experimental :: + * Rescale each feature individually to a common range [min, max] linearly using column summary + * statistics, which is also known as min-max normalization or Rescaling. The rescaled value for + * feature E is calculated as, + * + * Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min + * + * For the case E_{max} == E_{min}, Rescaled(e_i) = 0.5 * (max + min) + * Note that since zero values will probably be transformed to non-zero values, output of the + * transformer will be DenseVector even for sparse input. 
+ */ +@Experimental +class MinMaxScaler(override val uid: String) + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + + def this() = this(Identifiable.randomUID("minMaxScal")) + + setDefault(min -> 0.0, max -> 1.0) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMin(value: Double): this.type = set(min, value) + + /** @group setParam */ + def setMax(value: Double): this.type = set(max, value) + + override def fit(dataset: DataFrame): MinMaxScalerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val summary = Statistics.colStats(input) + copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[MinMaxScaler]]. + * + * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). 
+ */ +@Experimental +class MinMaxScalerModel private[ml] ( + override val uid: String, + val originalMin: Vector, + val originalMax: Vector) + extends Model[MinMaxScalerModel] with MinMaxScalerParams { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMin(value: Double): this.type = set(min, value) + + /** @group setParam */ + def setMax(value: Double): this.type = set(max, value) + + + override def transform(dataset: DataFrame): DataFrame = { + val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray + val minArray = originalMin.toArray + + val reScale = udf { (vector: Vector) => + val scale = $(max) - $(min) + + // 0 in sparse vector will probably be rescaled to non-zero + val values = vector.toArray + val size = values.size + var i = 0 + while (i < size) { + val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5 + values(i) = raw * scale + $(min) + i += 1 + } + Vectors.dense(values) + } + + dataset.withColumn($(outputCol), reScale(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MinMaxScalerModel = { + val copied = new MinMaxScalerModel(uid, originalMin, originalMax) + copyValues(copied, extra) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/61d7b533/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala new file mode 100644 index 0000000..c452054 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -0,0 +1,68 @@ +/* + 
* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, SQLContext} + +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("MinMaxScaler fit basic case") { + val sqlContext = new SQLContext(sc) + + val data = Array( + Vectors.dense(1, 0, Long.MinValue), + Vectors.dense(2, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(3, Long.MaxValue)), + Vectors.sparse(3, Array(0), Array(1.5))) + + val expected: Array[Vector] = Array( + Vectors.dense(-5, 0, -5), + Vectors.dense(0, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(5, 5)), + Vectors.sparse(3, Array(0), Array(-2.5))) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaled") + .setMin(-5) + .setMax(5) + + val model = scaler.fit(df) + model.transform(df).select("expected", "scaled").collect() + .foreach { case Row(vector1: Vector, vector2: Vector) => + assert(vector1.equals(vector2), "Transformed vector is different 
with expected.") + } + } + + test("MinMaxScaler arguments max must be larger than min") { + withClue("arguments max must be larger than min") { + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler().setMin(10).setMax(0) + scaler.validateParams() + } + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler().setMin(0).setMax(0) + scaler.validateParams() + } + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
