Or did we forget to add MatrixDumper?

On Nov 4, 2011, at 5:52 PM, Grant Ingersoll wrote:

> Is that supposed to be ConfusionMatrixDumper in the driver.classes.props?
> 
> On Nov 4, 2011, at 7:20 AM, [email protected] wrote:
> 
>> Author: srowen
>> Date: Fri Nov  4 11:20:03 2011
>> New Revision: 1197510
>> 
>> URL: http://svn.apache.org/viewvc?rev=1197510&view=rev
>> Log:
>> MAHOUT-838 Add confusion matrix dumper
>> 
>> Added:
>>    mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/
>>    
>> mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
>> Modified:
>>    
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
>>    
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
>>    
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
>>    mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
>>    
>> mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
>>    mahout/trunk/src/conf/driver.classes.props
>> 
>> Modified: 
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
>> URL: 
>> http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1197510&r1=1197509&r2=1197510&view=diff
>> ==============================================================================
>> --- 
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
>>  (original)
>> +++ 
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
>>  Fri Nov  4 11:20:03 2011
>> @@ -19,35 +19,40 @@ package org.apache.mahout.classifier;
>> 
>> import java.util.Collection;
>> import java.util.Collections;
>> -import java.util.LinkedHashMap;
>> import java.util.Map;
>> 
>> -import com.google.common.collect.Maps;
>> import org.apache.commons.lang.StringUtils;
>> -import org.apache.mahout.math.CardinalityException;
>> import org.apache.mahout.math.DenseMatrix;
>> import org.apache.mahout.math.Matrix;
>> 
>> import com.google.common.base.Preconditions;
>> +import com.google.common.collect.Maps;
>> 
>> /**
>>  * The ConfusionMatrix Class stores the result of Classification of a Test 
>> Dataset.
>>  * 
>> + * The fact of whether there is a default is not stored. A row of zeros is 
>> the only indicator that there is no default.
>> + * 
>>  * See http://en.wikipedia.org/wiki/Confusion_matrix for background
>>  */
>> public class ConfusionMatrix {
>> -
>> -  private final Map<String,Integer> labelMap = new 
>> LinkedHashMap<String,Integer>();
>> +  private final Map<String,Integer> labelMap = Maps.newLinkedHashMap();
>>   private final int[][] confusionMatrix;
>>   private String defaultLabel = "unknown";
>> 
>>   public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
>>     confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
>>     this.defaultLabel = defaultLabel;
>> +    int i = 0;
>>     for (String label : labels) {
>> -      labelMap.put(label, labelMap.size());
>> +      labelMap.put(label, i++);
>>     }
>> -    labelMap.put(defaultLabel, labelMap.size());
>> +    labelMap.put(defaultLabel, i);
>> +  }
>> +  
>> +  public ConfusionMatrix(Matrix m) {
>> +    confusionMatrix = new int[m.numRows()][m.numRows()];
>> +    setMatrix(m);
>>   }
>> 
>>   public int[][] getConfusionMatrix() {
>> @@ -76,7 +81,7 @@ public class ConfusionMatrix {
>>     return confusionMatrix[labelId][labelId];
>>   }
>> 
>> -  public double getTotal(String label) {
>> +  public int getTotal(String label) {
>>     int labelId = labelMap.get(label);
>>     int labelTotal = 0;
>>     for (int i = 0; i < labelMap.size(); i++) {
>> @@ -94,25 +99,25 @@ public class ConfusionMatrix {
>>   }
>> 
>>   public int getCount(String correctLabel, String classifiedLabel) {
>> -    Preconditions.checkArgument(labelMap.containsKey(correctLabel),
>> -                                "Label not found: " + correctLabel);
>> -    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel),
>> -                                "Label not found: " + classifiedLabel);
>> +    Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label 
>> not found: " + correctLabel);
>> +    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), 
>> "Label not found: " + classifiedLabel);
>>     int correctId = labelMap.get(correctLabel);
>>     int classifiedId = labelMap.get(classifiedLabel);
>>     return confusionMatrix[correctId][classifiedId];
>>   }
>> 
>>   public void putCount(String correctLabel, String classifiedLabel, int 
>> count) {
>> -    Preconditions.checkArgument(labelMap.containsKey(correctLabel),
>> -                                "Label not found: " + correctLabel);
>> -    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel),
>> -                                "Label not found: " + classifiedLabel);
>> +    Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label 
>> not found: " + correctLabel);
>> +    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), 
>> "Label not found: " + classifiedLabel);
>>     int correctId = labelMap.get(correctLabel);
>>     int classifiedId = labelMap.get(classifiedLabel);
>>     confusionMatrix[correctId][classifiedId] = count;
>>   }
>> 
>> +  public String getDefaultLabel() {
>> +    return defaultLabel;
>> +  }
>> +  
>>   public void incrementCount(String correctLabel, String classifiedLabel, 
>> int count) {
>>     putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, 
>> classifiedLabel));
>>   }
>> @@ -132,45 +137,69 @@ public class ConfusionMatrix {
>>   }
>> 
>>   public Matrix getMatrix() {
>> -      int length = confusionMatrix.length;
>> -      Matrix m = new DenseMatrix(length, length);
>> -      for (int r = 0; r < length; r++) {
>> -              for (int c = 0; c < length; c++) {
>> -                      m.set(r, c, confusionMatrix[r][c]);
>> -              }
>> -      }
>> -      Map<String,Integer> labels = Maps.newHashMap();
>> -      for (Map.Entry<String, Integer> entry : labelMap.entrySet()) {
>> -              labels.put(entry.getKey(), entry.getValue());
>> -      }
>> -      m.setRowLabelBindings(labels);
>> -      m.setColumnLabelBindings(labels);
>> -      return m;
>> +    int length = confusionMatrix.length;
>> +    Matrix m = new DenseMatrix(length, length);
>> +    for (int r = 0; r < length; r++) {
>> +      for (int c = 0; c < length; c++) {
>> +        m.set(r, c, confusionMatrix[r][c]);
>> +      }
>> +    }
>> +    Map<String,Integer> labels = Maps.newHashMap();
>> +    for(Map.Entry<String, Integer> entry : labelMap.entrySet()) {
>> +      labels.put(entry.getKey(), entry.getValue());
>> +    }
>> +    m.setRowLabelBindings(labels);
>> +    m.setColumnLabelBindings(labels);
>> +    return m;
>>   }
>> -
>> +  
>>   public void setMatrix(Matrix m) {
>> -      int length = confusionMatrix.length;
>> -      if (m.numRows() != m.numCols()) {
>> -      throw new CardinalityException(m.numRows(), m.numCols());
>> -    }
>> -    if (m.numRows() != length) {
>> -      throw new CardinalityException(m.numRows(), length);
>> -    }
>> -      for (int r = 0; r < length; r++) {
>> -              for (int c = 0; c < length; c++) {
>> -                      confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
>> -              }
>> -      }
>> -      Map<String,Integer> labels = m.getRowLabelBindings();
>> -      if (labels == null) {
>> +    int length = confusionMatrix.length;
>> +    if (m.numRows() != m.numCols()) {
>> +      throw new IllegalArgumentException(
>> +          String.format("ConfusionMatrix: matrix({},{}) must be square", 
>> m.numRows(), m.numCols()));
>> +    }
>> +    for (int r = 0; r < length; r++) {
>> +      for (int c = 0; c < length; c++) {
>> +        confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
>> +      }
>> +    }
>> +    Map<String,Integer> labels = m.getRowLabelBindings();
>> +    if (labels == null) {
>>       labels = m.getColumnLabelBindings();
>>     }
>> -    labelMap.clear();    
>> -      if (labels != null) {
>> -      labelMap.putAll(labels);
>> -      }
>> +    if (labels != null) {
>> +      String[] sorted = sortLabels(labels);
>> +      verifyLabels(length, sorted);
>> +      labelMap.clear();
>> +      for(int i = 0; i < length; i++) {
>> +        labelMap.put(sorted[i], i);
>> +      }
>> +    }
>> +  }
>> +  
>> +  private static String[] sortLabels(Map<String,Integer> labels) {
>> +    String[] sorted = new String[labels.keySet().size()];
>> +    for(String label: labels.keySet()) {
>> +      Integer index = labels.get(label);
>> +      sorted[index] = label;
>> +    }
>> +    return sorted;
>> +  }
>> +  
>> +  private void verifyLabels(int length, String[] sorted) {
>> +    Preconditions.checkArgument(sorted.length == length, "One label, one 
>> row");
>> +    for(int i = 0; i < length; i++) {
>> +      if (sorted[i] == null) {
>> +        Preconditions.checkArgument(false, "One label, one row");
>> +      }
>> +    }
>>   }
>> 
>> +  /**
>> +   * This is overloaded. toString() is not a formatted report you print for 
>> a manager :)
>> +   * Assume that if there are no default assignments, the default feature 
>> was not used
>> +   */
>>   @Override
>>   public String toString() {
>>     StringBuilder returnString = new StringBuilder(200);
>> @@ -178,26 +207,37 @@ public class ConfusionMatrix {
>>     returnString.append("Confusion Matrix\n");
>>     
>> returnString.append("-------------------------------------------------------").append('\n');
>> 
>> +    int unclassified = getTotal(defaultLabel);
>>     for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
>> +      if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
>> +        continue;
>> +      }
>> +      
>>       
>> returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 
>> 5)).append('\t');
>>     }
>> 
>>     returnString.append("<--Classified as").append('\n');
>> -    
>>     for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
>> +      if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
>> +        continue;
>> +      }
>>       String correctLabel = entry.getKey();
>>       int labelTotal = 0;
>>       for (String classifiedLabel : this.labelMap.keySet()) {
>> +        if (classifiedLabel.equals(defaultLabel) && unclassified == 0) {
>> +          continue;
>> +        }
>>         returnString.append(
>> -          StringUtils.rightPad(Integer.toString(getCount(correctLabel, 
>> classifiedLabel)), 5)).append('\t');
>> +            StringUtils.rightPad(Integer.toString(getCount(correctLabel, 
>> classifiedLabel)), 5)).append('\t');
>>         labelTotal += getCount(correctLabel, classifiedLabel);
>>       }
>>       returnString.append(" |  
>> ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
>> -          .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
>> -          .append(" = ").append(correctLabel).append('\n');
>> +      .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
>> +      .append(" = ").append(correctLabel).append('\n');
>> +    }
>> +    if (unclassified > 0) {
>> +      returnString.append("Default Category: 
>> ").append(defaultLabel).append(": ").append(unclassified).append('\n');
>>     }
>> -    returnString.append("Default Category: 
>> ").append(defaultLabel).append(": ").append(
>> -      labelMap.get(defaultLabel)).append('\n');
>>     returnString.append('\n');
>>     return returnString.toString();
>>   }
>> @@ -212,5 +252,5 @@ public class ConfusionMatrix {
>>     } while (val > 0);
>>     return returnString.toString();
>>   }
>> -  
>> +
>> }
>> 
>> Modified: 
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
>> URL: 
>> http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=1197510&r1=1197509&r2=1197510&view=diff
>> ==============================================================================
>> --- 
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
>>  (original)
>> +++ 
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
>>  Fri Nov  4 11:20:03 2011
>> @@ -33,6 +33,7 @@ import org.apache.commons.cli2.builder.D
>> import org.apache.commons.cli2.builder.GroupBuilder;
>> import org.apache.commons.cli2.commandline.Parser;
>> import org.apache.mahout.classifier.ClassifierResult;
>> +import org.apache.mahout.classifier.ConfusionMatrix;
>> import org.apache.mahout.classifier.ResultAnalyzer;
>> import 
>> org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver;
>> import org.apache.mahout.common.CommandLineUtil;
>> @@ -104,9 +105,14 @@ public final class TestClassifier {
>>       "Method of Classification: sequential|mapreduce. Default Value: 
>> sequential").withShortName("method")
>>         .create();
>> 
>> +    Option confusionMatrixOpt = 
>> obuilder.withLongName("confusionMatrix").withRequired(false).withArgument(
>> +        
>> abuilder.withName("confusionMatrix").withMinimum(1).withMaximum(1).create()).withDescription(
>> +        "Export ConfusionMatrix as 
>> SequenceFile").withShortName("cm").create();
>> +      
>>     Group group = 
>> gbuilder.withName("Options").withOption(defaultCatOpt).withOption(dirOpt).withOption(
>>       
>> encodingOpt).withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt).withOption(dataSourceOpt)
>> -        
>> .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt).create();
>> +        
>> .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt)
>> +        .withOption(confusionMatrixOpt).create();
>> 
>>     try {
>>       Parser parser = new Parser();
>> @@ -163,6 +169,11 @@ public final class TestClassifier {
>>         classificationMethod = (String) cmdLine.getValue(methodOpt);
>>       }
>> 
>> +      String confusionMatrixFile = null;
>> +      if (cmdLine.hasOption(confusionMatrixOpt)) {
>> +        confusionMatrixFile = (String) cmdLine.getValue(confusionMatrixOpt);
>> +      }
>> +      
>>       params.setGramSize(gramSize);
>>       params.set("verbose", Boolean.toString(verbose));
>>       params.setBasePath(modelBasePath);
>> @@ -172,6 +183,7 @@ public final class TestClassifier {
>>       params.set("encoding", encoding);
>>       params.set("alpha_i", alphaI);
>>       params.set("testDirPath", testDirPath);
>> +      params.set("confusionMatrix", confusionMatrixFile);
>> 
>>       if ("sequential".equalsIgnoreCase(classificationMethod)) {
>>         classifySequential(params);
>> @@ -253,12 +265,12 @@ public final class TestClassifier {
>>           }
>>           lineNum++;
>>         }
>> -        /*
>> -         * log.info("{}\t{}\t{}/{}", new Object[] {correctLabel,
>> -         * resultAnalyzer.getConfusionMatrix().getAccuracy(correctLabel),
>> -         * resultAnalyzer.getConfusionMatrix().getCorrect(correctLabel),
>> -         * resultAnalyzer.getConfusionMatrix().getTotal(correctLabel)});
>> -         */
>> +        ConfusionMatrix matrix = resultAnalyzer.getConfusionMatrix();
>> +        log.info("{}", matrix);
>> +        BayesClassifierDriver.confusionMatrixSeqFileExport(params, matrix);
>> +
>> +        log.info("ConfusionMatrix: {}", matrix.toString());
>> +           
>>         log.info("Classified instances from {}", file.getName());
>>         if (verbose) {
>>           log.info("Performance stats {}", operationStats.toString());
>> 
>> Modified: 
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
>> URL: 
>> http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java?rev=1197510&r1=1197509&r2=1197510&view=diff
>> ==============================================================================
>> --- 
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
>>  (original)
>> +++ 
>> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
>>  Fri Nov  4 11:20:03 2011
>> @@ -21,10 +21,15 @@ import java.io.IOException;
>> import java.util.Map;
>> 
>> import com.google.common.collect.Maps;
>> +import com.google.common.io.Closeables;
>> +
>> import org.apache.hadoop.conf.Configurable;
>> import org.apache.hadoop.conf.Configuration;
>> +import org.apache.hadoop.fs.FileSystem;
>> import org.apache.hadoop.fs.Path;
>> import org.apache.hadoop.io.DoubleWritable;
>> +import org.apache.hadoop.io.SequenceFile;
>> +import org.apache.hadoop.io.Text;
>> import org.apache.hadoop.mapred.FileInputFormat;
>> import org.apache.hadoop.mapred.FileOutputFormat;
>> import org.apache.hadoop.mapred.JobClient;
>> @@ -38,6 +43,7 @@ import org.apache.mahout.common.Paramete
>> import org.apache.mahout.common.StringTuple;
>> import org.apache.mahout.common.iterator.sequencefile.PathType;
>> import 
>> org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
>> +import org.apache.mahout.math.MatrixWritable;
>> import org.slf4j.Logger;
>> import org.slf4j.LoggerFactory;
>> 
>> @@ -83,6 +89,10 @@ public final class BayesClassifierDriver
>>     Path outputFiles = new Path(outPath, "part*");
>>     ConfusionMatrix matrix = readResult(outputFiles, conf, params);
>>     log.info("{}", matrix);
>> +    if (params.get("confusionMatrix") != null) {
>> +      confusionMatrixSeqFileExport(params, matrix);
>> +    }
>> +    
>>   }
>> 
>>   public static ConfusionMatrix readResult(Path pathPattern, Configuration 
>> conf, Parameters params) {
>> @@ -117,6 +127,24 @@ public final class BayesClassifierDriver
>>       }
>>     }
>>     return matrix;
>> +  }
>> 
>> +  public static void confusionMatrixSeqFileExport(Parameters params, 
>> ConfusionMatrix matrix) throws IOException {
>> +    if (params.get("confusionMatrix") != null) {
>> +      Configuration conf = new Configuration();
>> +      FileSystem fs = FileSystem.get(conf);
>> +      SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf,
>> +          new Path(params.get("confusionMatrix")), Text.class, 
>> MatrixWritable.class);
>> +      String name = params.get("confusionMatrix");
>> +      // embed file name as sequence key- useful for tuning classifiers
>> +      name = name.substring(name.lastIndexOf('/') + 1, name.length());
>> +      Text key = new Text(name);
>> +      MatrixWritable mw = new MatrixWritable(matrix.getMatrix());
>> +      try {
>> +        writer.append(key, mw);
>> +      } finally {
>> +        Closeables.closeQuietly(writer);
>> +      }
>> +    }
>>   }
>> }
>> 
>> Modified: 
>> mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
>> URL: 
>> http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java?rev=1197510&r1=1197509&r2=1197510&view=diff
>> ==============================================================================
>> --- 
>> mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java 
>> (original)
>> +++ 
>> mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java 
>> Fri Nov  4 11:20:03 2011
>> @@ -123,15 +123,15 @@ public class MatrixWritable implements W
>>     }
>> 
>>     if (hasLabels) {
>> -            Map<String,Integer> columnLabelBindings = Maps.newHashMap();
>> -            Map<String,Integer> rowLabelBindings = Maps.newHashMap();
>> -            readLabels(in, columnLabelBindings, rowLabelBindings);
>> -            if (!columnLabelBindings.isEmpty()) {
>> -                    r.setColumnLabelBindings(columnLabelBindings);
>> -            }
>> -            if (!rowLabelBindings.isEmpty()) {
>> -                    r.setRowLabelBindings(rowLabelBindings);
>> -            }
>> +      Map<String,Integer> columnLabelBindings = Maps.newHashMap();
>> +      Map<String,Integer> rowLabelBindings = Maps.newHashMap();
>> +      readLabels(in, columnLabelBindings, rowLabelBindings);
>> +      if (!columnLabelBindings.isEmpty()) {
>> +        r.setColumnLabelBindings(columnLabelBindings);
>> +      }
>> +      if (!rowLabelBindings.isEmpty()) {
>> +        r.setRowLabelBindings(rowLabelBindings);
>> +      }
>>     }
>> 
>>     return r;
>> @@ -159,7 +159,7 @@ public class MatrixWritable implements W
>>       VectorWritable.writeVector(out, matrix.viewRow(i), false);
>>     }
>>     if ((flags & FLAG_LABELS) != 0) {
>> -            writeLabelBindings(out, matrix.getColumnLabelBindings(), 
>> matrix.getRowLabelBindings());
>> +      writeLabelBindings(out, matrix.getColumnLabelBindings(), 
>> matrix.getRowLabelBindings());
>>     }
>>   }
>> }
>> 
>> Modified: 
>> mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
>> URL: 
>> http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java?rev=1197510&r1=1197509&r2=1197510&view=diff
>> ==============================================================================
>> --- 
>> mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
>>  (original)
>> +++ 
>> mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
>>  Fri Nov  4 11:20:03 2011
>> @@ -17,6 +17,7 @@
>> 
>> package org.apache.mahout.math;
>> 
>> +import com.google.common.collect.Maps;
>> import com.google.common.io.Closeables;
>> import org.apache.hadoop.io.Writable;
>> import org.junit.Test;
>> @@ -26,7 +27,6 @@ import java.io.ByteArrayOutputStream;
>> import java.io.DataInputStream;
>> import java.io.DataOutputStream;
>> import java.io.IOException;
>> -import java.util.HashMap;
>> import java.util.Map;
>> 
>> public final class MatrixWritableTest extends MahoutTestCase {
>> @@ -36,13 +36,14 @@ public final class MatrixWritableTest ex
>>              Matrix m = new SparseMatrix(5, 5);
>>              m.set(1, 2, 3.0);
>>              m.set(3, 4, 5.0);
>> -            Map<String, Integer> bindings = new HashMap<String, Integer>();
>> +            Map<String, Integer> bindings = Maps.newHashMap();
>>              bindings.put("A", 0);
>>              bindings.put("B", 1);
>>              bindings.put("C", 2);
>>              bindings.put("D", 3);
>>              bindings.put("default", 4);
>>              m.setRowLabelBindings(bindings);
>> +    m.setColumnLabelBindings(bindings);
>>              doTestMatrixWritableEquals(m);
>>      }
>> 
>> @@ -51,12 +52,13 @@ public final class MatrixWritableTest ex
>>              Matrix m = new DenseMatrix(5,5);
>>              m.set(1, 2, 3.0);
>>              m.set(3, 4, 5.0);
>> -            Map<String, Integer> bindings = new HashMap<String, Integer>();
>> +            Map<String, Integer> bindings = Maps.newHashMap();
>>              bindings.put("A", 0);
>>              bindings.put("B", 1);
>>              bindings.put("C", 2);
>>              bindings.put("D", 3);
>>              bindings.put("default", 4);
>> +    m.setRowLabelBindings(bindings);
>>              m.setColumnLabelBindings(bindings);
>>              doTestMatrixWritableEquals(m);
>>      }
>> @@ -66,7 +68,9 @@ public final class MatrixWritableTest ex
>>              MatrixWritable matrixWritable2 = new MatrixWritable();
>>              writeAndRead(matrixWritable, matrixWritable2);
>>              Matrix m2 = matrixWritable2.get();
>> -            compareMatrices(m, m2);  // not sure this works?
>> +            compareMatrices(m, m2); 
>> +    doCheckBindings(m2.getRowLabelBindings());
>> +    doCheckBindings(m2.getColumnLabelBindings());    
>>      }
>> 
>>      private static void compareMatrices(Matrix m, Matrix m2) {
>> @@ -98,6 +102,14 @@ public final class MatrixWritableTest ex
>>              }
>>      }
>> 
>> +  private static void doCheckBindings(Map<String,Integer> labels) {
>> +    assertTrue("Missing label", labels.keySet().contains("A"));
>> +    assertTrue("Missing label", labels.keySet().contains("B"));
>> +    assertTrue("Missing label", labels.keySet().contains("C"));
>> +    assertTrue("Missing label", labels.keySet().contains("D"));
>> +    assertTrue("Missing label", labels.keySet().contains("default"));
>> +  }
>> +
>>      private static void writeAndRead(Writable toWrite, Writable toRead) 
>> throws IOException {
>>              ByteArrayOutputStream baos = new ByteArrayOutputStream();
>>              DataOutputStream dos = new DataOutputStream(baos);
>> 
>> Added: 
>> mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
>> URL: 
>> http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java?rev=1197510&view=auto
>> ==============================================================================
>> --- 
>> mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
>>  (added)
>> +++ 
>> mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
>>  Fri Nov  4 11:20:03 2011
>> @@ -0,0 +1,423 @@
>> +/*
>> + * Licensed to the Apache Software Foundation (ASF) under one or more
>> + * contributor license agreements.  See the NOTICE file distributed with
>> + * this work for additional information regarding copyright ownership.
>> + * The ASF licenses this file to You under the Apache License, Version 2.0
>> + * (the "License"); you may not use this file except in compliance with
>> + * the License.  You may obtain a copy of the License at
>> + *
>> + *     http://www.apache.org/licenses/LICENSE-2.0
>> + *
>> + * Unless required by applicable law or agreed to in writing, software
>> + * distributed under the License is distributed on an "AS IS" BASIS,
>> + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
>> + * See the License for the specific language governing permissions and
>> + * limitations under the License.
>> + */
>> +
>> +package org.apache.mahout.classifier;
>> +
>> +import java.io.File;
>> +import java.io.FileOutputStream;
>> +import java.io.IOException;
>> +import java.io.OutputStream;
>> +import java.io.PrintStream;
>> +import java.util.Iterator;
>> +import java.util.List;
>> +import java.util.Map;
>> +
>> +import org.apache.hadoop.conf.Configuration;
>> +import org.apache.hadoop.fs.FileSystem;
>> +import org.apache.hadoop.fs.Path;
>> +import org.apache.hadoop.io.SequenceFile;
>> +import org.apache.hadoop.io.Text;
>> +import org.apache.hadoop.util.ToolRunner;
>> +import org.apache.mahout.common.AbstractJob;
>> +import org.apache.mahout.common.commandline.DefaultOptionCreator;
>> +import org.apache.mahout.math.Matrix;
>> +import org.apache.mahout.math.MatrixWritable;
>> +
>> +import com.google.common.collect.Lists;
>> +
>> +/**
>> + * Export a ConfusionMatrix in various text formats: 
>> + *   ToString version
>> + *   Grayscale HTML table
>> + *   Summary HTML table 
>> + *   Table of counts
>> + *   all with optional HTML wrappers
>> + * 
>> + * Input format: Hadoop SequenceFile with Text key and MatrixWritable 
>> value, 1 pair
>> + * 
>> + * Intended to consume ConfusionMatrix SequenceFile output by Bayes
>> + * TestClassifier class
>> + */
>> +public final class ConfusionMatrixDumper extends AbstractJob {
>> +
>> +  // HTML wrapper - default CSS
>> +  private static final String HEADER = "<html>" +
>> +      "<head>\n" +
>> +      "<title>TITLE</title>\n" +
>> +      "</head>" +
>> +      "<body>\n" +
>> +      "<style type='text/css'> \n" +
>> +      "table\n" +
>> +      "{\n" +
>> +      "border:3px solid black; text-align:left;\n" +
>> +      "}\n" +
>> +      "th.normalHeader\n" +
>> +      "{\n" +
>> +      "border:1px solid 
>> black;border-collapse:collapse;text-align:center;background-color:white\n" +
>> +      "}\n" +
>> +      "th.tallHeader\n" +
>> +      "{\n" +
>> +      "border:1px solid 
>> black;border-collapse:collapse;text-align:center;background-color:white; 
>> height:6em\n" +
>> +      "}\n" +
>> +      "tr.label\n" +
>> +      "{\n" +
>> +      "border:1px solid 
>> black;border-collapse:collapse;text-align:center;background-color:white\n" +
>> +      "}\n" +
>> +      "tr.row\n" +
>> +      "{\n" +
>> +      "border:1px solid gray;text-align:center;background-color:snow\n" +
>> +      "}\n" +
>> +      "td\n" +
>> +      "{\n" +
>> +      "min-width:2em\n" +
>> +      "}\n" +
>> +      "td.cell\n" +
>> +      "{\n" +
>> +      "border:1px solid black;text-align:right;background-color:snow\n" +
>> +      "}\n" +
>> +      "td.empty\n" +
>> +      "{\n" +
>> +      "border:0px;text-align:right;background-color:snow\n" +
>> +      "}\n" +
>> +      "td.white\n" +
>> +      "{\n" +
>> +      "border:0px solid black;text-align:right;background-color:white\n" +
>> +      "}\n" +
>> +      "td.black\n" +
>> +      "{\n" +
>> +      "border:0px solid red;text-align:right;background-color:black\n" +
>> +      "}\n" +
>> +      "td.gray1\n" +
>> +      "{\n" +
>> +      "border:0px solid green;text-align:right; 
>> background-color:LightGray\n" +
>> +      "}\n" +
>> +      "td.gray2\n" +
>> +      "{\n" +
>> +      "border:0px solid blue;text-align:right;background-color:gray\n" +
>> +      "}\n" +
>> +      "td.gray3\n" +
>> +      "{\n" +
>> +      "border:0px solid red;text-align:right;background-color:DarkGray\n" +
>> +      "}\n" +
>> +      "th" +
>> +      "{\n" +
>> +      "        text-align: center;\n" +
>> +      "        vertical-align: bottom;\n" +
>> +      "        padding-bottom: 3px;\n" +
>> +      "        padding-left: 5px;\n" +
>> +      "        padding-right: 5px;\n" +
>> +      "}\n" +
>> +      "     .verticalText\n" +
>> +      "      {\n" +
>> +      "        text-align: center;\n" +
>> +      "        vertical-align: middle;\n" +
>> +      "        width: 20px;\n" +
>> +      "        margin: 0px;\n" +
>> +      "        padding: 0px;\n" +
>> +      "        padding-left: 3px;\n" +
>> +      "        padding-right: 3px;\n" +
>> +      "        padding-top: 10px;\n" +
>> +      "        white-space: nowrap;\n" +
>> +      "        -webkit-transform: rotate(-90deg); \n" +
>> +      "        -moz-transform: rotate(-90deg);         \n" +
>> +      "      };\n" +
>> +      "</style>\n";
>> +  private static final String FOOTER = "</html></body>";
>> +  
>> +  // CSS style names. 
>> +  private static final String CSS_TABLE = "table";
>> +  private static final String CSS_LABEL = "label";
>> +  private static final String CSS_TALL_HEADER = "tall";
>> +  private static final String CSS_VERTICAL = "verticalText";
>> +  private static final String CSS_CELL = "cell";
>> +  private static final String CSS_EMPTY = "empty";
>> +  private static final String[] CSS_GRAY_CELLS = {"white", "gray1", 
>> "gray2", "gray3", "black"};
>> +  
>> +  private ConfusionMatrixDumper() {}
>> +  
>> +  public static void main(String[] args) throws Exception {
>> +    ToolRunner.run(new ConfusionMatrixDumper(), args);
>> +  }
>> +  
>> +  @Override
>> +  public int run(String[] args) throws IOException {
>> +    addInputOption();
>> +    addOption("output", "o", "Output path", null); // AbstractJob output 
>> feature requires param
>> +    addOption(DefaultOptionCreator.overwriteOption().create());
>> +    addFlag("html", null, "Create complete HTML page");
>> +    addFlag("text", null, "Dump simple text");
>> +    Map<String, String> parsedArgs = parseArguments(args);
>> +    if (parsedArgs == null) {
>> +      return -1;
>> +    }
>> +    
>> +    Path inputPath = getInputPath();
>> +    String outputFile = parsedArgs.containsKey("--output") ? 
>> parsedArgs.get("--output") : null;
>> +    boolean text = parsedArgs.containsKey("--text");
>> +    boolean wrapHtml = parsedArgs.containsKey("--html");
>> +    PrintStream out = getPrintStream(outputFile);
>> +    if (text) {
>> +      exportText(inputPath, out);
>> +    } else {
>> +      exportTable(inputPath, out, wrapHtml);
>> +    }
>> +    out.flush();
>> +    if (out != System.out) {
>> +      out.close();
>> +    }
>> +    return 0;
>> +  }
>> +  
>> +  private static void exportText(Path inputPath, PrintStream out) throws 
>> IOException {
>> +    MatrixWritable mw = new MatrixWritable();
>> +    Text key = new Text();
>> +    readSeqFile(inputPath, key, mw);
>> +    Matrix m = mw.get();
>> +    ConfusionMatrix cm = new ConfusionMatrix(m);
>> +    out.println(cm.toString());
>> +  }
>> +  
>> +  private static void exportTable(Path inputPath, PrintStream out, boolean 
>> wrapHtml) throws IOException {
>> +    MatrixWritable mw = new MatrixWritable();
>> +    Text key = new Text();
>> +    readSeqFile(inputPath, key, mw);
>> +    String fileName = inputPath.getName();
>> +    fileName = fileName.substring(fileName.lastIndexOf('/') + 1, 
>> fileName.length());
>> +    Matrix m = mw.get();
>> +    ConfusionMatrix cm = new ConfusionMatrix(m);
>> +    if (wrapHtml) {
>> +      printHeader(out, fileName);
>> +    }
>> +    out.println("<p/>");
>> +    printSummaryTable(cm, out);
>> +    out.println("<p/>");
>> +    printGrayTable(cm, out);
>> +    out.println("<p/>");
>> +    printCountsTable(cm, out);
>> +    out.println("<p/>");
>> +    printTextInBox(cm, out);
>> +    out.println("<p/>");
>> +    if (wrapHtml) {
>> +      printFooter(out);
>> +    }
>> +  }
>> +  
>> +  private static List<String> stripDefault(ConfusionMatrix cm) {
>> +    List<String> stripped = Lists.newArrayList(cm.getLabels().iterator()); 
>> +    String defaultLabel = cm.getDefaultLabel();
>> +    int unclassified = cm.getTotal(defaultLabel);
>> +    if (unclassified > 0) {
>> +      return stripped;
>> +    }
>> +    stripped.remove(defaultLabel);
>> +    return stripped;
>> +  }
>> +  
>> +  // TODO: test - this should work with HDFS files
>> +  private static void readSeqFile(Path path, Text key, MatrixWritable m) 
>> throws IOException {
>> +    Configuration conf = new Configuration();
>> +    FileSystem fs = FileSystem.get(conf);
>> +    SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
>> +    reader.next(key, m);
>> +  }
>> +  
>> +  // TODO: test - this might not work with HDFS files?
>> +  //     after all, it does no seeks
>> +  private static PrintStream getPrintStream(String outputFilename) throws 
>> IOException {
>> +    if (outputFilename != null) {
>> +      File outputFile = new File(outputFilename);
>> +      if (outputFile.exists()) { 
>> +        outputFile.delete();
>> +      }
>> +      outputFile.createNewFile();
>> +      OutputStream os = new FileOutputStream(outputFile);
>> +      return new PrintStream(os);
>> +    } else {
>> +      return System.out;
>> +    }
>> +  }
>> +  
>> +  private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) {
>> +    Iterator<String> iter = cm.getLabels().iterator();
>> +    int count = 0;
>> +    while(iter.hasNext()) {
>> +      count += cm.getCount(rowLabel, iter.next());
>> +    }
>> +    return count;
>> +  }
>> +  
>> +  // HTML generator code
>> +  
>> +  private static void printTextInBox(ConfusionMatrix cm, PrintStream out) {
>> +    out.println("<div style='width:90%;overflow:scroll;'>");
>> +    out.println("<pre>");
>> +    out.println(cm.toString());
>> +    out.println("</pre>");
>> +    out.println("</div>");
>> +  }
>> +  
>> +  public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) 
>> {
>> +    format("<table class='%s'>\n", out, CSS_TABLE);
>> +    format("<tr class='%s'>", out, CSS_LABEL);
>> +    out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>");
>> +    out.println("</tr>");
>> +    List<String> labels = stripDefault(cm);
>> +    for(String label: labels) {
>> +      printSummaryRow(cm, out, label);
>> +    }
>> +    out.println("</table>");
>> +  }
>> +  
>> +  private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, 
>> String label) {
>> +    format("<tr class='%s'>", out, CSS_CELL);
>> +    int correct = cm.getCorrect(label);
>> +    double accuracy = cm.getAccuracy(label);
>> +    int count = getCount(cm, label);
>> +    format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>",
>> +           out, CSS_CELL, label, count, correct, (int) 
>> Math.round(accuracy));
>> +    out.println("</tr>");
>> +  }
>> +  
>> +  private static int getCount(ConfusionMatrix cm, String label) {
>> +    int count = 0;
>> +    for (String s : cm.getLabels()) {
>> +      count += cm.getCount(label, s);
>> +    }
>> +    return count;
>> +  }
>> +  
>> +  public static void printGrayTable(ConfusionMatrix cm, PrintStream out) {
>> +    format("<table class='%s'>\n", out, CSS_TABLE);
>> +    printCountsHeader(cm, out, true);
>> +    printGrayRows(cm, out);
>> +    out.println("</table>");
>> +  }
>> +  
>> +  /**
>> +   * Print each value in a four-value grayscale based on count/max. 
>> +   * Gives a mostly white matrix with grays in misclassified, and black in 
>> diagonal.
>> +   * TODO: Using the sqrt(count/max) as the rating is more stringent 
>> +   */
>> +  private static void printGrayRows(ConfusionMatrix cm, PrintStream out) {
>> +    List<String> labels = stripDefault(cm);
>> +    for (String label: labels) {
>> +      printGrayRow(cm, out, labels, label);
>> +    }
>> +  }
>> +  
>> +  private static void printGrayRow(ConfusionMatrix cm, PrintStream out, 
>> Iterable<String> labels, String rowLabel) {
>> +    format("<tr class='%s'>", out, CSS_LABEL);
>> +    format("<td>%s</td>", out, rowLabel);
>> +    int total = getLabelTotal(cm, rowLabel);
>> +    for (String columnLabel: labels) {
>> +      printGrayCell(cm, out, total, rowLabel, columnLabel);
>> +    }
>> +    out.println("</tr>");
>> +  }
>> +  
>> +  // assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of 
>> inputs
>> +  // assign black to count = total, meaning complete success
>> +  // alternative rating is to use sqrt(total) instead of total - this is 
>> more drastic 
>> +  private static void printGrayCell(ConfusionMatrix cm,
>> +                                    PrintStream out,
>> +                                    int total,
>> +                                    String rowLabel,
>> +                                    String columnLabel) {
>> +    
>> +    int count = cm.getCount(rowLabel, columnLabel);
>> +    if (count == 0) {
>> +      out.format("<td class='%s'/>", CSS_EMPTY);
>> +    } else {
>> +      // 0 is white, full is black, everything else gray
>> +      int rating = (int) ((count/ (double) total) * 4);
>> +      String css = CSS_GRAY_CELLS[rating];
>> +      format("<td class='%s' title='%s'>%s</td>", out, css, columnLabel, 
>> count);
>> +    }
>> +  }
>> +  
>> +  public static void printCountsTable(ConfusionMatrix cm, PrintStream out) {
>> +    int length = cm.getLabels().size();
>> +    format("<table class='%s'>\n", out, CSS_TABLE);
>> +    printCountsHeader(cm, out, false);
>> +    printCountsRows(cm, out);
>> +    out.println("</table>");
>> +  }
>> +  
>> +  private static void printCountsRows(ConfusionMatrix cm, PrintStream out) {
>> +    List<String> labels = stripDefault(cm);
>> +    for(String label: labels) {
>> +      printCountsRow(cm, out, labels, label);
>> +    }
>> +  }
>> +  
>> +  private static void printCountsRow(ConfusionMatrix cm, PrintStream out, 
>> Iterable<String> labels, String rowLabel) {
>> +    out.println("<tr>");
>> +    format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel);
>> +    for(String columnLabel: labels) {
>> +      printCountsCell(cm, out, rowLabel, columnLabel);
>> +    }
>> +    out.println("</tr>");
>> +  }
>> +  
>> +  private static void printCountsCell(ConfusionMatrix cm, PrintStream out, 
>> String rowLabel, String columnLabel) {
>> +    int count = cm.getCount(rowLabel, columnLabel);
>> +    String s = count == 0 ? "" : Integer.toString(count);
>> +    format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, columnLabel, 
>> s);
>> +  }
>> +  
>> +  private static void printCountsHeader(ConfusionMatrix cm, PrintStream 
>> out, boolean vertical) {
>> +    List<String> labels = stripDefault(cm);
>> +    int longest = getLongestHeader(labels);
>> +    if (vertical) {
>> +      // do vertical - rotation is a bitch
>> +      out.format("<tr class='%s' style='height:%dem'><th>&nbsp;</th>\n", 
>> CSS_TALL_HEADER, longest/2);
>> +      for(String label: labels) {
>> +        out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, 
>> label);
>> +      }
>> +      out.println("</tr>");
>> +    } else {
>> +      // header - empty cell in upper left
>> +      out.format("<tr class='%s'><td class='%s'></td>\n", CSS_TABLE, 
>> CSS_LABEL);
>> +      for(String label: labels) {
>> +        out.format("<td>%s</td>", label);
>> +      }
>> +      out.format("</tr>");
>> +    }
>> +  }
>> +  
>> +  private static int getLongestHeader(Iterable<String> labels) {
>> +    int max = 0;
>> +    for (String label: labels) {
>> +      max = Math.max(label.length(), max);
>> +    }
>> +    return max;
>> +  }
>> +  
>> +  private static void format(String format, PrintStream out, Object ... 
>> args) {
>> +    String format2 = String.format(format, args);
>> +    out.println(format2);
>> +  }
>> +  
>> +  public static void printHeader(PrintStream out, CharSequence title) {
>> +    out.println(HEADER.replace("TITLE", title));
>> +  }
>> +  
>> +  public static void printFooter(PrintStream out) {
>> +    out.println(FOOTER);
>> +  }
>> +  
>> +}
>> 
>> Modified: mahout/trunk/src/conf/driver.classes.props
>> URL: 
>> http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1197510&r1=1197509&r2=1197510&view=diff
>> ==============================================================================
>> --- mahout/trunk/src/conf/driver.classes.props (original)
>> +++ mahout/trunk/src/conf/driver.classes.props Fri Nov  4 11:20:03 2011
>> @@ -46,4 +46,6 @@ org.apache.mahout.classifier.sequencelea
>> org.apache.mahout.classifier.sequencelearning.hmm.RandomSequenceGenerator = 
>> hmmpredict : Generate random sequence of observations by given HMM
>> org.apache.mahout.utils.SplitInput = split : Split Input data into test and 
>> train sets
>> org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob = 
>> trainnb : Train the Vector-based Bayes classifier
>> -org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb 
>> : Test the Vector-based Bayes classifier
>> \ No newline at end of file
>> +org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb 
>> : Test the Vector-based Bayes classifier
>> +org.apache.mahout.classifier.ConfusionMatrixDumper = cmdump : Dump 
>> confusion matrix in HTML or text formats
>> +org.apache.mahout.utils.MatrixDumper = matrixdump : Dump matrix in CSV 
>> format
>> \ No newline at end of file
>> 
>> 
> 
> --------------------------------------------
> Grant Ingersoll
> http://www.lucidimagination.com
> 
> 
> 

--------------------------------------------
Grant Ingersoll
http://www.lucidimagination.com



Reply via email to