Repository: mahout Updated Branches: refs/heads/master 60bb75192 -> a70a8733c
MAHOUT-1930 Add Test for Standard Scaler closes apache/mahout#280 Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/a70a8733 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/a70a8733 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/a70a8733 Branch: refs/heads/master Commit: a70a8733c6db0e5dcf02384f8dd474469c42e7c5 Parents: 60bb751 Author: rawkintrevo <[email protected]> Authored: Mon Feb 20 07:38:33 2017 -0600 Committer: rawkintrevo <[email protected]> Committed: Mon Feb 20 07:38:33 2017 -0600 ---------------------------------------------------------------------- .../preprocessing/StandardScaler.scala | 17 ++++++++-- .../math/algorithms/PreprocessorSuiteBase.scala | 33 +++++++++++++++++++- 2 files changed, 46 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/a70a8733/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala b/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala index 98d0be1..5863330 100644 --- a/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala +++ b/math-scala/src/main/scala/org/apache/mahout/math/algorithms/preprocessing/StandardScaler.scala @@ -29,6 +29,18 @@ import org.apache.mahout.math.{Vector => MahoutVector, Matrix} /** * Scales columns to mean 0 and unit variance + * + * An important note: the equivalent call in R would be something like + * ```r + * N <- nrow(x) + * scale(x, scale= apply(x, 2, sd) * sqrt((N-1)/N)) + * ``` + * + * This is because R's sd() uses degrees of freedom = 1 (it divides by N-1) to calculate standard deviation. 
+ * Multiplying the standard deviation by sqrt((N-1)/N) 'undoes' this correction. + * + * The StandardScaler of sklearn uses degrees of freedom = 0 for its calculation, so results + * should match. */ class StandardScaler extends PreprocessorFitter { @@ -40,15 +52,14 @@ class StandardScaler extends PreprocessorFitter { } -class StandardScalerModel(meanVec: MahoutVector, - stdev: MahoutVector +class StandardScalerModel(val meanVec: MahoutVector, + val stdev: MahoutVector ) extends PreprocessorModel { def transform[K](input: DrmLike[K]): DrmLike[K] = { implicit val ctx = input.context - // Some mapBlock() calls need it // implicit val ktag = input.keyClassTag http://git-wip-us.apache.org/repos/asf/mahout/blob/a70a8733/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala b/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala index 9e8f029..ec76c11 100644 --- a/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala +++ b/math-scala/src/test/scala/org/apache/mahout/math/algorithms/PreprocessorSuiteBase.scala @@ -19,7 +19,7 @@ package org.apache.mahout.math.algorithms -import org.apache.mahout.math.algorithms.preprocessing.{AsFactor, AsFactorModel} +import org.apache.mahout.math.algorithms.preprocessing._ import org.apache.mahout.math.drm.drmParallelize import org.apache.mahout.math.scalabindings.{dense, sparse, svec} import org.apache.mahout.math.scalabindings.RLikeOps._ @@ -56,4 +56,35 @@ trait PreprocessorSuiteBase extends DistributedMahoutSuite with Matchers { (myAnswer.norm - correctAnswer.norm) should be <= epsilon } + + test("standard scaler test") { + /** + * R Prototype + * x <- matrix( c(1,2,3,1,5,9,5,-15,-2), nrow=3) + * scale(x, scale= apply(x, 2, sd) * sqrt(2/3)) + * # ^^ note: R uses 
degrees of freedom = 1 for standard deviation calculations. + * # we don't (and neither does sklearn) + * # the *sqrt((N-1)/N) 'undoes' the degrees of freedom = 1 + */ + + val A = drmParallelize(dense( + (1, 1, 5), + (2, 5, -15), + (3, 9, -2)), numPartitions = 2) + + val scaler: StandardScalerModel = new StandardScaler().fit(A) + + val correctAnswer = dense( + (-1.2247449, -1.2247449, 1.0860993), + (0.0000000, 0.0000000, -1.3274547), + (1.2247449, 1.2247449, 0.2413554)) + + val myAnswer = scaler.transform(A).collect + println(scaler.meanVec) + println(scaler.stdev) + + val epsilon = 1E-6 + (myAnswer - correctAnswer).norm should be <= epsilon + + } }
