Author: ogrisel
Date: Fri Jan 20 17:09:18 2012
New Revision: 1234006
URL: http://svn.apache.org/viewvc?rev=1234006&view=rev
Log:
STANBOL-197: refactored the training set API to make it possible to access the
fields of the examples
Added:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/Example.java
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TrainingSet.java
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java?rev=1234006&r1=1234005&r2=1234006&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
Fri Jan 20 17:09:18 2012
@@ -79,6 +79,7 @@ import org.apache.stanbol.enhancer.topic
import org.apache.stanbol.enhancer.topic.TrainingSet;
import org.apache.stanbol.enhancer.topic.TrainingSetException;
import org.apache.stanbol.enhancer.topic.UTCTimeStamper;
+import org.apache.stanbol.enhancer.topic.training.Example;
import org.osgi.framework.InvalidSyntaxException;
import org.osgi.service.cm.ConfigurationException;
import org.osgi.service.component.ComponentContext;
@@ -338,6 +339,10 @@ public class TopicClassificationEngine e
return acceptedLanguages;
}
+ public List<TopicSuggestion> suggestTopics(Collection<Object> contents) // overload for multi-field texts such as Example.contents
throws ClassifierException {
+ return suggestTopics(StringUtils.join(contents, "\n\n")); // fields are separated by a blank line, then classified as one text by the String overload
+ }
+
public List<TopicSuggestion> suggestTopics(String text) throws
ClassifierException {
List<TopicSuggestion> suggestedTopics = new
ArrayList<TopicSuggestion>(MAX_SUGGESTIONS * 3);
SolrServer solrServer = getActiveSolrServer();
@@ -687,12 +692,12 @@ public class TopicClassificationEngine e
Collection<Object> broaderTopics) throws
TrainingSetException,
ClassifierException {
long start = System.currentTimeMillis();
- Batch<String> examples = Batch.emtpyBatch(String.class);
+ Batch<Example> examples = Batch.emtpyBatch(Example.class);
StringBuffer sb = new StringBuffer();
int offset = 0;
do {
examples = trainingSet.getPositiveExamples(impactedTopics,
examples.nextOffset);
- for (String example : examples.items) {
+ for (Example example : examples.items) {
if ((cvFoldCount != 0) && (offset % cvFoldCount ==
cvFoldIndex)) {
// we are performing a cross validation session and this
example belong to the test
// fold hence should be skipped
@@ -700,7 +705,7 @@ public class TopicClassificationEngine e
continue;
}
offset++;
- sb.append(example);
+ sb.append(StringUtils.join(example.contents, "\n\n"));
sb.append("\n\n");
}
} while (sb.length() < MAX_CHARS_PER_TOPIC && examples.hasMore);
@@ -874,10 +879,10 @@ public class TopicClassificationEngine e
int falseNegatives = 0;
int positiveSupport = 0;
offset = 0;
- Batch<String> examples = Batch.emtpyBatch(String.class);
+ Batch<Example> examples = Batch.emtpyBatch(Example.class);
do {
examples = trainingSet.getPositiveExamples(topics,
examples.nextOffset);
- for (String example : examples.items) {
+ for (Example example : examples.items) {
if (!(offset % foldCount == foldIndex)) {
// this example is not part of the test fold,
skip it
offset++;
@@ -885,7 +890,8 @@ public class TopicClassificationEngine e
}
positiveSupport++;
offset++;
- List<TopicSuggestion> suggestedTopics =
classifier.suggestTopics(example);
+ List<TopicSuggestion> suggestedTopics = classifier
+ .suggestTopics(example.contents);
boolean match = false;
for (TopicSuggestion suggestedTopic :
suggestedTopics) {
if (topic.equals(suggestedTopic.uri)) {
@@ -896,7 +902,7 @@ public class TopicClassificationEngine e
}
if (!match) {
falseNegatives++;
- // falseNegativeExamples.add(exampleId);
+ falseNegativeExamples.add(example.id);
}
}
} while (examples.hasMore); // TODO: put a bound on the
number of examples
@@ -905,10 +911,10 @@ public class TopicClassificationEngine e
int falsePositives = 0;
int negativeSupport = 0;
offset = 0;
- examples = Batch.emtpyBatch(String.class);
+ examples = Batch.emtpyBatch(Example.class);
do {
examples = trainingSet.getNegativeExamples(topics,
examples.nextOffset);
- for (String example : examples.items) {
+ for (Example example : examples.items) {
if (!(offset % foldCount == foldIndex)) {
// TODO: change the dataset API to include
exampleId
// this example is not part of the test fold,
skip it
@@ -917,11 +923,12 @@ public class TopicClassificationEngine e
}
negativeSupport++;
offset++;
- List<TopicSuggestion> suggestedTopics =
classifier.suggestTopics(example);
+ List<TopicSuggestion> suggestedTopics = classifier
+ .suggestTopics(example.contents);
for (TopicSuggestion suggestedTopic :
suggestedTopics) {
if (topic.equals(suggestedTopic.uri)) {
falsePositives++;
- // falsePositiveExamples.add(exampleId);
+ falsePositiveExamples.add(example.id);
break;
}
}
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java?rev=1234006&r1=1234005&r2=1234006&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java
Fri Jan 20 17:09:18 2012
@@ -38,6 +38,7 @@ import org.apache.solr.client.solrj.resp
import org.apache.solr.client.solrj.util.ClientUtils;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrInputDocument;
+import org.apache.stanbol.enhancer.topic.training.Example;
import org.osgi.framework.InvalidSyntaxException;
import org.osgi.service.cm.ConfigurationException;
import org.osgi.service.component.ComponentContext;
@@ -195,17 +196,17 @@ public class SolrTrainingSet extends Con
}
@Override
- public Batch<String> getPositiveExamples(List<String> topics, Object
offset) throws TrainingSetException {
+ public Batch<Example> getPositiveExamples(List<String> topics, Object
offset) throws TrainingSetException {
return getExamples(topics, offset, true);
}
@Override
- public Batch<String> getNegativeExamples(List<String> topics, Object
offset) throws TrainingSetException {
+ public Batch<Example> getNegativeExamples(List<String> topics, Object
offset) throws TrainingSetException {
return getExamples(topics, offset, false);
}
- protected Batch<String> getExamples(List<String> topics, Object offset,
boolean positive) throws TrainingSetException {
- List<String> items = new ArrayList<String>();
+ protected Batch<Example> getExamples(List<String> topics, Object offset,
boolean positive) throws TrainingSetException {
+ List<Example> items = new ArrayList<Example>();
SolrServer solrServer = getActiveSolrServer();
SolrQuery query = new SolrQuery();
List<String> parts = new ArrayList<String>();
@@ -244,13 +245,13 @@ public class SolrTrainingSet extends Con
nextExampleId =
result.getFirstValue(exampleIdField).toString();
} else {
count++;
+ String exampleId =
result.getFirstValue(exampleIdField).toString();
+ Collection<Object> labelValues =
result.getFieldValues(topicUrisField);
Collection<Object> textValues =
result.getFieldValues(exampleTextField);
if (textValues == null) {
continue;
}
- for (Object value : textValues) {
- items.add(value.toString());
- }
+ items.add(new Example(exampleId, labelValues, textValues));
}
}
} catch (SolrServerException e) {
@@ -259,7 +260,7 @@ public class SolrTrainingSet extends Con
StringUtils.join(topics, "', '"), solrCoreId);
throw new TrainingSetException(msg, e);
}
- return new Batch<String>(items, nextExampleId != null, nextExampleId);
+ return new Batch<Example>(items, nextExampleId != null, nextExampleId);
}
@Override
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TrainingSet.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TrainingSet.java?rev=1234006&r1=1234005&r2=1234006&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TrainingSet.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TrainingSet.java
Fri Jan 20 17:09:18 2012
@@ -19,6 +19,8 @@ package org.apache.stanbol.enhancer.topi
import java.util.Date;
import java.util.List;
+import org.apache.stanbol.enhancer.topic.training.Example;
+
/**
* Source of categorized text documents that can be used to build a the
statistical model of a
* TopicClassifier.
@@ -60,7 +62,7 @@ public interface TrainingSet {
* marker value to fetch the next batch. Pass null to fetch the
first batch.
* @return a batch of example suitable for training a classifier model for
the requested topics.
*/
- Batch<String> getPositiveExamples(List<String> topics, Object offset)
throws TrainingSetException;
+ Batch<Example> getPositiveExamples(List<String> topics, Object offset)
throws TrainingSetException;
/**
* Fetch examples representative of any document not specifically
classified in one of the passed topics.
@@ -76,7 +78,7 @@ public interface TrainingSet {
* @return a batch of examples suitable for training (negative-refinement)
a classifier model for the
* requested topics.
*/
- Batch<String> getNegativeExamples(List<String> topics, Object offset)
throws TrainingSetException;
+ Batch<Example> getNegativeExamples(List<String> topics, Object offset)
throws TrainingSetException;
/**
* Number of examples to fetch at once.
Added:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/Example.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/Example.java?rev=1234006&view=auto
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/Example.java
(added)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/Example.java
Fri Jan 20 17:09:18 2012
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.stanbol.enhancer.topic.training;
+
+import java.util.Collection;
+
+import org.apache.commons.lang.StringUtils;
+
+/**
+ * Data transfer object to pack the items of a multi-label text classification training set: the
+ * document id, its labels (the target signal to predict) and its text fields.
+ *
+ * The collections passed to the constructor are stored by reference (no defensive copy), so the
+ * public fields expose exactly the collections the training set implementation handed over.
+ */
+public class Example {
+
+ /**
+ * Unique id of the document in the training set (e.g. used to report misclassified examples).
+ */
+ public final String id;
+
+ /**
+ * Identifiers of the labels (categories, topics, tags...) of the document. This is the target signal
+ * to predict.
+ *
+ * In practice this is expected to be a collection of String items, but the element type is not
+ * enforced, to avoid the GC overhead of a converted copy and to be able to work with the native data
+ * structures returned by SolrJ.
+ *
+ * NOTE(review): presumably null when the source document has no label values (the Solr-backed
+ * training set passes the field values through without a null check) - confirm before iterating.
+ */
+ public final Collection<Object> labels;
+
+ /**
+ * Text fields of the document (could be headers, paragraphs, text extractions of PDF files...). Any
+ * markup is assumed to have been cleaned up in some preprocessing step.
+ *
+ * In practice this is expected to be a collection of String items, but the element type is not
+ * enforced, to avoid the GC overhead of a converted copy and to be able to work with the native data
+ * structures returned by SolrJ. Use {@link #getContentString()} to get all fields as a single text.
+ */
+ public final Collection<Object> contents;
+ /**
+ * Create a new example. The two collections are stored as-is, not copied.
+ *
+ * @param id unique id of the document
+ * @param labelValues labels of the document, exposed as {@link #labels}
+ * @param textValues text fields of the document, exposed as {@link #contents}
+ */
+ public Example(String id, Collection<Object> labelValues,
Collection<Object> textValues) {
+ this.id = id;
+ this.labels = labelValues;
+ this.contents = textValues;
+ }
+
+ /**
+ * Concatenate the text fields into a single string, separating consecutive fields with a blank line
+ * ({@code "\n\n"}). Each field is rendered through StringUtils.join.
+ *
+ * NOTE(review): contents is not null-checked here; confirm how StringUtils.join handles a null
+ * collection before calling this on examples without text fields.
+ *
+ * @return concatenated version of the content fields.
+ */
+ public String getContentString() {
+ return StringUtils.join(contents, "\n\n");
+ }
+}
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java?rev=1234006&r1=1234005&r2=1234006&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java
Fri Jan 20 17:09:18 2012
@@ -41,6 +41,7 @@ import org.apache.stanbol.enhancer.topic
import org.apache.stanbol.enhancer.topic.SolrTrainingSet;
import org.apache.stanbol.enhancer.topic.TrainingSetException;
import org.apache.stanbol.enhancer.topic.UTCTimeStamper;
+import org.apache.stanbol.enhancer.topic.training.Example;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -97,7 +98,7 @@ public class TrainingSetTest extends Emb
@Test
public void testEmptyTrainingSet() throws TrainingSetException {
- Batch<String> examples = trainingSet.getPositiveExamples(new
ArrayList<String>(), null);
+ Batch<Example> examples = trainingSet.getPositiveExamples(new
ArrayList<String>(), null);
assertEquals(examples.items.size(), 0);
assertFalse(examples.hasMore);
examples = trainingSet.getNegativeExamples(new ArrayList<String>(),
null);
@@ -120,19 +121,20 @@ public class TrainingSetTest extends Emb
trainingSet.registerExample("example2", "Text of example2.",
Arrays.asList(TOPIC_1, TOPIC_2));
trainingSet.registerExample("example3", "Text of example3.", new
ArrayList<String>());
- Batch<String> examples =
trainingSet.getPositiveExamples(Arrays.asList(TOPIC_2), null);
+ Batch<Example> examples =
trainingSet.getPositiveExamples(Arrays.asList(TOPIC_2), null);
assertEquals(1, examples.items.size());
- assertEquals(examples.items, Arrays.asList("Text of example2."));
+ assertEquals(examples.items.get(0).getContentString(), "Text of
example2.");
assertFalse(examples.hasMore);
examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1,
TOPIC_3), null);
assertEquals(2, examples.items.size());
- assertEquals(examples.items, Arrays.asList("Text of example1.", "Text
of example2."));
+ assertEquals(examples.items.get(0).getContentString(), "Text of
example1.");
+ assertEquals(examples.items.get(1).getContentString(), "Text of
example2.");
assertFalse(examples.hasMore);
examples = trainingSet.getNegativeExamples(Arrays.asList(TOPIC_1),
null);
assertEquals(1, examples.items.size());
- assertEquals(examples.items, Arrays.asList("Text of example3."));
+ assertEquals(examples.items.get(0).getContentString(), "Text of
example3.");
assertFalse(examples.hasMore);
// Test example update by adding topic3 to example1. The results of
the previous query should remain
@@ -140,14 +142,15 @@ public class TrainingSetTest extends Emb
trainingSet.registerExample("example1", "Text of example1.",
Arrays.asList(TOPIC_1, TOPIC_3));
examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1,
TOPIC_3), null);
assertEquals(2, examples.items.size());
- assertEquals(examples.items, Arrays.asList("Text of example1.", "Text
of example2."));
+ assertEquals(examples.items.get(0).getContentString(), "Text of
example1.");
+ assertEquals(examples.items.get(1).getContentString(), "Text of
example2.");
assertFalse(examples.hasMore);
// Test example removal
trainingSet.registerExample("example1", null, Arrays.asList(TOPIC_1,
TOPIC_3));
examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1,
TOPIC_3), null);
assertEquals(1, examples.items.size());
- assertEquals(examples.items, Arrays.asList("Text of example2."));
+ assertEquals(examples.items.get(0).getContentString(), "Text of
example2.");
assertFalse(examples.hasMore);
trainingSet.registerExample("example2", null, Arrays.asList(TOPIC_1,
TOPIC_3));
@@ -158,52 +161,76 @@ public class TrainingSetTest extends Emb
@Test
public void testBatchingPositiveExamples() throws ConfigurationException,
TrainingSetException {
+ Set<String> expectedCollectedIds = new HashSet<String>();
Set<String> expectedCollectedText = new HashSet<String>();
+ Set<String> collectedIds = new HashSet<String>();
Set<String> collectedText = new HashSet<String>();
for (int i = 0; i < 28; i++) {
+ String id = "example-" + i;
String text = "Text of example" + i + ".";
- trainingSet.registerExample("example-" + i, text,
Arrays.asList(TOPIC_1));
+ trainingSet.registerExample(id, text, Arrays.asList(TOPIC_1));
+ expectedCollectedIds.add(id);
expectedCollectedText.add(text);
}
trainingSet.setBatchSize(10);
- Batch<String> examples =
trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, TOPIC_2), null);
+ Batch<Example> examples =
trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, TOPIC_2), null);
assertEquals(10, examples.items.size());
- collectedText.addAll(examples.items);
+ for (Example example : examples.items) {
+ collectedIds.add(example.id);
+ collectedText.add(example.getContentString());
+ }
assertTrue(examples.hasMore);
examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1,
TOPIC_2), examples.nextOffset);
assertEquals(10, examples.items.size());
- collectedText.addAll(examples.items);
+ for (Example example : examples.items) {
+ collectedIds.add(example.id);
+ collectedText.add(example.getContentString());
+ }
assertTrue(examples.hasMore);
examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1,
TOPIC_2), examples.nextOffset);
assertEquals(8, examples.items.size());
- collectedText.addAll(examples.items);
+ for (Example example : examples.items) {
+ collectedIds.add(example.id);
+ collectedText.add(example.getContentString());
+ }
assertFalse(examples.hasMore);
+ assertEquals(expectedCollectedIds, collectedIds);
assertEquals(expectedCollectedText, collectedText);
}
@Test
public void testBatchingNegativeExamplesAndAutoId() throws
ConfigurationException, TrainingSetException {
+ Set<String> expectedCollectedIds = new HashSet<String>();
Set<String> expectedCollectedText = new HashSet<String>();
+ Set<String> collectedIds = new HashSet<String>();
Set<String> collectedText = new HashSet<String>();
for (int i = 0; i < 17; i++) {
String text = "Text of example" + i + ".";
- trainingSet.registerExample(null, text, Arrays.asList(TOPIC_1));
+ String id = trainingSet.registerExample(null, text,
Arrays.asList(TOPIC_1));
+ expectedCollectedIds.add(id);
expectedCollectedText.add(text);
}
trainingSet.setBatchSize(10);
- Batch<String> examples =
trainingSet.getNegativeExamples(Arrays.asList(TOPIC_2), null);
+ Batch<Example> examples =
trainingSet.getNegativeExamples(Arrays.asList(TOPIC_2), null);
assertEquals(10, examples.items.size());
- collectedText.addAll(examples.items);
+ for (Example example : examples.items) {
+ collectedIds.add(example.id);
+ collectedText.add(example.getContentString());
+ }
assertTrue(examples.hasMore);
examples = trainingSet.getNegativeExamples(Arrays.asList(TOPIC_2),
examples.nextOffset);
assertEquals(7, examples.items.size());
- collectedText.addAll(examples.items);
+ for (Example example : examples.items) {
+ collectedIds.add(example.id);
+ collectedText.add(example.getContentString());
+ }
assertFalse(examples.hasMore);
+ assertEquals(expectedCollectedIds, collectedIds);
assertEquals(expectedCollectedText, collectedText);
}