Or did we forget to add MatrixDumper? On Nov 4, 2011, at 5:52 PM, Grant Ingersoll wrote:
> Is that supposed to be ConfusionMatrixDumper in the driver.classes.props? > > On Nov 4, 2011, at 7:20 AM, [email protected] wrote: > >> Author: srowen >> Date: Fri Nov 4 11:20:03 2011 >> New Revision: 1197510 >> >> URL: http://svn.apache.org/viewvc?rev=1197510&view=rev >> Log: >> MAHOUT-838 Add confusion matrix dumper >> >> Added: >> mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ >> >> mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java >> Modified: >> >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java >> >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java >> >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java >> mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java >> >> mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java >> mahout/trunk/src/conf/driver.classes.props >> >> Modified: >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java >> URL: >> http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1197510&r1=1197509&r2=1197510&view=diff >> ============================================================================== >> --- >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java >> (original) >> +++ >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java >> Fri Nov 4 11:20:03 2011 >> @@ -19,35 +19,40 @@ package org.apache.mahout.classifier; >> >> import java.util.Collection; >> import java.util.Collections; >> -import java.util.LinkedHashMap; >> import java.util.Map; >> >> -import com.google.common.collect.Maps; >> import org.apache.commons.lang.StringUtils; >> -import org.apache.mahout.math.CardinalityException; >> import 
org.apache.mahout.math.DenseMatrix; >> import 
org.apache.mahout.math.Matrix; >> >> import com.google.common.base.Preconditions; >> +import com.google.common.collect.Maps; >> >> /** >> * The ConfusionMatrix Class stores the result of Classification of a Test >> Dataset. >> * >> + * The fact of whether there is a default is not stored. A row of zeros is >> the only indicator that there is no default. >> + * >> * See http://en.wikipedia.org/wiki/Confusion_matrix for background >> */ >> public class ConfusionMatrix { >> - >> - private final Map<String,Integer> labelMap = new >> LinkedHashMap<String,Integer>(); >> + private final Map<String,Integer> labelMap = Maps.newLinkedHashMap(); >> private final int[][] confusionMatrix; >> private String defaultLabel = "unknown"; >> >> public ConfusionMatrix(Collection<String> labels, String defaultLabel) { >> confusionMatrix = new int[labels.size() + 1][labels.size() + 1]; >> this.defaultLabel = defaultLabel; >> + int i = 0; >> for (String label : labels) { >> - labelMap.put(label, labelMap.size()); >> + labelMap.put(label, i++); >> } >> - labelMap.put(defaultLabel, labelMap.size()); >> + labelMap.put(defaultLabel, i); >> + } >> + >> + public ConfusionMatrix(Matrix m) { >> + confusionMatrix = new int[m.numRows()][m.numRows()]; >> + setMatrix(m); >> } >> >> public int[][] getConfusionMatrix() { >> @@ -76,7 +81,7 @@ public class ConfusionMatrix { >> return confusionMatrix[labelId][labelId]; >> } >> >> - public double getTotal(String label) { >> + public int getTotal(String label) { >> int labelId = labelMap.get(label); >> int labelTotal = 0; >> for (int i = 0; i < labelMap.size(); i++) { >> @@ -94,25 +99,25 @@ public class ConfusionMatrix { >> } >> >> public int getCount(String correctLabel, String classifiedLabel) { >> - Preconditions.checkArgument(labelMap.containsKey(correctLabel), >> - "Label not found: " + correctLabel); >> - Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), >> - "Label not found: " + classifiedLabel); >> + 
Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label >> not found: " + correctLabel); >> + Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), >> "Label not found: " + classifiedLabel); >> int correctId = labelMap.get(correctLabel); >> int classifiedId = labelMap.get(classifiedLabel); >> return confusionMatrix[correctId][classifiedId]; >> } >> >> public void putCount(String correctLabel, String classifiedLabel, int >> count) { >> - Preconditions.checkArgument(labelMap.containsKey(correctLabel), >> - "Label not found: " + correctLabel); >> - Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), >> - "Label not found: " + classifiedLabel); >> + Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label >> not found: " + correctLabel); >> + Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), >> "Label not found: " + classifiedLabel); >> int correctId = labelMap.get(correctLabel); >> int classifiedId = labelMap.get(classifiedLabel); >> confusionMatrix[correctId][classifiedId] = count; >> } >> >> + public String getDefaultLabel() { >> + return defaultLabel; >> + } >> + >> public void incrementCount(String correctLabel, String classifiedLabel, >> int count) { >> putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, >> classifiedLabel)); >> } >> @@ -132,45 +137,69 @@ public class ConfusionMatrix { >> } >> >> public Matrix getMatrix() { >> - int length = confusionMatrix.length; >> - Matrix m = new DenseMatrix(length, length); >> - for (int r = 0; r < length; r++) { >> - for (int c = 0; c < length; c++) { >> - m.set(r, c, confusionMatrix[r][c]); >> - } >> - } >> - Map<String,Integer> labels = Maps.newHashMap(); >> - for (Map.Entry<String, Integer> entry : labelMap.entrySet()) { >> - labels.put(entry.getKey(), entry.getValue()); >> - } >> - m.setRowLabelBindings(labels); >> - m.setColumnLabelBindings(labels); >> - return m; >> + int length = confusionMatrix.length; >> + Matrix m = new 
DenseMatrix(length, length); >> + for (int r = 0; r < length; r++) { >> + for (int c = 0; c < length; c++) { >> + m.set(r, c, confusionMatrix[r][c]); >> + } >> + } >> + Map<String,Integer> labels = Maps.newHashMap(); >> + for(Map.Entry<String, Integer> entry : labelMap.entrySet()) { >> + labels.put(entry.getKey(), entry.getValue()); >> + } >> + m.setRowLabelBindings(labels); >> + m.setColumnLabelBindings(labels); >> + return m; >> } >> - >> + >> public void setMatrix(Matrix m) { >> - int length = confusionMatrix.length; >> - if (m.numRows() != m.numCols()) { >> - throw new CardinalityException(m.numRows(), m.numCols()); >> - } >> - if (m.numRows() != length) { >> - throw new CardinalityException(m.numRows(), length); >> - } >> - for (int r = 0; r < length; r++) { >> - for (int c = 0; c < length; c++) { >> - confusionMatrix[r][c] = (int) Math.round(m.get(r, c)); >> - } >> - } >> - Map<String,Integer> labels = m.getRowLabelBindings(); >> - if (labels == null) { >> + int length = confusionMatrix.length; >> + if (m.numRows() != m.numCols()) { >> + throw new IllegalArgumentException( >> + String.format("ConfusionMatrix: matrix({},{}) must be square", >> m.numRows(), m.numCols())); >> + } >> + for (int r = 0; r < length; r++) { >> + for (int c = 0; c < length; c++) { >> + confusionMatrix[r][c] = (int) Math.round(m.get(r, c)); >> + } >> + } >> + Map<String,Integer> labels = m.getRowLabelBindings(); >> + if (labels == null) { >> labels = m.getColumnLabelBindings(); >> } >> - labelMap.clear(); >> - if (labels != null) { >> - labelMap.putAll(labels); >> - } >> + if (labels != null) { >> + String[] sorted = sortLabels(labels); >> + verifyLabels(length, sorted); >> + labelMap.clear(); >> + for(int i = 0; i < length; i++) { >> + labelMap.put(sorted[i], i); >> + } >> + } >> + } >> + >> + private static String[] sortLabels(Map<String,Integer> labels) { >> + String[] sorted = new String[labels.keySet().size()]; >> + for(String label: labels.keySet()) { >> + Integer index = 
labels.get(label); >> + sorted[index] = label; >> + } >> + return sorted; >> + } >> + >> + private void verifyLabels(int length, String[] sorted) { >> + Preconditions.checkArgument(sorted.length == length, "One label, one >> row"); >> + for(int i = 0; i < length; i++) { >> + if (sorted[i] == null) { >> + Preconditions.checkArgument(false, "One label, one row"); >> + } >> + } >> } >> >> + /** >> + * This is overloaded. toString() is not a formatted report you print for >> a manager :) >> + * Assume that if there are no default assignments, the default feature >> was not used >> + */ >> @Override >> public String toString() { >> StringBuilder returnString = new StringBuilder(200); >> @@ -178,26 +207,37 @@ public class ConfusionMatrix { >> returnString.append("Confusion Matrix\n"); >> >> returnString.append("-------------------------------------------------------").append('\n'); >> >> + int unclassified = getTotal(defaultLabel); >> for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) { >> + if (entry.getKey().equals(defaultLabel) && unclassified == 0) { >> + continue; >> + } >> + >> >> returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), >> 5)).append('\t'); >> } >> >> returnString.append("<--Classified as").append('\n'); >> - >> for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) { >> + if (entry.getKey().equals(defaultLabel) && unclassified == 0) { >> + continue; >> + } >> String correctLabel = entry.getKey(); >> int labelTotal = 0; >> for (String classifiedLabel : this.labelMap.keySet()) { >> + if (classifiedLabel.equals(defaultLabel) && unclassified == 0) { >> + continue; >> + } >> returnString.append( >> - StringUtils.rightPad(Integer.toString(getCount(correctLabel, >> classifiedLabel)), 5)).append('\t'); >> + StringUtils.rightPad(Integer.toString(getCount(correctLabel, >> classifiedLabel)), 5)).append('\t'); >> labelTotal += getCount(correctLabel, classifiedLabel); >> } >> returnString.append(" | >> 
").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t') >> - .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)) >> - .append(" = ").append(correctLabel).append('\n'); >> + .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)) >> + .append(" = ").append(correctLabel).append('\n'); >> + } >> + if (unclassified > 0) { >> + returnString.append("Default Category: >> ").append(defaultLabel).append(": ").append(unclassified).append('\n'); >> } >> - returnString.append("Default Category: >> ").append(defaultLabel).append(": ").append( >> - labelMap.get(defaultLabel)).append('\n'); >> returnString.append('\n'); >> return returnString.toString(); >> } >> @@ -212,5 +252,5 @@ public class ConfusionMatrix { >> } while (val > 0); >> return returnString.toString(); >> } >> - >> + >> } >> >> Modified: >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java >> URL: >> http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=1197510&r1=1197509&r2=1197510&view=diff >> ============================================================================== >> --- >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java >> (original) >> +++ >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java >> Fri Nov 4 11:20:03 2011 >> @@ -33,6 +33,7 @@ import org.apache.commons.cli2.builder.D >> import org.apache.commons.cli2.builder.GroupBuilder; >> import org.apache.commons.cli2.commandline.Parser; >> import org.apache.mahout.classifier.ClassifierResult; >> +import org.apache.mahout.classifier.ConfusionMatrix; >> import org.apache.mahout.classifier.ResultAnalyzer; >> import >> org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver; >> import org.apache.mahout.common.CommandLineUtil; >> @@ -104,9 +105,14 @@ public final class TestClassifier { >> "Method of Classification: 
sequential|mapreduce. Default Value: >> sequential").withShortName("method") >> .create(); >> >> + Option confusionMatrixOpt = >> obuilder.withLongName("confusionMatrix").withRequired(false).withArgument( >> + >> abuilder.withName("confusionMatrix").withMinimum(1).withMaximum(1).create()).withDescription( >> + "Export ConfusionMatrix as >> SequenceFile").withShortName("cm").create(); >> + >> Group group = >> gbuilder.withName("Options").withOption(defaultCatOpt).withOption(dirOpt).withOption( >> >> encodingOpt).withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt).withOption(dataSourceOpt) >> - >> .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt).create(); >> + >> .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt) >> + .withOption(confusionMatrixOpt).create(); >> >> try { >> Parser parser = new Parser(); >> @@ -163,6 +169,11 @@ public final class TestClassifier { >> classificationMethod = (String) cmdLine.getValue(methodOpt); >> } >> >> + String confusionMatrixFile = null; >> + if (cmdLine.hasOption(confusionMatrixOpt)) { >> + confusionMatrixFile = (String) cmdLine.getValue(confusionMatrixOpt); >> + } >> + >> params.setGramSize(gramSize); >> params.set("verbose", Boolean.toString(verbose)); >> params.setBasePath(modelBasePath); >> @@ -172,6 +183,7 @@ public final class TestClassifier { >> params.set("encoding", encoding); >> params.set("alpha_i", alphaI); >> params.set("testDirPath", testDirPath); >> + params.set("confusionMatrix", confusionMatrixFile); >> >> if ("sequential".equalsIgnoreCase(classificationMethod)) { >> classifySequential(params); >> @@ -253,12 +265,12 @@ public final class TestClassifier { >> } >> lineNum++; >> } >> - /* >> - * log.info("{}\t{}\t{}/{}", new Object[] {correctLabel, >> - * resultAnalyzer.getConfusionMatrix().getAccuracy(correctLabel), >> - * resultAnalyzer.getConfusionMatrix().getCorrect(correctLabel), >> - * 
resultAnalyzer.getConfusionMatrix().getTotal(correctLabel)}); >> - */ >> + ConfusionMatrix matrix = resultAnalyzer.getConfusionMatrix(); >> + log.info("{}", matrix); >> + BayesClassifierDriver.confusionMatrixSeqFileExport(params, matrix); >> + >> + log.info("ConfusionMatrix: {}", matrix.toString()); >> + >> log.info("Classified instances from {}", file.getName()); >> if (verbose) { >> log.info("Performance stats {}", operationStats.toString()); >> >> Modified: >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java >> URL: >> http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java?rev=1197510&r1=1197509&r2=1197510&view=diff >> ============================================================================== >> --- >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java >> (original) >> +++ >> mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java >> Fri Nov 4 11:20:03 2011 >> @@ -21,10 +21,15 @@ import java.io.IOException; >> import java.util.Map; >> >> import com.google.common.collect.Maps; >> +import com.google.common.io.Closeables; >> + >> import org.apache.hadoop.conf.Configurable; >> import org.apache.hadoop.conf.Configuration; >> +import org.apache.hadoop.fs.FileSystem; >> import org.apache.hadoop.fs.Path; >> import org.apache.hadoop.io.DoubleWritable; >> +import org.apache.hadoop.io.SequenceFile; >> +import org.apache.hadoop.io.Text; >> import org.apache.hadoop.mapred.FileInputFormat; >> import org.apache.hadoop.mapred.FileOutputFormat; >> import org.apache.hadoop.mapred.JobClient; >> @@ -38,6 +43,7 @@ import org.apache.mahout.common.Paramete >> import org.apache.mahout.common.StringTuple; >> import org.apache.mahout.common.iterator.sequencefile.PathType; >> import >> 
org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; >> +import org.apache.mahout.math.MatrixWritable; >> import org.slf4j.Logger; >> import org.slf4j.LoggerFactory; >> >> @@ -83,6 +89,10 @@ public final class BayesClassifierDriver >> Path outputFiles = new Path(outPath, "part*"); >> ConfusionMatrix matrix = readResult(outputFiles, conf, params); >> log.info("{}", matrix); >> + if (params.get("confusionMatrix") != null) { >> + confusionMatrixSeqFileExport(params, matrix); >> + } >> + >> } >> >> public static ConfusionMatrix readResult(Path pathPattern, Configuration >> conf, Parameters params) { >> @@ -117,6 +127,24 @@ public final class BayesClassifierDriver >> } >> } >> return matrix; >> + } >> >> + public static void confusionMatrixSeqFileExport(Parameters params, >> ConfusionMatrix matrix) throws IOException { >> + if (params.get("confusionMatrix") != null) { >> + Configuration conf = new Configuration(); >> + FileSystem fs = FileSystem.get(conf); >> + SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, >> + new Path(params.get("confusionMatrix")), Text.class, >> MatrixWritable.class); >> + String name = params.get("confusionMatrix"); >> + // embed file name as sequence key- useful for tuning classifiers >> + name = name.substring(name.lastIndexOf('/') + 1, name.length()); >> + Text key = new Text(name); >> + MatrixWritable mw = new MatrixWritable(matrix.getMatrix()); >> + try { >> + writer.append(key, mw); >> + } finally { >> + Closeables.closeQuietly(writer); >> + } >> + } >> } >> } >> >> Modified: >> mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java >> URL: >> http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java?rev=1197510&r1=1197509&r2=1197510&view=diff >> ============================================================================== >> --- >> mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java >> (original) >> +++ >> 
mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java >> Fri Nov 4 11:20:03 2011 >> @@ -123,15 +123,15 @@ public class MatrixWritable implements W >> } >> >> if (hasLabels) { >> - Map<String,Integer> columnLabelBindings = Maps.newHashMap(); >> - Map<String,Integer> rowLabelBindings = Maps.newHashMap(); >> - readLabels(in, columnLabelBindings, rowLabelBindings); >> - if (!columnLabelBindings.isEmpty()) { >> - r.setColumnLabelBindings(columnLabelBindings); >> - } >> - if (!rowLabelBindings.isEmpty()) { >> - r.setRowLabelBindings(rowLabelBindings); >> - } >> + Map<String,Integer> columnLabelBindings = Maps.newHashMap(); >> + Map<String,Integer> rowLabelBindings = Maps.newHashMap(); >> + readLabels(in, columnLabelBindings, rowLabelBindings); >> + if (!columnLabelBindings.isEmpty()) { >> + r.setColumnLabelBindings(columnLabelBindings); >> + } >> + if (!rowLabelBindings.isEmpty()) { >> + r.setRowLabelBindings(rowLabelBindings); >> + } >> } >> >> return r; >> @@ -159,7 +159,7 @@ public class MatrixWritable implements W >> VectorWritable.writeVector(out, matrix.viewRow(i), false); >> } >> if ((flags & FLAG_LABELS) != 0) { >> - writeLabelBindings(out, matrix.getColumnLabelBindings(), >> matrix.getRowLabelBindings()); >> + writeLabelBindings(out, matrix.getColumnLabelBindings(), >> matrix.getRowLabelBindings()); >> } >> } >> } >> >> Modified: >> mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java >> URL: >> http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java?rev=1197510&r1=1197509&r2=1197510&view=diff >> ============================================================================== >> --- >> mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java >> (original) >> +++ >> mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java >> Fri Nov 4 11:20:03 2011 >> @@ -17,6 +17,7 @@ >> >> package org.apache.mahout.math; >> >> +import 
com.google.common.collect.Maps; >> import com.google.common.io.Closeables; >> import org.apache.hadoop.io.Writable; >> import org.junit.Test; >> @@ -26,7 +27,6 @@ import java.io.ByteArrayOutputStream; >> import java.io.DataInputStream; >> import java.io.DataOutputStream; >> import java.io.IOException; >> -import java.util.HashMap; >> import java.util.Map; >> >> public final class MatrixWritableTest extends MahoutTestCase { >> @@ -36,13 +36,14 @@ public final class MatrixWritableTest ex >> Matrix m = new SparseMatrix(5, 5); >> m.set(1, 2, 3.0); >> m.set(3, 4, 5.0); >> - Map<String, Integer> bindings = new HashMap<String, Integer>(); >> + Map<String, Integer> bindings = Maps.newHashMap(); >> bindings.put("A", 0); >> bindings.put("B", 1); >> bindings.put("C", 2); >> bindings.put("D", 3); >> bindings.put("default", 4); >> m.setRowLabelBindings(bindings); >> + m.setColumnLabelBindings(bindings); >> doTestMatrixWritableEquals(m); >> } >> >> @@ -51,12 +52,13 @@ public final class MatrixWritableTest ex >> Matrix m = new DenseMatrix(5,5); >> m.set(1, 2, 3.0); >> m.set(3, 4, 5.0); >> - Map<String, Integer> bindings = new HashMap<String, Integer>(); >> + Map<String, Integer> bindings = Maps.newHashMap(); >> bindings.put("A", 0); >> bindings.put("B", 1); >> bindings.put("C", 2); >> bindings.put("D", 3); >> bindings.put("default", 4); >> + m.setRowLabelBindings(bindings); >> m.setColumnLabelBindings(bindings); >> doTestMatrixWritableEquals(m); >> } >> @@ -66,7 +68,9 @@ public final class MatrixWritableTest ex >> MatrixWritable matrixWritable2 = new MatrixWritable(); >> writeAndRead(matrixWritable, matrixWritable2); >> Matrix m2 = matrixWritable2.get(); >> - compareMatrices(m, m2); // not sure this works? 
>> + compareMatrices(m, m2); >> + doCheckBindings(m2.getRowLabelBindings()); >> + doCheckBindings(m2.getColumnLabelBindings()); >> } >> >> private static void compareMatrices(Matrix m, Matrix m2) { >> @@ -98,6 +102,14 @@ public final class MatrixWritableTest ex >> } >> } >> >> + private static void doCheckBindings(Map<String,Integer> labels) { >> + assertTrue("Missing label", labels.keySet().contains("A")); >> + assertTrue("Missing label", labels.keySet().contains("B")); >> + assertTrue("Missing label", labels.keySet().contains("C")); >> + assertTrue("Missing label", labels.keySet().contains("D")); >> + assertTrue("Missing label", labels.keySet().contains("default")); >> + } >> + >> private static void writeAndRead(Writable toWrite, Writable toRead) >> throws IOException { >> ByteArrayOutputStream baos = new ByteArrayOutputStream(); >> DataOutputStream dos = new DataOutputStream(baos); >> >> Added: >> mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java >> URL: >> http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java?rev=1197510&view=auto >> ============================================================================== >> --- >> mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java >> (added) >> +++ >> mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java >> Fri Nov 4 11:20:03 2011 >> @@ -0,0 +1,423 @@ >> +/* >> + * Licensed to the Apache Software Foundation (ASF) under one or more >> + * contributor license agreements. See the NOTICE file distributed with >> + * this work for additional information regarding copyright ownership. >> + * The ASF licenses this file to You under the Apache License, Version 2.0 >> + * (the "License"); you may not use this file except in compliance with >> + * the License. 
You may obtain a copy of the License at >> + * >> + * http://www.apache.org/licenses/LICENSE-2.0 >> + * >> + * Unless required by applicable law or agreed to in writing, software >> + * distributed under the License is distributed on an "AS IS" BASIS, >> + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. >> + * See the License for the specific language governing permissions and >> + * limitations under the License. >> + */ >> + >> +package org.apache.mahout.classifier; >> + >> +import java.io.File; >> +import java.io.FileOutputStream; >> +import java.io.IOException; >> +import java.io.OutputStream; >> +import java.io.PrintStream; >> +import java.util.Iterator; >> +import java.util.List; >> +import java.util.Map; >> + >> +import org.apache.hadoop.conf.Configuration; >> +import org.apache.hadoop.fs.FileSystem; >> +import org.apache.hadoop.fs.Path; >> +import org.apache.hadoop.io.SequenceFile; >> +import org.apache.hadoop.io.Text; >> +import org.apache.hadoop.util.ToolRunner; >> +import org.apache.mahout.common.AbstractJob; >> +import org.apache.mahout.common.commandline.DefaultOptionCreator; >> +import org.apache.mahout.math.Matrix; >> +import org.apache.mahout.math.MatrixWritable; >> + >> +import com.google.common.collect.Lists; >> + >> +/** >> + * Export a ConfusionMatrix in various text formats: >> + * ToString version >> + * Grayscale HTML table >> + * Summary HTML table >> + * Table of counts >> + * all with optional HTML wrappers >> + * >> + * Input format: Hadoop SequenceFile with Text key and MatrixWritable >> value, 1 pair >> + * >> + * Intended to consume ConfusionMatrix SequenceFile output by Bayes >> + * TestClassifier class >> + */ >> +public final class ConfusionMatrixDumper extends AbstractJob { >> + >> + // HTML wrapper - default CSS >> + private static final String HEADER = "<html>" + >> + "<head>\n" + >> + "<title>TITLE</title>\n" + >> + "</head>" + >> + "<body>\n" + >> + "<style type='text/css'> \n" + >> + "table\n" + >> + 
"{\n" + >> + "border:3px solid black; text-align:left;\n" + >> + "}\n" + >> + "th.normalHeader\n" + >> + "{\n" + >> + "border:1px solid >> black;border-collapse:collapse;text-align:center;background-color:white\n" + >> + "}\n" + >> + "th.tallHeader\n" + >> + "{\n" + >> + "border:1px solid >> black;border-collapse:collapse;text-align:center;background-color:white; >> height:6em\n" + >> + "}\n" + >> + "tr.label\n" + >> + "{\n" + >> + "border:1px solid >> black;border-collapse:collapse;text-align:center;background-color:white\n" + >> + "}\n" + >> + "tr.row\n" + >> + "{\n" + >> + "border:1px solid gray;text-align:center;background-color:snow\n" + >> + "}\n" + >> + "td\n" + >> + "{\n" + >> + "min-width:2em\n" + >> + "}\n" + >> + "td.cell\n" + >> + "{\n" + >> + "border:1px solid black;text-align:right;background-color:snow\n" + >> + "}\n" + >> + "td.empty\n" + >> + "{\n" + >> + "border:0px;text-align:right;background-color:snow\n" + >> + "}\n" + >> + "td.white\n" + >> + "{\n" + >> + "border:0px solid black;text-align:right;background-color:white\n" + >> + "}\n" + >> + "td.black\n" + >> + "{\n" + >> + "border:0px solid red;text-align:right;background-color:black\n" + >> + "}\n" + >> + "td.gray1\n" + >> + "{\n" + >> + "border:0px solid green;text-align:right; >> background-color:LightGray\n" + >> + "}\n" + >> + "td.gray2\n" + >> + "{\n" + >> + "border:0px solid blue;text-align:right;background-color:gray\n" + >> + "}\n" + >> + "td.gray3\n" + >> + "{\n" + >> + "border:0px solid red;text-align:right;background-color:DarkGray\n" + >> + "}\n" + >> + "th" + >> + "{\n" + >> + " text-align: center;\n" + >> + " vertical-align: bottom;\n" + >> + " padding-bottom: 3px;\n" + >> + " padding-left: 5px;\n" + >> + " padding-right: 5px;\n" + >> + "}\n" + >> + " .verticalText\n" + >> + " {\n" + >> + " text-align: center;\n" + >> + " vertical-align: middle;\n" + >> + " width: 20px;\n" + >> + " margin: 0px;\n" + >> + " padding: 0px;\n" + >> + " padding-left: 3px;\n" + >> + " padding-right: 
3px;\n" + >> + " padding-top: 10px;\n" + >> + " white-space: nowrap;\n" + >> + " -webkit-transform: rotate(-90deg); \n" + >> + " -moz-transform: rotate(-90deg); \n" + >> + " };\n" + >> + "</style>\n"; >> + private static final String FOOTER = "</html></body>"; >> + >> + // CSS style names. >> + private static final String CSS_TABLE = "table"; >> + private static final String CSS_LABEL = "label"; >> + private static final String CSS_TALL_HEADER = "tall"; >> + private static final String CSS_VERTICAL = "verticalText"; >> + private static final String CSS_CELL = "cell"; >> + private static final String CSS_EMPTY = "empty"; >> + private static final String[] CSS_GRAY_CELLS = {"white", "gray1", >> "gray2", "gray3", "black"}; >> + >> + private ConfusionMatrixDumper() {} >> + >> + public static void main(String[] args) throws Exception { >> + ToolRunner.run(new ConfusionMatrixDumper(), args); >> + } >> + >> + @Override >> + public int run(String[] args) throws IOException { >> + addInputOption(); >> + addOption("output", "o", "Output path", null); // AbstractJob output >> feature requires param >> + addOption(DefaultOptionCreator.overwriteOption().create()); >> + addFlag("html", null, "Create complete HTML page"); >> + addFlag("text", null, "Dump simple text"); >> + Map<String, String> parsedArgs = parseArguments(args); >> + if (parsedArgs == null) { >> + return -1; >> + } >> + >> + Path inputPath = getInputPath(); >> + String outputFile = parsedArgs.containsKey("--output") ? 
>> parsedArgs.get("--output") : null; >> + boolean text = parsedArgs.containsKey("--text"); >> + boolean wrapHtml = parsedArgs.containsKey("--html"); >> + PrintStream out = getPrintStream(outputFile); >> + if (text) { >> + exportText(inputPath, out); >> + } else { >> + exportTable(inputPath, out, wrapHtml); >> + } >> + out.flush(); >> + if (out != System.out) { >> + out.close(); >> + } >> + return 0; >> + } >> + >> + private static void exportText(Path inputPath, PrintStream out) throws >> IOException { >> + MatrixWritable mw = new MatrixWritable(); >> + Text key = new Text(); >> + readSeqFile(inputPath, key, mw); >> + Matrix m = mw.get(); >> + ConfusionMatrix cm = new ConfusionMatrix(m); >> + out.println(cm.toString()); >> + } >> + >> + private static void exportTable(Path inputPath, PrintStream out, boolean >> wrapHtml) throws IOException { >> + MatrixWritable mw = new MatrixWritable(); >> + Text key = new Text(); >> + readSeqFile(inputPath, key, mw); >> + String fileName = inputPath.getName(); >> + fileName = fileName.substring(fileName.lastIndexOf('/') + 1, >> fileName.length()); >> + Matrix m = mw.get(); >> + ConfusionMatrix cm = new ConfusionMatrix(m); >> + if (wrapHtml) { >> + printHeader(out, fileName); >> + } >> + out.println("<p/>"); >> + printSummaryTable(cm, out); >> + out.println("<p/>"); >> + printGrayTable(cm, out); >> + out.println("<p/>"); >> + printCountsTable(cm, out); >> + out.println("<p/>"); >> + printTextInBox(cm, out); >> + out.println("<p/>"); >> + if (wrapHtml) { >> + printFooter(out); >> + } >> + } >> + >> + private static List<String> stripDefault(ConfusionMatrix cm) { >> + List<String> stripped = Lists.newArrayList(cm.getLabels().iterator()); >> + String defaultLabel = cm.getDefaultLabel(); >> + int unclassified = cm.getTotal(defaultLabel); >> + if (unclassified > 0) { >> + return stripped; >> + } >> + stripped.remove(defaultLabel); >> + return stripped; >> + } >> + >> + // TODO: test - this should work with HDFS files >> + private 
static void readSeqFile(Path path, Text key, MatrixWritable m) >> throws IOException { >> + Configuration conf = new Configuration(); >> + FileSystem fs = FileSystem.get(conf); >> + SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf); >> + reader.next(key, m); >> + } >> + >> + // TODO: test - this might not work with HDFS files? >> + // after all, it does no seeks >> + private static PrintStream getPrintStream(String outputFilename) throws >> IOException { >> + if (outputFilename != null) { >> + File outputFile = new File(outputFilename); >> + if (outputFile.exists()) { >> + outputFile.delete(); >> + } >> + outputFile.createNewFile(); >> + OutputStream os = new FileOutputStream(outputFile); >> + return new PrintStream(os); >> + } else { >> + return System.out; >> + } >> + } >> + >> + private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) { >> + Iterator<String> iter = cm.getLabels().iterator(); >> + int count = 0; >> + while(iter.hasNext()) { >> + count += cm.getCount(rowLabel, iter.next()); >> + } >> + return count; >> + } >> + >> + // HTML generator code >> + >> + private static void printTextInBox(ConfusionMatrix cm, PrintStream out) { >> + out.println("<div style='width:90%;overflow:scroll;'>"); >> + out.println("<pre>"); >> + out.println(cm.toString()); >> + out.println("</pre>"); >> + out.println("</div>"); >> + } >> + >> + public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) >> { >> + format("<table class='%s'>\n", out, CSS_TABLE); >> + format("<tr class='%s'>", out, CSS_LABEL); >> + out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>"); >> + out.println("</tr>"); >> + List<String> labels = stripDefault(cm); >> + for(String label: labels) { >> + printSummaryRow(cm, out, label); >> + } >> + out.println("</table>"); >> + } >> + >> + private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, >> String label) { >> + format("<tr class='%s'>", out, CSS_CELL); >> + int correct = 
cm.getCorrect(label); >> + double accuracy = cm.getAccuracy(label); >> + int count = getCount(cm, label); >> + format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>", >> + out, CSS_CELL, label, count, correct, (int) >> Math.round(accuracy)); >> + out.println("</tr>"); >> + } >> + >> + private static int getCount(ConfusionMatrix cm, String label) { >> + int count = 0; >> + for (String s : cm.getLabels()) { >> + count += cm.getCount(label, s); >> + } >> + return count; >> + } >> + >> + public static void printGrayTable(ConfusionMatrix cm, PrintStream out) { >> + format("<table class='%s'>\n", out, CSS_TABLE); >> + printCountsHeader(cm, out, true); >> + printGrayRows(cm, out); >> + out.println("</table>"); >> + } >> + >> + /** >> + * Print each value in a four-value grayscale based on count/max. >> + * Gives a mostly white matrix with grays in misclassified, and black in >> diagonal. >> + * TODO: Using the sqrt(count/max) as the rating is more stringent >> + */ >> + private static void printGrayRows(ConfusionMatrix cm, PrintStream out) { >> + List<String> labels = stripDefault(cm); >> + for (String label: labels) { >> + printGrayRow(cm, out, labels, label); >> + } >> + } >> + >> + private static void printGrayRow(ConfusionMatrix cm, PrintStream out, >> Iterable<String> labels, String rowLabel) { >> + format("<tr class='%s'>", out, CSS_LABEL); >> + format("<td>%s</td>", out, rowLabel); >> + int total = getLabelTotal(cm, rowLabel); >> + for (String columnLabel: labels) { >> + printGrayCell(cm, out, total, rowLabel, columnLabel); >> + } >> + out.println("</tr>"); >> + } >> + >> + // assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of >> inputs >> + // assign black to count = total, meaning complete success >> + // alternative rating is to use sqrt(total) instead of total - this is >> more drastic >> + private static void printGrayCell(ConfusionMatrix cm, >> + PrintStream out, >> + int total, >> + String rowLabel, >> + String columnLabel) { >> + 
>> + int count = cm.getCount(rowLabel, columnLabel); >> + if (count == 0) { >> + out.format("<td class='%s'/>", CSS_EMPTY); >> + } else { >> + // 0 is white, full is black, everything else gray >> + int rating = (int) ((count/ (double) total) * 4); >> + String css = CSS_GRAY_CELLS[rating]; >> + format("<td class='%s' title='%s'>%s</td>", out, css, columnLabel, >> count); >> + } >> + } >> + >> + public static void printCountsTable(ConfusionMatrix cm, PrintStream out) { >> + int length = cm.getLabels().size(); >> + format("<table class='%s'>\n", out, CSS_TABLE); >> + printCountsHeader(cm, out, false); >> + printCountsRows(cm, out); >> + out.println("</table>"); >> + } >> + >> + private static void printCountsRows(ConfusionMatrix cm, PrintStream out) { >> + List<String> labels = stripDefault(cm); >> + for(String label: labels) { >> + printCountsRow(cm, out, labels, label); >> + } >> + } >> + >> + private static void printCountsRow(ConfusionMatrix cm, PrintStream out, >> Iterable<String> labels, String rowLabel) { >> + out.println("<tr>"); >> + format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel); >> + for(String columnLabel: labels) { >> + printCountsCell(cm, out, rowLabel, columnLabel); >> + } >> + out.println("</tr>"); >> + } >> + >> + private static void printCountsCell(ConfusionMatrix cm, PrintStream out, >> String rowLabel, String columnLabel) { >> + int count = cm.getCount(rowLabel, columnLabel); >> + String s = count == 0 ? 
"" : Integer.toString(count); >> + format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, columnLabel, >> s); >> + } >> + >> + private static void printCountsHeader(ConfusionMatrix cm, PrintStream >> out, boolean vertical) { >> + List<String> labels = stripDefault(cm); >> + int longest = getLongestHeader(labels); >> + if (vertical) { >> + // do vertical - rotation is a bitch >> + out.format("<tr class='%s' style='height:%dem'><th> </th>\n", >> CSS_TALL_HEADER, longest/2); >> + for(String label: labels) { >> + out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, >> label); >> + } >> + out.println("</tr>"); >> + } else { >> + // header - empty cell in upper left >> + out.format("<tr class='%s'><td class='%s'></td>\n", CSS_TABLE, >> CSS_LABEL); >> + for(String label: labels) { >> + out.format("<td>%s</td>", label); >> + } >> + out.format("</tr>"); >> + } >> + } >> + >> + private static int getLongestHeader(Iterable<String> labels) { >> + int max = 0; >> + for (String label: labels) { >> + max = Math.max(label.length(), max); >> + } >> + return max; >> + } >> + >> + private static void format(String format, PrintStream out, Object ... 
>> args) { >> + String format2 = String.format(format, args); >> + out.println(format2); >> + } >> + >> + public static void printHeader(PrintStream out, CharSequence title) { >> + out.println(HEADER.replace("TITLE", title)); >> + } >> + >> + public static void printFooter(PrintStream out) { >> + out.println(FOOTER); >> + } >> + >> +} >> >> Modified: mahout/trunk/src/conf/driver.classes.props >> URL: >> http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1197510&r1=1197509&r2=1197510&view=diff >> ============================================================================== >> --- mahout/trunk/src/conf/driver.classes.props (original) >> +++ mahout/trunk/src/conf/driver.classes.props Fri Nov 4 11:20:03 2011 >> @@ -46,4 +46,6 @@ org.apache.mahout.classifier.sequencelea >> org.apache.mahout.classifier.sequencelearning.hmm.RandomSequenceGenerator = >> hmmpredict : Generate random sequence of observations by given HMM >> org.apache.mahout.utils.SplitInput = split : Split Input data into test and >> train sets >> org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob = >> trainnb : Train the Vector-based Bayes classifier >> -org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb >> : Test the Vector-based Bayes classifier >> \ No newline at end of file >> +org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb >> : Test the Vector-based Bayes classifier >> +org.apache.mahout.classifier.ConfusionMatrixDumper = cmdump : Dump >> confusion matrix in HTML or text formats >> +org.apache.mahout.utils.MatrixDumper = matrixdump : Dump matrix in CSV >> format >> \ No newline at end of file >> >> > > -------------------------------------------- > Grant Ingersoll > http://www.lucidimagination.com > > > -------------------------------------------- Grant Ingersoll http://www.lucidimagination.com
