Author: ogrisel
Date: Tue Jan 17 18:54:33 2012
New Revision: 1232533
URL: http://svn.apache.org/viewvc?rev=1232533&view=rev
Log:
STANBOL-197: compute precision, recall and f1 score + averaging across CV folds
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java?rev=1232533&r1=1232532&r2=1232533&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
Tue Jan 17 18:54:33 2012
@@ -26,6 +26,7 @@ import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.Dictionary;
+import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.LinkedHashSet;
@@ -838,11 +839,17 @@ public class TopicClassificationEngine e
@Override
public int process(List<SolrDocument> batch) throws
TrainingSetException, ClassifierException {
+ int offset;
for (SolrDocument topicMetadata : batch) {
String topic =
topicMetadata.getFirstValue(topicUriField).toString();
List<String> impactedTopics = new ArrayList<String>();
- int offset = 0;
+
Batch<String> examples = Batch.emtpyBatch(String.class);
+
+ List<String> falseNegativeExamples = new
ArrayList<String>();
+ int truePositives = 0;
+ int falseNegatives = 0;
+ offset = 0;
do {
examples =
trainingSet.getPositiveExamples(impactedTopics, examples.nextOffset);
for (String example : examples.items) {
@@ -853,16 +860,58 @@ public class TopicClassificationEngine e
}
offset++;
if
(classifier.suggestTopics(example).contains(topic)) {
- // count positive success
+ truePositives++;
} else {
- // collect false negatives
+ falseNegatives++;
+ // falseNegativeExamples.add(exampleId);
}
}
} while (examples.hasMore); // TODO: put a bound on the
number of examples
- // TODO: handle false positives with negative examples here
+ List<String> falsePositiveExamples = new
ArrayList<String>();
+ int trueNegatives = 0;
+ int falsePositives = 0;
+ offset = 0;
+ do {
+ examples =
trainingSet.getNegativeExamples(impactedTopics, examples.nextOffset);
+ for (String example : examples.items) {
+ if (!(offset % foldCount == foldIndex)) {
+ // TODO: change the dataset API to include
exampleId
+ // this example is not part of the test fold,
skip it
+ offset++;
+ continue;
+ }
+ offset++;
+ if
(classifier.suggestTopics(example).contains(topic)) {
+ falsePositives++;
+ // TODO: change the dataset API to include
exampleId
+ // falsePositiveExamples.add(exampleId);
+ } else {
+ trueNegatives++;
+ }
+ }
+ } while (examples.hasMore); // TODO: put a bound on the
number of examples
- // TODO: store performance statistics for current model in
the original classifier
+ // compute precision, recall and f1 score for the current
test fold and topic
+ float precision = 0;
+ if (truePositives != 0 || falsePositives != 0) {
+ precision = truePositives / (float) (truePositives +
falsePositives);
+ }
+ float recall = 0;
+ if (trueNegatives != 0 || falseNegatives != 0) {
+ recall = trueNegatives / (float) (trueNegatives +
falseNegatives);
+ }
+ float f1 = 0;
+ if (precision != 0 || recall != 0) {
+ f1 = 2 * precision * recall / (precision + recall);
+ }
+ updatePerformanceMetadata(topic, precision, recall, f1,
falsePositiveExamples,
+ falseNegativeExamples);
+ }
+ try {
+ getActiveSolrServer().commit();
+ } catch (Exception e) {
+ throw new ClassifierException(e);
}
return batch.size();
}
@@ -874,37 +923,122 @@ public class TopicClassificationEngine e
cvFoldIndex + 1, cvFoldCount, engineId, (stop - start) / 1000.0,
averageF1));
}
+ /**
+ * Update the performance statistics in a metadata entry of a topic. It
is the responsibility of the
+ * caller to commit.
+ */
+ protected void updatePerformanceMetadata(String topicId,
+ float precision,
+ float recall,
+ float f1,
+ List<String>
falsePositiveExamples,
+ List<String>
falseNegativeExamples) throws ClassifierException {
+ SolrServer solrServer = getActiveSolrServer();
+ try {
+ SolrQuery query = new SolrQuery(entryTypeField + ":" +
METADATA_ENTRY + " AND " + topicUriField
+ + ":" +
ClientUtils.escapeQueryChars(topicId));
+ for (SolrDocument result : solrServer.query(query).getResults()) {
+ // there should be only one (or none: tolerated)
+ // fetch any old values to update (all metadata fields are
assumed to be stored)
+ Map<String,Collection<Object>> fieldValues = new
HashMap<String,Collection<Object>>();
+ for (String fieldName : result.getFieldNames()) {
+ fieldValues.put(fieldName,
result.getFieldValues(fieldName));
+ }
+ addToList(fieldValues, precisionField, precision);
+ addToList(fieldValues, recallField, recall);
+ addToList(fieldValues, f1Field, f1);
+ // TODO: handle supports too...
+ // addToList(fieldValues, falsePositivesField,
falsePositiveExamples);
+ // addToList(fieldValues, falseNegativesField,
falseNegativeExamples);
+ SolrInputDocument newEntry = new SolrInputDocument();
+ for (Map.Entry<String,Collection<Object>> entry :
fieldValues.entrySet()) {
+ newEntry.addField(entry.getKey(), entry.getValue());
+ }
+ newEntry.setField(modelEvaluationDateField,
UTCTimeStamper.nowUtcDate());
+ solrServer.add(newEntry);
+ }
+ } catch (Exception e) {
+ String msg = String.format(
+ "Error updating performance metadata for topic '%s' on Solr
Core '%s'", topicId, solrCoreId);
+ throw new ClassifierException(msg, e);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ protected void addToList(Map<String,Collection<Object>> fieldValues,
String fieldName, Object value) {
+ Collection<Object> values = new ArrayList<Object>();
+ if (fieldValues.get(fieldName) != null) {
+ values.addAll(fieldValues.get(fieldName));
+ }
+ if (value instanceof Collection) {
+ values.addAll((Collection<Object>) value);
+ } else {
+ values.add(value);
+ }
+ fieldValues.put(fieldName, values);
+ }
+
@Override
- public ClassificationReport getPerformanceEstimates(String topic) throws
ClassifierException {
+ public ClassificationReport getPerformanceEstimates(String topicId) throws
ClassifierException {
SolrServer solrServer = getActiveSolrServer();
- SolrQuery query = new SolrQuery(entryIdField + ":" + METADATA_ENTRY +
" AND " + topicUriField + ":"
- + ClientUtils.escapeQueryChars(topic));
+ SolrQuery query = new SolrQuery(entryTypeField + ":" + METADATA_ENTRY
+ " AND " + topicUriField + ":"
+ +
ClientUtils.escapeQueryChars(topicId));
try {
SolrDocumentList results = solrServer.query(query).getResults();
if (results.isEmpty()) {
- throw new ClassifierException(String.format("%s is not a
registered topic", topic));
+ throw new ClassifierException(String.format("'%s' is not a
registered topic", topicId));
}
SolrDocument metadata = results.get(0);
- float precision = (Float) metadata.getFirstValue(precisionField);
- float recall = (Float) metadata.getFirstValue(recallField);
- float f1 = (Float) metadata.getFirstValue(f1Field);
- int positiveSupport = (Integer)
metadata.getFirstValue(positiveSupportField);
- int negativeSupport = (Integer)
metadata.getFirstValue(negativeSupportField);
+ Float precision = computeMeanValue(metadata, precisionField);
+ Float recall = computeMeanValue(metadata, recallField);
+ Float f1 = computeMeanValue(metadata, f1Field);
+ int positiveSupport = computeSumValue(metadata,
positiveSupportField);
+ int negativeSupport = computeSumValue(metadata,
negativeSupportField);
Date evaluationDate = (Date)
metadata.getFirstValue(modelEvaluationDateField);
boolean uptodate = evaluationDate != null;
ClassificationReport report = new ClassificationReport(precision,
recall, f1, positiveSupport,
negativeSupport, uptodate, evaluationDate);
- for (Object falsePositiveId :
metadata.getFieldValues(FALSE_POSITIVES_FIELD)) {
+ if (metadata.getFieldValues(falsePositivesField) == null) {
+ metadata.setField(falsePositivesField, new
ArrayList<Object>());
+ }
+ for (Object falsePositiveId :
metadata.getFieldValues(falsePositivesField)) {
report.falsePositiveExampleIds.add(falsePositiveId.toString());
}
- for (Object falseNegativeId :
metadata.getFieldValues(FALSE_NEGATIVES_FIELD)) {
+ if (metadata.getFieldValues(falseNegativesField) == null) {
+ metadata.setField(falseNegativesField, new
ArrayList<Object>());
+ }
+ for (Object falseNegativeId :
metadata.getFieldValues(falseNegativesField)) {
report.falseNegativeExampleIds.add(falseNegativeId.toString());
}
return report;
} catch (SolrServerException e) {
throw new ClassifierException(String.format("Error fetching the
performance report for topic "
- + topic));
+ + topicId));
+ }
+ }
+
+ protected Float computeMeanValue(SolrDocument metadata, String fielName) {
+ Float mean = 0f;
+ Collection<Object> values = metadata.getFieldValues(fielName);
+ if (values == null || values.isEmpty()) {
+ return mean;
+ }
+ for (Object v : values) {
+ mean += (Float) v / values.size();
+ }
+ return mean;
+ }
+
+ protected Integer computeSumValue(SolrDocument metadata, String fielName) {
+ Integer sum = 0;
+ Collection<Object> values = metadata.getFieldValues(fielName);
+ if (values == null || values.isEmpty()) {
+ return sum;
+ }
+ for (Object v : values) {
+ sum += (Integer) v;
}
+ return sum;
}
}
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java?rev=1232533&r1=1232532&r2=1232533&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
Tue Jan 17 18:54:33 2012
@@ -337,7 +337,46 @@ public class TopicEngineTest extends Bas
}
- //@Test
+ @Test
+ public void testUpdatePerformanceEstimates() throws Exception {
+ ClassificationReport performanceEstimates;
+ // no registered topic
+ try {
+ classifier.getPerformanceEstimates("urn:t/001");
+ fail("Should have raised ClassifierException");
+ } catch (ClassifierException e) {
+ // expected
+ }
+
+ // register some topics
+ classifier.addTopic("urn:t/001", null);
+ classifier.addTopic("urn:t/002", Arrays.asList("urn:t/001"));
+ performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
+ assertFalse(performanceEstimates.uptodate);
+
+ // update the performance metadata manually
+ classifier.updatePerformanceMetadata("urn:t/002", 0.76f, 0.60f, 0.67f,
Arrays.asList("ex14", "ex78"),
+ Arrays.asList("ex34", "ex23", "ex89"));
+ classifier.getActiveSolrServer().commit();
+ performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
+ assertTrue(performanceEstimates.uptodate);
+ assertEquals(Float.valueOf(0.76f),
Float.valueOf(performanceEstimates.precision));
+ assertEquals(Float.valueOf(0.60f),
Float.valueOf(performanceEstimates.recall));
+ assertEquals(Float.valueOf(0.67f),
Float.valueOf(performanceEstimates.f1));
+
assertTrue(classifier.getBroaderTopics("urn:t/002").contains("urn:t/001"));
+
+ // accumulate other folds statistics and compute means of statistics
+ classifier.updatePerformanceMetadata("urn:t/002", 0.79f, 0.63f, 0.72f,
Arrays.asList("ex1", "ex5"),
+ Arrays.asList("ex3", "ex10", "ex11"));
+ classifier.getActiveSolrServer().commit();
+ performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
+ assertTrue(performanceEstimates.uptodate);
+ assertEquals(Float.valueOf(0.775f),
Float.valueOf(performanceEstimates.precision));
+ assertEquals(Float.valueOf(0.615f),
Float.valueOf(performanceEstimates.recall));
+ assertEquals(Float.valueOf(0.69500005f),
Float.valueOf(performanceEstimates.f1));
+ }
+
+ // @Test
public void testCrossValidation() throws Exception {
// seed a pseudo random number generator for reproducible tests
Random rng = new Random(0);