Author: gsingers
Date: Wed Nov 2 01:11:00 2011
New Revision: 1196420
URL: http://svn.apache.org/viewvc?rev=1196420&view=rev
Log:
MAHOUT-857: add in LogLikelihood and OnlineSummarizer to ResultAnalyzer
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java?rev=1196420&r1=1196419&r2=1196420&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java Wed Nov 2 01:11:00 2011
@@ -34,6 +34,7 @@ public class ClassifierResult {
private String label;
private double score;
+ // Double.MAX_VALUE is a sentinel meaning "no log-likelihood supplied"; ResultAnalyzer skips such results.
+ private double logLikelihood = Double.MAX_VALUE;
+ /** Creates a result with a null label, a score of 0, and no log-likelihood. */
public ClassifierResult() { }
@@ -45,7 +46,21 @@ public class ClassifierResult {
+ /** Creates a result for {@code label}; score and log-likelihood keep their field defaults. */
public ClassifierResult(String label) {
this.label = label;
}
-
+
+ /**
+ * Creates a result carrying a label, its score, and a log-likelihood.
+ *
+ * @param label the label this result reports
+ * @param score the classifier score associated with {@code label}
+ * @param logLikelihood log-likelihood to report; {@code Double.MAX_VALUE} marks it as unavailable
+ */
+ public ClassifierResult(String label, double score, double logLikelihood) {
+ this(label);
+ this.score = score;
+ this.logLikelihood = logLikelihood;
+ }
+
+ /**
+ * Returns the log-likelihood attached to this result.
+ *
+ * @return the stored log-likelihood, or {@code Double.MAX_VALUE} if none was supplied
+ */
+ public double getLogLikelihood() {
+ return this.logLikelihood;
+ }
+
+ /**
+ * Attaches a log-likelihood to this result.
+ *
+ * @param logLikelihood value to store; {@code Double.MAX_VALUE} marks it as unavailable
+ */
+ public void setLogLikelihood(final double logLikelihood) {
+ this.logLikelihood = logLikelihood;
+ }
+
+ /** @return the label this result carries (null if none was set) */
public String getLabel() {
return label;
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java?rev=1196420&r1=1196419&r2=1196420&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java Wed Nov 2 01:11:00 2011
@@ -22,6 +22,7 @@ import java.text.NumberFormat;
import java.util.Collection;
import org.apache.commons.lang.StringUtils;
+import org.apache.mahout.math.stats.OnlineSummarizer;
/**
* ResultAnalyzer captures the classification statistics and displays in a tabular manner
@@ -29,6 +30,8 @@ import org.apache.commons.lang.StringUti
public class ResultAnalyzer {
private final ConfusionMatrix confusionMatrix;
+ private OnlineSummarizer summarizer;
+ boolean hasLL = false;
/*
* === Summary ===
@@ -43,6 +46,7 @@ public class ResultAnalyzer {
+ /** Creates an analyzer over {@code labelSet}; {@code defaultLabel} is passed through to the ConfusionMatrix. */
public ResultAnalyzer(Collection<String> labelSet, String defaultLabel) {
confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel);
+ // Only fed by addInstance, and only for results whose log-likelihood is not the Double.MAX_VALUE sentinel.
+ summarizer = new OnlineSummarizer();
}
public ConfusionMatrix getConfusionMatrix() {
@@ -65,6 +69,10 @@ public class ResultAnalyzer {
incorrectlyClassified++;
}
confusionMatrix.addInstance(correctLabel, classifiedResult);
+ // Double.MAX_VALUE is ClassifierResult's "no log-likelihood" sentinel; only real values are summarized.
+ double reportedLogLikelihood = classifiedResult.getLogLikelihood();
+ if (reportedLogLikelihood != Double.MAX_VALUE) {
+ summarizer.add(reportedLogLikelihood);
+ hasLL = true;
+ }
return result;
}
@@ -91,7 +99,12 @@ public class ResultAnalyzer {
returnString.append('\n');
returnString.append(confusionMatrix);
-
+ // Append a log-likelihood summary only when at least one result supplied a value.
+ if (hasLL) {
+ returnString.append("\n\n");
+ // OnlineSummarizer quartile indices: 1 = 25th percentile, 2 = median, 3 = 75th percentile.
+ returnString.append("Avg. Log-likelihood: ").append(summarizer.getMean())
+ .append(" 25%-ile: ").append(summarizer.getQuartile(1))
+ .append(" 75%-ile: ").append(summarizer.getQuartile(3));
+ }
+
return returnString.toString();
}
}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java?rev=1196420&r1=1196419&r2=1196420&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java Wed Nov 2 01:11:00 2011
@@ -86,7 +86,8 @@ public class TestNewsGroups {
Vector result = classifier.classifyFull(input);
int cat = result.maxValueIndex();
double score = result.maxValue();
- ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score);
+ // Log-likelihood the model assigns to the true category `actual`; ResultAnalyzer summarizes these values.
+ double ll = classifier.logLikelihood(actual, input);
+ ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll);
ra.addInstance(newsGroups.values().get(actual), cr);
}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1196420&r1=1196419&r2=1196420&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java Wed Nov 2 01:11:00 2011
@@ -238,6 +238,8 @@ public final class TrainNewsGroups {
List<String> ngNames = Lists.newArrayList(newsGroups.values());
List<ModelDissector.Weight> weights = md.summary(100);
+ // Section header for the per-feature weight table printed by the loop below.
+ System.out.println("============");
+ System.out.println("Model Dissection");
for (ModelDissector.Weight w : weights) {
System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n",
w.getFeature(), w.getWeight(),
ngNames.get(w.getMaxImpact() + 1),