Repository: spark Updated Branches: refs/heads/master 446738e51 -> 67a5132c2
[SPARK-7013][ML][TEST] Add unit test for spark.ml StandardScaler I have added a unit test for ML's StandardScaler that checks its output against R's. Please review. Thanks. Author: RoyGaoVLIS <[email protected]> Closes #6665 from RoyGao/7013. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/67a5132c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/67a5132c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/67a5132c Branch: refs/heads/master Commit: 67a5132c21bc8338adbae80b33b85e8fa0ddda34 Parents: 446738e Author: RoyGaoVLIS <[email protected]> Authored: Tue Nov 17 23:00:49 2015 -0800 Committer: Xiangrui Meng <[email protected]> Committed: Tue Nov 17 23:00:49 2015 -0800 ---------------------------------------------------------------------- .../spark/ml/feature/StandardScalerSuite.scala | 108 +++++++++++++++++++ 1 file changed, 108 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/67a5132c/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala new file mode 100644 index 0000000..879a3ae --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+/**
+ * Tests for [[StandardScaler]]. Each withMean / withStd configuration is fit on a fixed
+ * three-row dataset, and the transformed vectors are compared with precomputed references.
+ */
+class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  /** Input vectors shared by every test; each one becomes one DataFrame row. */
+  @transient var data: Array[Vector] = _
+  /** Expected output for withMean = false, withStd = true (what the default-parameter test uses). */
+  @transient var resWithStd: Array[Vector] = _
+  /** Expected output for withMean = true, withStd = false (centering only). */
+  @transient var resWithMean: Array[Vector] = _
+  /** Expected output for withMean = true, withStd = true (centering and scaling). */
+  @transient var resWithBoth: Array[Vector] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    data = Array(
+      Vectors.dense(-2.0, 2.3, 0.0),
+      Vectors.dense(0.0, -5.1, 1.0),
+      Vectors.dense(1.7, -0.6, 3.3)
+    )
+    // Reference values were reported by the author as matching R's output. They reflect scaling
+    // by the sample (n - 1) standard deviation: e.g. column 0 has std sqrt(6.86 / 2) ~= 1.852,
+    // and -2.0 / 1.852 ~= -1.0799.
+    resWithMean = Array(
+      Vectors.dense(-1.9, 3.433333333333, -1.433333333333),
+      Vectors.dense(0.1, -3.966666666667, -0.433333333333),
+      Vectors.dense(1.8, 0.533333333333, 1.866666666667)
+    )
+    resWithStd = Array(
+      Vectors.dense(-1.079898494312, 0.616834091415, 0.0),
+      Vectors.dense(0.0, -1.367762550529, 0.590968109266),
+      Vectors.dense(0.917913720165, -0.160913241239, 1.950194760579)
+    )
+    resWithBoth = Array(
+      Vectors.dense(-1.0259035695965, 0.920781324866, -0.8470542899497),
+      Vectors.dense(0.0539949247156, -1.063815317078, -0.256086180682),
+      Vectors.dense(0.9719086448809, 0.143033992212, 1.103140470631)
+    )
+  }
+
+  /**
+   * Checks that every row's "standarded_features" vector matches its "expected" vector within an
+   * absolute tolerance of 1e-5. It also requires one row per input vector, so an empty result
+   * cannot pass vacuously.
+   *
+   * Named distinctly (rather than `assertResult`) so it does not overload ScalaTest's
+   * `Assertions.assertResult(expected)(actual)`.
+   */
+  private def assertStandardized(df: DataFrame): Unit = {
+    val rows = df.select("standarded_features", "expected").collect()
+    assert(rows.length === data.length, "Unexpected number of rows after standardization.")
+    rows.foreach { case Row(actual: Vector, expected: Vector) =>
+      assert(actual ~== expected absTol 1E-5,
+        "The vector value is not correct after standardization.")
+    }
+  }
+
+  /** Pairs each input vector with its expected output as a ("features", "expected") DataFrame. */
+  private def toDataFrame(expected: Array[Vector]): DataFrame =
+    sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+
+  /**
+   * Wires the shared input/output columns onto `scaler`, fits it on [[data]], and checks that
+   * transforming the same data reproduces `expected`.
+   */
+  private def fitAndCheck(scaler: StandardScaler, expected: Array[Vector]): Unit = {
+    val df = toDataFrame(expected)
+    val model = scaler
+      .setInputCol("features")
+      .setOutputCol("standarded_features")
+      .fit(df)
+    assertStandardized(model.transform(df))
+  }
+
+  test("Standardization with default parameter") {
+    fitAndCheck(new StandardScaler(), resWithStd)
+  }
+
+  test("Standardization with setter") {
+    fitAndCheck(new StandardScaler().setWithMean(true).setWithStd(true), resWithBoth)
+    fitAndCheck(new StandardScaler().setWithMean(true).setWithStd(false), resWithMean)
+    fitAndCheck(new StandardScaler().setWithMean(false).setWithStd(true), resWithStd)
+    // With both flags off, the transform should return the input unchanged.
+    fitAndCheck(new StandardScaler().setWithMean(false).setWithStd(false), data)
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
