Author: srowen
Date: Thu Feb 9 13:57:13 2012
New Revision: 1242323

URL: http://svn.apache.org/viewvc?rev=1242323&view=rev
Log:
MAHOUT-784 refine and make consistent the confusion matrix output

Modified: mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java Modified: mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java?rev=1242323&r1=1242322&r2=1242323&view=diff ============================================================================== --- mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java (original) +++ mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java Thu Feb 9 13:57:13 2012 @@ -40,105 +40,84 @@ import org.apache.mahout.math.MatrixWrit import com.google.common.collect.Lists; /** - * Export a ConfusionMatrix in various text formats: - * ToString version - * Grayscale HTML table - * Summary HTML table - * Table of counts - * all with optional HTML wrappers + * Export a ConfusionMatrix in various text formats: ToString version Grayscale HTML table Summary HTML table + * Table of counts all with optional HTML wrappers * * Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair * - * Intended to consume ConfusionMatrix SequenceFile output by Bayes - * TestClassifier class + * Intended to consume ConfusionMatrix SequenceFile output by Bayes TestClassifier class */ public final class ConfusionMatrixDumper extends AbstractJob { - + // HTML wrapper - default CSS - private static final String HEADER = "<html>" + - "<head>\n" + - "<title>TITLE</title>\n" + - "</head>" + - "<body>\n" + - "<style type='text/css'> \n" + - "table\n" + - "{\n" + - "border:3px solid black; text-align:left;\n" + - "}\n" + - "th.normalHeader\n" + - "{\n" + - "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" + - "}\n" + - "th.tallHeader\n" + - "{\n" + - "border:1px solid 
black;border-collapse:collapse;text-align:center;background-color:white; height:6em\n" + - "}\n" + - "tr.label\n" + - "{\n" + - "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" + - "}\n" + - "tr.row\n" + - "{\n" + - "border:1px solid gray;text-align:center;background-color:snow\n" + - "}\n" + - "td\n" + - "{\n" + - "min-width:2em\n" + - "}\n" + - "td.cell\n" + - "{\n" + - "border:1px solid black;text-align:right;background-color:snow\n" + - "}\n" + - "td.empty\n" + - "{\n" + - "border:0px;text-align:right;background-color:snow\n" + - "}\n" + - "td.white\n" + - "{\n" + - "border:0px solid black;text-align:right;background-color:white\n" + - "}\n" + - "td.black\n" + - "{\n" + - "border:0px solid red;text-align:right;background-color:black\n" + - "}\n" + - "td.gray1\n" + - "{\n" + - "border:0px solid green;text-align:right; background-color:LightGray\n" + - "}\n" + - "td.gray2\n" + - "{\n" + - "border:0px solid blue;text-align:right;background-color:gray\n" + - "}\n" + - "td.gray3\n" + - "{\n" + - "border:0px solid red;text-align:right;background-color:DarkGray\n" + - "}\n" + - "th" + - "{\n" + - " text-align: center;\n" + - " vertical-align: bottom;\n" + - " padding-bottom: 3px;\n" + - " padding-left: 5px;\n" + - " padding-right: 5px;\n" + - "}\n" + - " .verticalText\n" + - " {\n" + - " text-align: center;\n" + - " vertical-align: middle;\n" + - " width: 20px;\n" + - " margin: 0px;\n" + - " padding: 0px;\n" + - " padding-left: 3px;\n" + - " padding-right: 3px;\n" + - " padding-top: 10px;\n" + - " white-space: nowrap;\n" + - " -webkit-transform: rotate(-90deg); \n" + - " -moz-transform: rotate(-90deg); \n" + - " };\n" + - "</style>\n"; + private static final String HEADER = "<html>" + + "<head>\n" + + "<title>TITLE</title>\n" + + "</head>" + + "<body>\n" + + "<style type='text/css'> \n" + + "table\n" + + "{\n" + + "border:3px solid black; text-align:left;\n" + + "}\n" + + "th.normalHeader\n" + + "{\n" + + "border:1px solid 
black;border-collapse:collapse;text-align:center;background-color:white\n" + + "}\n" + + "th.tallHeader\n" + + "{\n" + + "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white; height:6em\n" + + "}\n" + + "tr.label\n" + + "{\n" + + "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" + + "}\n" + + "tr.row\n" + + "{\n" + + "border:1px solid gray;text-align:center;background-color:snow\n" + + "}\n" + + "td\n" + + "{\n" + + "min-width:2em\n" + + "}\n" + + "td.cell\n" + + "{\n" + + "border:1px solid black;text-align:right;background-color:snow\n" + + "}\n" + + "td.empty\n" + + "{\n" + + "border:0px;text-align:right;background-color:snow\n" + + "}\n" + + "td.white\n" + + "{\n" + + "border:0px solid black;text-align:right;background-color:white\n" + + "}\n" + + "td.black\n" + + "{\n" + + "border:0px solid red;text-align:right;background-color:black\n" + + "}\n" + + "td.gray1\n" + + "{\n" + + "border:0px solid green;text-align:right; background-color:LightGray\n" + + "}\n" + "td.gray2\n" + "{\n" + + "border:0px solid blue;text-align:right;background-color:gray\n" + + "}\n" + "td.gray3\n" + "{\n" + + "border:0px solid red;text-align:right;background-color:DarkGray\n" + + "}\n" + "th" + "{\n" + " text-align: center;\n" + + " vertical-align: bottom;\n" + + " padding-bottom: 3px;\n" + " padding-left: 5px;\n" + + " padding-right: 5px;\n" + "}\n" + " .verticalText\n" + + " {\n" + " text-align: center;\n" + + " vertical-align: middle;\n" + " width: 20px;\n" + + " margin: 0px;\n" + " padding: 0px;\n" + + " padding-left: 3px;\n" + " padding-right: 3px;\n" + + " padding-top: 10px;\n" + " white-space: nowrap;\n" + + " -webkit-transform: rotate(-90deg); \n" + + " -moz-transform: rotate(-90deg); \n" + " };\n" + + "</style>\n"; private static final String FOOTER = "</html></body>"; - // CSS style names. + // CSS style names. 
private static final String CSS_TABLE = "table"; private static final String CSS_LABEL = "label"; private static final String CSS_TALL_HEADER = "tall"; @@ -160,7 +139,7 @@ public final class ConfusionMatrixDumper addOption(DefaultOptionCreator.overwriteOption().create()); addFlag("html", null, "Create complete HTML page"); addFlag("text", null, "Dump simple text"); - Map<String, String> parsedArgs = parseArguments(args); + Map<String,String> parsedArgs = parseArguments(args); if (parsedArgs == null) { return -1; } @@ -183,11 +162,26 @@ public final class ConfusionMatrixDumper } private static void exportText(Path inputPath, PrintStream out) throws IOException { + String TAB_SEPARATOR = "|"; MatrixWritable mw = new MatrixWritable(); Text key = new Text(); readSeqFile(inputPath, key, mw); Matrix m = mw.get(); ConfusionMatrix cm = new ConfusionMatrix(m); + out.println(String.format("%-40s", "Label") + TAB_SEPARATOR + String.format("%-10s", "Total") + + TAB_SEPARATOR + String.format("%-10s", "Correct") + TAB_SEPARATOR + + String.format("%-6s", "%") + TAB_SEPARATOR); + out.println(String.format("%-70s", "-").replace(' ', '-')); + List<String> labels = stripDefault(cm); + for (String label : labels) { + int correct = cm.getCorrect(label); + double accuracy = cm.getAccuracy(label); + int count = getCount(cm, label); + out.println(String.format("%-40s", label) + TAB_SEPARATOR + String.format("%-10s", count) + + TAB_SEPARATOR + String.format("%-10s", correct) + TAB_SEPARATOR + + String.format("%-6s", (int) Math.round(accuracy)) + TAB_SEPARATOR); + } + out.println(String.format("%-70s", "-").replace(' ', '-')); out.println(cm.toString()); } @@ -217,7 +211,7 @@ public final class ConfusionMatrixDumper } private static List<String> stripDefault(ConfusionMatrix cm) { - List<String> stripped = Lists.newArrayList(cm.getLabels().iterator()); + List<String> stripped = Lists.newArrayList(cm.getLabels().iterator()); String defaultLabel = cm.getDefaultLabel(); int unclassified = 
cm.getTotal(defaultLabel); if (unclassified > 0) { @@ -236,11 +230,11 @@ public final class ConfusionMatrixDumper } // TODO: test - this might not work with HDFS files? - // after all, it does no seeks + // after all, it does no seeks private static PrintStream getPrintStream(String outputFilename) throws IOException { if (outputFilename != null) { File outputFile = new File(outputFilename); - if (outputFile.exists()) { + if (outputFile.exists()) { outputFile.delete(); } outputFile.createNewFile(); @@ -254,7 +248,7 @@ public final class ConfusionMatrixDumper private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) { Iterator<String> iter = cm.getLabels().iterator(); int count = 0; - while(iter.hasNext()) { + while (iter.hasNext()) { count += cm.getCount(rowLabel, iter.next()); } return count; @@ -276,7 +270,7 @@ public final class ConfusionMatrixDumper out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>"); out.println("</tr>"); List<String> labels = stripDefault(cm); - for(String label: labels) { + for (String label : labels) { printSummaryRow(cm, out, label); } out.println("</table>"); @@ -287,8 +281,8 @@ public final class ConfusionMatrixDumper int correct = cm.getCorrect(label); double accuracy = cm.getAccuracy(label); int count = getCount(cm, label); - format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>", - out, CSS_CELL, label, count, correct, (int) Math.round(accuracy)); + format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>", out, CSS_CELL, label, count, correct, + (int) Math.round(accuracy)); out.println("</tr>"); } @@ -308,22 +302,24 @@ public final class ConfusionMatrixDumper } /** - * Print each value in a four-value grayscale based on count/max. - * Gives a mostly white matrix with grays in misclassified, and black in diagonal. - * TODO: Using the sqrt(count/max) as the rating is more stringent + * Print each value in a four-value grayscale based on count/max. 
Gives a mostly white matrix with grays in + * misclassified, and black in diagonal. TODO: Using the sqrt(count/max) as the rating is more stringent */ private static void printGrayRows(ConfusionMatrix cm, PrintStream out) { List<String> labels = stripDefault(cm); - for (String label: labels) { + for (String label : labels) { printGrayRow(cm, out, labels, label); } } - private static void printGrayRow(ConfusionMatrix cm, PrintStream out, Iterable<String> labels, String rowLabel) { + private static void printGrayRow(ConfusionMatrix cm, + PrintStream out, + Iterable<String> labels, + String rowLabel) { format("<tr class='%s'>", out, CSS_LABEL); format("<td>%s</td>", out, rowLabel); int total = getLabelTotal(cm, rowLabel); - for (String columnLabel: labels) { + for (String columnLabel : labels) { printGrayCell(cm, out, total, rowLabel, columnLabel); } out.println("</tr>"); @@ -331,7 +327,7 @@ public final class ConfusionMatrixDumper // assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of inputs // assign black to count = total, meaning complete success - // alternative rating is to use sqrt(total) instead of total - this is more drastic + // alternative rating is to use sqrt(total) instead of total - this is more drastic private static void printGrayCell(ConfusionMatrix cm, PrintStream out, int total, @@ -343,7 +339,7 @@ public final class ConfusionMatrixDumper out.format("<td class='%s'/>", CSS_EMPTY); } else { // 0 is white, full is black, everything else gray - int rating = (int) ((count/ (double) total) * 4); + int rating = (int) ((count / (double) total) * 4); String css = CSS_GRAY_CELLS[rating]; format("<td class='%s' title='%s'>%s</td>", out, css, columnLabel, count); } @@ -358,15 +354,18 @@ public final class ConfusionMatrixDumper private static void printCountsRows(ConfusionMatrix cm, PrintStream out) { List<String> labels = stripDefault(cm); - for(String label: labels) { + for (String label : labels) { printCountsRow(cm, out, labels, label); } } 
- private static void printCountsRow(ConfusionMatrix cm, PrintStream out, Iterable<String> labels, String rowLabel) { + private static void printCountsRow(ConfusionMatrix cm, + PrintStream out, + Iterable<String> labels, + String rowLabel) { out.println("<tr>"); format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel); - for(String columnLabel: labels) { + for (String columnLabel : labels) { printCountsCell(cm, out, rowLabel, columnLabel); } out.println("</tr>"); @@ -383,15 +382,15 @@ public final class ConfusionMatrixDumper int longest = getLongestHeader(labels); if (vertical) { // do vertical - rotation is a bitch - out.format("<tr class='%s' style='height:%dem'><th> </th>\n", CSS_TALL_HEADER, longest/2); - for(String label: labels) { + out.format("<tr class='%s' style='height:%dem'><th> </th>\n", CSS_TALL_HEADER, longest / 2); + for (String label : labels) { out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, label); } out.println("</tr>"); } else { // header - empty cell in upper left out.format("<tr class='%s'><td class='%s'></td>\n", CSS_TABLE, CSS_LABEL); - for(String label: labels) { + for (String label : labels) { out.format("<td>%s</td>", label); } out.format("</tr>"); @@ -400,13 +399,13 @@ public final class ConfusionMatrixDumper private static int getLongestHeader(Iterable<String> labels) { int max = 0; - for (String label: labels) { + for (String label : labels) { max = Math.max(label.length(), max); } return max; } - private static void format(String format, PrintStream out, Object ... args) { + private static void format(String format, PrintStream out, Object... args) { String format2 = String.format(format, args); out.println(format2); }