Repository: mahout
Updated Branches:
  refs/heads/master 9cf90546d -> 5083f5835


MAHOUT-1574 - Add sparse handling to rows and columns of DiagonalMatrix

Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/dd78ed94
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/dd78ed94
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/dd78ed94

Branch: refs/heads/master
Commit: dd78ed9479559cd222f24fa0be57655cf2e3075b
Parents: 9cf9054
Author: Ted Dunning <[email protected]>
Authored: Fri Jun 6 19:19:03 2014 -0700
Committer: Ted Dunning <[email protected]>
Committed: Fri Jun 6 19:19:03 2014 -0700

----------------------------------------------------------------------
 .../org/apache/mahout/math/DiagonalMatrix.java  | 206 ++++++++++++++++++-
 .../apache/mahout/math/DiagonalMatrixTest.java  |  43 ++++
 2 files changed, 244 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/dd78ed94/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java 
b/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
index 2a027f7..3e20a4a 100644
--- a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
@@ -17,6 +17,9 @@
 
 package org.apache.mahout.math;
 
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
 public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps {
   private final Vector diagonal;
 
@@ -60,6 +63,195 @@ public class DiagonalMatrix extends AbstractMatrix 
implements MatrixTimesOps {
     throw new UnsupportedOperationException("Can't assign a row to a diagonal 
matrix");
   }
 
+  @Override
+  public Vector viewRow(int row) {
+    return new SingleElementVector(row);
+  }
+
+  /**
+   * Returns a non-dense view of one column of this diagonal matrix.
+   *
+   * @param column index of the column to view; it is also the only position
+   *               of the returned vector that is backed by the diagonal
+   * @return a {@link SingleElementVector} created for the given column
+   */
+  @Override
+  public Vector viewColumn(int column) {
+    return new SingleElementVector(column);
+  }
+
+  /**
+   * Special class to implement views of rows and columns of a diagonal matrix.
+   */
+  public class SingleElementVector extends AbstractVector {
+    private int index;
+
+    public SingleElementVector(int index) {
+      super(diagonal.size());
+      this.index = index;
+    }
+
+    /**
+     * Reads one position of this row/column view.
+     *
+     * @param index position to read
+     * @return the diagonal value when {@code index} is the position this view
+     *         was created for, otherwise zero
+     */
+    @Override
+    public double getQuick(int index) {
+      // Every slot except the one on the diagonal reads as zero.
+      return index == this.index ? diagonal.get(index) : 0;
+    }
+
+    /**
+     * Writes through to the diagonal of the enclosing matrix.
+     *
+     * @param index position to write; must be the position this view was
+     *              created for
+     * @param value new value for the diagonal entry
+     * @throws IllegalArgumentException if {@code index} is any other position
+     */
+    @Override
+    public void set(int index, double value) {
+      if (index != this.index) {
+        throw new IllegalArgumentException("Can't set off-diagonal element of 
diagonal matrix");
+      }
+      diagonal.set(index, value);
+    }
+
+    /**
+     * Iterates over the single position of this view that is backed by the
+     * diagonal. Exactly one element is produced, whatever value it holds.
+     */
+    @Override
+    protected Iterator<Element> iterateNonZero() {
+      return new Iterator<Element>() {
+        // Becomes true once the one and only element has been handed out.
+        boolean consumed = false;
+
+        @Override
+        public boolean hasNext() {
+          return !consumed;
+        }
+
+        @Override
+        public Element next() {
+          if (consumed) {
+            throw new NoSuchElementException("Only one non-zero element in a 
row or column of a diagonal matrix");
+          }
+          consumed = true;
+          // The element reads and writes the diagonal directly.
+          return new Element() {
+            @Override
+            public int index() {
+              return SingleElementVector.this.index;
+            }
+
+            @Override
+            public double get() {
+              return diagonal.get(SingleElementVector.this.index);
+            }
+
+            @Override
+            public void set(double value) {
+              diagonal.set(SingleElementVector.this.index, value);
+            }
+          };
+        }
+
+        @Override
+        public void remove() {
+          throw new UnsupportedOperationException("Can't remove from vector 
view");
+        }
+      };
+    }
+
+    @Override
+    protected Iterator<Element> iterator() {
+      return new Iterator<Element>() {
+        int i = 0;
+
+        Element r = new Element() {
+          @Override
+          public double get() {
+            if (i == index) {
+              return diagonal.get(index);
+            } else {
+              return 0;
+            }
+          }
+
+          @Override
+          public int index() {
+            return i;
+          }
+
+          @Override
+          public void set(double value) {
+            if (i == index) {
+              diagonal.set(index, value);
+            } else {
+              throw new IllegalArgumentException("Can't set any element but 
diagonal");
+            }
+          }
+        };
+
+        @Override
+        public boolean hasNext() {
+          return i < diagonal.size() - 1;
+        }
+
+        @Override
+        public Element next() {
+          if (i < SingleElementVector.this.size() - 1) {
+            i++;
+            return r;
+          } else {
+            throw new NoSuchElementException("Attempted to access passed last 
element of vector");
+          }
+        }
+
+
+        @Override
+        public void remove() {
+          throw new UnsupportedOperationException("Default operation");
+        }
+      };
+    }
+
+    @Override
+    protected Matrix matrixLike(int rows, int columns) {
+      return new DiagonalMatrix(rows, columns);
+    }
+
+    @Override
+    public boolean isDense() {
+      return false;
+    }
+
+    @Override
+    public boolean isSequentialAccess() {
+      return true;
+    }
+
+    @Override
+    public void mergeUpdates(OrderedIntDoubleMapping updates) {
+      throw new UnsupportedOperationException("Default operation");
+    }
+
+    @Override
+    public Vector like() {
+      return new DenseVector(size());
+    }
+
+    /**
+     * Writes the diagonal entry behind this view.
+     *
+     * @param index position to write; must be the position this view was
+     *              created for
+     * @param value new value for the diagonal entry
+     * @throws IllegalArgumentException if {@code index} is any other position
+     */
+    @Override
+    public void setQuick(int index, double value) {
+      if (index != this.index) {
+        throw new IllegalArgumentException("Can't set off-diagonal element of 
DiagonalMatrix");
+      }
+      // index equals this.index here, so this hits the same diagonal slot.
+      diagonal.set(index, value);
+    }
+
+    @Override
+    public int getNumNondefaultElements() {
+      return 1;
+    }
+
+    @Override
+    public double getLookupCost() {
+      return 0;
+    }
+
+    @Override
+    public double getIteratorAdvanceCost() {
+      return 1;
+    }
+
+    @Override
+    public boolean isAddConstantTime() {
+      return false;
+    }
+  }
+
   /**
    * Provides a view of the diagonal of a matrix.
    */
@@ -147,22 +339,26 @@ public class DiagonalMatrix extends AbstractMatrix 
implements MatrixTimesOps {
 
   @Override
   public Matrix timesRight(Matrix that) {
-    if (that.numRows() != diagonal.size())
+    if (that.numRows() != diagonal.size()) {
       throw new IllegalArgumentException("Incompatible number of rows in the 
right operand of matrix multiplication.");
+    }
     Matrix m = that.like();
-    for (int row = 0; row < diagonal.size(); row++)
+    for (int row = 0; row < diagonal.size(); row++) {
       m.assignRow(row, that.viewRow(row).times(diagonal.getQuick(row)));
+    }
     return m;
   }
 
   @Override
   public Matrix timesLeft(Matrix that) {
-    if (that.numCols() != diagonal.size())
+    if (that.numCols() != diagonal.size()) {
       throw new IllegalArgumentException(
-          "Incompatible number of rows in the left operand of matrix-matrix 
multiplication.");
+        "Incompatible number of rows in the left operand of matrix-matrix 
multiplication.");
+    }
     Matrix m = that.like();
-    for (int col = 0; col < diagonal.size(); col++)
+    for (int col = 0; col < diagonal.size(); col++) {
       m.assignColumn(col, that.viewColumn(col).times(diagonal.getQuick(col)));
+    }
     return m;
   }
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/dd78ed94/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java
----------------------------------------------------------------------
diff --git a/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java 
b/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java
index 5b3a278..2ca7be0 100644
--- a/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java
+++ b/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java
@@ -18,8 +18,11 @@
 package org.apache.mahout.math;
 
 import org.apache.mahout.math.function.Functions;
+import org.junit.Assert;
 import org.junit.Test;
 
+import java.util.Iterator;
+
 public class DiagonalMatrixTest extends MahoutTestCase {
   @Test
   public void testBasics() {
@@ -46,4 +49,44 @@ public class DiagonalMatrixTest extends MahoutTestCase {
     assertEquals(100, a.times(m.transpose()).aggregate(Functions.PLUS, 
Functions.ABS), 1.0e-10);
   }
 
+  @Test
+  public void testSparsity() {
+    Vector d = new DenseVector(10);
+    for (int i = 0; i < 10; i++) {
+      d.set(i, i * i);
+    }
+    DiagonalMatrix m = new DiagonalMatrix(d);
+
+    Assert.assertFalse(m.viewRow(0).isDense());
+    Assert.assertFalse(m.viewColumn(0).isDense());
+
+    for (int i = 0; i < 10; i++) {
+      assertEquals(i * i, m.viewRow(i).zSum(), 0);
+      assertEquals(i * i, m.viewRow(i).get(i), 0);
+
+      assertEquals(i * i, m.viewColumn(i).zSum(), 0);
+      assertEquals(i * i, m.viewColumn(i).get(i), 0);
+    }
+
+    Iterator<Vector.Element> ix = m.viewRow(7).nonZeroes().iterator();
+    assertTrue(ix.hasNext());
+    Vector.Element r = ix.next();
+    assertEquals(7, r.index());
+    assertEquals(49, r.get(), 0);
+    assertFalse(ix.hasNext());
+
+    assertEquals(0, m.viewRow(5).get(3), 0);
+    assertEquals(0, m.viewColumn(8).get(3), 0);
+
+    m.viewRow(3).set(3, 1);
+    assertEquals(1, m.get(3, 3), 0);
+
+    for (Vector.Element element : m.viewRow(6).all()) {
+      if (element.index() == 6) {
+        assertEquals(36, element.get(), 0);
+      }                                    else {
+        assertEquals(0, element.get(), 0);
+      }
+    }
+  }
 }

Reply via email to