Repository: spark
Updated Branches:
  refs/heads/branch-1.6 737f07172 -> 3f63f08f9


[SPARK-7013][ML][TEST] Add unit test for spark.ml StandardScaler

I have added a unit test for ML's StandardScaler by comparing with R's output;
please review it for me.
Thanks.

Author: RoyGaoVLIS <[email protected]>

Closes #6665 from RoyGao/7013.

(cherry picked from commit 67a5132c21bc8338adbae80b33b85e8fa0ddda34)
Signed-off-by: Xiangrui Meng <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3f63f08f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3f63f08f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3f63f08f

Branch: refs/heads/branch-1.6
Commit: 3f63f08f9db6073ef9b6318ba20ebfbd1bbd263a
Parents: 737f071
Author: RoyGaoVLIS <[email protected]>
Authored: Tue Nov 17 23:00:49 2015 -0800
Committer: Xiangrui Meng <[email protected]>
Committed: Tue Nov 17 23:01:03 2015 -0800

----------------------------------------------------------------------
 .../spark/ml/feature/StandardScalerSuite.scala  | 108 +++++++++++++++++++
 1 file changed, 108 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3f63f08f/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
new file mode 100644
index 0000000..879a3ae
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, 
Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {

  // Input vectors and the expected standardization results for each
  // (withMean, withStd) combination. Expected values were precomputed
  // externally (per the commit message, with R's scale()).
  @transient var data: Array[Vector] = _
  @transient var resWithStd: Array[Vector] = _
  @transient var resWithMean: Array[Vector] = _
  @transient var resWithBoth: Array[Vector] = _

  override def beforeAll(): Unit = {
    super.beforeAll()

    data = Array(
      Vectors.dense(-2.0, 2.3, 0.0),
      Vectors.dense(0.0, -5.1, 1.0),
      Vectors.dense(1.7, -0.6, 3.3)
    )
    // Expected output for centering only (withMean = true, withStd = false).
    resWithMean = Array(
      Vectors.dense(-1.9, 3.433333333333, -1.433333333333),
      Vectors.dense(0.1, -3.966666666667, -0.433333333333),
      Vectors.dense(1.8, 0.533333333333, 1.866666666667)
    )
    // Expected output for scaling only (withMean = false, withStd = true) —
    // these are also the StandardScaler defaults.
    resWithStd = Array(
      Vectors.dense(-1.079898494312, 0.616834091415, 0.0),
      Vectors.dense(0.0, -1.367762550529, 0.590968109266),
      Vectors.dense(0.917913720165, -0.160913241239, 1.950194760579)
    )
    // Expected output for both centering and scaling.
    resWithBoth = Array(
      Vectors.dense(-1.0259035695965, 0.920781324866, -0.8470542899497),
      Vectors.dense(0.0539949247156, -1.063815317078, -0.256086180682),
      Vectors.dense(0.9719086448809, 0.143033992212, 1.103140470631)
    )
  }

  /**
   * Asserts that every row's standardized vector matches its expected vector
   * within an absolute tolerance of 1e-5.
   *
   * Named `assertScaledResult` (rather than `assertResult`) so it does not
   * overload/shadow ScalaTest's built-in `assertResult(expected)(actual)`.
   *
   * @param dataframe a DataFrame containing a "standarded_features" column
   *                  (the scaler output) and an "expected" column
   */
  private def assertScaledResult(dataframe: DataFrame): Unit = {
    dataframe.select("standarded_features", "expected").collect().foreach {
      case Row(vector1: Vector, vector2: Vector) =>
        assert(vector1 ~== vector2 absTol 1E-5,
          "The vector value is not correct after standardization.")
    }
  }

  test("Standardization with default parameter") {
    // Defaults are withMean = false, withStd = true, so the expected output
    // is resWithStd.
    val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")

    val standardScaler0 = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("standarded_features")
      .fit(df0)

    assertResultHelper(standardScaler0.transform(df0))
  }

  test("Standardization with setter") {
    // Each DataFrame pairs the input with the expected output for one
    // explicit (withMean, withStd) combination; (false, false) is identity.
    val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
    val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
    val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")

    val standardScaler1 = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("standarded_features")
      .setWithMean(true)
      .setWithStd(true)
      .fit(df1)

    val standardScaler2 = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("standarded_features")
      .setWithMean(true)
      .setWithStd(false)
      .fit(df2)

    val standardScaler3 = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("standarded_features")
      .setWithMean(false)
      .setWithStd(false)
      .fit(df3)

    assertScaledResult(standardScaler1.transform(df1))
    assertScaledResult(standardScaler2.transform(df2))
    assertScaledResult(standardScaler3.transform(df3))
  }
}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to