Author: tdunning
Date: Wed Sep 15 06:19:40 2010
New Revision: 997194
URL: http://svn.apache.org/viewvc?rev=997194&view=rev
Log:
Added model dissector itself
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java?rev=997194&r1=997193&r2=997194&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java Wed Sep 15 06:19:40 2010
@@ -17,15 +17,28 @@
package org.apache.mahout.classifier.sgd;
+import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
+import com.google.common.collect.Ordering;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.UnaryFunction;
+import org.apache.mahout.math.matrix.GaussSeidel;
import org.apache.mahout.vectors.Dictionary;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
import java.util.Map;
+import java.util.PriorityQueue;
import java.util.Set;
/**
@@ -35,52 +48,79 @@ import java.util.Set;
* in the original space.
*/
public class ModelDissector {
- int records = 0;
- private Dictionary dict;
- private Matrix a;
- private Matrix b;
+ private Map<String,Vector> weightMap;
public ModelDissector(int n) {
- a = new SparseRowMatrix(new int[]{Integer.MAX_VALUE, Integer.MAX_VALUE}, true);
- b = new SparseRowMatrix(new int[]{Integer.MAX_VALUE, n});
-
- dict.intern("Intercept Value");
+ weightMap = Maps.newHashMap();
}
- public void addExample(Set<String> features, Vector score) {
- for (Vector.Element element : score) {
- b.set(records, element.index(), element.get());
+ public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) {
+ features.assign(0);
+ final int numCategories = learner.numCategories();
+ for (String feature : traceDictionary.keySet()) {
+ weightMap = weightMap;
+ if (!weightMap.containsKey(feature)) {
+ for (Integer where : traceDictionary.get(feature)) {
+ features.set(where, 1);
+ }
+
+ Vector v = learner.classifyNoLink(features);
+ weightMap.put(feature, v);
+
+ for (Integer where : traceDictionary.get(feature)) {
+ features.set(where, 0);
+ }
+ }
}
- for (String feature : features) {
- int j = dict.intern(feature);
- a.set(records, j, 1);
+ }
+
+ public List<Weight> summary(int n) {
+ PriorityQueue<Weight> pq = new PriorityQueue<Weight>();
+ for (String s : weightMap.keySet()) {
+ pq.add(new Weight(s, weightMap.get(s)));
+ while (pq.size() > n) {
+ pq.poll();
+ }
}
- records++;
+ List<Weight> r = Lists.newArrayList(pq);
+ Collections.sort(r, Ordering.natural().reverse());
+ return r;
}
- public void addExample(Set<String> features, double score) {
- b.set(records, 0, score);
+ public static class Weight implements Comparable<Weight> {
+ private String feature;
+ private double value;
+ private int maxIndex;
+ private Vector weights;
+
+ public Weight(String feature, Vector weights) {
+ this.weights = weights;
+ this.feature = feature;
+ value = weights.norm(1);
+ maxIndex = weights.maxValueIndex();
+ }
- a.set(records, 0, 1);
- for (String feature : features) {
- int j = dict.intern(feature);
- a.set(records, j, 1);
+ @Override
+ public int compareTo(Weight other) {
+ int r = Double.compare(this.value, other.value);
+ if (r != 0) {
+ return r;
+ } else {
+ return feature.compareTo(other.feature);
+ }
+ }
+
+ public String getFeature() {
+ return feature;
+ }
+
+ public double getWeight() {
+ return value;
}
- records++;
- }
- public Matrix solve() {
- Matrix az = a.viewPart(new int[]{0, 0}, new int[]{records, dict.size()});
- Matrix bz = b.viewPart(new int[]{0, 0}, new int[]{records, b.columnSize()});
- QRDecomposition qr = new QRDecomposition(az.transpose().times(az));
- Matrix x = qr.solve(bz);
- Map<String, Integer> labels = Maps.newHashMap();
- int i = 0;
- for (String s : dict.values()) {
- labels.put(s, i++);
+ public int getMaxImpact() {
+ return maxIndex;
}
- x.setRowLabelBindings(labels);
- return x;
}
}