Author: srowen
Date: Mon Oct  3 07:51:03 2011
New Revision: 1178324

URL: http://svn.apache.org/viewvc?rev=1178324&view=rev
Log:
MAHOUT-812 help make confusion matrix writable

Added:
    
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1178324&r1=1178323&r2=1178324&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
 Mon Oct  3 07:51:03 2011
@@ -1,9 +1,9 @@
 /**
- * Licensed to the Apache Software Foundation (ASF) under one or more
+ * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
  * this work for additional information regarding copyright ownership.
  * The ASF licenses this file to You under the Apache License, Version 2.0
  * (the "License"); you may not use this file except in compliance with
  * the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
@@ -17,12 +17,18 @@
 
 package org.apache.mahout.classifier;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.Map;
 
+import com.google.common.collect.Maps;
 import org.apache.commons.lang.StringUtils;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
 
 import com.google.common.base.Preconditions;
 
@@ -127,6 +133,46 @@ public class ConfusionMatrix {
     return this;
   }
   
+  public Matrix getMatrix() {
+         int length = confusionMatrix.length;
+         Matrix m = new DenseMatrix(length, length);
+         for (int r = 0; r < length; r++) {
+                 for (int c = 0; c < length; c++) {
+                         m.set(r, c, confusionMatrix[r][c]);
+                 }
+         }
+         Map<String,Integer> labels = Maps.newHashMap();
+         for (Map.Entry<String, Integer> entry : labelMap.entrySet()) {
+                 labels.put(entry.getKey(), entry.getValue());
+         }
+         m.setRowLabelBindings(labels);
+         m.setColumnLabelBindings(labels);
+         return m;
+  }
+
+  public void setMatrix(Matrix m) {
+         int length = confusionMatrix.length;
+         if (m.numRows() != m.numCols()) {
+      throw new CardinalityException(m.numRows(), m.numCols());
+    }
+    if (m.numRows() != length) {
+      throw new CardinalityException(m.numRows(), length);
+    }
+         for (int r = 0; r < length; r++) {
+                 for (int c = 0; c < length; c++) {
+                         confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
+                 }
+         }
+         Map<String,Integer> labels = m.getRowLabelBindings();
+         if (labels == null) {
+      labels = m.getColumnLabelBindings();
+    }
+    labelMap.clear();    
+         if (labels != null) {
+      labelMap.putAll(labels);
+         }
+  }
+  
   @Override
   public String toString() {
     StringBuilder returnString = new StringBuilder(200);

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java?rev=1178324&r1=1178323&r2=1178324&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java 
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java 
Mon Oct  3 07:51:03 2011
@@ -18,19 +18,23 @@
 package org.apache.mahout.math;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
 import org.apache.hadoop.io.Writable;
 
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.Map;
 
 public class MatrixWritable implements Writable {
 
+  private static final int FLAG_DENSE = 0x01;
+  private static final int FLAG_SEQUENTIAL = 0x02;
+  private static final int FLAG_LABELS = 0x04;
+  private static final int NUM_FLAGS = 3;
+
   private Matrix matrix;
-  private static final int NUM_FLAGS = 2;
-  private static final int FLAG_DENSE = 1;
-  private static final int FLAG_SEQUENTIAL = 2;
 
   public MatrixWritable() {
   }
@@ -103,6 +107,7 @@ public class MatrixWritable implements W
     Preconditions.checkArgument(flags >> NUM_FLAGS == 0, "Unknown flags set: %d", Integer.toString(flags, 2));
     boolean dense = (flags & FLAG_DENSE) != 0;
     boolean sequential = (flags & FLAG_SEQUENTIAL) != 0;
+    boolean hasLabels = (flags & FLAG_LABELS) != 0;
 
     int rows = in.readInt();
     int columns = in.readInt();
@@ -118,6 +123,18 @@ public class MatrixWritable implements W
       r.viewRow(row).assign(VectorWritable.readVector(in));
     }
 
+    if (hasLabels) {
+       Map<String,Integer> columnLabelBindings = Maps.newHashMap();
+       Map<String,Integer> rowLabelBindings = Maps.newHashMap();
+       readLabels(in, columnLabelBindings, rowLabelBindings);
+       if (!columnLabelBindings.isEmpty()) {
+               r.setColumnLabelBindings(columnLabelBindings);
+       }
+       if (!rowLabelBindings.isEmpty()) {
+               r.setRowLabelBindings(rowLabelBindings);
+       }
+    }
+
     return r;
   }
 
@@ -131,6 +148,9 @@ public class MatrixWritable implements W
     if (row.isSequentialAccess()) {
       flags |= FLAG_SEQUENTIAL;
     }
+    if (matrix.getRowLabelBindings() != null || matrix.getColumnLabelBindings() != null) {
+      flags |= FLAG_LABELS;
+    }
     out.writeInt(flags);
 
     out.writeInt(matrix.rowSize());
@@ -139,5 +159,8 @@ public class MatrixWritable implements W
     for (int i = 0; i < matrix.rowSize(); i++) {
       VectorWritable.writeVector(out, matrix.viewRow(i), false);
     }
+    if ((flags & FLAG_LABELS) != 0) {
+       writeLabelBindings(out, matrix.getColumnLabelBindings(), matrix.getRowLabelBindings());
+    }
   }
 }

Added: 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java?rev=1178324&view=auto
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
 (added)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
 Mon Oct  3 07:51:03 2011
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Map;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.junit.Test;
+
+public final class ConfusionMatrixTest extends MahoutTestCase {
+
+  private static final int[][] VALUES = {{2, 3}, {10, 20}};
+  private static final String[] LABELS = {"Label1", "Label2"};
+  private static final String DEFAULT_LABEL = "other";
+  
+  @Test
+  public void testBuild() {
+    ConfusionMatrix cm = fillCM(VALUES, LABELS, DEFAULT_LABEL);
+    checkValues(cm);
+    checkAccuracy(cm);
+  }
+
+  @Test
+  public void testGetMatrix() {
+           ConfusionMatrix cm = fillCM(VALUES, LABELS, DEFAULT_LABEL);
+           Matrix m = cm.getMatrix();
+           Map<String, Integer> rowLabels = m.getRowLabelBindings();
+           assertEquals(cm.getLabels().size(), m.numCols());
+           assertTrue(rowLabels.keySet().contains(LABELS[0]));
+           assertTrue(rowLabels.keySet().contains(LABELS[1]));
+           assertTrue(rowLabels.keySet().contains(DEFAULT_LABEL));
+           assertEquals(2, cm.getCorrect(LABELS[0]));
+           assertEquals(20, cm.getCorrect(LABELS[1]));
+           assertEquals(0, cm.getCorrect(DEFAULT_LABEL));
+  }
+
+  private static void checkValues(ConfusionMatrix cm) {
+    int[][] counts = cm.getConfusionMatrix();
+    cm.toString();
+    assertEquals(counts.length, counts[0].length);
+    assertEquals(3, counts.length);
+    assertEquals(VALUES[0][0], counts[0][0]);
+    assertEquals(VALUES[0][1], counts[0][1]);
+    assertEquals(VALUES[1][0], counts[1][0]);
+    assertEquals(VALUES[1][1], counts[1][1]);
+    assertTrue(Arrays.equals(new int[3], counts[2])); // zeros
+    assertEquals(0, counts[0][2]);
+    assertEquals(0, counts[1][2]);
+    assertEquals(3, cm.getLabels().size());
+    assertTrue(cm.getLabels().contains(LABELS[0]));
+    assertTrue(cm.getLabels().contains(LABELS[1]));
+    assertTrue(cm.getLabels().contains(DEFAULT_LABEL));
+
+  }
+
+  private static void checkAccuracy(ConfusionMatrix cm) {
+    Collection<String> labelstrs = cm.getLabels();
+    assertEquals(3, labelstrs.size());
+    assertEquals(40.0, cm.getAccuracy("Label1"), EPSILON);
+    assertEquals(66.666666667, cm.getAccuracy("Label2"), EPSILON);
+    assertTrue(Double.isNaN(cm.getAccuracy("other")));
+  }
+  
+  private static ConfusionMatrix fillCM(int[][] values, String[] labels, String defaultLabel) {
+    Collection<String> labelList = new ArrayList<String>();
+    labelList.add(labels[0]);
+    labelList.add(labels[1]);
+    ConfusionMatrix cm = new ConfusionMatrix(labelList, defaultLabel);
+    int[][] v = cm.getConfusionMatrix();
+    v[0][0] = values[0][0];
+    v[0][1] = values[0][1];
+    v[1][0] = values[1][0];
+    v[1][1] = values[1][1];
+    return cm;
+  }
+  
+}

Added: 
mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java?rev=1178324&view=auto
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java 
(added)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java 
Mon Oct  3 07:51:03 2011
@@ -0,0 +1,120 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+public final class MatrixWritableTest extends MahoutTestCase {
+
+       @Test
+       public void testSparseMatrixWritable() throws Exception {
+               Matrix m = new SparseMatrix(5, 5);
+               m.set(1, 2, 3.0);
+               m.set(3, 4, 5.0);
+               Map<String, Integer> bindings = new HashMap<String, Integer>();
+               bindings.put("A", 0);
+               bindings.put("B", 1);
+               bindings.put("C", 2);
+               bindings.put("D", 3);
+               bindings.put("default", 4);
+               m.setRowLabelBindings(bindings);
+               doTestMatrixWritableEquals(m);
+       }
+
+       @Test
+       public void testDenseMatrixWritable() throws Exception {
+               Matrix m = new DenseMatrix(5,5);
+               m.set(1, 2, 3.0);
+               m.set(3, 4, 5.0);
+               Map<String, Integer> bindings = new HashMap<String, Integer>();
+               bindings.put("A", 0);
+               bindings.put("B", 1);
+               bindings.put("C", 2);
+               bindings.put("D", 3);
+               bindings.put("default", 4);
+               m.setColumnLabelBindings(bindings);
+               doTestMatrixWritableEquals(m);
+       }
+
+       private static void doTestMatrixWritableEquals(Matrix m) throws IOException {
+               Writable matrixWritable = new MatrixWritable(m);
+               MatrixWritable matrixWritable2 = new MatrixWritable();
+               writeAndRead(matrixWritable, matrixWritable2);
+               Matrix m2 = matrixWritable2.get();
+               compareMatrices(m, m2);  // not sure this works?
+       }
+
+       private static void compareMatrices(Matrix m, Matrix m2) {
+               assertEquals(m.numRows(), m2.numRows());
+               assertEquals(m.numCols(), m2.numCols());
+               for(int r = 0; r < m.numRows(); r++) {
+                       for(int c = 0; c < m.numCols(); c++) {
+                               assertEquals(m.get(r, c), m2.get(r, c), EPSILON);
+                       }
+               }
+               Map<String,Integer> bindings = m.getRowLabelBindings();
+               Map<String, Integer> bindings2 = m2.getRowLabelBindings();
+               assertEquals(bindings == null, bindings2 == null);
+               if (bindings != null) {
+                       assertEquals(bindings.size(), m.numRows());
+                       assertEquals(bindings.size(), bindings2.size());
+                       for(Map.Entry<String,Integer> entry : bindings.entrySet()) {
+                               assertEquals(entry.getValue(), bindings2.get(entry.getKey()));
+                       }
+               }
+               bindings = m.getColumnLabelBindings();
+               bindings2 = m2.getColumnLabelBindings();
+               assertEquals(bindings == null, bindings2 == null);
+               if (bindings != null) {
+                       assertEquals(bindings.size(), bindings2.size());
+                       for(Map.Entry<String,Integer> entry : bindings.entrySet()) {
+                               assertEquals(entry.getValue(), bindings2.get(entry.getKey()));
+                       }
+               }
+       }
+
+       private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
+               ByteArrayOutputStream baos = new ByteArrayOutputStream();
+               DataOutputStream dos = new DataOutputStream(baos);
+               try {
+                       toWrite.write(dos);
+               } finally {
+                       Closeables.closeQuietly(dos);
+               }
+
+               ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+               DataInputStream dis = new DataInputStream(bais);
+               try {
+                       toRead.readFields(dis);
+               } finally {
+                       Closeables.closeQuietly(dis);
+               }
+       }
+
+
+}

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java?rev=1178324&r1=1178323&r2=1178324&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java 
(original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java 
Mon Oct  3 07:51:03 2011
@@ -53,11 +53,27 @@ public final class VectorWritableTest ex
     doTestVectorWritableEquals(v);
   }
 
+  @Test
+  public void testNamedVectorWritable() throws Exception {
+    Vector v = new DenseVector(5);
+    v = new NamedVector(v, "Victor");
+    v.set(1, 3.0);
+    v.set(3, 5.0);
+    doTestVectorWritableEquals(v);
+  }
+
   private static void doTestVectorWritableEquals(Vector v) throws IOException {
     Writable vectorWritable = new VectorWritable(v);
     VectorWritable vectorWritable2 = new VectorWritable();
     writeAndRead(vectorWritable, vectorWritable2);
     Vector v2 = vectorWritable2.get();
+    if (v instanceof NamedVector) {
+       assertTrue(v2 instanceof NamedVector);
+       NamedVector nv = (NamedVector) v;
+       NamedVector nv2 = (NamedVector) v2;
+       assertEquals(nv.getName(), nv2.getName());
+       assertEquals("Victor", nv.getName());
+    }
     assertEquals(v, v2);
   }
 


Reply via email to