Author: ogrisel
Date: Fri Jan 20 17:09:18 2012
New Revision: 1234006

URL: http://svn.apache.org/viewvc?rev=1234006&view=rev
Log:
STANBOL-197: refactored the training set API to make it possible to access the
fields of the examples

Added:
    
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/
    
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/Example.java
Modified:
    
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
    
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java
    
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TrainingSet.java
    
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java?rev=1234006&r1=1234005&r2=1234006&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
 Fri Jan 20 17:09:18 2012
@@ -79,6 +79,7 @@ import org.apache.stanbol.enhancer.topic
 import org.apache.stanbol.enhancer.topic.TrainingSet;
 import org.apache.stanbol.enhancer.topic.TrainingSetException;
 import org.apache.stanbol.enhancer.topic.UTCTimeStamper;
+import org.apache.stanbol.enhancer.topic.training.Example;
 import org.osgi.framework.InvalidSyntaxException;
 import org.osgi.service.cm.ConfigurationException;
 import org.osgi.service.component.ComponentContext;
@@ -338,6 +339,10 @@ public class TopicClassificationEngine e
         return acceptedLanguages;
     }
 
+    public List<TopicSuggestion> suggestTopics(Collection<Object> contents) 
throws ClassifierException {
+        return suggestTopics(StringUtils.join(contents, "\n\n"));
+    }
+
     public List<TopicSuggestion> suggestTopics(String text) throws 
ClassifierException {
         List<TopicSuggestion> suggestedTopics = new 
ArrayList<TopicSuggestion>(MAX_SUGGESTIONS * 3);
         SolrServer solrServer = getActiveSolrServer();
@@ -687,12 +692,12 @@ public class TopicClassificationEngine e
                                Collection<Object> broaderTopics) throws 
TrainingSetException,
                                                                 
