Author: colen
Date: Thu Apr 10 23:25:11 2014
New Revision: 1586502
URL: http://svn.apache.org/r1586502
Log:
OPENNLP-81 Added doccat evaluator, with misclassified and fine grained reports.
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluationErrorListener.java
(with props)
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluatorTool.java
(with props)
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatFineGrainedReportListener.java
(with props)
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DoccatEvaluationMonitor.java
(with props)
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/CLI.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/EvaluationErrorPrinter.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DocumentCategorizerEvaluator.java
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/eval/Evaluator.java
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/CLI.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/CLI.java?rev=1586502&r1=1586501&r2=1586502&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/CLI.java
(original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/CLI.java
Thu Apr 10 23:25:11 2014
@@ -32,6 +32,7 @@ import opennlp.tools.cmdline.chunker.Chu
import opennlp.tools.cmdline.chunker.ChunkerTrainerTool;
import opennlp.tools.cmdline.dictionary.DictionaryBuilderTool;
import opennlp.tools.cmdline.doccat.DoccatConverterTool;
+import opennlp.tools.cmdline.doccat.DoccatEvaluatorTool;
import opennlp.tools.cmdline.doccat.DoccatTool;
import opennlp.tools.cmdline.doccat.DoccatTrainerTool;
import opennlp.tools.cmdline.entitylinker.EntityLinkerTool;
@@ -80,6 +81,7 @@ public final class CLI {
// Document Categorizer
tools.add(new DoccatTool());
tools.add(new DoccatTrainerTool());
+ tools.add(new DoccatEvaluatorTool());
tools.add(new DoccatConverterTool());
// Dictionary Builder
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/EvaluationErrorPrinter.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/EvaluationErrorPrinter.java?rev=1586502&r1=1586501&r2=1586502&view=diff
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/EvaluationErrorPrinter.java
(original)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/EvaluationErrorPrinter.java
Thu Apr 10 23:25:11 2014
@@ -104,6 +104,12 @@ public abstract class EvaluationErrorPri
}
}
+ // for others
+ protected void printError(T referenceSample, T predictedSample) {
+ printSamples(referenceSample, predictedSample);
+ printStream.println();
+ }
+
/**
* Auxiliary method to print tag errors
*
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluationErrorListener.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluationErrorListener.java?rev=1586502&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluationErrorListener.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluationErrorListener.java
Thu Apr 10 23:25:11 2014
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.cmdline.doccat;
+
+import java.io.OutputStream;
+
+import opennlp.tools.cmdline.EvaluationErrorPrinter;
+import opennlp.tools.doccat.DoccatEvaluationMonitor;
+import opennlp.tools.doccat.DocumentSample;
+import opennlp.tools.util.eval.EvaluationMonitor;
+
+/**
+ * A default implementation of {@link EvaluationMonitor} that prints to an
+ * output stream.
+ *
+ */
+public class DoccatEvaluationErrorListener extends
+ EvaluationErrorPrinter<DocumentSample> implements DoccatEvaluationMonitor {
+
+ /**
+ * Creates a listener that will print to System.err
+ */
+ public DoccatEvaluationErrorListener() {
+ super(System.err);
+ }
+
+ /**
+ * Creates a listener that will print to a given {@link OutputStream}
+ */
+ public DoccatEvaluationErrorListener(OutputStream outputStream) {
+ super(outputStream);
+ }
+
+ @Override
+ public void missclassified(DocumentSample reference, DocumentSample
prediction) {
+ printError(reference, prediction);
+ }
+
+}
Propchange:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluationErrorListener.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluatorTool.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluatorTool.java?rev=1586502&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluatorTool.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluatorTool.java
Thu Apr 10 23:25:11 2014
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.cmdline.doccat;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.LinkedList;
+import java.util.List;
+
+import opennlp.tools.cmdline.AbstractEvaluatorTool;
+import opennlp.tools.cmdline.CmdLineUtil;
+import opennlp.tools.cmdline.PerformanceMonitor;
+import opennlp.tools.cmdline.TerminateToolException;
+import opennlp.tools.cmdline.ArgumentParser.OptionalParameter;
+import opennlp.tools.cmdline.ArgumentParser.ParameterDescription;
+import opennlp.tools.cmdline.doccat.DoccatEvaluatorTool.EvalToolParams;
+import opennlp.tools.cmdline.params.DetailedFMeasureEvaluatorParams;
+import opennlp.tools.cmdline.params.EvaluatorParams;
+import opennlp.tools.doccat.DoccatEvaluationMonitor;
+import opennlp.tools.doccat.DoccatModel;
+import opennlp.tools.doccat.DocumentCategorizerEvaluator;
+import opennlp.tools.doccat.DocumentCategorizerME;
+import opennlp.tools.doccat.DocumentSample;
+import opennlp.tools.util.ObjectStream;
+import opennlp.tools.util.eval.EvaluationMonitor;
+
+public final class DoccatEvaluatorTool extends
+ AbstractEvaluatorTool<DocumentSample, EvalToolParams> {
+
+ interface EvalToolParams extends EvaluatorParams,
+ DetailedFMeasureEvaluatorParams {
+ @ParameterDescription(valueName = "outputFile", description = "the path of
the fine-grained report file.")
+ @OptionalParameter
+ File getReportOutputFile();
+ }
+
+ public DoccatEvaluatorTool() {
+ super(DocumentSample.class, EvalToolParams.class);
+ }
+
+ public String getShortDescription() {
+ return "Measures the performance of the Doccat model with the reference
data";
+ }
+
+ public void run(String format, String[] args) {
+ super.run(format, args);
+
+ DoccatModel model = new DoccatModelLoader().load(params.getModel());
+
+ List<EvaluationMonitor<DocumentSample>> listeners = new
LinkedList<EvaluationMonitor<DocumentSample>>();
+ if (params.getMisclassified()) {
+ listeners.add(new DoccatEvaluationErrorListener());
+ }
+
+ DoccatFineGrainedReportListener reportListener = null;
+ File reportFile = params.getReportOutputFile();
+ OutputStream reportOutputStream = null;
+ if (reportFile != null) {
+ CmdLineUtil.checkOutputFile("Report Output File", reportFile);
+ try {
+ reportOutputStream = new FileOutputStream(reportFile);
+ reportListener = new
DoccatFineGrainedReportListener(reportOutputStream);
+ listeners.add(reportListener);
+ } catch (FileNotFoundException e) {
+ throw new TerminateToolException(-1,
+ "IO error while creating Doccat fine-grained report file: "
+ + e.getMessage());
+ }
+ }
+
+ DocumentCategorizerEvaluator evaluator = new DocumentCategorizerEvaluator(
+ new DocumentCategorizerME(model),
+ listeners.toArray(new DoccatEvaluationMonitor[listeners.size()]));
+
+ final PerformanceMonitor monitor = new PerformanceMonitor("doc");
+
+ ObjectStream<DocumentSample> measuredSampleStream = new
ObjectStream<DocumentSample>() {
+
+ public DocumentSample read() throws IOException {
+ monitor.incrementCounter();
+ return sampleStream.read();
+ }
+
+ public void reset() throws IOException {
+ sampleStream.reset();
+ }
+
+ public void close() throws IOException {
+ sampleStream.close();
+ }
+ };
+
+ monitor.startAndPrintThroughput();
+
+ try {
+ evaluator.evaluate(measuredSampleStream);
+ } catch (IOException e) {
+ System.err.println("failed");
+ throw new TerminateToolException(-1, "IO error while reading test data: "
+ + e.getMessage(), e);
+ } finally {
+ try {
+ measuredSampleStream.close();
+ } catch (IOException e) {
+ // sorry that this can fail
+ }
+ }
+
+ monitor.stopAndPrintFinalResult();
+
+ System.out.println();
+
+ System.out.println(evaluator);
+
+ if (reportListener != null) {
+ System.out.println("Writing fine-grained report to "
+ + params.getReportOutputFile().getAbsolutePath());
+ reportListener.writeReport();
+
+ try {
+ // TODO: is it a problem to close the stream now?
+ reportOutputStream.close();
+ } catch (IOException e) {
+ // nothing to do
+ }
+ }
+ }
+}
Propchange:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatEvaluatorTool.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatFineGrainedReportListener.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatFineGrainedReportListener.java?rev=1586502&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatFineGrainedReportListener.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatFineGrainedReportListener.java
Thu Apr 10 23:25:11 2014
@@ -0,0 +1,775 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.cmdline.doccat;
+
+import java.io.OutputStream;
+import java.io.PrintStream;
+import java.text.MessageFormat;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedSet;
+import java.util.TreeSet;
+
+import opennlp.tools.doccat.DoccatEvaluationMonitor;
+import opennlp.tools.doccat.DocumentSample;
+import opennlp.tools.util.Span;
+import opennlp.tools.util.eval.FMeasure;
+import opennlp.tools.util.eval.Mean;
+
+/**
+ * Generates a detailed report for the Document Categorizer.
+ * <p>
+ * It is possible to use it from an API and access the statistics using the
+ * provided getters
+ *
+ */
+public class DoccatFineGrainedReportListener implements
DoccatEvaluationMonitor {
+
+ private final PrintStream printStream;
+ private final Stats stats = new Stats();
+
+ private static final char[] alpha = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
+ 'w', 'x', 'y', 'z' };
+
+ /**
+ * Creates a listener that will print to {@link System#err}
+ */
+ public DoccatFineGrainedReportListener() {
+ this(System.err);
+ }
+
+ /**
+ * Creates a listener that prints to a given {@link OutputStream}
+ */
+ public DoccatFineGrainedReportListener(OutputStream outputStream) {
+ this.printStream = new PrintStream(outputStream);
+ }
+
+ // methods inherited from EvaluationMonitor
+
+ public void missclassified(DocumentSample reference, DocumentSample
prediction) {
+ stats.add(reference, prediction);
+ }
+
+ public void correctlyClassified(DocumentSample reference,
+ DocumentSample prediction) {
+ stats.add(reference, prediction);
+ }
+
+ /**
+ * Writes the report to the {@link OutputStream}. Should be called only after
+ * the evaluation process
+ */
+ public void writeReport() {
+ printGeneralStatistics();
+ printTagsErrorRank();
+ printGeneralConfusionTable();
+ }
+
+ public long getNumberOfSentences() {
+ return stats.getNumberOfSentences();
+ }
+
+ public double getAverageSentenceSize() {
+ return stats.getAverageSentenceSize();
+ }
+
+ public int getMinSentenceSize() {
+ return stats.getMinSentenceSize();
+ }
+
+ public int getMaxSentenceSize() {
+ return stats.getMaxSentenceSize();
+ }
+
+ public int getNumberOfTags() {
+ return stats.getNumberOfTags();
+ }
+
+ public double getAccuracy() {
+ return stats.getAccuracy();
+ }
+
+ // token stats
+
+ public double getTokenAccuracy(String token) {
+ return stats.getTokenAccuracy(token);
+ }
+
+ public SortedSet<String> getTokensOrderedByFrequency() {
+ return stats.getTokensOrderedByFrequency();
+ }
+
+ public int getTokenFrequency(String token) {
+ return stats.getTokenFrequency(token);
+ }
+
+ public int getTokenErrors(String token) {
+ return stats.getTokenErrors(token);
+ }
+
+ public SortedSet<String> getTokensOrderedByNumberOfErrors() {
+ return stats.getTokensOrderedByNumberOfErrors();
+ }
+
+ public SortedSet<String> getTagsOrderedByErrors() {
+ return stats.getTagsOrderedByErrors();
+ }
+
+ public int getTagFrequency(String tag) {
+ return stats.getTagFrequency(tag);
+ }
+
+ public int getTagErrors(String tag) {
+ return stats.getTagErrors(tag);
+ }
+
+ public double getTagPrecision(String tag) {
+ return stats.getTagPrecision(tag);
+ }
+
+ public double getTagRecall(String tag) {
+ return stats.getTagRecall(tag);
+ }
+
+ public double getTagFMeasure(String tag) {
+ return stats.getTagFMeasure(tag);
+ }
+
+ public SortedSet<String> getConfusionMatrixTagset() {
+ return stats.getConfusionMatrixTagset();
+ }
+
+ public SortedSet<String> getConfusionMatrixTagset(String token) {
+ return stats.getConfusionMatrixTagset(token);
+ }
+
+ public double[][] getConfusionMatrix() {
+ return stats.getConfusionMatrix();
+ }
+
+ public double[][] getConfusionMatrix(String token) {
+ return stats.getConfusionMatrix(token);
+ }
+
+ private String matrixToString(SortedSet<String> tagset, double[][] data,
+ boolean filter) {
+    // we don't want to print trivial cases (acc=1)
+ int initialIndex = 0;
+ String[] tags = tagset.toArray(new String[tagset.size()]);
+ StringBuilder sb = new StringBuilder();
+ int minColumnSize = Integer.MIN_VALUE;
+ String[][] matrix = new String[data.length][data[0].length];
+ for (int i = 0; i < data.length; i++) {
+ int j = 0;
+ for (; j < data[i].length - 1; j++) {
+ matrix[i][j] = data[i][j] > 0 ? Integer.toString((int) data[i][j])
+ : ".";
+ if (minColumnSize < matrix[i][j].length()) {
+ minColumnSize = matrix[i][j].length();
+ }
+ }
+ matrix[i][j] = MessageFormat.format("{0,number,#.##%}", data[i][j]);
+ if (data[i][j] == 1 && filter) {
+ initialIndex = i + 1;
+ }
+ }
+
+ final String headerFormat = "%" + (minColumnSize + 2) + "s "; // | 1234567
|
+ final String cellFormat = "%" + (minColumnSize + 2) + "s "; // | 12345 |
+ final String diagFormat = " %" + (minColumnSize + 2) + "s";
+ for (int i = initialIndex; i < tagset.size(); i++) {
+ sb.append(String.format(headerFormat,
+ generateAlphaLabel(i - initialIndex).trim()));
+ }
+ sb.append("| Accuracy | <-- classified as\n");
+ for (int i = initialIndex; i < data.length; i++) {
+ int j = initialIndex;
+ for (; j < data[i].length - 1; j++) {
+ if (i == j) {
+ String val = "<" + matrix[i][j] + ">";
+ sb.append(String.format(diagFormat, val));
+ } else {
+ sb.append(String.format(cellFormat, matrix[i][j]));
+ }
+ }
+ sb.append(
+ String.format("| %-6s | %3s = ", matrix[i][j],
+ generateAlphaLabel(i - initialIndex))).append(tags[i]);
+ sb.append("\n");
+ }
+ return sb.toString();
+ }
+
+ private void printGeneralStatistics() {
+ printHeader("Evaluation summary");
+ printStream.append(
+ String.format("%21s: %6s", "Number of documents",
+ Long.toString(getNumberOfSentences()))).append("\n");
+ printStream.append(
+ String.format("%21s: %6s", "Min sentence size", getMinSentenceSize()))
+ .append("\n");
+ printStream.append(
+ String.format("%21s: %6s", "Max sentence size", getMaxSentenceSize()))
+ .append("\n");
+ printStream.append(
+ String.format("%21s: %6s", "Average sentence size",
+ MessageFormat.format("{0,number,#.##}", getAverageSentenceSize())))
+ .append("\n");
+ printStream.append(
+ String.format("%21s: %6s", "Categories count", getNumberOfTags()))
+ .append("\n");
+ printStream.append(
+ String.format("%21s: %6s", "Accuracy",
+ MessageFormat.format("{0,number,#.##%}", getAccuracy()))).append(
+ "\n");
+ }
+
+ private void printTagsErrorRank() {
+ printHeader("Detailed Accuracy By Tag");
+ SortedSet<String> tags = getTagsOrderedByErrors();
+ printStream.append("\n");
+
+ int maxTagSize = 3;
+
+ for (String t : tags) {
+ if (t.length() > maxTagSize) {
+ maxTagSize = t.length();
+ }
+ }
+
+ int tableSize = 65 + maxTagSize;
+
+ String headerFormat = "| %" + maxTagSize
+ + "s | %6s | %6s | %7s | %9s | %6s | %9s |\n";
+ String format = "| %" + maxTagSize
+ + "s | %6s | %6s | %-7s | %-9s | %-6s | %-9s |\n";
+
+ printLine(tableSize);
+ printStream.append(String.format(headerFormat, "Tag", "Errors", "Count",
+ "% Err", "Precision", "Recall", "F-Measure"));
+ printLine(tableSize);
+
+ Iterator<String> tagIterator = tags.iterator();
+ while (tagIterator.hasNext()) {
+ String tag = tagIterator.next();
+ int ocurrencies = getTagFrequency(tag);
+ int errors = getTagErrors(tag);
+ String rate = MessageFormat.format("{0,number,#.###}", (double) errors
+ / ocurrencies);
+
+ double p = getTagPrecision(tag);
+ double r = getTagRecall(tag);
+ double f = getTagFMeasure(tag);
+
+ printStream.append(String.format(format, tag, errors, ocurrencies, rate,
+ MessageFormat.format("{0,number,#.###}", p > 0 ? p : 0),
+ MessageFormat.format("{0,number,#.###}", r > 0 ? r : 0),
+ MessageFormat.format("{0,number,#.###}", f > 0 ? f : 0))
+
+ );
+ }
+ printLine(tableSize);
+ }
+
+ private void printGeneralConfusionTable() {
+ printHeader("Confusion matrix");
+
+ SortedSet<String> labels = getConfusionMatrixTagset();
+
+ double[][] confusionMatrix = getConfusionMatrix();
+
+ int line = 0;
+ for (String label : labels) {
+ if (confusionMatrix[line][confusionMatrix[0].length - 1] == 1) {
+ printStream.append(label).append(" (")
+ .append(Integer.toString((int) confusionMatrix[line][line]))
+ .append(") ");
+ }
+ line++;
+ }
+
+ printStream.append("\n\n");
+
+ printStream.append(matrixToString(labels, confusionMatrix, true));
+ }
+
+  /** Auxiliary method that prints an emphasised report header */
+ private void printHeader(String text) {
+ printStream.append("\n=== ").append(text).append(" ===\n");
+ }
+
+ /** Auxiliary method that prints a horizontal line of a given size */
+ private void printLine(int size) {
+ for (int i = 0; i < size; i++) {
+ printStream.append("-");
+ }
+ printStream.append("\n");
+ }
+
+ private static final String generateAlphaLabel(int index) {
+
+ char labelChars[] = new char[3];
+ int i;
+
+ for (i = 2; i >= 0; i--) {
+ labelChars[i] = alpha[index % alpha.length];
+ index = index / alpha.length - 1;
+ if (index < 0) {
+ break;
+ }
+ }
+
+ return new String(labelChars);
+ }
+
+ private class Stats {
+
+ // general statistics
+ private final Mean accuracy = new Mean();
+ private final Mean averageSentenceLength = new Mean();
+ private int minimalSentenceLength = Integer.MAX_VALUE;
+ private int maximumSentenceLength = Integer.MIN_VALUE;
+
+ // token statistics
+ private final Map<String, Mean> tokAccuracies = new HashMap<String,
Mean>();
+ private final Map<String, Counter> tokOcurrencies = new HashMap<String,
Counter>();
+ private final Map<String, Counter> tokErrors = new HashMap<String,
Counter>();
+
+ // tag statistics
+ private final Map<String, Counter> tagOcurrencies = new HashMap<String,
Counter>();
+ private final Map<String, Counter> tagErrors = new HashMap<String,
Counter>();
+ private final Map<String, FMeasure> tagFMeasure = new HashMap<String,
FMeasure>();
+
+ // represents a Confusion Matrix that aggregates all tokens
+ private final Map<String, ConfusionMatrixLine> generalConfusionMatrix =
new HashMap<String, ConfusionMatrixLine>();
+
+ // represents a set of Confusion Matrix for each token
+ private final Map<String, Map<String, ConfusionMatrixLine>>
tokenConfusionMatrix = new HashMap<String, Map<String, ConfusionMatrixLine>>();
+
+ public void add(DocumentSample reference, DocumentSample prediction) {
+ int length = reference.getText().length;
+ averageSentenceLength.add(length);
+
+ if (minimalSentenceLength > length) {
+ minimalSentenceLength = length;
+ }
+ if (maximumSentenceLength < length) {
+ maximumSentenceLength = length;
+ }
+
+ // String[] toks = reference.getSentence();
+ String[] refs = { reference.getCategory() };
+ String[] preds = { prediction.getCategory() };
+
+ updateTagFMeasure(refs, preds);
+
+ // for (int i = 0; i < toks.length; i++) {
+ add("xx", reference.getCategory(), prediction.getCategory());
+ // }
+ }
+
+ /**
+     * Records a new evaluation data point
+     *
+     * @param tok
+     *          the evaluated token
+     * @param ref
+     *          the reference category
+     * @param pred
+     *          the predicted category
+ */
+ private void add(String tok, String ref, String pred) {
+ // token stats
+ if (!tokAccuracies.containsKey(tok)) {
+ tokAccuracies.put(tok, new Mean());
+ tokOcurrencies.put(tok, new Counter());
+ tokErrors.put(tok, new Counter());
+ }
+ tokOcurrencies.get(tok).increment();
+
+ // tag stats
+ if (!tagOcurrencies.containsKey(ref)) {
+ tagOcurrencies.put(ref, new Counter());
+ tagErrors.put(ref, new Counter());
+ }
+ tagOcurrencies.get(ref).increment();
+
+ // updates general, token and tag error stats
+ if (ref.equals(pred)) {
+ tokAccuracies.get(tok).add(1);
+ accuracy.add(1);
+ } else {
+ tokAccuracies.get(tok).add(0);
+ tokErrors.get(tok).increment();
+ tagErrors.get(ref).increment();
+ accuracy.add(0);
+ }
+
+ // populate confusion matrixes
+ if (!generalConfusionMatrix.containsKey(ref)) {
+ generalConfusionMatrix.put(ref, new ConfusionMatrixLine(ref));
+ }
+ generalConfusionMatrix.get(ref).increment(pred);
+
+ if (!tokenConfusionMatrix.containsKey(tok)) {
+ tokenConfusionMatrix.put(tok,
+ new HashMap<String, ConfusionMatrixLine>());
+ }
+ if (!tokenConfusionMatrix.get(tok).containsKey(ref)) {
+ tokenConfusionMatrix.get(tok).put(ref, new ConfusionMatrixLine(ref));
+ }
+ tokenConfusionMatrix.get(tok).get(ref).increment(pred);
+ }
+
+ private void updateTagFMeasure(String[] refs, String[] preds) {
+ // create a set with all tags
+ Set<String> tags = new HashSet<String>(Arrays.asList(refs));
+ tags.addAll(Arrays.asList(preds));
+
+ // create samples for each tag
+ for (String tag : tags) {
+ List<Span> reference = new ArrayList<Span>();
+ List<Span> prediction = new ArrayList<Span>();
+ for (int i = 0; i < refs.length; i++) {
+ if (refs[i].equals(tag)) {
+ reference.add(new Span(i, i + 1));
+ }
+ if (preds[i].equals(tag)) {
+ prediction.add(new Span(i, i + 1));
+ }
+ }
+ if (!this.tagFMeasure.containsKey(tag)) {
+ this.tagFMeasure.put(tag, new FMeasure());
+ }
+ // populate the fmeasure
+ this.tagFMeasure.get(tag).updateScores(
+ reference.toArray(new Span[reference.size()]),
+ prediction.toArray(new Span[prediction.size()]));
+ }
+ }
+
+ public double getAccuracy() {
+ return accuracy.mean();
+ }
+
+ public int getNumberOfTags() {
+ return this.tagOcurrencies.keySet().size();
+ }
+
+ public long getNumberOfSentences() {
+ return this.averageSentenceLength.count();
+ }
+
+ public double getAverageSentenceSize() {
+ return this.averageSentenceLength.mean();
+ }
+
+ public int getMinSentenceSize() {
+ return this.minimalSentenceLength;
+ }
+
+ public int getMaxSentenceSize() {
+ return this.maximumSentenceLength;
+ }
+
+ public double getTokenAccuracy(String token) {
+ return tokAccuracies.get(token).mean();
+ }
+
+ public int getTokenErrors(String token) {
+ return tokErrors.get(token).value();
+ }
+
+ public int getTokenFrequency(String token) {
+ return tokOcurrencies.get(token).value();
+ }
+
+ public SortedSet<String> getTokensOrderedByFrequency() {
+ SortedSet<String> toks = new TreeSet<String>(new Comparator<String>() {
+ public int compare(String o1, String o2) {
+ if (o1.equals(o2)) {
+ return 0;
+ }
+ int e1 = 0, e2 = 0;
+ if (tokOcurrencies.containsKey(o1))
+ e1 = tokOcurrencies.get(o1).value();
+ if (tokOcurrencies.containsKey(o2))
+ e2 = tokOcurrencies.get(o2).value();
+ if (e1 == e2) {
+ return o1.compareTo(o2);
+ }
+ return e2 - e1;
+ }
+ });
+
+ toks.addAll(tokOcurrencies.keySet());
+
+ return Collections.unmodifiableSortedSet(toks);
+ }
+
+ public SortedSet<String> getTokensOrderedByNumberOfErrors() {
+ SortedSet<String> toks = new TreeSet<String>(new Comparator<String>() {
+ public int compare(String o1, String o2) {
+ if (o1.equals(o2)) {
+ return 0;
+ }
+ int e1 = 0, e2 = 0;
+ if (tokErrors.containsKey(o1))
+ e1 = tokErrors.get(o1).value();
+ if (tokErrors.containsKey(o2))
+ e2 = tokErrors.get(o2).value();
+ if (e1 == e2) {
+ return o1.compareTo(o2);
+ }
+ return e2 - e1;
+ }
+ });
+ toks.addAll(tokErrors.keySet());
+ return toks;
+ }
+
+ public int getTagFrequency(String tag) {
+ return tagOcurrencies.get(tag).value();
+ }
+
+ public int getTagErrors(String tag) {
+ return tagErrors.get(tag).value();
+ }
+
+ public double getTagFMeasure(String tag) {
+ return tagFMeasure.get(tag).getFMeasure();
+ }
+
+ public double getTagRecall(String tag) {
+ return tagFMeasure.get(tag).getRecallScore();
+ }
+
+ public double getTagPrecision(String tag) {
+ return tagFMeasure.get(tag).getPrecisionScore();
+ }
+
+ public SortedSet<String> getTagsOrderedByErrors() {
+ SortedSet<String> tags = new TreeSet<String>(new Comparator<String>() {
+ public int compare(String o1, String o2) {
+ if (o1.equals(o2)) {
+ return 0;
+ }
+ int e1 = 0, e2 = 0;
+ if (tagErrors.containsKey(o1))
+ e1 = tagErrors.get(o1).value();
+ if (tagErrors.containsKey(o2))
+ e2 = tagErrors.get(o2).value();
+ if (e1 == e2) {
+ return o1.compareTo(o2);
+ }
+ return e2 - e1;
+ }
+ });
+ tags.addAll(tagErrors.keySet());
+ return Collections.unmodifiableSortedSet(tags);
+ }
+
+ public SortedSet<String> getConfusionMatrixTagset() {
+ return getConfusionMatrixTagset(generalConfusionMatrix);
+ }
+
+ public double[][] getConfusionMatrix() {
+ return createConfusionMatrix(getConfusionMatrixTagset(),
+ generalConfusionMatrix);
+ }
+
+ public SortedSet<String> getConfusionMatrixTagset(String token) {
+ return getConfusionMatrixTagset(tokenConfusionMatrix.get(token));
+ }
+
+ public double[][] getConfusionMatrix(String token) {
+ return createConfusionMatrix(getConfusionMatrixTagset(token),
+ tokenConfusionMatrix.get(token));
+ }
+
+ /**
+ * Creates a matrix with N lines and N + 1 columns with the data from
+ * confusion matrix. The last column is the accuracy.
+ */
+ private double[][] createConfusionMatrix(SortedSet<String> tagset,
+ Map<String, ConfusionMatrixLine> data) {
+ int size = tagset.size();
+ double[][] matrix = new double[size][size + 1];
+ int line = 0;
+ for (String ref : tagset) {
+ int column = 0;
+ for (String pred : tagset) {
+ matrix[line][column] = (double) (data.get(ref) != null ? data
+ .get(ref).getValue(pred) : 0);
+ column++;
+ }
+ // set accuracy
+ matrix[line][column] = (double) (data.get(ref) != null ? data.get(ref)
+ .getAccuracy() : 0);
+ line++;
+ }
+
+ return matrix;
+ }
+
+ private SortedSet<String> getConfusionMatrixTagset(
+ Map<String, ConfusionMatrixLine> data) {
+ SortedSet<String> tags = new TreeSet<String>(new
CategoryComparator(data));
+ tags.addAll(data.keySet());
+ List<String> col = new LinkedList<String>();
+ for (String t : tags) {
+ col.addAll(data.get(t).line.keySet());
+ }
+ tags.addAll(col);
+ return Collections.unmodifiableSortedSet(tags);
+ }
+ }
+
+ /**
+ * A comparator that sorts the confusion matrix labels according to the
+ * accuracy of each line
+ */
+ private static class CategoryComparator implements Comparator<String> {
+
+ private Map<String, ConfusionMatrixLine> confusionMatrix;
+
+ public CategoryComparator(Map<String, ConfusionMatrixLine>
confusionMatrix) {
+ this.confusionMatrix = confusionMatrix;
+ }
+
+ public int compare(String o1, String o2) {
+ if (o1.equals(o2)) {
+ return 0;
+ }
+ ConfusionMatrixLine t1 = confusionMatrix.get(o1);
+ ConfusionMatrixLine t2 = confusionMatrix.get(o2);
+ if (t1 == null || t2 == null) {
+ if (t1 == null) {
+ return 1;
+ } else if (t2 == null) {
+ return -1;
+ }
+ return 0;
+ }
+ double r1 = t1.getAccuracy();
+ double r2 = t2.getAccuracy();
+ if (r1 == r2) {
+ return o1.compareTo(o2);
+ }
+ if (r2 > r1) {
+ return 1;
+ }
+ return -1;
+ }
+
+ }
+
+ /**
+ * Represents a line in the confusion table.
+ */
+ private static class ConfusionMatrixLine {
+
+ private Map<String, Counter> line = new HashMap<String, Counter>();
+ private String ref;
+ private int total = 0;
+ private int correct = 0;
+ private double acc = -1;
+
+ /**
+ * Creates a new {@link ConfusionMatrixLine}
+ *
+ * @param ref
+ * the reference column
+ */
+ public ConfusionMatrixLine(String ref) {
+ this.ref = ref;
+ }
+
+ /**
+ * Increments the counter for the given column and updates the statistics.
+ *
+ * @param column
+ * the column to be incremented
+ */
+ public void increment(String column) {
+ total++;
+ if (column.equals(ref))
+ correct++;
+ if (!line.containsKey(column)) {
+ line.put(column, new Counter());
+ }
+ line.get(column).increment();
+ }
+
+ /**
+ * Gets the calculated accuracy of this element
+ *
+ * @return the accuracy
+ */
+ public double getAccuracy() {
+ // we save the accuracy because it is frequently used by the comparator
+ if (acc == -1) {
+ if (total == 0)
+ acc = 0;
+ acc = (double) correct / (double) total;
+ }
+ return acc;
+ }
+
+ /**
+ * Gets the value given a column
+ *
+ * @param column
+ * the column
+ * @return the counter value
+ */
+ public int getValue(String column) {
+ Counter c = line.get(column);
+ if (c == null)
+ return 0;
+ return c.value();
+ }
+ }
+
+ /**
+ * Implements a simple counter
+ */
+ private static class Counter {
+ private int c = 0;
+
+ public void increment() {
+ c++;
+ }
+
+ public int value() {
+ return c;
+ }
+ }
+
+}
Propchange:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/cmdline/doccat/DoccatFineGrainedReportListener.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DoccatEvaluationMonitor.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DoccatEvaluationMonitor.java?rev=1586502&view=auto
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DoccatEvaluationMonitor.java
(added)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DoccatEvaluationMonitor.java
Thu Apr 10 23:25:11 2014
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.doccat;
+
+import opennlp.tools.util.eval.EvaluationMonitor;
+
+/**
+ * A marker interface for evaluation listeners that receive
+ * {@link DocumentSample} events during document categorizer evaluation.
+ * It declares no members beyond those inherited from
+ * {@link EvaluationMonitor}.
+ */
+public interface DoccatEvaluationMonitor extends
+ EvaluationMonitor<DocumentSample> {
+
+}
Propchange:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DoccatEvaluationMonitor.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DocumentCategorizerEvaluator.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DocumentCategorizerEvaluator.java?rev=1586502&r1=1586501&r2=1586502&view=diff
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DocumentCategorizerEvaluator.java
(original)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/doccat/DocumentCategorizerEvaluator.java
Thu Apr 10 23:25:11 2014
@@ -18,10 +18,8 @@
package opennlp.tools.doccat;
-import java.util.Iterator;
-
-import opennlp.tools.postag.POSSample;
import opennlp.tools.tokenize.TokenSample;
+import opennlp.tools.util.eval.Evaluator;
import opennlp.tools.util.eval.Mean;
/**
@@ -32,7 +30,7 @@ import opennlp.tools.util.eval.Mean;
* @see DocumentCategorizer
* @see DocumentSample
*/
-public class DocumentCategorizerEvaluator {
+public class DocumentCategorizerEvaluator extends Evaluator<DocumentSample>{
private DocumentCategorizer categorizer;
@@ -43,7 +41,9 @@ public class DocumentCategorizerEvaluato
*
* @param categorizer
*/
- public DocumentCategorizerEvaluator(DocumentCategorizer categorizer) {
+ public DocumentCategorizerEvaluator(DocumentCategorizer categorizer,
+ DoccatEvaluationMonitor ... listeners) {
+ super(listeners);
this.categorizer = categorizer;
}
@@ -56,7 +56,7 @@ public class DocumentCategorizerEvaluato
*
* @param sample the reference {@link TokenSample}.
*/
- public void evaluteSample(DocumentSample sample) {
+ public DocumentSample processSample(DocumentSample sample) {
String document[] = sample.getText();
@@ -70,21 +70,8 @@ public class DocumentCategorizerEvaluato
else {
accuracy.add(0);
}
- }
- /**
- * Reads all {@link DocumentSample} objects from the stream
- * and evaluates each {@link DocumentSample} object with
- * {@link #evaluteSample(DocumentSample)} method.
- *
- * @param samples the stream of reference {@link POSSample} which
- * should be evaluated.
- */
- public void evaluate(Iterator<DocumentSample> samples) {
-
- while (samples.hasNext()) {
- evaluteSample(samples.next());
- }
+ return new DocumentSample(cat, sample.getText());
}
/**
Modified:
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/eval/Evaluator.java
URL:
http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/eval/Evaluator.java?rev=1586502&r1=1586501&r2=1586502&view=diff
==============================================================================
---
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/eval/Evaluator.java
(original)
+++
opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/eval/Evaluator.java
Thu Apr 10 23:25:11 2014
@@ -59,10 +59,7 @@ public abstract class Evaluator<T> {
*
* @return the predicted sample
*/
- protected T processSample(T reference) {
- // should be overridden by subclass... in the future we will make it abstract.
- return null;
- }
+ protected abstract T processSample(T reference);
/**
* Evaluates the given reference object. The default implementation calls