Author: ogrisel
Date: Fri Jan 13 18:58:24 2012
New Revision: 1231246
URL: http://svn.apache.org/viewvc?rev=1231246&view=rev
Log:
STANBOL-197: WIP: TDD for cross validation-based evaluation
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
URL: http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java?rev=1231246&r1=1231245&r2=1231246&view=diff
==============================================================================
--- incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java (original)
+++ incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java Fri Jan 13 18:58:24 2012
@@ -677,8 +677,12 @@ public class TopicClassificationEngine e
@Override
public void setCrossValidationInfo(int foldIndex, int foldCount) {
- // TODO Auto-generated method stub
-
+ // The fold kept aside for evaluation must be one of the foldCount folds, i.e. foldIndex < foldCount.
+ // Comparing directly against foldCount (instead of foldCount - 1) also avoids an int overflow
+ // when foldCount == Integer.MIN_VALUE.
+ // NOTE(review): negative values are not rejected here; confirm whether a negative index is
+ // meant as a "cross validation disabled" sentinel before adding such a check.
+ if (foldIndex >= foldCount) {
+ throw new IllegalArgumentException(String.format(
+ "foldIndex=%d should be smaller than foldCount=%d", foldIndex, foldCount));
+ }
+ cvFoldIndex = foldIndex;
+ cvFoldCount = foldCount;
}
@Override
@@ -693,13 +697,36 @@ public class TopicClassificationEngine e
}
- public void updatePerformanceEstimates(boolean incremental) throws ClassifierException, TrainingSetException {
-
+ @Override
+ public int updatePerformanceEstimates(boolean incremental) throws ClassifierException,
+ TrainingSetException {
+ // TODO (STANBOL-197): cross validation is not implemented yet. This stub evaluates
+ // nothing and always reports 0 updated topics.
+ int updatedTopics = 0;
+ return updatedTopics;
}
@Override
public ClassificationReport getPerformanceEstimates(String topic) throws ClassifierException {
- // TODO Auto-generated method stub
- return null;
+ // The evaluation metrics are read from the metadata entry of the topic.
+ SolrServer solrServer = getActiveSolrServer();
+ SolrQuery query = new SolrQuery(entryIdField + ":" + METADATA_ENTRY + " AND " + topicUriField + ":"
+ + ClientUtils.escapeQueryChars(topic));
+ try {
+ SolrDocumentList results = solrServer.query(query).getResults();
+ if (results.isEmpty()) {
+ throw new ClassifierException(String.format("%s is not a registered topic", topic));
+ }
+ SolrDocument metadata = results.get(0);
+ // The metric fields may be absent for a topic that was never evaluated (callers expect a
+ // report that is not up to date in that case), so never unbox a possibly-null Float.
+ float precision = getFloatOrZero(metadata, precisionField);
+ float recall = getFloatOrZero(metadata, recallField);
+ float f1 = getFloatOrZero(metadata, f1Field);
+ // TODO: store and read back the positive / negative support counts; reported as 0 for now.
+ Date evaluationDate = (Date) metadata.getFirstValue(modelEvaluationDateField);
+ // NOTE(review): "up to date" currently only means that an evaluation date was recorded;
+ // it does not compare that date with the last model update.
+ boolean uptodate = evaluationDate != null;
+ return new ClassificationReport(precision, recall, f1, 0, 0, uptodate, evaluationDate);
+ } catch (SolrServerException e) {
+ ClassifierException ce = new ClassifierException(
+ "Error fetching the performance report for topic " + topic);
+ // keep the Solr failure as the cause so that it is not lost in stack traces
+ ce.initCause(e);
+ throw ce;
+ }
}
+
+ /**
+ * Read a numeric field of a Solr document as a float.
+ *
+ * @param doc the document to read from
+ * @param field the name of the field
+ * @return the first value of the field as a float, or 0 if the field has no value
+ */
+ private static float getFloatOrZero(SolrDocument doc, String field) {
+ Object value = doc.getFirstValue(field);
+ return value == null ? 0f : ((Number) value).floatValue();
+ }
}
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java
URL: http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java?rev=1231246&r1=1231245&r2=1231246&view=diff
==============================================================================
--- incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java (original)
+++ incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java Fri Jan 13 18:58:24 2012
@@ -17,6 +17,7 @@
package org.apache.stanbol.enhancer.topic;
import java.util.ArrayList;
+import java.util.Date;
import java.util.List;
/**
@@ -65,6 +66,10 @@ public class ClassificationReport {
*/
public final int negativeSupport;
+ /**
+ * {@code true} if the estimates above come from an evaluation of the topic model, {@code false}
+ * if the topic has not been evaluated yet.
+ * NOTE(review): the Solr-backed engine currently derives this from {@code evaluationDate != null};
+ * confirm whether model updates made after that date should also make the report stale.
+ */
+ public final boolean uptodate;
+
+ /**
+ * Date of the evaluation the estimates come from; may be {@code null} when no evaluation was recorded.
+ */
+ public final Date evaluationDate;
+
public final List<String> falsePositiveExampleIds = new ArrayList<String>();
public final List<String> falseNegativeExampleIds = new ArrayList<String>();
@@ -73,12 +78,16 @@ public class ClassificationReport {
float recall,
float f1,
int positiveSupport,
- int negativeSupport) {
+ int negativeSupport,
+ boolean uptodate,
+ Date evaluationDate) {
this.precision = precision;
this.recall = recall;
this.f1 = f1;
this.positiveSupport = positiveSupport;
this.negativeSupport = negativeSupport;
+ this.uptodate = uptodate;
+ // defensive copy: java.util.Date is mutable and the caller may keep a reference to it
+ this.evaluationDate = evaluationDate == null ? null : new Date(evaluationDate.getTime());
}
}
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
URL: http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java?rev=1231246&r1=1231245&r2=1231246&view=diff
==============================================================================
--- incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java (original)
+++ incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java Fri Jan 13 18:58:24 2012
@@ -112,9 +112,11 @@ public interface TopicClassifier {
/**
* Perform k-fold cross validation of the model to compute estimates of the precision, recall and f1
* score.
+ *
+ * @param incremental NOTE(review): presumably restricts the evaluation to topics whose estimates
+ *        are not up to date (by analogy with updateModel(incremental)); confirm once implemented
+ * @return number of updated topics
*/
- public void updatePerformanceEstimates(boolean incremental) throws
ClassifierException,
-
TrainingSetException;
+ public int updatePerformanceEstimates(boolean incremental) throws
ClassifierException,
+
TrainingSetException;
/**
* Tell the classifier which slice of data to keep aside while training
for model evaluation using k-folds
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
URL: http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java?rev=1231246&r1=1231245&r2=1231246&view=diff
==============================================================================
--- incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java (original)
+++ incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java Fri Jan 13 18:58:24 2012
@@ -17,6 +17,7 @@
package org.apache.stanbol.enhancer.engine.topic;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -28,17 +29,24 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Hashtable;
import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.TreeMap;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang.StringUtils;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.client.solrj.embedded.EmbeddedSolrServer;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.common.params.CommonParams;
import org.apache.stanbol.commons.solr.utils.StreamQueryRequest;
+import org.apache.stanbol.enhancer.topic.ClassificationReport;
+import org.apache.stanbol.enhancer.topic.ClassifierException;
import org.apache.stanbol.enhancer.topic.SolrTrainingSet;
import org.apache.stanbol.enhancer.topic.TopicSuggestion;
+import org.apache.stanbol.enhancer.topic.TrainingSetException;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -321,12 +329,145 @@ public class TopicEngineTest extends Bas
classifier.addTopic(law, null);
assertEquals(1, classifier.updateModel(true));
assertEquals(0, classifier.updateModel(true));
-
+
// registering new subtopics invalidate the models of the parent as
well
classifier.addTopic("urn:topics/sportsmafia", Arrays.asList(football,
business));
assertEquals(3, classifier.updateModel(true));
assertEquals(0, classifier.updateModel(true));
-
+
+ }
+
+ //@Test // disabled while cross validation is work in progress (STANBOL-197)
+ public void testCrossValidation() throws Exception {
+ // seed a pseudo random number generator for reproducible tests
+ Random rng = new Random(0);
+
+ // build an artificial data set used for training models and evaluation
+ int numberOfTopics = 10;
+ int numberOfDocuments = 100;
+ int vocabSizeMin = 10;
+ int vocabSizeMax = 25; // we are using the alphabet as base terms
+ initArtificialTrainingSet(numberOfTopics, numberOfDocuments, vocabSizeMin, vocabSizeMax, rng);
+ // a sample of the generated topics whose reports are checked
+ String[] checkedTopics = {"urn:t/001", "urn:t/002", "urn:t/003"};
+
+ // by default the reports are not computed
+ for (String topic : checkedTopics) {
+ assertFalse(topic + " should not be evaluated yet",
+ classifier.getPerformanceEstimates(topic).uptodate);
+ }
+
+ // asking for the report of an unknown topic is an error
+ try {
+ classifier.getPerformanceEstimates("urn:doesnotexist");
+ fail("Should have raised a ClassifierException");
+ } catch (ClassifierException e) {
+ // expected
+ }
+
+ // evaluate all the registered topics at once
+ assertEquals(numberOfTopics, classifier.updatePerformanceEstimates(true));
+ for (String topic : checkedTopics) {
+ ClassificationReport performanceEstimates = classifier.getPerformanceEstimates(topic);
+ assertTrue(topic + " should be up to date", performanceEstimates.uptodate);
+ assertGreater(performanceEstimates.precision, 0.8f);
+ assertGreater(performanceEstimates.recall, 0.8f);
+ assertGreater(performanceEstimates.f1, 0.8f);
+ // NOTE(review): numberOfDocuments is 100, so these two bounds leave essentially no slack if
+ // the positive and negative supports partition the evaluated examples: revisit the thresholds.
+ assertGreater(performanceEstimates.positiveSupport, 10);
+ assertGreater(performanceEstimates.negativeSupport, 90);
+ assertNotNull(performanceEstimates.evaluationDate);
+ }
+
+ // TODO: test model invalidation by registering a sub topic manually
+ }
+
+ /**
+ * Fail unless {@code large} is strictly greater than {@code small}.
+ *
+ * @param large the value expected to be the greater one
+ * @param small the lower bound
+ * @throws AssertionError if {@code large <= small} or if either value is NaN
+ */
+ protected void assertGreater(float large, float small) {
+ // written as !(large > small) so that equal values and NaN on either side fail the assertion
+ if (!(large > small)) {
+ throw new AssertionError(String.format("Expected %f to be greater than %f.", large, small));
+ }
+ }
+
+ /**
+ * Register {@code numberOfTopics} artificial topics (urn:t/000, urn:t/001, ...) and fill the
+ * training set with {@code numberOfDocuments} random examples, each labelled with its dominant topics.
+ *
+ * @param vocabSizeMin minimum number of terms in the vocabulary of a topic
+ * @param vocabSizeMax maximum number of terms in the vocabulary of a topic (at most 26)
+ * @param rng seeded random generator, for reproducible data sets
+ */
+ protected void initArtificialTrainingSet(int numberOfTopics,
+ int numberOfDocuments,
+ int vocabSizeMin,
+ int vocabSizeMax,
+ Random rng) throws ClassifierException, TrainingSetException {
+ // define some artificial topics and register them to the classifier
+ char[] alphabet = "abcdefghijklmnopqrstuvwxyz".toCharArray();
+ // vocabulary terms are built from the letters of the alphabet: fail early on impossible sizes
+ if (vocabSizeMin < 0 || vocabSizeMin > vocabSizeMax || vocabSizeMax > alphabet.length) {
+ throw new IllegalArgumentException(String.format(
+ "expected 0 <= vocabSizeMin=%d <= vocabSizeMax=%d <= %d",
+ vocabSizeMin, vocabSizeMax, alphabet.length));
+ }
+ String[] topics = new String[numberOfTopics];
+ Map<String,String[]> vocabularies = new TreeMap<String,String[]>();
+ for (int i = 0; i < numberOfTopics; i++) {
+ String topic = String.format("urn:t/%03d", i);
+ topics[i] = topic;
+ classifier.addTopic(topic, null);
+ int vocSize = rng.nextInt(vocabSizeMax + 1 - vocabSizeMin) + vocabSizeMin;
+ String[] terms = new String[vocSize];
+
+ for (int j = 0; j < vocSize; j++) {
+ // define some artificial vocabulary for each topic to automatically generate random examples
+ // with some topic structure
+ // if i = 1, will generate: ["a1", "b1", "c1", ...]
+ terms[j] = alphabet[j] + String.valueOf(i);
+ }
+ vocabularies.put(topic, terms);
+ }
+ classifier.setTrainingSet(trainingSet);
+
+ // build a random data set where each example has a couple of dominant topics and other words
+ for (int i = 0; i < numberOfDocuments; i++) {
+ List<String> documentTerms = new ArrayList<String>();
+
+ // add terms from some dominant topics that are used as classification target
+ int numberOfDominantTopics = rng.nextInt(4) + 1; // between 1 and 4 topics
+ List<String> documentTopics = new ArrayList<String>();
+ for (int j = 0; j < numberOfDominantTopics; j++) {
+ String topic = randomTopicAndTerms(topics, vocabularies, documentTerms, 50, 100, rng);
+ // the same topic can be drawn more than once: label the example with it only once
+ if (!documentTopics.contains(topic)) {
+ documentTopics.add(topic);
+ }
+ }
+ // add a few noise terms from non-dominant topics: these topics are NOT classification targets
+ for (int j = 0; j < 10; j++) {
+ randomTopicAndTerms(topics, vocabularies, documentTerms, 1, 10, rng);
+ }
+ // add some non discriminative terms not linked to any topic
+ for (int k = 0; k < 100; k++) {
+ documentTerms.add(String.valueOf(alphabet[rng.nextInt(alphabet.length)]));
+ }
+ // register the generated example in the training set, labelled with its dominant topics only
+ trainingSet.registerExample(String.format("example_%03d", i),
+ StringUtils.join(documentTerms, " "), documentTopics);
+ }
+ }
+
+ /**
+ * Pick one topic at random and append to {@code documentTerms} between {@code min} and
+ * {@code max} (inclusive) terms drawn at random from that topic's vocabulary.
+ *
+ * @return the topic that was picked
+ */
+ protected String randomTopicAndTerms(String[] topics,
+ Map<String,String[]> vocabularies,
+ List<String> documentTerms,
+ int min,
+ int max,
+ Random rng) {
+ String picked = topics[rng.nextInt(topics.length)];
+ String[] vocabulary = vocabularies.get(picked);
+ int remaining = min + rng.nextInt(max - min + 1);
+ while (remaining-- > 0) {
+ String term = vocabulary[rng.nextInt(vocabulary.length)];
+ documentTerms.add(term);
+ }
+ return picked;
}
protected Hashtable<String,Object> getDefaultClassifierConfigParams() {