ClassifierException {
         long start = System.currentTimeMillis();
-        Batch<String> examples = Batch.emtpyBatch(String.class);
+        Batch<Example> examples = Batch.emtpyBatch(Example.class);
         StringBuffer sb = new StringBuffer();
         int offset = 0;
         do {
             examples = trainingSet.getPositiveExamples(impactedTopics, 
examples.nextOffset);
-            for (String example : examples.items) {
+            for (Example example : examples.items) {
                 if ((cvFoldCount != 0) && (offset % cvFoldCount == 
cvFoldIndex)) {
                     // we are performing a cross validation session and this 
example belong to the test
                     // fold hence should be skipped
@@ -700,7 +705,7 @@ public class TopicClassificationEngine e
                     continue;
                 }
                 offset++;
-                sb.append(example);
+                sb.append(StringUtils.join(example.contents, "\n\n"));
                 sb.append("\n\n");
             }
         } while (sb.length() < MAX_CHARS_PER_TOPIC && examples.hasMore);
@@ -874,10 +879,10 @@ public class TopicClassificationEngine e
                     int falseNegatives = 0;
                     int positiveSupport = 0;
                     offset = 0;
-                    Batch<String> examples = Batch.emtpyBatch(String.class);
+                    Batch<Example> examples = Batch.emtpyBatch(Example.class);
                     do {
                         examples = trainingSet.getPositiveExamples(topics, 
examples.nextOffset);
-                        for (String example : examples.items) {
+                        for (Example example : examples.items) {
                             if (!(offset % foldCount == foldIndex)) {
                                 // this example is not part of the test fold, 
skip it
                                 offset++;
@@ -885,7 +890,8 @@ public class TopicClassificationEngine e
                             }
                             positiveSupport++;
                             offset++;
-                            List<TopicSuggestion> suggestedTopics = 
classifier.suggestTopics(example);
+                            List<TopicSuggestion> suggestedTopics = classifier
+                                    .suggestTopics(example.contents);
                             boolean match = false;
                             for (TopicSuggestion suggestedTopic : 
suggestedTopics) {
                                 if (topic.equals(suggestedTopic.uri)) {
@@ -896,7 +902,7 @@ public class TopicClassificationEngine e
                             }
                             if (!match) {
                                 falseNegatives++;
-                                // falseNegativeExamples.add(exampleId);
+                                falseNegativeExamples.add(example.id);
                             }
                         }
                     } while (examples.hasMore); // TODO: put a bound on the 
number of examples
@@ -905,10 +911,10 @@ public class TopicClassificationEngine e
                     int falsePositives = 0;
                     int negativeSupport = 0;
                     offset = 0;
-                    examples = Batch.emtpyBatch(String.class);
+                    examples = Batch.emtpyBatch(Example.class);
                     do {
                         examples = trainingSet.getNegativeExamples(topics, 
examples.nextOffset);
-                        for (String example : examples.items) {
+                        for (Example example : examples.items) {
                             if (!(offset % foldCount == foldIndex)) {
                                 // TODO: change the dataset API to include 
exampleId
                                 // this example is not part of the test fold, 
skip it
@@ -917,11 +923,12 @@ public class TopicClassificationEngine e
                             }
                             negativeSupport++;
                             offset++;
-                            List<TopicSuggestion> suggestedTopics = 
classifier.suggestTopics(example);
+                            List<TopicSuggestion> suggestedTopics = classifier
+                                    .suggestTopics(example.contents);
                             for (TopicSuggestion suggestedTopic : 
suggestedTopics) {
                                 if (topic.equals(suggestedTopic.uri)) {
                                     falsePositives++;
-                                    // falsePositiveExamples.add(exampleId);
+                                    falsePositiveExamples.add(example.id);
                                     break;
                                 }
                             }

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java?rev=1234006&r1=1234005&r2=1234006&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java
 Fri Jan 20 17:09:18 2012
@@ -38,6 +38,7 @@ import org.apache.solr.client.solrj.resp
 import org.apache.solr.client.solrj.util.ClientUtils;
 import org.apache.solr.common.SolrDocument;
 import org.apache.solr.common.SolrInputDocument;
+import org.apache.stanbol.enhancer.topic.training.Example;
 import org.osgi.framework.InvalidSyntaxException;
 import org.osgi.service.cm.ConfigurationException;
 import org.osgi.service.component.ComponentContext;
@@ -195,17 +196,17 @@ public class SolrTrainingSet extends Con
     }
 
     @Override
-    public Batch<String> getPositiveExamples(List<String> topics, Object 
offset) throws TrainingSetException {
+    public Batch<Example> getPositiveExamples(List<String> topics, Object 
offset) throws TrainingSetException {
         return getExamples(topics, offset, true);
     }
 
     @Override
-    public Batch<String> getNegativeExamples(List<String> topics, Object 
offset) throws TrainingSetException {
+    public Batch<Example> getNegativeExamples(List<String> topics, Object 
offset) throws TrainingSetException {
         return getExamples(topics, offset, false);
     }
 
-    protected Batch<String> getExamples(List<String> topics, Object offset, 
boolean positive) throws TrainingSetException {
-        List<String> items = new ArrayList<String>();
+    protected Batch<Example> getExamples(List<String> topics, Object offset, 
boolean positive) throws TrainingSetException {
+        List<Example> items = new ArrayList<Example>();
         SolrServer solrServer = getActiveSolrServer();
         SolrQuery query = new SolrQuery();
         List<String> parts = new ArrayList<String>();
@@ -244,13 +245,13 @@ public class SolrTrainingSet extends Con
                     nextExampleId = 
result.getFirstValue(exampleIdField).toString();
                 } else {
                     count++;
+                    String exampleId = 
result.getFirstValue(exampleIdField).toString();
+                    Collection<Object> labelValues = 
result.getFieldValues(topicUrisField);
                     Collection<Object> textValues = 
result.getFieldValues(exampleTextField);
                     if (textValues == null) {
                         continue;
                     }
-                    for (Object value : textValues) {
-                        items.add(value.toString());
-                    }
+                    items.add(new Example(exampleId, labelValues, textValues));
                 }
             }
         } catch (SolrServerException e) {
@@ -259,7 +260,7 @@ public class SolrTrainingSet extends Con
                 StringUtils.join(topics, "', '"), solrCoreId);
             throw new TrainingSetException(msg, e);
         }
-        return new Batch<String>(items, nextExampleId != null, nextExampleId);
+        return new Batch<Example>(items, nextExampleId != null, nextExampleId);
     }
 
     @Override

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TrainingSet.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TrainingSet.java?rev=1234006&r1=1234005&r2=1234006&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TrainingSet.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TrainingSet.java
 Fri Jan 20 17:09:18 2012
@@ -19,6 +19,8 @@ package org.apache.stanbol.enhancer.topi
 import java.util.Date;
 import java.util.List;
 
+import org.apache.stanbol.enhancer.topic.training.Example;
+
 /**
  * Source of categorized text documents that can be used to build a the 
statistical model of a
  * TopicClassifier.
@@ -60,7 +62,7 @@ public interface TrainingSet {
      *            marker value to fetch the next batch. Pass null to fetch the 
first batch.
      * @return a batch of example suitable for training a classifier model for 
the requested topics.
      */
-    Batch<String> getPositiveExamples(List<String> topics, Object offset) 
throws TrainingSetException;
+    Batch<Example> getPositiveExamples(List<String> topics, Object offset) 
throws TrainingSetException;
 
     /**
      * Fetch examples representative of any document not specifically 
classified in one of the passed topics.
@@ -76,7 +78,7 @@ public interface TrainingSet {
      * @return a batch of examples suitable for training (negative-refinement) 
a classifier model for the
      *         requested topics.
      */
-    Batch<String> getNegativeExamples(List<String> topics, Object offset) 
throws TrainingSetException;
+    Batch<Example> getNegativeExamples(List<String> topics, Object offset) 
throws TrainingSetException;
 
     /**
      * Number of examples to fetch at once.

Added: 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/Example.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/Example.java?rev=1234006&view=auto
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/Example.java
 (added)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/training/Example.java
 Fri Jan 20 17:09:18 2012
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.stanbol.enhancer.topic.training;
+
+import java.util.Collection;
+
+import org.apache.commons.lang.StringUtils;
+
/**
 * Data transfer object packing one item of a multi-label text classification training set: the
 * document identifier, its labels and its text fields.
 * 
 * The {@link #labels} and {@link #contents} collections are never {@code null}: a {@code null}
 * argument passed to the constructor is replaced by an empty collection, so consumers can iterate
 * over them without null checks.
 */
public class Example {

    /**
     * Separator inserted between consecutive content fields by {@link #getContentString()}.
     */
    public static final String CONTENT_SEPARATOR = "\n\n";

    /**
     * Unique id of the document.
     */
    public final String id;

    /**
     * Identifiers of the labels (categories, topics, tags...) of the document. This is the target
     * signal to predict. Empty (never {@code null}) when the document has no label.
     * 
     * In practice this is expected to be a collection of String items but we do not enforce the cast
     * to avoid the GC overhead and be able to work with the native data-structures returned by SolrJ.
     * For the same reason the collection is stored as-is (no defensive copy): do not mutate it.
     */
    public final Collection<Object> labels;

    /**
     * Text fields of the document (could be headers, paragraphs, text extractions of PDF files...).
     * Any markup is assumed to have been cleaned up in some preprocessing step. Empty (never
     * {@code null}) when the document has no text field.
     * 
     * In practice this is expected to be a collection of String items but we do not enforce the cast
     * to avoid the GC overhead and be able to work with the native data-structures returned by SolrJ.
     * For the same reason the collection is stored as-is (no defensive copy): do not mutate it.
     */
    public final Collection<Object> contents;

    /**
     * @param id
     *            unique id of the document
     * @param labelValues
     *            labels of the document; {@code null} (which SolrJ may return for a field missing
     *            from the stored document) is treated as "no label"
     * @param textValues
     *            text fields of the document; {@code null} is treated as "no content"
     */
    public Example(String id, Collection<Object> labelValues, Collection<Object> textValues) {
        this.id = id;
        // Normalize missing collections to empty ones so that consumers (training, cross
        // validation) never have to special-case null.
        this.labels = labelValues != null ? labelValues : java.util.Collections.<Object> emptyList();
        this.contents = textValues != null ? textValues : java.util.Collections.<Object> emptyList();
    }

    /**
     * @return concatenated version of the content fields, separated by {@link #CONTENT_SEPARATOR};
     *         the empty string when there is no content field.
     */
    public String getContentString() {
        return getContentString(CONTENT_SEPARATOR);
    }

    /**
     * Concatenate the content fields using a custom separator. {@code null} fields are rendered as
     * empty strings (the separator is still emitted around them).
     * 
     * @param separator
     *            text inserted between consecutive fields; {@code null} is treated as ""
     * @return the concatenated content, or the empty string when there is no content field
     */
    public String getContentString(String separator) {
        String sep = (separator == null) ? "" : separator;
        StringBuilder sb = new StringBuilder();
        boolean first = true;
        for (Object value : contents) {
            if (!first) {
                sb.append(sep);
            }
            first = false;
            if (value != null) {
                sb.append(value);
            }
        }
        return sb.toString();
    }
}

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java?rev=1234006&r1=1234005&r2=1234006&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java
 Fri Jan 20 17:09:18 2012
@@ -41,6 +41,7 @@ import org.apache.stanbol.enhancer.topic
 import org.apache.stanbol.enhancer.topic.SolrTrainingSet;
 import org.apache.stanbol.enhancer.topic.TrainingSetException;
 import org.apache.stanbol.enhancer.topic.UTCTimeStamper;
+import org.apache.stanbol.enhancer.topic.training.Example;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -97,7 +98,7 @@ public class TrainingSetTest extends Emb
 
     @Test
     public void testEmptyTrainingSet() throws TrainingSetException {
-        Batch<String> examples = trainingSet.getPositiveExamples(new 
ArrayList<String>(), null);
+        Batch<Example> examples = trainingSet.getPositiveExamples(new 
ArrayList<String>(), null);
         assertEquals(examples.items.size(), 0);
         assertFalse(examples.hasMore);
         examples = trainingSet.getNegativeExamples(new ArrayList<String>(), 
null);
@@ -120,19 +121,20 @@ public class TrainingSetTest extends Emb
         trainingSet.registerExample("example2", "Text of example2.", 
Arrays.asList(TOPIC_1, TOPIC_2));
         trainingSet.registerExample("example3", "Text of example3.", new 
ArrayList<String>());
 
-        Batch<String> examples = 
trainingSet.getPositiveExamples(Arrays.asList(TOPIC_2), null);
+        Batch<Example> examples = 
trainingSet.getPositiveExamples(Arrays.asList(TOPIC_2), null);
         assertEquals(1, examples.items.size());
-        assertEquals(examples.items, Arrays.asList("Text of example2."));
+        assertEquals(examples.items.get(0).getContentString(), "Text of 
example2.");
         assertFalse(examples.hasMore);
 
         examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, 
TOPIC_3), null);
         assertEquals(2, examples.items.size());
-        assertEquals(examples.items, Arrays.asList("Text of example1.", "Text 
of example2."));
+        assertEquals(examples.items.get(0).getContentString(), "Text of 
example1.");
+        assertEquals(examples.items.get(1).getContentString(), "Text of 
example2.");
         assertFalse(examples.hasMore);
 
         examples = trainingSet.getNegativeExamples(Arrays.asList(TOPIC_1), 
null);
         assertEquals(1, examples.items.size());
-        assertEquals(examples.items, Arrays.asList("Text of example3."));
+        assertEquals(examples.items.get(0).getContentString(), "Text of 
example3.");
         assertFalse(examples.hasMore);
 
         // Test example update by adding topic3 to example1. The results of 
the previous query should remain
@@ -140,14 +142,15 @@ public class TrainingSetTest extends Emb
         trainingSet.registerExample("example1", "Text of example1.", 
Arrays.asList(TOPIC_1, TOPIC_3));
         examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, 
TOPIC_3), null);
         assertEquals(2, examples.items.size());
-        assertEquals(examples.items, Arrays.asList("Text of example1.", "Text 
of example2."));
+        assertEquals(examples.items.get(0).getContentString(), "Text of 
example1.");
+        assertEquals(examples.items.get(1).getContentString(), "Text of 
example2.");
         assertFalse(examples.hasMore);
 
         // Test example removal
         trainingSet.registerExample("example1", null, Arrays.asList(TOPIC_1, 
TOPIC_3));
         examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, 
TOPIC_3), null);
         assertEquals(1, examples.items.size());
-        assertEquals(examples.items, Arrays.asList("Text of example2."));
+        assertEquals(examples.items.get(0).getContentString(), "Text of 
example2.");
         assertFalse(examples.hasMore);
 
         trainingSet.registerExample("example2", null, Arrays.asList(TOPIC_1, 
TOPIC_3));
@@ -158,52 +161,76 @@ public class TrainingSetTest extends Emb
 
     @Test
     public void testBatchingPositiveExamples() throws ConfigurationException, 
TrainingSetException {
+        Set<String> expectedCollectedIds = new HashSet<String>();
         Set<String> expectedCollectedText = new HashSet<String>();
+        Set<String> collectedIds = new HashSet<String>();
         Set<String> collectedText = new HashSet<String>();
         for (int i = 0; i < 28; i++) {
+            String id = "example-" + i;
             String text = "Text of example" + i + ".";
-            trainingSet.registerExample("example-" + i, text, 
Arrays.asList(TOPIC_1));
+            trainingSet.registerExample(id, text, Arrays.asList(TOPIC_1));
+            expectedCollectedIds.add(id);
             expectedCollectedText.add(text);
         }
         trainingSet.setBatchSize(10);
-        Batch<String> examples = 
trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, TOPIC_2), null);
+        Batch<Example> examples = 
trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, TOPIC_2), null);
         assertEquals(10, examples.items.size());
-        collectedText.addAll(examples.items);
+        for (Example example : examples.items) {
+            collectedIds.add(example.id);
+            collectedText.add(example.getContentString());
+        }
         assertTrue(examples.hasMore);
 
         examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, 
TOPIC_2), examples.nextOffset);
         assertEquals(10, examples.items.size());
-        collectedText.addAll(examples.items);
+        for (Example example : examples.items) {
+            collectedIds.add(example.id);
+            collectedText.add(example.getContentString());
+        }
         assertTrue(examples.hasMore);
 
         examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, 
TOPIC_2), examples.nextOffset);
         assertEquals(8, examples.items.size());
-        collectedText.addAll(examples.items);
+        for (Example example : examples.items) {
+            collectedIds.add(example.id);
+            collectedText.add(example.getContentString());
+        }
         assertFalse(examples.hasMore);
 
+        assertEquals(expectedCollectedIds, collectedIds);
         assertEquals(expectedCollectedText, collectedText);
     }
 
     @Test
     public void testBatchingNegativeExamplesAndAutoId() throws 
ConfigurationException, TrainingSetException {
+        Set<String> expectedCollectedIds = new HashSet<String>();
         Set<String> expectedCollectedText = new HashSet<String>();
+        Set<String> collectedIds = new HashSet<String>();
         Set<String> collectedText = new HashSet<String>();
         for (int i = 0; i < 17; i++) {
             String text = "Text of example" + i + ".";
-            trainingSet.registerExample(null, text, Arrays.asList(TOPIC_1));
+            String id = trainingSet.registerExample(null, text, 
Arrays.asList(TOPIC_1));
+            expectedCollectedIds.add(id);
             expectedCollectedText.add(text);
         }
         trainingSet.setBatchSize(10);
-        Batch<String> examples = 
trainingSet.getNegativeExamples(Arrays.asList(TOPIC_2), null);
+        Batch<Example> examples = 
trainingSet.getNegativeExamples(Arrays.asList(TOPIC_2), null);
         assertEquals(10, examples.items.size());
-        collectedText.addAll(examples.items);
+        for (Example example : examples.items) {
+            collectedIds.add(example.id);
+            collectedText.add(example.getContentString());
+        }
         assertTrue(examples.hasMore);
 
         examples = trainingSet.getNegativeExamples(Arrays.asList(TOPIC_2), 
examples.nextOffset);
         assertEquals(7, examples.items.size());
-        collectedText.addAll(examples.items);
+        for (Example example : examples.items) {
+            collectedIds.add(example.id);
+            collectedText.add(example.getContentString());
+        }
         assertFalse(examples.hasMore);
 
+        assertEquals(expectedCollectedIds, collectedIds);
         assertEquals(expectedCollectedText, collectedText);
     }
 


Reply via email to