Repository: mahout Updated Branches: refs/heads/flink-binding 854a2893c -> 5a6d14059
MAHOUT-1781 Dense matrix view multiplication is 4x slower than non-view one Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/5a6d1405 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/5a6d1405 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/5a6d1405 Branch: refs/heads/flink-binding Commit: 5a6d140599024cd8a6bb2a4b2187376fbad40a7d Parents: 854a289 Author: smarthi <[email protected]> Authored: Tue Oct 27 01:33:15 2015 -0400 Committer: smarthi <[email protected]> Committed: Tue Oct 27 01:37:25 2015 -0400 ---------------------------------------------------------------------- .../scalabindings/RLikeMatrixOpsSuite.scala | 10 ++ .../scalabindings/RLikeVectorOpsSuite.scala | 40 +++++++- .../org/apache/mahout/math/DenseVector.java | 99 +++++++++++++++++++- .../java/org/apache/mahout/math/MatrixView.java | 4 +- .../mahout/math/VectorBinaryAggregate.java | 3 +- .../java/org/apache/mahout/math/VectorView.java | 4 +- .../mahout/sparkbindings/SparkEngine.scala | 2 +- 7 files changed, 154 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/5a6d1405/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala index 79d2899..b44e295 100644 --- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala +++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala @@ -133,6 +133,7 @@ class RLikeMatrixOpsSuite extends FunSuite with MahoutSuite { // Dense matrix tests. 
println(s"Ad %*% Bd: ${getMmulAvgs(mxAd, mxBd, n)}") + println(s"Ad(::,::) %*% Bd: ${getMmulAvgs(mxAd(0 until mxAd.nrow,::), mxBd, n)}") println(s"Ad' %*% Bd: ${getMmulAvgs(mxAd.t, mxBd, n)}") println(s"Ad %*% Bd': ${getMmulAvgs(mxAd, mxBd.t, n)}") println(s"Ad' %*% Bd': ${getMmulAvgs(mxAd.t, mxBd.t, n)}") @@ -353,4 +354,13 @@ class RLikeMatrixOpsSuite extends FunSuite with MahoutSuite { } + test("dense-view-debug") { + val d = 500 + // Dense row-wise + val mxAd = new DenseMatrix(d, d) := Matrices.gaussianView(d, d, 134) + 1 + val mxBd = new DenseMatrix(d, d) := Matrices.gaussianView(d, d, 134) - 1 + + mxAd(0 until mxAd.nrow, ::) %*% mxBd + + } } http://git-wip-us.apache.org/repos/asf/mahout/blob/5a6d1405/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala index 832937b..72754f8 100644 --- a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala +++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala @@ -17,13 +17,21 @@ package org.apache.mahout.math.scalabindings +import org.apache.log4j.{Level, BasicConfigurator} import org.scalatest.FunSuite -import org.apache.mahout.math.Vector +import org.apache.mahout.math._ +import scalabindings._ import RLikeOps._ import org.apache.mahout.test.MahoutSuite +import org.apache.mahout.logging._ + class RLikeVectorOpsSuite extends FunSuite with MahoutSuite { + BasicConfigurator.configure() + private[scalabindings] final implicit val log = getLog(classOf[RLikeVectorOpsSuite]) + setLogLevel(Level.DEBUG) + test("Hadamard") { val a: Vector = (1, 2, 3) val b = (3, 4, 5) @@ -33,4 +41,34 @@ class RLikeVectorOpsSuite extends FunSuite with MahoutSuite { assert(c ===(3, 
8, 15)) } + test("dot-view performance") { + + val dv1 = new DenseVector(500) := Matrices.uniformView(1, 500, 1234)(0, ::) + val dv2 = new DenseVector(500) := Matrices.uniformView(1, 500, 1244)(0, ::) + + val nit = 300000 + + // warm up + dv1 dot dv2 + + val dmsStart = System.currentTimeMillis() + for (i ← 0 until nit) + dv1 dot dv2 + val dmsMs = System.currentTimeMillis() - dmsStart + + val (dvv1, dvv2) = dv1(0 until dv1.length) → dv2(0 until dv2.length) + + // Warm up. + dvv1 dot dvv2 + + val dvmsStart = System.currentTimeMillis() + for (i ← 0 until nit) + dvv1 dot dvv2 + val dvmsMs = System.currentTimeMillis() - dvmsStart + + debug(f"dense vector dots:${dmsMs}%.2f ms.") + debug(f"dense view dots:${dvmsMs}%.2f ms.") + + } + } http://git-wip-us.apache.org/repos/asf/mahout/blob/5a6d1405/math/src/main/java/org/apache/mahout/math/DenseVector.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/DenseVector.java b/math/src/main/java/org/apache/mahout/math/DenseVector.java index 3633e58..3961966 100644 --- a/math/src/main/java/org/apache/mahout/math/DenseVector.java +++ b/math/src/main/java/org/apache/mahout/math/DenseVector.java @@ -203,7 +203,7 @@ public class DenseVector extends AbstractVector { if (offset + length > size()) { throw new IndexException(offset + length, size()); } - return new VectorView(this, offset, length); + return new DenseVectorView(this, offset, length); } @Override @@ -342,4 +342,101 @@ public class DenseVector extends AbstractVector { values[index] = value; } } + + private final class DenseVectorView extends VectorView { + + public DenseVectorView(Vector vector, int offset, int cardinality) { + super(vector, offset, cardinality); + } + + @Override + public double dot(Vector x) { + + // Apply custom dot kernels for pairs of dense vectors or their views to reduce + // view indirection. 
+ if (x instanceof DenseVectorView) { + + if (size() != x.size()) + throw new IllegalArgumentException("Cardinality mismatch during dot(x,y)."); + + DenseVectorView xv = (DenseVectorView) x; + double[] thisValues = ((DenseVector) vector).values; + double[] thatValues = ((DenseVector) xv.vector).values; + int untilOffset = offset + size(); + + int i, j; + double sum = 0.0; + + // Provoking SSE + int until4 = offset + (size() & ~3); + for ( + i = offset, j = xv.offset; + i < until4; + i += 4, j += 4 + ) { + sum += thisValues[i] * thatValues[j] + + thisValues[i + 1] * thatValues[j + 1] + + thisValues[i + 2] * thatValues[j + 2] + + thisValues[i + 3] * thatValues[j + 3]; + } + + // Picking up the slack + for ( + i = offset, j = xv.offset; + i < untilOffset; + ) { + sum += thisValues[i++] * thatValues[j++]; + } + return sum; + + } else if (x instanceof DenseVector ) { + + if (size() != x.size()) + throw new IllegalArgumentException("Cardinality mismatch during dot(x,y)."); + + DenseVector xv = (DenseVector) x; + double[] thisValues = ((DenseVector) vector).values; + double[] thatValues = xv.values; + int untilOffset = offset + size(); + + int i, j; + double sum = 0.0; + + // Provoking SSE + int until4 = offset + (size() & ~3); + for ( + i = offset, j = 0; + i < until4; + i += 4, j += 4 + ) { + sum += thisValues[i] * thatValues[j] + + thisValues[i + 1] * thatValues[j + 1] + + thisValues[i + 2] * thatValues[j + 2] + + thisValues[i + 3] * thatValues[j + 3]; + } + + // Picking up slack + for ( ; + i < untilOffset; + ) { + sum += thisValues[i++] * thatValues[j++]; + } + return sum; + + } else { + return super.dot(x); + } + } + + @Override + public Vector viewPart(int offset, int length) { + if (offset < 0) { + throw new IndexException(offset, size()); + } + if (offset + length > size()) { + throw new IndexException(offset + length, size()); + } + return new DenseVectorView(vector, offset + this.offset, length); + } + } } 
http://git-wip-us.apache.org/repos/asf/mahout/blob/5a6d1405/math/src/main/java/org/apache/mahout/math/MatrixView.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/MatrixView.java b/math/src/main/java/org/apache/mahout/math/MatrixView.java index 86760d5..951515b 100644 --- a/math/src/main/java/org/apache/mahout/math/MatrixView.java +++ b/math/src/main/java/org/apache/mahout/math/MatrixView.java @@ -142,7 +142,7 @@ public class MatrixView extends AbstractMatrix { if (column < 0 || column >= columnSize()) { throw new IndexException(column, columnSize()); } - return new VectorView(matrix.viewColumn(column + offset[COL]), offset[ROW], rowSize()); + return matrix.viewColumn(column + offset[COL]).viewPart(offset[ROW], rowSize()); } @Override @@ -150,7 +150,7 @@ public class MatrixView extends AbstractMatrix { if (row < 0 || row >= rowSize()) { throw new IndexException(row, rowSize()); } - return new VectorView(matrix.viewRow(row + offset[ROW]), offset[COL], columnSize()); + return matrix.viewRow(row + offset[ROW]).viewPart(offset[COL], columnSize()); } @Override http://git-wip-us.apache.org/repos/asf/mahout/blob/5a6d1405/math/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java b/math/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java index 3ec3189..4d3a80f 100644 --- a/math/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java +++ b/math/src/main/java/org/apache/mahout/math/VectorBinaryAggregate.java @@ -471,7 +471,8 @@ public abstract class VectorBinaryAggregate { @Override public double aggregate(Vector x, Vector y, DoubleDoubleFunction fa, DoubleDoubleFunction fc) { double result = fc.apply(x.getQuick(0), y.getQuick(0)); - for (int i = 1; i < x.size(); ++i) { + int s = x.size(); + for (int i = 1; i < s; 
++i) { result = fa.apply(result, fc.apply(x.getQuick(i), y.getQuick(i))); } return result; http://git-wip-us.apache.org/repos/asf/mahout/blob/5a6d1405/math/src/main/java/org/apache/mahout/math/VectorView.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/VectorView.java b/math/src/main/java/org/apache/mahout/math/VectorView.java index d61a038..f69a8c7 100644 --- a/math/src/main/java/org/apache/mahout/math/VectorView.java +++ b/math/src/main/java/org/apache/mahout/math/VectorView.java @@ -24,10 +24,10 @@ import com.google.common.collect.AbstractIterator; /** Implements subset view of a Vector */ public class VectorView extends AbstractVector { - private Vector vector; + protected Vector vector; // the offset into the Vector - private int offset; + protected int offset; /** For serialization purposes only */ public VectorView() { http://git-wip-us.apache.org/repos/asf/mahout/blob/5a6d1405/spark/src/main/scala/org/apache/mahout/sparkbindings/SparkEngine.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/SparkEngine.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/SparkEngine.scala index d89a8de..4f30d10 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/SparkEngine.scala +++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/SparkEngine.scala @@ -176,7 +176,7 @@ object SparkEngine extends DistributedEngine { private[sparkbindings] def parallelizeInCore(m: Matrix, numPartitions: Int = 1) (implicit sc: DistributedContext): DrmRdd[Int] = { - val p = (0 until m.nrow).map(i => i → m(i, ::)) + val p = (0 until m.nrow).map(i ⇒ i → m(i, ::)) sc.parallelize(p, numPartitions) }
