Is that supposed to be ConfusionMatrixDumper in the driver.classes.props? On Nov 4, 2011, at 7:20 AM, [email protected] wrote:
> Author: srowen > Date: Fri Nov 4 11:20:03 2011 > New Revision: 1197510 > > URL: http://svn.apache.org/viewvc?rev=1197510&view=rev > Log: > MAHOUT-838 Add confusion matrix dumper > > Added: > mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ > > mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java > Modified: > > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java > > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java > > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java > mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java > > mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java > mahout/trunk/src/conf/driver.classes.props > > Modified: > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java > URL: > http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1197510&r1=1197509&r2=1197510&view=diff > ============================================================================== > --- > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java > (original) > +++ > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java > Fri Nov 4 11:20:03 2011 > @@ -19,35 +19,40 @@ package org.apache.mahout.classifier; > > import java.util.Collection; > import java.util.Collections; > -import java.util.LinkedHashMap; > import java.util.Map; > > -import com.google.common.collect.Maps; > import org.apache.commons.lang.StringUtils; > -import org.apache.mahout.math.CardinalityException; > import org.apache.mahout.math.DenseMatrix; > import org.apache.mahout.math.Matrix; > > import com.google.common.base.Preconditions; > +import com.google.common.collect.Maps; > > /** > * The ConfusionMatrix Class stores the result of 
Classification of a Test > Dataset. > * > + * The fact of whether there is a default is not stored. A row of zeros is > the only indicator that there is no default. > + * > * See http://en.wikipedia.org/wiki/Confusion_matrix for background > */ > public class ConfusionMatrix { > - > - private final Map<String,Integer> labelMap = new > LinkedHashMap<String,Integer>(); > + private final Map<String,Integer> labelMap = Maps.newLinkedHashMap(); > private final int[][] confusionMatrix; > private String defaultLabel = "unknown"; > > public ConfusionMatrix(Collection<String> labels, String defaultLabel) { > confusionMatrix = new int[labels.size() + 1][labels.size() + 1]; > this.defaultLabel = defaultLabel; > + int i = 0; > for (String label : labels) { > - labelMap.put(label, labelMap.size()); > + labelMap.put(label, i++); > } > - labelMap.put(defaultLabel, labelMap.size()); > + labelMap.put(defaultLabel, i); > + } > + > + public ConfusionMatrix(Matrix m) { > + confusionMatrix = new int[m.numRows()][m.numRows()]; > + setMatrix(m); > } > > public int[][] getConfusionMatrix() { > @@ -76,7 +81,7 @@ public class ConfusionMatrix { > return confusionMatrix[labelId][labelId]; > } > > - public double getTotal(String label) { > + public int getTotal(String label) { > int labelId = labelMap.get(label); > int labelTotal = 0; > for (int i = 0; i < labelMap.size(); i++) { > @@ -94,25 +99,25 @@ public class ConfusionMatrix { > } > > public int getCount(String correctLabel, String classifiedLabel) { > - Preconditions.checkArgument(labelMap.containsKey(correctLabel), > - "Label not found: " + correctLabel); > - Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), > - "Label not found: " + classifiedLabel); > + Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label > not found: " + correctLabel); > + Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), > "Label not found: " + classifiedLabel); > int correctId = labelMap.get(correctLabel); > int 
classifiedId = labelMap.get(classifiedLabel); > return confusionMatrix[correctId][classifiedId]; > } > > public void putCount(String correctLabel, String classifiedLabel, int > count) { > - Preconditions.checkArgument(labelMap.containsKey(correctLabel), > - "Label not found: " + correctLabel); > - Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), > - "Label not found: " + classifiedLabel); > + Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label > not found: " + correctLabel); > + Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), > "Label not found: " + classifiedLabel); > int correctId = labelMap.get(correctLabel); > int classifiedId = labelMap.get(classifiedLabel); > confusionMatrix[correctId][classifiedId] = count; > } > > + public String getDefaultLabel() { > + return defaultLabel; > + } > + > public void incrementCount(String correctLabel, String classifiedLabel, int > count) { > putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, > classifiedLabel)); > } > @@ -132,45 +137,69 @@ public class ConfusionMatrix { > } > > public Matrix getMatrix() { > - int length = confusionMatrix.length; > - Matrix m = new DenseMatrix(length, length); > - for (int r = 0; r < length; r++) { > - for (int c = 0; c < length; c++) { > - m.set(r, c, confusionMatrix[r][c]); > - } > - } > - Map<String,Integer> labels = Maps.newHashMap(); > - for (Map.Entry<String, Integer> entry : labelMap.entrySet()) { > - labels.put(entry.getKey(), entry.getValue()); > - } > - m.setRowLabelBindings(labels); > - m.setColumnLabelBindings(labels); > - return m; > + int length = confusionMatrix.length; > + Matrix m = new DenseMatrix(length, length); > + for (int r = 0; r < length; r++) { > + for (int c = 0; c < length; c++) { > + m.set(r, c, confusionMatrix[r][c]); > + } > + } > + Map<String,Integer> labels = Maps.newHashMap(); > + for(Map.Entry<String, Integer> entry : labelMap.entrySet()) { > + labels.put(entry.getKey(), 
entry.getValue()); > + } > + m.setRowLabelBindings(labels); > + m.setColumnLabelBindings(labels); > + return m; > } > - > + > public void setMatrix(Matrix m) { > - int length = confusionMatrix.length; > - if (m.numRows() != m.numCols()) { > - throw new CardinalityException(m.numRows(), m.numCols()); > - } > - if (m.numRows() != length) { > - throw new CardinalityException(m.numRows(), length); > - } > - for (int r = 0; r < length; r++) { > - for (int c = 0; c < length; c++) { > - confusionMatrix[r][c] = (int) Math.round(m.get(r, c)); > - } > - } > - Map<String,Integer> labels = m.getRowLabelBindings(); > - if (labels == null) { > + int length = confusionMatrix.length; > + if (m.numRows() != m.numCols()) { > + throw new IllegalArgumentException( > + String.format("ConfusionMatrix: matrix({},{}) must be square", > m.numRows(), m.numCols())); > + } > + for (int r = 0; r < length; r++) { > + for (int c = 0; c < length; c++) { > + confusionMatrix[r][c] = (int) Math.round(m.get(r, c)); > + } > + } > + Map<String,Integer> labels = m.getRowLabelBindings(); > + if (labels == null) { > labels = m.getColumnLabelBindings(); > } > - labelMap.clear(); > - if (labels != null) { > - labelMap.putAll(labels); > - } > + if (labels != null) { > + String[] sorted = sortLabels(labels); > + verifyLabels(length, sorted); > + labelMap.clear(); > + for(int i = 0; i < length; i++) { > + labelMap.put(sorted[i], i); > + } > + } > + } > + > + private static String[] sortLabels(Map<String,Integer> labels) { > + String[] sorted = new String[labels.keySet().size()]; > + for(String label: labels.keySet()) { > + Integer index = labels.get(label); > + sorted[index] = label; > + } > + return sorted; > + } > + > + private void verifyLabels(int length, String[] sorted) { > + Preconditions.checkArgument(sorted.length == length, "One label, one > row"); > + for(int i = 0; i < length; i++) { > + if (sorted[i] == null) { > + Preconditions.checkArgument(false, "One label, one row"); > + } > + } > } > > + /** 
> + * This is overloaded. toString() is not a formatted report you print for > a manager :) > + * Assume that if there are no default assignments, the default feature > was not used > + */ > @Override > public String toString() { > StringBuilder returnString = new StringBuilder(200); > @@ -178,26 +207,37 @@ public class ConfusionMatrix { > returnString.append("Confusion Matrix\n"); > > returnString.append("-------------------------------------------------------").append('\n'); > > + int unclassified = getTotal(defaultLabel); > for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) { > + if (entry.getKey().equals(defaultLabel) && unclassified == 0) { > + continue; > + } > + > > returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), > 5)).append('\t'); > } > > returnString.append("<--Classified as").append('\n'); > - > for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) { > + if (entry.getKey().equals(defaultLabel) && unclassified == 0) { > + continue; > + } > String correctLabel = entry.getKey(); > int labelTotal = 0; > for (String classifiedLabel : this.labelMap.keySet()) { > + if (classifiedLabel.equals(defaultLabel) && unclassified == 0) { > + continue; > + } > returnString.append( > - StringUtils.rightPad(Integer.toString(getCount(correctLabel, > classifiedLabel)), 5)).append('\t'); > + StringUtils.rightPad(Integer.toString(getCount(correctLabel, > classifiedLabel)), 5)).append('\t'); > labelTotal += getCount(correctLabel, classifiedLabel); > } > returnString.append(" | > ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t') > - .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)) > - .append(" = ").append(correctLabel).append('\n'); > + .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)) > + .append(" = ").append(correctLabel).append('\n'); > + } > + if (unclassified > 0) { > + returnString.append("Default Category: > ").append(defaultLabel).append(": 
").append(unclassified).append('\n'); > } > - returnString.append("Default Category: ").append(defaultLabel).append(": > ").append( > - labelMap.get(defaultLabel)).append('\n'); > returnString.append('\n'); > return returnString.toString(); > } > @@ -212,5 +252,5 @@ public class ConfusionMatrix { > } while (val > 0); > return returnString.toString(); > } > - > + > } > > Modified: > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java > URL: > http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=1197510&r1=1197509&r2=1197510&view=diff > ============================================================================== > --- > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java > (original) > +++ > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java > Fri Nov 4 11:20:03 2011 > @@ -33,6 +33,7 @@ import org.apache.commons.cli2.builder.D > import org.apache.commons.cli2.builder.GroupBuilder; > import org.apache.commons.cli2.commandline.Parser; > import org.apache.mahout.classifier.ClassifierResult; > +import org.apache.mahout.classifier.ConfusionMatrix; > import org.apache.mahout.classifier.ResultAnalyzer; > import > org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver; > import org.apache.mahout.common.CommandLineUtil; > @@ -104,9 +105,14 @@ public final class TestClassifier { > "Method of Classification: sequential|mapreduce. 
Default Value: > sequential").withShortName("method") > .create(); > > + Option confusionMatrixOpt = > obuilder.withLongName("confusionMatrix").withRequired(false).withArgument( > + > abuilder.withName("confusionMatrix").withMinimum(1).withMaximum(1).create()).withDescription( > + "Export ConfusionMatrix as > SequenceFile").withShortName("cm").create(); > + > Group group = > gbuilder.withName("Options").withOption(defaultCatOpt).withOption(dirOpt).withOption( > > encodingOpt).withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt).withOption(dataSourceOpt) > - > .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt).create(); > + > .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt) > + .withOption(confusionMatrixOpt).create(); > > try { > Parser parser = new Parser(); > @@ -163,6 +169,11 @@ public final class TestClassifier { > classificationMethod = (String) cmdLine.getValue(methodOpt); > } > > + String confusionMatrixFile = null; > + if (cmdLine.hasOption(confusionMatrixOpt)) { > + confusionMatrixFile = (String) cmdLine.getValue(confusionMatrixOpt); > + } > + > params.setGramSize(gramSize); > params.set("verbose", Boolean.toString(verbose)); > params.setBasePath(modelBasePath); > @@ -172,6 +183,7 @@ public final class TestClassifier { > params.set("encoding", encoding); > params.set("alpha_i", alphaI); > params.set("testDirPath", testDirPath); > + params.set("confusionMatrix", confusionMatrixFile); > > if ("sequential".equalsIgnoreCase(classificationMethod)) { > classifySequential(params); > @@ -253,12 +265,12 @@ public final class TestClassifier { > } > lineNum++; > } > - /* > - * log.info("{}\t{}\t{}/{}", new Object[] {correctLabel, > - * resultAnalyzer.getConfusionMatrix().getAccuracy(correctLabel), > - * resultAnalyzer.getConfusionMatrix().getCorrect(correctLabel), > - * resultAnalyzer.getConfusionMatrix().getTotal(correctLabel)}); > - */ > + ConfusionMatrix matrix = 
resultAnalyzer.getConfusionMatrix(); > + log.info("{}", matrix); > + BayesClassifierDriver.confusionMatrixSeqFileExport(params, matrix); > + > + log.info("ConfusionMatrix: {}", matrix.toString()); > + > log.info("Classified instances from {}", file.getName()); > if (verbose) { > log.info("Performance stats {}", operationStats.toString()); > > Modified: > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java > URL: > http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java?rev=1197510&r1=1197509&r2=1197510&view=diff > ============================================================================== > --- > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java > (original) > +++ > mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java > Fri Nov 4 11:20:03 2011 > @@ -21,10 +21,15 @@ import java.io.IOException; > import java.util.Map; > > import com.google.common.collect.Maps; > +import com.google.common.io.Closeables; > + > import org.apache.hadoop.conf.Configurable; > import org.apache.hadoop.conf.Configuration; > +import org.apache.hadoop.fs.FileSystem; > import org.apache.hadoop.fs.Path; > import org.apache.hadoop.io.DoubleWritable; > +import org.apache.hadoop.io.SequenceFile; > +import org.apache.hadoop.io.Text; > import org.apache.hadoop.mapred.FileInputFormat; > import org.apache.hadoop.mapred.FileOutputFormat; > import org.apache.hadoop.mapred.JobClient; > @@ -38,6 +43,7 @@ import org.apache.mahout.common.Paramete > import org.apache.mahout.common.StringTuple; > import org.apache.mahout.common.iterator.sequencefile.PathType; > import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; > +import org.apache.mahout.math.MatrixWritable; > import org.slf4j.Logger; > import org.slf4j.LoggerFactory; > > @@ 
-83,6 +89,10 @@ public final class BayesClassifierDriver > Path outputFiles = new Path(outPath, "part*"); > ConfusionMatrix matrix = readResult(outputFiles, conf, params); > log.info("{}", matrix); > + if (params.get("confusionMatrix") != null) { > + confusionMatrixSeqFileExport(params, matrix); > + } > + > } > > public static ConfusionMatrix readResult(Path pathPattern, Configuration > conf, Parameters params) { > @@ -117,6 +127,24 @@ public final class BayesClassifierDriver > } > } > return matrix; > + } > > + public static void confusionMatrixSeqFileExport(Parameters params, > ConfusionMatrix matrix) throws IOException { > + if (params.get("confusionMatrix") != null) { > + Configuration conf = new Configuration(); > + FileSystem fs = FileSystem.get(conf); > + SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, > + new Path(params.get("confusionMatrix")), Text.class, > MatrixWritable.class); > + String name = params.get("confusionMatrix"); > + // embed file name as sequence key- useful for tuning classifiers > + name = name.substring(name.lastIndexOf('/') + 1, name.length()); > + Text key = new Text(name); > + MatrixWritable mw = new MatrixWritable(matrix.getMatrix()); > + try { > + writer.append(key, mw); > + } finally { > + Closeables.closeQuietly(writer); > + } > + } > } > } > > Modified: > mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java > URL: > http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java?rev=1197510&r1=1197509&r2=1197510&view=diff > ============================================================================== > --- > mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java > (original) > +++ > mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java > Fri Nov 4 11:20:03 2011 > @@ -123,15 +123,15 @@ public class MatrixWritable implements W > } > > if (hasLabels) { > - Map<String,Integer> columnLabelBindings = 
Maps.newHashMap(); > - Map<String,Integer> rowLabelBindings = Maps.newHashMap(); > - readLabels(in, columnLabelBindings, rowLabelBindings); > - if (!columnLabelBindings.isEmpty()) { > - r.setColumnLabelBindings(columnLabelBindings); > - } > - if (!rowLabelBindings.isEmpty()) { > - r.setRowLabelBindings(rowLabelBindings); > - } > + Map<String,Integer> columnLabelBindings = Maps.newHashMap(); > + Map<String,Integer> rowLabelBindings = Maps.newHashMap(); > + readLabels(in, columnLabelBindings, rowLabelBindings); > + if (!columnLabelBindings.isEmpty()) { > + r.setColumnLabelBindings(columnLabelBindings); > + } > + if (!rowLabelBindings.isEmpty()) { > + r.setRowLabelBindings(rowLabelBindings); > + } > } > > return r; > @@ -159,7 +159,7 @@ public class MatrixWritable implements W > VectorWritable.writeVector(out, matrix.viewRow(i), false); > } > if ((flags & FLAG_LABELS) != 0) { > - writeLabelBindings(out, matrix.getColumnLabelBindings(), > matrix.getRowLabelBindings()); > + writeLabelBindings(out, matrix.getColumnLabelBindings(), > matrix.getRowLabelBindings()); > } > } > } > > Modified: > mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java > URL: > http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java?rev=1197510&r1=1197509&r2=1197510&view=diff > ============================================================================== > --- > mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java > (original) > +++ > mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java > Fri Nov 4 11:20:03 2011 > @@ -17,6 +17,7 @@ > > package org.apache.mahout.math; > > +import com.google.common.collect.Maps; > import com.google.common.io.Closeables; > import org.apache.hadoop.io.Writable; > import org.junit.Test; > @@ -26,7 +27,6 @@ import java.io.ByteArrayOutputStream; > import java.io.DataInputStream; > import java.io.DataOutputStream; > import java.io.IOException; 
> -import java.util.HashMap; > import java.util.Map; > > public final class MatrixWritableTest extends MahoutTestCase { > @@ -36,13 +36,14 @@ public final class MatrixWritableTest ex > Matrix m = new SparseMatrix(5, 5); > m.set(1, 2, 3.0); > m.set(3, 4, 5.0); > - Map<String, Integer> bindings = new HashMap<String, Integer>(); > + Map<String, Integer> bindings = Maps.newHashMap(); > bindings.put("A", 0); > bindings.put("B", 1); > bindings.put("C", 2); > bindings.put("D", 3); > bindings.put("default", 4); > m.setRowLabelBindings(bindings); > + m.setColumnLabelBindings(bindings); > doTestMatrixWritableEquals(m); > } > > @@ -51,12 +52,13 @@ public final class MatrixWritableTest ex > Matrix m = new DenseMatrix(5,5); > m.set(1, 2, 3.0); > m.set(3, 4, 5.0); > - Map<String, Integer> bindings = new HashMap<String, Integer>(); > + Map<String, Integer> bindings = Maps.newHashMap(); > bindings.put("A", 0); > bindings.put("B", 1); > bindings.put("C", 2); > bindings.put("D", 3); > bindings.put("default", 4); > + m.setRowLabelBindings(bindings); > m.setColumnLabelBindings(bindings); > doTestMatrixWritableEquals(m); > } > @@ -66,7 +68,9 @@ public final class MatrixWritableTest ex > MatrixWritable matrixWritable2 = new MatrixWritable(); > writeAndRead(matrixWritable, matrixWritable2); > Matrix m2 = matrixWritable2.get(); > - compareMatrices(m, m2); // not sure this works? 
> + compareMatrices(m, m2); > + doCheckBindings(m2.getRowLabelBindings()); > + doCheckBindings(m2.getColumnLabelBindings()); > } > > private static void compareMatrices(Matrix m, Matrix m2) { > @@ -98,6 +102,14 @@ public final class MatrixWritableTest ex > } > } > > + private static void doCheckBindings(Map<String,Integer> labels) { > + assertTrue("Missing label", labels.keySet().contains("A")); > + assertTrue("Missing label", labels.keySet().contains("B")); > + assertTrue("Missing label", labels.keySet().contains("C")); > + assertTrue("Missing label", labels.keySet().contains("D")); > + assertTrue("Missing label", labels.keySet().contains("default")); > + } > + > private static void writeAndRead(Writable toWrite, Writable toRead) > throws IOException { > ByteArrayOutputStream baos = new ByteArrayOutputStream(); > DataOutputStream dos = new DataOutputStream(baos); > > Added: > mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java > URL: > http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java?rev=1197510&view=auto > ============================================================================== > --- > mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java > (added) > +++ > mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java > Fri Nov 4 11:20:03 2011 > @@ -0,0 +1,423 @@ > +/* > + * Licensed to the Apache Software Foundation (ASF) under one or more > + * contributor license agreements. See the NOTICE file distributed with > + * this work for additional information regarding copyright ownership. > + * The ASF licenses this file to You under the Apache License, Version 2.0 > + * (the "License"); you may not use this file except in compliance with > + * the License. 
You may obtain a copy of the License at > + * > + * http://www.apache.org/licenses/LICENSE-2.0 > + * > + * Unless required by applicable law or agreed to in writing, software > + * distributed under the License is distributed on an "AS IS" BASIS, > + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. > + * See the License for the specific language governing permissions and > + * limitations under the License. > + */ > + > +package org.apache.mahout.classifier; > + > +import java.io.File; > +import java.io.FileOutputStream; > +import java.io.IOException; > +import java.io.OutputStream; > +import java.io.PrintStream; > +import java.util.Iterator; > +import java.util.List; > +import java.util.Map; > + > +import org.apache.hadoop.conf.Configuration; > +import org.apache.hadoop.fs.FileSystem; > +import org.apache.hadoop.fs.Path; > +import org.apache.hadoop.io.SequenceFile; > +import org.apache.hadoop.io.Text; > +import org.apache.hadoop.util.ToolRunner; > +import org.apache.mahout.common.AbstractJob; > +import org.apache.mahout.common.commandline.DefaultOptionCreator; > +import org.apache.mahout.math.Matrix; > +import org.apache.mahout.math.MatrixWritable; > + > +import com.google.common.collect.Lists; > + > +/** > + * Export a ConfusionMatrix in various text formats: > + * ToString version > + * Grayscale HTML table > + * Summary HTML table > + * Table of counts > + * all with optional HTML wrappers > + * > + * Input format: Hadoop SequenceFile with Text key and MatrixWritable value, > 1 pair > + * > + * Intended to consume ConfusionMatrix SequenceFile output by Bayes > + * TestClassifier class > + */ > +public final class ConfusionMatrixDumper extends AbstractJob { > + > + // HTML wrapper - default CSS > + private static final String HEADER = "<html>" + > + "<head>\n" + > + "<title>TITLE</title>\n" + > + "</head>" + > + "<body>\n" + > + "<style type='text/css'> \n" + > + "table\n" + > + "{\n" + > + "border:3px solid black; text-align:left;\n" + 
> + "}\n" + > + "th.normalHeader\n" + > + "{\n" + > + "border:1px solid > black;border-collapse:collapse;text-align:center;background-color:white\n" + > + "}\n" + > + "th.tallHeader\n" + > + "{\n" + > + "border:1px solid > black;border-collapse:collapse;text-align:center;background-color:white; > height:6em\n" + > + "}\n" + > + "tr.label\n" + > + "{\n" + > + "border:1px solid > black;border-collapse:collapse;text-align:center;background-color:white\n" + > + "}\n" + > + "tr.row\n" + > + "{\n" + > + "border:1px solid gray;text-align:center;background-color:snow\n" + > + "}\n" + > + "td\n" + > + "{\n" + > + "min-width:2em\n" + > + "}\n" + > + "td.cell\n" + > + "{\n" + > + "border:1px solid black;text-align:right;background-color:snow\n" + > + "}\n" + > + "td.empty\n" + > + "{\n" + > + "border:0px;text-align:right;background-color:snow\n" + > + "}\n" + > + "td.white\n" + > + "{\n" + > + "border:0px solid black;text-align:right;background-color:white\n" + > + "}\n" + > + "td.black\n" + > + "{\n" + > + "border:0px solid red;text-align:right;background-color:black\n" + > + "}\n" + > + "td.gray1\n" + > + "{\n" + > + "border:0px solid green;text-align:right; > background-color:LightGray\n" + > + "}\n" + > + "td.gray2\n" + > + "{\n" + > + "border:0px solid blue;text-align:right;background-color:gray\n" + > + "}\n" + > + "td.gray3\n" + > + "{\n" + > + "border:0px solid red;text-align:right;background-color:DarkGray\n" + > + "}\n" + > + "th" + > + "{\n" + > + " text-align: center;\n" + > + " vertical-align: bottom;\n" + > + " padding-bottom: 3px;\n" + > + " padding-left: 5px;\n" + > + " padding-right: 5px;\n" + > + "}\n" + > + " .verticalText\n" + > + " {\n" + > + " text-align: center;\n" + > + " vertical-align: middle;\n" + > + " width: 20px;\n" + > + " margin: 0px;\n" + > + " padding: 0px;\n" + > + " padding-left: 3px;\n" + > + " padding-right: 3px;\n" + > + " padding-top: 10px;\n" + > + " white-space: nowrap;\n" + > + " -webkit-transform: rotate(-90deg); \n" + > + " 
-moz-transform: rotate(-90deg); \n" + > + " };\n" + > + "</style>\n"; > + private static final String FOOTER = "</html></body>"; > + > + // CSS style names. > + private static final String CSS_TABLE = "table"; > + private static final String CSS_LABEL = "label"; > + private static final String CSS_TALL_HEADER = "tall"; > + private static final String CSS_VERTICAL = "verticalText"; > + private static final String CSS_CELL = "cell"; > + private static final String CSS_EMPTY = "empty"; > + private static final String[] CSS_GRAY_CELLS = {"white", "gray1", "gray2", > "gray3", "black"}; > + > + private ConfusionMatrixDumper() {} > + > + public static void main(String[] args) throws Exception { > + ToolRunner.run(new ConfusionMatrixDumper(), args); > + } > + > + @Override > + public int run(String[] args) throws IOException { > + addInputOption(); > + addOption("output", "o", "Output path", null); // AbstractJob output > feature requires param > + addOption(DefaultOptionCreator.overwriteOption().create()); > + addFlag("html", null, "Create complete HTML page"); > + addFlag("text", null, "Dump simple text"); > + Map<String, String> parsedArgs = parseArguments(args); > + if (parsedArgs == null) { > + return -1; > + } > + > + Path inputPath = getInputPath(); > + String outputFile = parsedArgs.containsKey("--output") ? 
> parsedArgs.get("--output") : null; > + boolean text = parsedArgs.containsKey("--text"); > + boolean wrapHtml = parsedArgs.containsKey("--html"); > + PrintStream out = getPrintStream(outputFile); > + if (text) { > + exportText(inputPath, out); > + } else { > + exportTable(inputPath, out, wrapHtml); > + } > + out.flush(); > + if (out != System.out) { > + out.close(); > + } > + return 0; > + } > + > + private static void exportText(Path inputPath, PrintStream out) throws > IOException { > + MatrixWritable mw = new MatrixWritable(); > + Text key = new Text(); > + readSeqFile(inputPath, key, mw); > + Matrix m = mw.get(); > + ConfusionMatrix cm = new ConfusionMatrix(m); > + out.println(cm.toString()); > + } > + > + private static void exportTable(Path inputPath, PrintStream out, boolean > wrapHtml) throws IOException { > + MatrixWritable mw = new MatrixWritable(); > + Text key = new Text(); > + readSeqFile(inputPath, key, mw); > + String fileName = inputPath.getName(); > + fileName = fileName.substring(fileName.lastIndexOf('/') + 1, > fileName.length()); > + Matrix m = mw.get(); > + ConfusionMatrix cm = new ConfusionMatrix(m); > + if (wrapHtml) { > + printHeader(out, fileName); > + } > + out.println("<p/>"); > + printSummaryTable(cm, out); > + out.println("<p/>"); > + printGrayTable(cm, out); > + out.println("<p/>"); > + printCountsTable(cm, out); > + out.println("<p/>"); > + printTextInBox(cm, out); > + out.println("<p/>"); > + if (wrapHtml) { > + printFooter(out); > + } > + } > + > + private static List<String> stripDefault(ConfusionMatrix cm) { > + List<String> stripped = Lists.newArrayList(cm.getLabels().iterator()); > + String defaultLabel = cm.getDefaultLabel(); > + int unclassified = cm.getTotal(defaultLabel); > + if (unclassified > 0) { > + return stripped; > + } > + stripped.remove(defaultLabel); > + return stripped; > + } > + > + // TODO: test - this should work with HDFS files > + private static void readSeqFile(Path path, Text key, MatrixWritable m) > 
throws IOException { > + Configuration conf = new Configuration(); > + FileSystem fs = FileSystem.get(conf); > + SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf); > + reader.next(key, m); > + } > + > + // TODO: test - this might not work with HDFS files? > + // after all, it does no seeks > + private static PrintStream getPrintStream(String outputFilename) throws > IOException { > + if (outputFilename != null) { > + File outputFile = new File(outputFilename); > + if (outputFile.exists()) { > + outputFile.delete(); > + } > + outputFile.createNewFile(); > + OutputStream os = new FileOutputStream(outputFile); > + return new PrintStream(os); > + } else { > + return System.out; > + } > + } > + > + private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) { > + Iterator<String> iter = cm.getLabels().iterator(); > + int count = 0; > + while(iter.hasNext()) { > + count += cm.getCount(rowLabel, iter.next()); > + } > + return count; > + } > + > + // HTML generator code > + > + private static void printTextInBox(ConfusionMatrix cm, PrintStream out) { > + out.println("<div style='width:90%;overflow:scroll;'>"); > + out.println("<pre>"); > + out.println(cm.toString()); > + out.println("</pre>"); > + out.println("</div>"); > + } > + > + public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) { > + format("<table class='%s'>\n", out, CSS_TABLE); > + format("<tr class='%s'>", out, CSS_LABEL); > + out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>"); > + out.println("</tr>"); > + List<String> labels = stripDefault(cm); > + for(String label: labels) { > + printSummaryRow(cm, out, label); > + } > + out.println("</table>"); > + } > + > + private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, > String label) { > + format("<tr class='%s'>", out, CSS_CELL); > + int correct = cm.getCorrect(label); > + double accuracy = cm.getAccuracy(label); > + int count = getCount(cm, label); > + format("<td 
class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>", > + out, CSS_CELL, label, count, correct, (int) Math.round(accuracy)); > + out.println("</tr>"); > + } > + > + private static int getCount(ConfusionMatrix cm, String label) { > + int count = 0; > + for (String s : cm.getLabels()) { > + count += cm.getCount(label, s); > + } > + return count; > + } > + > + public static void printGrayTable(ConfusionMatrix cm, PrintStream out) { > + format("<table class='%s'>\n", out, CSS_TABLE); > + printCountsHeader(cm, out, true); > + printGrayRows(cm, out); > + out.println("</table>"); > + } > + > + /** > + * Print each value in a four-value grayscale based on count/max. > + * Gives a mostly white matrix with grays in misclassified, and black in > diagonal. > + * TODO: Using the sqrt(count/max) as the rating is more stringent > + */ > + private static void printGrayRows(ConfusionMatrix cm, PrintStream out) { > + List<String> labels = stripDefault(cm); > + for (String label: labels) { > + printGrayRow(cm, out, labels, label); > + } > + } > + > + private static void printGrayRow(ConfusionMatrix cm, PrintStream out, > Iterable<String> labels, String rowLabel) { > + format("<tr class='%s'>", out, CSS_LABEL); > + format("<td>%s</td>", out, rowLabel); > + int total = getLabelTotal(cm, rowLabel); > + for (String columnLabel: labels) { > + printGrayCell(cm, out, total, rowLabel, columnLabel); > + } > + out.println("</tr>"); > + } > + > + // assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of > inputs > + // assign black to count = total, meaning complete success > + // alternative rating is to use sqrt(total) instead of total - this is > more drastic > + private static void printGrayCell(ConfusionMatrix cm, > + PrintStream out, > + int total, > + String rowLabel, > + String columnLabel) { > + > + int count = cm.getCount(rowLabel, columnLabel); > + if (count == 0) { > + out.format("<td class='%s'/>", CSS_EMPTY); > + } else { > + // 0 is white, full is black, everything 
else gray > + int rating = (int) ((count/ (double) total) * 4); > + String css = CSS_GRAY_CELLS[rating]; > + format("<td class='%s' title='%s'>%s</td>", out, css, columnLabel, > count); > + } > + } > + > + public static void printCountsTable(ConfusionMatrix cm, PrintStream out) { > + int length = cm.getLabels().size(); > + format("<table class='%s'>\n", out, CSS_TABLE); > + printCountsHeader(cm, out, false); > + printCountsRows(cm, out); > + out.println("</table>"); > + } > + > + private static void printCountsRows(ConfusionMatrix cm, PrintStream out) { > + List<String> labels = stripDefault(cm); > + for(String label: labels) { > + printCountsRow(cm, out, labels, label); > + } > + } > + > + private static void printCountsRow(ConfusionMatrix cm, PrintStream out, > Iterable<String> labels, String rowLabel) { > + out.println("<tr>"); > + format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel); > + for(String columnLabel: labels) { > + printCountsCell(cm, out, rowLabel, columnLabel); > + } > + out.println("</tr>"); > + } > + > + private static void printCountsCell(ConfusionMatrix cm, PrintStream out, > String rowLabel, String columnLabel) { > + int count = cm.getCount(rowLabel, columnLabel); > + String s = count == 0 ? 
"" : Integer.toString(count); > + format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, columnLabel, > s); > + } > + > + private static void printCountsHeader(ConfusionMatrix cm, PrintStream out, > boolean vertical) { > + List<String> labels = stripDefault(cm); > + int longest = getLongestHeader(labels); > + if (vertical) { > + // do vertical - rotation is a bitch > + out.format("<tr class='%s' style='height:%dem'><th> </th>\n", > CSS_TALL_HEADER, longest/2); > + for(String label: labels) { > + out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, label); > + } > + out.println("</tr>"); > + } else { > + // header - empty cell in upper left > + out.format("<tr class='%s'><td class='%s'></td>\n", CSS_TABLE, > CSS_LABEL); > + for(String label: labels) { > + out.format("<td>%s</td>", label); > + } > + out.format("</tr>"); > + } > + } > + > + private static int getLongestHeader(Iterable<String> labels) { > + int max = 0; > + for (String label: labels) { > + max = Math.max(label.length(), max); > + } > + return max; > + } > + > + private static void format(String format, PrintStream out, Object ... 
> args) { > + String format2 = String.format(format, args); > + out.println(format2); > + } > + > + public static void printHeader(PrintStream out, CharSequence title) { > + out.println(HEADER.replace("TITLE", title)); > + } > + > + public static void printFooter(PrintStream out) { > + out.println(FOOTER); > + } > + > +} > > Modified: mahout/trunk/src/conf/driver.classes.props > URL: > http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1197510&r1=1197509&r2=1197510&view=diff > ============================================================================== > --- mahout/trunk/src/conf/driver.classes.props (original) > +++ mahout/trunk/src/conf/driver.classes.props Fri Nov 4 11:20:03 2011 > @@ -46,4 +46,6 @@ org.apache.mahout.classifier.sequencelea > org.apache.mahout.classifier.sequencelearning.hmm.RandomSequenceGenerator = > hmmpredict : Generate random sequence of observations by given HMM > org.apache.mahout.utils.SplitInput = split : Split Input data into test and > train sets > org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob = trainnb > : Train the Vector-based Bayes classifier > -org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb : > Test the Vector-based Bayes classifier > \ No newline at end of file > +org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb : > Test the Vector-based Bayes classifier > +org.apache.mahout.classifier.ConfusionMatrixDumper = cmdump : Dump confusion > matrix in HTML or text formats > +org.apache.mahout.utils.MatrixDumper = matrixdump : Dump matrix in CSV format > \ No newline at end of file > > -------------------------------------------- Grant Ingersoll http://www.lucidimagination.com
