Repository: mahout Updated Branches: refs/heads/master c1ca30872 -> c20eee89c
MAHOUT-1464 fixed bug counting only positive column elements, now counts all non-zero (pat) closes apache/mahout#18 Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/c20eee89 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/c20eee89 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/c20eee89 Branch: refs/heads/master Commit: c20eee89c6cc669494cf7edbb80255a83e194a15 Parents: c1ca308 Author: pferrel <[email protected]> Authored: Sat Jun 14 09:24:30 2014 -0700 Committer: pferrel <[email protected]> Committed: Sat Jun 14 09:24:30 2014 -0700 ---------------------------------------------------------------------- .../apache/mahout/math/scalabindings/MatrixOps.scala | 7 ++++--- .../mahout/math/scalabindings/MatrixOpsSuite.scala | 12 ++++++++++++ .../org/apache/mahout/math/function/Functions.java | 15 +++++++++++++-- .../apache/mahout/sparkbindings/SparkEngine.scala | 6 ++---- .../apache/mahout/cf/CooccurrenceAnalysisSuite.scala | 4 ++-- .../mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala | 5 +++-- 6 files changed, 36 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/c20eee89/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/MatrixOps.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/MatrixOps.scala b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/MatrixOps.scala index 149feca..28acc5a 100644 --- a/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/MatrixOps.scala +++ b/math-scala/src/main/scala/org/apache/mahout/math/scalabindings/MatrixOps.scala @@ -176,7 +176,7 @@ class MatrixOps(val m: Matrix) { def rowMeans() = if (m.ncol == 0) rowSums() else rowSums() /= m.ncol - def numNonZeroElementsPerColumn() = m.aggregateColumns(vectorCountFunc) + def numNonZeroElementsPerColumn() = m.aggregateColumns(vectorCountNonZeroElementsFunc) } object MatrixOps { @@ -188,8 +188,9 @@ object MatrixOps { def apply(f: Vector): Double = f.sum } - private def vectorCountFunc = new VectorFunction { - def apply(f: Vector): Double = f.aggregate(Functions.PLUS, Functions.greater(0)) + private def vectorCountNonZeroElementsFunc = new VectorFunction { + //def apply(f: Vector): Double = f.aggregate(Functions.PLUS, Functions.notEqual(0)) + def apply(f: Vector): Double = f.getNumNonZeroElements().toDouble } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/c20eee89/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala index d59d3a5..8374a9b 100644 --- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala +++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/MatrixOpsSuite.scala @@ -123,4 +123,16 @@ class MatrixOpsSuite extends FunSuite with MahoutSuite { } + test("numNonZeroElementsPerColumn") { + val a = dense( + (2, 3, 4), + (3, 4, 5), + (-5, 0, -1), + (0, 0, 1) + ) + + a.numNonZeroElementsPerColumn() should equal(dvec(3,2,4)) + + } + } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/c20eee89/math/src/main/java/org/apache/mahout/math/function/Functions.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/function/Functions.java b/math/src/main/java/org/apache/mahout/math/function/Functions.java index 7a3db98..64315ce 100644 --- a/math/src/main/java/org/apache/mahout/math/function/Functions.java +++ b/math/src/main/java/org/apache/mahout/math/function/Functions.java @@ -27,11 +27,11 @@ It is provided "as is" without expressed or implied warranty. package org.apache.mahout.math.function; -import java.util.Date; - import com.google.common.base.Preconditions; import org.apache.mahout.math.jet.random.engine.MersenneTwister; +import java.util.Date; + /** * Function objects to be passed to generic methods. Contains the functions of {@link java.lang.Math} as function @@ -1393,6 +1393,17 @@ public final class Functions { }; } + /** Constructs a function that returns <tt>a != b ? 1 : 0</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */ + public static DoubleFunction notEqual(final double b) { + return new DoubleFunction() { + + @Override + public double apply(double a) { + return a != b ? 1 : 0; + } + }; + } + /** Constructs a function that returns <tt>a > b ? 1 : 0</tt>. <tt>a</tt> is a variable, <tt>b</tt> is fixed. */ public static DoubleFunction greater(final double b) { return new DoubleFunction() { http://git-wip-us.apache.org/repos/asf/mahout/blob/c20eee89/spark/src/main/scala/org/apache/mahout/sparkbindings/SparkEngine.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/SparkEngine.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/SparkEngine.scala index 7a1fb2d..a4eef9d 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/SparkEngine.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/SparkEngine.scala @@ -60,10 +60,8 @@ object SparkEngine extends DistributedEngine { .map(_._2) // Fold() doesn't work with kryo still. So work around it. .mapPartitions(iter => { - val acc = ((new DenseVector(n): Vector) /: iter){(acc, v) => - v.nonZeroes().foreach { elem => - if (elem.get() > 0) acc(elem.index) += 1 - } + val acc = ((new DenseVector(n): Vector) /: iter) { (acc, v) => + v.nonZeroes().foreach { elem => acc(elem.index) += 1} acc } Iterator(acc) http://git-wip-us.apache.org/repos/asf/mahout/blob/c20eee89/spark/src/test/scala/org/apache/mahout/cf/CooccurrenceAnalysisSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/cf/CooccurrenceAnalysisSuite.scala b/spark/src/test/scala/org/apache/mahout/cf/CooccurrenceAnalysisSuite.scala index 3c05a42..2db5f50 100644 --- a/spark/src/test/scala/org/apache/mahout/cf/CooccurrenceAnalysisSuite.scala +++ b/spark/src/test/scala/org/apache/mahout/cf/CooccurrenceAnalysisSuite.scala @@ -118,8 +118,8 @@ class CooccurrenceAnalysisSuite extends FunSuite with MahoutSuite with MahoutLoc } test("cooccurrence [A'A], [B'A] integer data using LLR") { - val a = dense((1000, 10, 0, 0, 0), (0, 0, 10000, 10, 0), (0, 0, 0, 0, 100), (10000, 0, 0, 1000, 0)) - val b = dense((100, 1000, 10000, 10000, 0), (10000, 1000, 100, 10, 0), (0, 0, 10, 0, 100), (10, 100, 0, 1000, 0)) + val a = dense((1000, 10, 0, 0, 0), (0, 0, -10000, 10, 0), (0, 0, 0, 0, 100), (10000, 0, 0, 1000, 0)) + val b = dense((100, 1000, -10000, 10000, 0), (10000, 1000, 100, 10, 0), (0, 0, 10, 0, -100), (10, 100, 0, 1000, 0)) val drmA = drmParallelize(m = a, numPartitions = 2) val drmB = drmParallelize(m = b, numPartitions = 2) http://git-wip-us.apache.org/repos/asf/mahout/blob/c20eee89/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala index 30a602b..3cd49cd 100644 --- a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala +++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/RLikeDrmOpsSuite.scala @@ -466,8 +466,9 @@ class RLikeDrmOpsSuite extends FunSuite with Matchers with MahoutLocalContext { test("numNonZeroElementsPerColumn") { val inCoreA = dense( (0, 2), - (3, 4), - (0, 30) + (3, 0), + (0, -30) + ) val drmA = drmParallelize(inCoreA, numPartitions = 2)
