Repository: mahout Updated Branches: refs/heads/master 9cf90546d -> 5083f5835
MAHOUT-1574 - Add sparse handling to rows and columns of DiagonalMatrix Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/dd78ed94 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/dd78ed94 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/dd78ed94 Branch: refs/heads/master Commit: dd78ed9479559cd222f24fa0be57655cf2e3075b Parents: 9cf9054 Author: Ted Dunning <[email protected]> Authored: Fri Jun 6 19:19:03 2014 -0700 Committer: Ted Dunning <[email protected]> Committed: Fri Jun 6 19:19:03 2014 -0700 ---------------------------------------------------------------------- .../org/apache/mahout/math/DiagonalMatrix.java | 206 ++++++++++++++++++- .../apache/mahout/math/DiagonalMatrixTest.java | 43 ++++ 2 files changed, 244 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/dd78ed94/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java ---------------------------------------------------------------------- diff --git a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java b/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java index 2a027f7..3e20a4a 100644 --- a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java +++ b/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java @@ -17,6 +17,9 @@ package org.apache.mahout.math; +import java.util.Iterator; +import java.util.NoSuchElementException; + public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps { private final Vector diagonal; @@ -60,6 +63,195 @@ public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps { throw new UnsupportedOperationException("Can't assign a row to a diagonal matrix"); } + @Override + public Vector viewRow(int row) { + return new SingleElementVector(row); + } + + @Override + public Vector viewColumn(int column) { + return new
SingleElementVector(column); + } + + /** + * Special class to implement views of rows and columns of a diagonal matrix. + */ + public class SingleElementVector extends AbstractVector { + private final int index; + + public SingleElementVector(int index) { + super(diagonal.size()); + this.index = index; + } + + @Override + public double getQuick(int index) { + if (index == this.index) { + return diagonal.get(index); + } else { + return 0; + } + } + + @Override + public void set(int index, double value) { + if (index == this.index) { + diagonal.set(index, value); + } else { + throw new IllegalArgumentException("Can't set off-diagonal element of diagonal matrix"); + } + } + + @Override + protected Iterator<Element> iterateNonZero() { + return new Iterator<Element>() { + boolean more = true; + + @Override + public boolean hasNext() { + return more; + } + + @Override + public Element next() { + if (more) { + more = false; + return new Element() { + @Override + public double get() { + return diagonal.get(index); + } + + @Override + public int index() { + return index; + } + + @Override + public void set(double value) { + diagonal.set(index, value); + } + }; + } else { + throw new NoSuchElementException("Only one non-zero element in a row or column of a diagonal matrix"); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException("Can't remove from vector view"); + } + }; + } + + @Override + protected Iterator<Element> iterator() { + return new Iterator<Element>() { + int i = -1; /* next() pre-increments, so starting at -1 makes the first element index 0 */ + + Element r = new Element() { + @Override + public double get() { + if (i == index) { + return diagonal.get(index); + } else { + return 0; + } + } + + @Override + public int index() { + return i; + } + + @Override + public void set(double value) { + if (i == index) { + diagonal.set(index, value); + } else { + throw new IllegalArgumentException("Can't set any element but diagonal"); + } + } + }; + + @Override + public boolean hasNext() { + return i < diagonal.size() - 1; + }
+ + @Override + public Element next() { + if (i < SingleElementVector.this.size() - 1) { + i++; + return r; + } else { + throw new NoSuchElementException("Attempted to access past last element of vector"); + } + } + + + @Override + public void remove() { + throw new UnsupportedOperationException("Default operation"); + } + }; + } + + @Override + protected Matrix matrixLike(int rows, int columns) { + return new DiagonalMatrix(rows, columns); + } + + @Override + public boolean isDense() { + return false; + } + + @Override + public boolean isSequentialAccess() { + return true; + } + + @Override + public void mergeUpdates(OrderedIntDoubleMapping updates) { + throw new UnsupportedOperationException("Default operation"); + } + + @Override + public Vector like() { + return new DenseVector(size()); + } + + @Override + public void setQuick(int index, double value) { + if (index == this.index) { + diagonal.set(this.index, value); + } else { + throw new IllegalArgumentException("Can't set off-diagonal element of DiagonalMatrix"); + } + } + + @Override + public int getNumNondefaultElements() { + return 1; + } + + @Override + public double getLookupCost() { + return 0; + } + + @Override + public double getIteratorAdvanceCost() { + return 1; + } + + @Override + public boolean isAddConstantTime() { + return false; + } + } + /** * Provides a view of the diagonal of a matrix.
*/ @@ -147,22 +339,26 @@ public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps { @Override public Matrix timesRight(Matrix that) { - if (that.numRows() != diagonal.size()) + if (that.numRows() != diagonal.size()) { throw new IllegalArgumentException("Incompatible number of rows in the right operand of matrix multiplication."); + } Matrix m = that.like(); - for (int row = 0; row < diagonal.size(); row++) + for (int row = 0; row < diagonal.size(); row++) { m.assignRow(row, that.viewRow(row).times(diagonal.getQuick(row))); + } return m; } @Override public Matrix timesLeft(Matrix that) { - if (that.numCols() != diagonal.size()) + if (that.numCols() != diagonal.size()) { throw new IllegalArgumentException( - "Incompatible number of rows in the left operand of matrix-matrix multiplication."); + "Incompatible number of columns in the left operand of matrix-matrix multiplication."); + } Matrix m = that.like(); - for (int col = 0; col < diagonal.size(); col++) + for (int col = 0; col < diagonal.size(); col++) { m.assignColumn(col, that.viewColumn(col).times(diagonal.getQuick(col))); + } return m; } } http://git-wip-us.apache.org/repos/asf/mahout/blob/dd78ed94/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java ---------------------------------------------------------------------- diff --git a/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java b/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java index 5b3a278..2ca7be0 100644 --- a/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java +++ b/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java @@ -18,8 +18,11 @@ package org.apache.mahout.math; import org.apache.mahout.math.function.Functions; +import org.junit.Assert; import org.junit.Test; +import java.util.Iterator; + public class DiagonalMatrixTest extends MahoutTestCase { @Test public void testBasics() { @@ -46,4 +49,44 @@ public class DiagonalMatrixTest extends MahoutTestCase {
assertEquals(100, a.times(m.transpose()).aggregate(Functions.PLUS, Functions.ABS), 1.0e-10); } + @Test + public void testSparsity() { + Vector d = new DenseVector(10); + for (int i = 0; i < 10; i++) { + d.set(i, i * i); + } + DiagonalMatrix m = new DiagonalMatrix(d); + + Assert.assertFalse(m.viewRow(0).isDense()); + Assert.assertFalse(m.viewColumn(0).isDense()); + + for (int i = 0; i < 10; i++) { + assertEquals(i * i, m.viewRow(i).zSum(), 0); + assertEquals(i * i, m.viewRow(i).get(i), 0); + + assertEquals(i * i, m.viewColumn(i).zSum(), 0); + assertEquals(i * i, m.viewColumn(i).get(i), 0); + } + + Iterator<Vector.Element> ix = m.viewRow(7).nonZeroes().iterator(); + assertTrue(ix.hasNext()); + Vector.Element r = ix.next(); + assertEquals(7, r.index()); + assertEquals(49, r.get(), 0); + assertFalse(ix.hasNext()); + + assertEquals(0, m.viewRow(5).get(3), 0); + assertEquals(0, m.viewColumn(8).get(3), 0); + + m.viewRow(3).set(3, 1); + assertEquals(1, m.get(3, 3), 0); + + int expectedIndex = 0; + for (Vector.Element element : m.viewRow(6).all()) { + assertEquals(expectedIndex, element.index()); + assertEquals(expectedIndex == 6 ? 36 : 0, element.get(), 0); + expectedIndex++; + } + assertEquals(10, expectedIndex); + } }
