Author: gsingers
Date: Wed Nov  2 01:11:00 2011
New Revision: 1196420

URL: http://svn.apache.org/viewvc?rev=1196420&view=rev
Log:
MAHOUT-857: add per-instance log-likelihood to ClassifierResult and summarize it (via OnlineSummarizer) in ResultAnalyzer

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java?rev=1196420&r1=1196419&r2=1196420&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java Wed Nov  2 01:11:00 2011
@@ -34,6 +34,7 @@ public class ClassifierResult {
 
   private String label;
   private double score;
+  private double logLikelihood = Double.MAX_VALUE;
   
   public ClassifierResult() { }
   
@@ -45,7 +46,21 @@ public class ClassifierResult {
   public ClassifierResult(String label) {
     this.label = label;
   }
-  
+
+  public ClassifierResult(String label, double score, double logLikelihood) {
+    this.label = label;
+    this.score = score;
+    this.logLikelihood = logLikelihood;
+  }
+
+  public double getLogLikelihood() {
+    return logLikelihood;
+  }
+
+  public void setLogLikelihood(double logLikelihood) {
+    this.logLikelihood = logLikelihood;
+  }
+
   public String getLabel() {
     return label;
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java?rev=1196420&r1=1196419&r2=1196420&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java Wed Nov  2 01:11:00 2011
@@ -22,6 +22,7 @@ import java.text.NumberFormat;
 import java.util.Collection;
 
 import org.apache.commons.lang.StringUtils;
+import org.apache.mahout.math.stats.OnlineSummarizer;
 
 /**
  * ResultAnalyzer captures the classification statistics and displays in a 
tabular manner
@@ -29,6 +30,8 @@ import org.apache.commons.lang.StringUti
 public class ResultAnalyzer {
   
   private final ConfusionMatrix confusionMatrix;
+  private OnlineSummarizer summarizer;
+  boolean hasLL = false;
   
   /*
    * === Summary ===
@@ -43,6 +46,7 @@ public class ResultAnalyzer {
   
   public ResultAnalyzer(Collection<String> labelSet, String defaultLabel) {
     confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel);
+    summarizer = new OnlineSummarizer();
   }
   
   public ConfusionMatrix getConfusionMatrix() {
@@ -65,6 +69,10 @@ public class ResultAnalyzer {
       incorrectlyClassified++;
     }
     confusionMatrix.addInstance(correctLabel, classifiedResult);
+    if (classifiedResult.getLogLikelihood() != Double.MAX_VALUE){
+      summarizer.add(classifiedResult.getLogLikelihood());
+      hasLL = true;
+    }
     return result;
   }
   
@@ -91,7 +99,12 @@ public class ResultAnalyzer {
     returnString.append('\n');
     
     returnString.append(confusionMatrix);
-    
+    if (hasLL) {
+      returnString.append("\n\n");
+      returnString.append("Avg. Log-likelihood: 
").append(summarizer.getMean()).append(" 25%-ile: 
").append(summarizer.getQuartile(1))
+              .append(" 75%-ile: ").append(summarizer.getQuartile(2));
+    }
+
     return returnString.toString();
   }
 }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java?rev=1196420&r1=1196419&r2=1196420&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java Wed Nov  2 01:11:00 2011
@@ -86,7 +86,8 @@ public class TestNewsGroups {
       Vector result = classifier.classifyFull(input);
       int cat = result.maxValueIndex();
       double score = result.maxValue();
-      ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), 
score);
+      double ll = classifier.logLikelihood(actual, input);
+      ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), 
score, ll);
       ra.addInstance(newsGroups.values().get(actual), cr);
 
     }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1196420&r1=1196419&r2=1196420&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java Wed Nov  2 01:11:00 2011
@@ -238,6 +238,8 @@ public final class TrainNewsGroups {
 
     List<String> ngNames = Lists.newArrayList(newsGroups.values());
     List<ModelDissector.Weight> weights = md.summary(100);
+    System.out.println("============");
+    System.out.println("Model Dissection");
     for (ModelDissector.Weight w : weights) {
       System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n",
                         w.getFeature(), w.getWeight(), 
ngNames.get(w.getMaxImpact() + 1),


Reply via email to