Repository: mahout Updated Branches: refs/heads/master d9b32f308 -> 800a9ed6d
MAHOUT-2019 SparkRow Matrix Speedup and fixing change to scala 2.11 made by build script Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/800a9ed6 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/800a9ed6 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/800a9ed6 Branch: refs/heads/master Commit: 800a9ed6d7e015aa82b9eb7624bb441b71a8f397 Parents: d9b32f3 Author: pferrel <[email protected]> Authored: Sat Nov 18 12:29:06 2017 -0800 Committer: pferrel <[email protected]> Committed: Sat Nov 18 12:34:07 2017 -0800 ---------------------------------------------------------------------- .../org/apache/mahout/math/SparseRowMatrix.java | 53 ++++++++++++++++++++ 1 file changed, 53 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/800a9ed6/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java index 6e06769..ee54ad0 100644 --- a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java +++ b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java @@ -19,7 +19,12 @@ package org.apache.mahout.math; import org.apache.mahout.math.flavor.MatrixFlavor; import org.apache.mahout.math.flavor.TraversingStructureEnum; +import org.apache.mahout.math.function.DoubleDoubleFunction; import org.apache.mahout.math.function.Functions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Iterator; /** * sparse matrix with general element values whose rows are accessible quickly. 
Implemented as a row
@@ -30,6 +35,8 @@ public class SparseRowMatrix extends AbstractMatrix {
 
   private final boolean randomAccessRows;
 
+  private static final Logger log = LoggerFactory.getLogger(SparseRowMatrix.class);
+
   /**
    * Construct a sparse matrix starting with the provided row vectors.
    *
@@ -133,6 +140,72 @@ public class SparseRowMatrix extends AbstractMatrix {
   }
 
+  /**
+   * Sets every cell to {@code function.apply(thisValue, otherValue)}, using the matching cell of {@code other}.
+   *
+   * <p>When {@code function.isLikeMult()} is true (assumed to mean {@code function(0, y) == 0}), rows backed by a
+   * {@link SequentialAccessSparseVector} only visit their stored entries, because unstored cells would stay zero.
+   * Every other row, and every row for other functions, visits all columns.
+   *
+   * @param other matrix supplying the second operand; must have the same row and column counts as this matrix
+   * @param function binary function applied cell by cell
+   * @return this matrix, modified in place
+   * @throws CardinalityException if {@code other} has a different number of rows or columns
+   */
   @Override
+  public Matrix assign(Matrix other, DoubleDoubleFunction function) {
+    int rows = rowSize();
+    if (rows != other.rowSize()) {
+      throw new CardinalityException(rows, other.rowSize());
+    }
+    int columns = columnSize();
+    if (columns != other.columnSize()) {
+      throw new CardinalityException(columns, other.columnSize());
+    }
+    boolean zeroPreserving = function.isLikeMult();
+    int denseRows = 0;
+    for (int row = 0; row < rows; row++) {
+      Vector rowVector = rowVectors[row];
+      // Rows are not guaranteed to be SequentialAccessSparseVectors (presumably RandomAccessSparseVectors when
+      // randomAccessRows is set, or caller-supplied vectors), so test the type instead of catching a cast failure.
+      if (zeroPreserving && rowVector instanceof SequentialAccessSparseVector) {
+        assignStoredEntries(row, ((SequentialAccessSparseVector) rowVector).iterateNonZero(), other, function);
+      } else {
+        if (zeroPreserving) {
+          denseRows++;
+        }
+        assignAllColumns(row, columns, other, function);
+      }
+    }
+    if (denseRows > 0) {
+      log.debug("{} of {} rows are not SequentialAccessSparseVectors; used the non-optimized assign for them",
+          denseRows, rows);
+    }
+    return this;
+  }
+
+  /**
+   * Applies {@code function} to each entry yielded by {@code storedEntries} for the given row.
+   *
+   * <p>Values are written through the iterator's element instead of {@code setQuick}: setting a stored entry to
+   * zero via {@code setQuick} can remove it from a sequential sparse row and shift later entries, which risks the
+   * live iterator skipping an entry.
+   */
+  private static void assignStoredEntries(int row, Iterator<Vector.Element> storedEntries, Matrix other,
+                                          DoubleDoubleFunction function) {
+    while (storedEntries.hasNext()) {
+      Vector.Element element = storedEntries.next();
+      element.set(function.apply(element.get(), other.getQuick(row, element.index())));
+    }
+  }
+
+  /** Applies {@code function} to every column of {@code row}, including cells that are not stored. */
+  private void assignAllColumns(int row, int columns, Matrix other, DoubleDoubleFunction function) {
+    for (int col = 0; col < columns; col++) {
+      setQuick(row, col, function.apply(getQuick(row, col), other.getQuick(row, col)));
+    }
+  }
+
+  @Override
   public Matrix assignColumn(int column, Vector other) {
     if (rowSize() != other.size()) {
       throw new CardinalityException(rowSize(), other.size());
