Author: robinanil
Date: Thu Feb 11 16:32:44 2010
New Revision: 909063

URL: http://svn.apache.org/viewvc?rev=909063&view=rev
Log:
Bayes Classifier: some classes modified to use math collections

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesFeatureMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfMapper.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java?rev=909063&r1=909062&r2=909063&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java Thu Feb 11 16:32:44 2010
@@ -20,19 +20,21 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.PriorityQueue;
 
+import org.apache.commons.lang.mutable.MutableDouble;
 import org.apache.mahout.classifier.ClassifierResult;
 import org.apache.mahout.classifier.bayes.common.ByScoreLabelResultComparator;
 import org.apache.mahout.classifier.bayes.exceptions.InvalidDatastoreException;
 import org.apache.mahout.classifier.bayes.interfaces.Algorithm;
 import org.apache.mahout.classifier.bayes.interfaces.Datastore;
+import org.apache.mahout.math.function.ObjectIntProcedure;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
 /**
  * Class implementing the Naive Bayes Classifier Algorithm
- *
+ * 
  */
 public class BayesAlgorithm implements Algorithm {
   
@@ -106,25 +108,33 @@
   }
   
   @Override
-  public double documentWeight(Datastore datastore,
-                               String label,
+  public double documentWeight(final Datastore datastore,
+                               final String label,
                                String[] document) throws 
InvalidDatastoreException {
-    Map<String,int[]> wordList = new HashMap<String,int[]>(1000);
+    OpenObjectIntHashMap<String> wordList = new OpenObjectIntHashMap<String>(
+        document.length / 2);
     for (String word : document) {
-      int[] count = wordList.get(word);
-      if (count == null) {
-        count = new int[] {0};
-        wordList.put(word, count);
+      if (wordList.containsKey(word) == false) {
+        wordList.put(word, 1);
+      } else {
+        wordList.put(word, wordList.get(word) + 1);
       }
-      count[0]++;
-    }
-    double result = 0.0;
-    for (Map.Entry<String,int[]> entry : wordList.entrySet()) {
-      String word = entry.getKey();
-      int count = entry.getValue()[0];
-      result += count * featureWeight(datastore, label, word);
     }
-    return result;
+    final MutableDouble result = new MutableDouble(0.0d);
+    
+    wordList.forEachPair(new ObjectIntProcedure<String>() {
+      
+      @Override
+      public boolean apply(String word, int frequency) {
+        try {
+          result.add(frequency * featureWeight(datastore, label, word));
+        } catch (InvalidDatastoreException e) {
+          throw new RuntimeException(e);
+        }
+        return true;
+      }
+    });
+    return result.doubleValue();
   }
   
   @Override

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java?rev=909063&r1=909062&r2=909063&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java Thu Feb 11 16:32:44 2010
@@ -20,19 +20,21 @@
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.PriorityQueue;
 
+import org.apache.commons.lang.mutable.MutableDouble;
 import org.apache.mahout.classifier.ClassifierResult;
 import org.apache.mahout.classifier.bayes.common.ByScoreLabelResultComparator;
 import org.apache.mahout.classifier.bayes.exceptions.InvalidDatastoreException;
 import org.apache.mahout.classifier.bayes.interfaces.Algorithm;
 import org.apache.mahout.classifier.bayes.interfaces.Datastore;
+import org.apache.mahout.math.function.ObjectIntProcedure;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
 /**
  * Class implementing the Complementary Naive Bayes Classifier Algorithm
- *
+ * 
  */
 public class CBayesAlgorithm implements Algorithm {
   
@@ -116,25 +118,33 @@
   }
   
   @Override
-  public double documentWeight(Datastore datastore,
-                               String label,
-                               String[] document) throws 
InvalidDatastoreException {
-    Map<String,int[]> wordList = new HashMap<String,int[]>(1000);
+  public double documentWeight(final Datastore datastore,
+                               final String label,
+                               final String[] document) throws 
InvalidDatastoreException {
+    OpenObjectIntHashMap<String> wordList = new OpenObjectIntHashMap<String>(
+        document.length / 2);
     for (String word : document) {
-      int[] count = wordList.get(word);
-      if (count == null) {
-        count = new int[] {0};
-        wordList.put(word, count);
+      if (wordList.containsKey(word) == false) {
+        wordList.put(word, 1);
+      } else {
+        wordList.put(word, wordList.get(word) + 1);
       }
-      count[0]++;
-    }
-    double result = 0.0;
-    for (Map.Entry<String,int[]> entry : wordList.entrySet()) {
-      String word = entry.getKey();
-      int count = entry.getValue()[0];
-      result += count * featureWeight(datastore, label, word);
     }
-    return result;
+    final MutableDouble result = new MutableDouble(0.0d);
+    
+    wordList.forEachPair(new ObjectIntProcedure<String>() {
+      
+      @Override
+      public boolean apply(String word, int frequency) {
+        try {
+          result.add(frequency * featureWeight(datastore, label, word));
+        } catch (InvalidDatastoreException e) {
+          throw new RuntimeException(e);
+        }
+        return true;
+      }
+    });
+    return result.doubleValue();
   }
   
   @Override

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java?rev=909063&r1=909062&r2=909063&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java Thu Feb 11 16:32:44 2010
@@ -19,8 +19,6 @@
 
 import java.io.IOException;
 import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
@@ -29,45 +27,53 @@
 import org.apache.mahout.classifier.bayes.interfaces.Datastore;
 import org.apache.mahout.classifier.bayes.io.SequenceFileModelReader;
 import org.apache.mahout.common.Parameters;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.map.OpenIntDoubleHashMap;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+
 /**
  * Class implementing the Datastore for Algorithms to read In-Memory model
  * 
  */
 public class InMemoryBayesDatastore implements Datastore {
-
-  private static final Logger log = 
LoggerFactory.getLogger(InMemoryBayesDatastore.class);
-
-  private final Map<String,Map<String,Map<String,Double>>> matrices 
-    = new HashMap<String,Map<String,Map<String,Double>>>();
   
-  private final Map<String,Map<String,Double>> vectors = new 
HashMap<String,Map<String,Double>>();
+  private static final Logger log = LoggerFactory
+      .getLogger(InMemoryBayesDatastore.class);
+  
+  private final OpenObjectIntHashMap<String> featureDictionary = new 
OpenObjectIntHashMap<String>();
+  
+  private final OpenObjectIntHashMap<String> labelDictionary = new 
OpenObjectIntHashMap<String>();
+  
+  private final OpenIntDoubleHashMap sigma_j = new OpenIntDoubleHashMap();
+  
+  private final OpenIntDoubleHashMap sigma_k = new OpenIntDoubleHashMap();
+  
+  private final OpenIntDoubleHashMap thetaNormalizerPerLabel = new 
OpenIntDoubleHashMap();
+  
+  private double sigma_jSigma_k = 1.0;
+  
+  private final SparseMatrix weightMatrix = new SparseMatrix(new int[] {1,0});
   
   private final Parameters params;
   
   private double thetaNormalizer = 1.0;
-
+  
   private double alphaI = 1.0;
-
+  
   public InMemoryBayesDatastore(Parameters params) {
-
-    matrices.put("weight", new HashMap<String, Map<String, Double>>());
-    vectors.put("sumWeight", new HashMap<String, Double>());
-    matrices.put("weight", new HashMap<String, Map<String, Double>>());
-    vectors.put("labelWeight", new HashMap<String, Double>());
-    vectors.put("thetaNormalizer", new HashMap<String, Double>());
     String basePath = params.get("basePath");
     this.params = params;
     params.set("sigma_j", basePath + "/trainer-weights/Sigma_j/part-*");
     params.set("sigma_k", basePath + "/trainer-weights/Sigma_k/part-*");
     params.set("sigma_kSigma_j", basePath
-        + "/trainer-weights/Sigma_kSigma_j/part-*");
+                                 + "/trainer-weights/Sigma_kSigma_j/part-*");
     params.set("thetaNormalizer", basePath + 
"/trainer-thetaNormalizer/part-*");
     params.set("weight", basePath + "/trainer-tfIdf/trainer-tfIdf/part-*");
     alphaI = Double.valueOf(params.get("alpha_i", "1.0"));
   }
-
+  
   @Override
   public void initialize() throws InvalidDatastoreException {
     Configuration conf = new Configuration();
@@ -78,129 +84,95 @@
     } catch (IOException e) {
       throw new InvalidDatastoreException(e.getMessage());
     }
-    updateVocabCount();
-    Collection<String> labels = getKeys("thetaNormalizer");
-    for (String label : labels) {
-      thetaNormalizer = Math.max(thetaNormalizer, Math.abs(vectorGetCell(
-          "thetaNormalizer", label)));
-    }
-    for (String label : labels) {
-      log.info("{} {} {} {}", new Object[] {label,
-                                            vectorGetCell("thetaNormalizer",
-                                              label),
+    for (String label : getKeys("")) {
+      log.info("{} {} {} {}", new Object[] {
+                                            label,
+                                            thetaNormalizerPerLabel
+                                                .get(getLabelID(label)),
                                             thetaNormalizer,
-                                            vectorGetCell("thetaNormalizer",
-                                              label) / thetaNormalizer});
+                                            thetaNormalizerPerLabel
+                                                .get(getLabelID(label))
+                                                / thetaNormalizer});
     }
   }
-
+  
   @Override
   public Collection<String> getKeys(String name) throws 
InvalidDatastoreException {
-    return vectors.get("labelWeight").keySet();
+    return labelDictionary.keys();
   }
   
   @Override
   public double getWeight(String matrixName, String row, String column) throws 
InvalidDatastoreException {
-    return matrixGetCell(matrixName, row, column);
+    if (matrixName.equals("weight")) {
+      if (column.equals("sigma_j")) {
+        return sigma_j.get(getFeatureID(row));
+      } else return weightMatrix.getQuick(getFeatureID(row), 
getLabelID(column));
+    } else throw new InvalidDatastoreException("Matrix not found: "
+                                               + matrixName);
   }
   
   @Override
   public double getWeight(String vectorName, String index) throws 
InvalidDatastoreException {
-    if (vectorName.equals("thetaNormalizer")) return vectorGetCell(vectorName,
-      index)
-                                                     / thetaNormalizer;
-    else if (vectorName.equals("params")) {
+    if (vectorName.equals("sumWeight")) {
+      if (index.equals("sigma_jSigma_k")) {
+        return sigma_jSigma_k;
+      } else if (index.equals("vocabCount")) {
+        return featureDictionary.size();
+      } else throw new InvalidDatastoreException();
+    } else if (vectorName.equals("thetaNormalizer")) {
+      return thetaNormalizerPerLabel.get(getLabelID(index)) / thetaNormalizer;
+    } else if (vectorName.equals("params")) {
       if (index.equals("alpha_i")) {
         return alphaI;
-      } else {
-        throw new InvalidDatastoreException();
-      }
+      } else throw new InvalidDatastoreException();
+    } else if (vectorName.equals("labelWeight")) {
+      return sigma_k.get(getLabelID(index));
+    } else throw new InvalidDatastoreException();
+  }
+  
+  private int getFeatureID(String feature) {
+    if (featureDictionary.containsKey(feature)) {
+      return featureDictionary.get(feature);
+    } else {
+      int id = featureDictionary.size();
+      featureDictionary.put(feature, id);
+      return id;
     }
-    return vectorGetCell(vectorName, index);
   }
   
-  private double matrixGetCell(String matrixName, String row, String col) 
throws InvalidDatastoreException {
-    Map<String,Map<String,Double>> matrix = matrices.get(matrixName);
-    if (matrix == null) {
-      throw new InvalidDatastoreException();
+  private int getLabelID(String label) {
+    if (labelDictionary.containsKey(label)) {
+      return labelDictionary.get(label);
+    } else {
+      int id = labelDictionary.size();
+      labelDictionary.put(label, id);
+      return id;
     }
-    Map<String,Double> rowVector = matrix.get(row);
-    if (rowVector == null) {
-      return 0.0;
-    }
-    return nullToZero(rowVector.get(col));
-  }
-  
-  private double vectorGetCell(String vectorName, String index) throws 
InvalidDatastoreException {
-    
-    Map<String,Double> vector = vectors.get(vectorName);
-    if (vector == null) {
-      throw new InvalidDatastoreException();
-    }
-    return nullToZero(vector.get(index));
-  }
-  
-  private void matrixPutCell(String matrixName,
-                             String row,
-                             String col,
-                             double weight) {
-    Map<String,Map<String,Double>> matrix = matrices.get(matrixName);
-    if (matrix == null) {
-      matrix = new HashMap<String,Map<String,Double>>();
-      matrices.put(matrixName, matrix);
-    }
-    Map<String,Double> rowVector = matrix.get(row);
-    if (rowVector == null) {
-      rowVector = new HashMap<String,Double>();
-      matrix.put(row, rowVector);
-    }
-    rowVector.put(col, weight);
-  }
-  
-  private void vectorPutCell(String vectorName, String index, double weight) {
-    
-    Map<String,Double> vector = vectors.get(vectorName);
-    if (vector == null) {
-      vector = new HashMap<String,Double>();
-      vectors.put(vectorName, vector);
-    }
-    vector.put(index, weight);
-  }
-  
-  private long sizeOfMatrix(String matrixName) {
-    Map<String,Map<String,Double>> matrix = matrices.get(matrixName);
-    if (matrix == null) {
-      return 0;
-    }
-    return matrix.size();
   }
   
   public void loadFeatureWeight(String feature, String label, double weight) {
-    matrixPutCell("weight", feature, label, weight);
+    int fid = getFeatureID(feature);
+    int lid = getLabelID(label);
+    weightMatrix.setQuick(fid, lid, weight);
   }
   
   public void setSumFeatureWeight(String feature, double weight) {
-    matrixPutCell("weight", feature, "sigma_j", weight);
+    int fid = getFeatureID(feature);
+    sigma_j.put(fid, weight);
   }
   
   public void setSumLabelWeight(String label, double weight) {
-    vectorPutCell("labelWeight", label, weight);
+    int lid = getLabelID(label);
+    sigma_k.put(lid, weight);
   }
   
   public void setThetaNormalizer(String label, double weight) {
-    vectorPutCell("thetaNormalizer", label, weight);
+    int lid = getLabelID(label);
+    thetaNormalizerPerLabel.put(lid, weight);
+    thetaNormalizer = Math.max(thetaNormalizer, Math.abs(weight));
   }
   
   public void setSigmaJSigmaK(double weight) {
-    vectorPutCell("sumWeight", "sigma_jSigma_k", weight);
+    this.sigma_jSigma_k = weight;
   }
-  
-  public void updateVocabCount() {
-    vectorPutCell("sumWeight", "vocabCount", sizeOfMatrix("weight"));
-  }
-  
-  private static double nullToZero(Double value) {
-    return value == null ? 0.0 : value;
-  }
-  
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java?rev=909063&r1=909062&r2=909063&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java Thu Feb 11 16:32:44 2010
@@ -32,11 +32,13 @@
 import org.apache.mahout.classifier.bayes.mapreduce.common.BayesConstants;
 import org.apache.mahout.common.Parameters;
 import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.math.map.OpenObjectDoubleHashMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * Mapper for Calculating the ThetaNormalizer for a label in Naive Bayes Algorithm
+ * Mapper for Calculating the ThetaNormalizer for a label in Naive Bayes
+ * Algorithm
  * 
  */
 public class BayesThetaNormalizerMapper extends MapReduceBase implements
@@ -45,7 +47,7 @@
   private static final Logger log = LoggerFactory
       .getLogger(BayesThetaNormalizerMapper.class);
   
-  private Map<String,Double> labelWeightSum;
+  private OpenObjectDoubleHashMap<String> labelWeightSum = new 
OpenObjectDoubleHashMap<String>();
   private double sigmaJSigmaK;
   private double vocabCount;
   private double alphaI = 1.0;
@@ -79,33 +81,34 @@
   @Override
   public void configure(JobConf job) {
     try {
-      if (labelWeightSum == null) {
-        labelWeightSum = new HashMap<String,Double>();
-        
-        DefaultStringifier<Map<String,Double>> mapStringifier = new 
DefaultStringifier<Map<String,Double>>(
-            job, GenericsUtil.getClass(labelWeightSum));
-        
-        String labelWeightSumString = mapStringifier.toString(labelWeightSum);
-        labelWeightSumString = job.get("cnaivebayes.sigma_k",
-          labelWeightSumString);
-        labelWeightSum = mapStringifier.fromString(labelWeightSumString);
-        
-        DefaultStringifier<Double> stringifier = new 
DefaultStringifier<Double>(
-            job, GenericsUtil.getClass(sigmaJSigmaK));
-        String sigmaJSigmaKString = stringifier.toString(sigmaJSigmaK);
-        sigmaJSigmaKString = job.get("cnaivebayes.sigma_jSigma_k",
-          sigmaJSigmaKString);
-        sigmaJSigmaK = stringifier.fromString(sigmaJSigmaKString);
-        
-        String vocabCountString = stringifier.toString(vocabCount);
-        vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString);
-        vocabCount = stringifier.fromString(vocabCountString);
-        
-        Parameters params = Parameters.fromString(job.get("bayes.parameters",
-          ""));
-        alphaI = Double.valueOf(params.get("alpha_i", "1.0"));
-        
+      labelWeightSum.clear();
+      Map<String,Double> labelWeightSumTemp = new HashMap<String,Double>();
+      
+      DefaultStringifier<Map<String,Double>> mapStringifier = new 
DefaultStringifier<Map<String,Double>>(
+          job, GenericsUtil.getClass(labelWeightSumTemp));
+      
+      String labelWeightSumString = 
mapStringifier.toString(labelWeightSumTemp);
+      labelWeightSumString = job.get("cnaivebayes.sigma_k",
+        labelWeightSumString);
+      labelWeightSumTemp = mapStringifier.fromString(labelWeightSumString);
+      for (String key : labelWeightSumTemp.keySet()) {
+        this.labelWeightSum.put(key, labelWeightSumTemp.get(key));
       }
+      DefaultStringifier<Double> stringifier = new DefaultStringifier<Double>(
+          job, GenericsUtil.getClass(sigmaJSigmaK));
+      String sigmaJSigmaKString = stringifier.toString(sigmaJSigmaK);
+      sigmaJSigmaKString = job.get("cnaivebayes.sigma_jSigma_k",
+        sigmaJSigmaKString);
+      sigmaJSigmaK = stringifier.fromString(sigmaJSigmaKString);
+      
+      String vocabCountString = stringifier.toString(vocabCount);
+      vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString);
+      vocabCount = stringifier.fromString(vocabCountString);
+      
+      Parameters params = Parameters
+          .fromString(job.get("bayes.parameters", ""));
+      alphaI = Double.valueOf(params.get("alpha_i", "1.0"));
+      
     } catch (IOException ex) {
       log.warn(ex.toString(), ex);
     }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java?rev=909063&r1=909062&r2=909063&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java Thu Feb 11 16:32:44 2010
@@ -32,6 +32,8 @@
 import org.apache.mahout.classifier.bayes.mapreduce.common.BayesConstants;
 import org.apache.mahout.common.Parameters;
 import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.math.function.ObjectDoubleProcedure;
+import org.apache.mahout.math.map.OpenObjectDoubleHashMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -46,7 +48,7 @@
   private static final Logger log = LoggerFactory
       .getLogger(CBayesThetaNormalizerMapper.class);
   
-  private Map<String,Double> labelWeightSum;
+  private OpenObjectDoubleHashMap<String> labelWeightSum = new 
OpenObjectDoubleHashMap<String>();
   private double sigmaJSigmaK;
   private double vocabCount;
   private double alphaI = 1.0;
@@ -60,30 +62,33 @@
    */
   @Override
   public void map(StringTuple key,
-                  DoubleWritable value,
-                  OutputCollector<StringTuple,DoubleWritable> output,
-                  Reporter reporter) throws IOException {
+                  final DoubleWritable value,
+                  final OutputCollector<StringTuple,DoubleWritable> output,
+                  final Reporter reporter) throws IOException {
     
     if (key.stringAt(0).equals(BayesConstants.FEATURE_SUM)) { // if it is from
       // the Sigma_j
       // folder
-      
-      for (Map.Entry<String,Double> stringDoubleEntry : labelWeightSum
-          .entrySet()) {
-        String label = stringDoubleEntry.getKey();
-        double weight = Math
-            .log((value.get() + alphaI)
-                 / (sigmaJSigmaK - stringDoubleEntry.getValue() + vocabCount));
-        
-        reporter.setStatus("Complementary Bayes Theta Normalizer Mapper: "
-                           + stringDoubleEntry + " => " + weight);
-        StringTuple normalizerTuple = new StringTuple(
-            BayesConstants.LABEL_THETA_NORMALIZER);
-        normalizerTuple.add(label);
-        output.collect(normalizerTuple, new DoubleWritable(weight)); // output
-        // Sigma_j
+      labelWeightSum.forEachPair(new ObjectDoubleProcedure<String>() {
         
-      }
+        @Override
+        public boolean apply(String label, double sigmaJ) {
+          double weight = Math.log((value.get() + alphaI)
+                                   / (sigmaJSigmaK - sigmaJ + vocabCount));
+          
+          reporter.setStatus("Complementary Bayes Theta Normalizer Mapper: "
+                             + label + " => " + weight);
+          StringTuple normalizerTuple = new StringTuple(
+              BayesConstants.LABEL_THETA_NORMALIZER);
+          normalizerTuple.add(label);
+          try {
+            output.collect(normalizerTuple, new DoubleWritable(weight));
+          } catch (IOException e) {
+           throw new RuntimeException(e);
+          } // output Sigma_j
+          return true;
+        }
+      });
       
     } else {
       String label = key.stringAt(1);
@@ -110,33 +115,35 @@
   @Override
   public void configure(JobConf job) {
     try {
-      if (labelWeightSum == null) {
-        labelWeightSum = new HashMap<String,Double>();
-        
-        DefaultStringifier<Map<String,Double>> mapStringifier = new 
DefaultStringifier<Map<String,Double>>(
-            job, GenericsUtil.getClass(labelWeightSum));
-        
-        String labelWeightSumString = mapStringifier.toString(labelWeightSum);
-        labelWeightSumString = job.get("cnaivebayes.sigma_k",
-          labelWeightSumString);
-        labelWeightSum = mapStringifier.fromString(labelWeightSumString);
-        
-        DefaultStringifier<Double> stringifier = new 
DefaultStringifier<Double>(
-            job, GenericsUtil.getClass(sigmaJSigmaK));
-        String sigmaJSigmaKString = stringifier.toString(sigmaJSigmaK);
-        sigmaJSigmaKString = job.get("cnaivebayes.sigma_jSigma_k",
-          sigmaJSigmaKString);
-        sigmaJSigmaK = stringifier.fromString(sigmaJSigmaKString);
-        
-        String vocabCountString = stringifier.toString(vocabCount);
-        vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString);
-        vocabCount = stringifier.fromString(vocabCountString);
-        
-        Parameters params = Parameters.fromString(job.get("bayes.parameters",
-          ""));
-        alphaI = Double.valueOf(params.get("alpha_i", "1.0"));
-        
+      labelWeightSum.clear();
+      Map<String,Double> labelWeightSumTemp = new HashMap<String,Double>();
+      
+      DefaultStringifier<Map<String,Double>> mapStringifier = new 
DefaultStringifier<Map<String,Double>>(
+          job, GenericsUtil.getClass(labelWeightSumTemp));
+      
+      String labelWeightSumString = 
mapStringifier.toString(labelWeightSumTemp);
+      labelWeightSumString = job.get("cnaivebayes.sigma_k",
+        labelWeightSumString);
+      labelWeightSumTemp = mapStringifier.fromString(labelWeightSumString);
+      for (String key : labelWeightSumTemp.keySet()) {
+        this.labelWeightSum.put(key, labelWeightSumTemp.get(key));
       }
+      
+      DefaultStringifier<Double> stringifier = new DefaultStringifier<Double>(
+          job, GenericsUtil.getClass(sigmaJSigmaK));
+      String sigmaJSigmaKString = stringifier.toString(sigmaJSigmaK);
+      sigmaJSigmaKString = job.get("cnaivebayes.sigma_jSigma_k",
+        sigmaJSigmaKString);
+      sigmaJSigmaK = stringifier.fromString(sigmaJSigmaKString);
+      
+      String vocabCountString = stringifier.toString(vocabCount);
+      vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString);
+      vocabCount = stringifier.fromString(vocabCountString);
+      
+      Parameters params = Parameters
+          .fromString(job.get("bayes.parameters", ""));
+      alphaI = Double.valueOf(params.get("alpha_i", "1.0"));
+      
     } catch (IOException ex) {
       log.warn(ex.toString(), ex);
     }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesFeatureMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesFeatureMapper.java?rev=909063&r1=909062&r2=909063&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesFeatureMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesFeatureMapper.java Thu Feb 11 16:32:44 2010
@@ -18,10 +18,11 @@
 package org.apache.mahout.classifier.bayes.mapreduce.common;
 
 import java.io.IOException;
-import java.util.HashMap;
+import java.util.Arrays;
+import java.util.Iterator;
 import java.util.List;
-import java.util.Map;
 
+import org.apache.commons.lang.mutable.MutableDouble;
 import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.JobConf;
@@ -29,14 +30,20 @@
 import org.apache.hadoop.mapred.Mapper;
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.hadoop.mapred.Reporter;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.shingle.ShingleFilter;
+import org.apache.lucene.analysis.tokenattributes.TermAttribute;
 import org.apache.mahout.common.Parameters;
 import org.apache.mahout.common.StringTuple;
-import org.apache.mahout.common.nlp.NGrams;
+import org.apache.mahout.math.function.ObjectIntProcedure;
+import org.apache.mahout.math.function.ObjectProcedure;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * Reads the input train set(preprocessed using the {@link org.apache.mahout.classifier.BayesFileFormatter}).
+ * Reads the input train set(preprocessed using the
+ * {@link org.apache.mahout.classifier.BayesFileFormatter}).
  */
 public class BayesFeatureMapper extends MapReduceBase implements
     Mapper<Text,Text,StringTuple,DoubleWritable> {
@@ -60,7 +67,8 @@
    * @param key
    *          The label
    * @param value
-   *          the features (all unique) associated w/ this label
+   *          the features (all unique) associated w/ this label in stringtuple
+   *          format
    * @param output
    *          The OutputCollector to write the results to
    * @param reporter
@@ -69,65 +77,91 @@
   @Override
   public void map(Text key,
                   Text value,
-                  OutputCollector<StringTuple,DoubleWritable> output,
+                  final OutputCollector<StringTuple,DoubleWritable> output,
                   Reporter reporter) throws IOException {
     // String line = value.toString();
-    String label = key.toString();
-    
-    Map<String,int[]> wordList = new HashMap<String,int[]>(1000);
-    
-    List<String> ngrams = new NGrams(value.toString(), gramSize)
-        .generateNGramsWithoutLabel();
-    
-    for (String ngram : ngrams) {
-      int[] count = wordList.get(ngram);
-      if (count == null) {
-        count = new int[1];
-        count[0] = 0;
-        wordList.put(ngram, count);
+    final String label = key.toString();
+    List<String> tokens = Arrays.asList(value.toString().split("[ ]+"));
+    OpenObjectIntHashMap<String> wordList = new OpenObjectIntHashMap<String>(
+        tokens.size() * gramSize);
+    
+    if (gramSize > 1) {
+      ShingleFilter sf = new ShingleFilter(new IteratorTokenStream(tokens
+          .iterator()), gramSize);
+      do {
+        String term = ((TermAttribute) sf.getAttribute(TermAttribute.class))
+            .term();
+        if (term.length() > 0) {
+          if (wordList.containsKey(term) == false) {
+            wordList.put(term, 1);
+          } else {
+            wordList.put(term, 1 + wordList.get(term));
+          }
+        }
+      } while (sf.incrementToken());
+    } else {
+      for (String term : tokens) {
+        if (wordList.containsKey(term) == false) {
+          wordList.put(term, 1);
+        } else {
+          wordList.put(term, 1 + wordList.get(term));
+        }
       }
-      count[0]++;
     }
-    double lengthNormalisation = 0.0;
-    for (int[] dKJ : wordList.values()) {
-      // key is label,word
-      double dkjValue = (double) dKJ[0];
-      lengthNormalisation += dkjValue * dkjValue;
-    }
-    lengthNormalisation = Math.sqrt(lengthNormalisation);
+    final MutableDouble lengthNormalisationMut = new MutableDouble(0);
+    wordList.forEachPair(new ObjectIntProcedure<String>() {
+      @Override
+      public boolean apply(String word, int dKJ) {
+        lengthNormalisationMut.add(dKJ * dKJ);
+        return true;
+      }
+    });
+    
+    final double lengthNormalisation = Math.sqrt(lengthNormalisationMut
+        .doubleValue());
     
     // Output Length Normalized + TF Transformed Frequency per Word per Class
     // Log(1 + D_ij)/SQRT( SIGMA(k, D_kj) )
-    for (Map.Entry<String,int[]> entry : wordList.entrySet()) {
-      // key is label,word
-      String token = entry.getKey();
-      StringTuple tuple = new StringTuple();
-      tuple.add(BayesConstants.WEIGHT);
-      tuple.add(label);
-      tuple.add(token);
-      DoubleWritable f = new DoubleWritable(Math.log(1.0 + entry.getValue()[0])
-                                            / lengthNormalisation);
-      output.collect(tuple, f);
-    }
+    wordList.forEachPair(new ObjectIntProcedure<String>() {
+      @Override
+      public boolean apply(String token, int dKJ) {
+        try {
+          StringTuple tuple = new StringTuple();
+          tuple.add(BayesConstants.WEIGHT);
+          tuple.add(label);
+          tuple.add(token);
+          DoubleWritable f = new DoubleWritable(Math.log(1.0 + dKJ)
+                                                / lengthNormalisation);
+          output.collect(tuple, f);
+        } catch (IOException e) {
+          throw new RuntimeException(e);
+        }
+        return true;
+      }
+    });
     reporter.setStatus("Bayes Feature Mapper: Document Label: " + label);
     
     // Output Document Frequency per Word per Class
-    
-    for (String token : wordList.keySet()) {
-      // key is label,word
-      
-      StringTuple dfTuple = new StringTuple();
-      dfTuple.add(BayesConstants.DOCUMENT_FREQUENCY);
-      dfTuple.add(label);
-      dfTuple.add(token);
-      output.collect(dfTuple, ONE);
-      
-      StringTuple tokenCountTuple = new StringTuple();
-      tokenCountTuple.add(BayesConstants.FEATURE_COUNT);
-      tokenCountTuple.add(token);
-      output.collect(tokenCountTuple, ONE);
-      
-    }
+    wordList.forEachKey(new ObjectProcedure<String>() {
+      @Override
+      public boolean apply(String token) {
+        try {
+          StringTuple dfTuple = new StringTuple();
+          dfTuple.add(BayesConstants.DOCUMENT_FREQUENCY);
+          dfTuple.add(label);
+          dfTuple.add(token);
+          output.collect(dfTuple, ONE);
+          
+          StringTuple tokenCountTuple = new StringTuple();
+          tokenCountTuple.add(BayesConstants.FEATURE_COUNT);
+          tokenCountTuple.add(token);
+          output.collect(tokenCountTuple, ONE);
+        } catch (IOException e) {
+          throw new RuntimeException(e);
+        }
+        return true;
+      }
+    });
     
     // output that we have seen the label to calculate the Count of Document per
     // class
@@ -149,4 +183,26 @@
       log.warn(ex.toString(), ex);
     }
   }
+  
+  /** Used to emit tokens from an input string array in the style of TokenStream */
+  public static class IteratorTokenStream extends TokenStream {
+    private final TermAttribute termAtt;
+    private final Iterator<String> iterator;
+    
+    public IteratorTokenStream(Iterator<String> iterator) {
+      this.iterator = iterator;
+      this.termAtt = (TermAttribute) addAttribute(TermAttribute.class);
+    }
+    
+    @Override
+    public boolean incrementToken() throws IOException {
+      if (iterator.hasNext()) {
+        clearAttributes();
+        termAtt.setTermBuffer(iterator.next());
+        return true;
+      } else {
+        return false;
+      }
+    }
+  }
 }

Modified: 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfMapper.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfMapper.java?rev=909063&r1=909062&r2=909063&view=diff
==============================================================================
--- 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfMapper.java
 (original)
+++ 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfMapper.java
 Thu Feb 11 16:32:44 2010
@@ -30,6 +30,7 @@
 import org.apache.hadoop.mapred.Reporter;
 import org.apache.hadoop.util.GenericsUtil;
 import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.math.map.OpenObjectDoubleHashMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -48,7 +49,7 @@
   
   private static final DoubleWritable ONE = new DoubleWritable(1.0);
   
-  private Map<String,Double> labelDocumentCounts;
+  private OpenObjectDoubleHashMap<String> labelDocumentCounts = new OpenObjectDoubleHashMap<String>();
   
   /**
    * We need to calculate the Tf-Idf of each feature in each label
@@ -87,20 +88,22 @@
   @Override
   public void configure(JobConf job) {
     try {
-      if (labelDocumentCounts == null) {
-        labelDocumentCounts = new HashMap<String,Double>();
-        
-        DefaultStringifier<Map<String,Double>> mapStringifier = new DefaultStringifier<Map<String,Double>>(
-            job, GenericsUtil.getClass(labelDocumentCounts));
-        
-        String labelDocumentCountString = mapStringifier
-            .toString(labelDocumentCounts);
-        labelDocumentCountString = job.get("cnaivebayes.labelDocumentCounts",
-          labelDocumentCountString);
-        
-        labelDocumentCounts = mapStringifier
-            .fromString(labelDocumentCountString);
+      this.labelDocumentCounts.clear();
+      Map<String,Double> labelDocCountTemp = new HashMap<String,Double>();
+      
+      DefaultStringifier<Map<String,Double>> mapStringifier = new DefaultStringifier<Map<String,Double>>(
+          job, GenericsUtil.getClass(labelDocCountTemp));
+      
+      String labelDocumentCountString = mapStringifier
+          .toString(labelDocCountTemp);
+      labelDocumentCountString = job.get("cnaivebayes.labelDocumentCounts",
+        labelDocumentCountString);
+      
+      labelDocCountTemp = mapStringifier.fromString(labelDocumentCountString);
+      for (String key : labelDocCountTemp.keySet()) {
+        this.labelDocumentCounts.put(key, labelDocCountTemp.get(key));
       }
+      
     } catch (IOException ex) {
       log.warn(ex.toString(), ex);
     }

Modified: 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java?rev=909063&r1=909062&r2=909063&view=diff
==============================================================================
--- 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java
 (original)
+++ 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java
 Thu Feb 11 16:32:44 2010
@@ -69,7 +69,7 @@
     store.loadFeatureWeight("ee", "e", 100);
     store.loadFeatureWeight("aa", "e", 50);
     store.loadFeatureWeight("dd", "e", 50);
-    store.updateVocabCount();
+
   }
   
   public void test() throws InvalidDatastoreException {
@@ -79,7 +79,7 @@
     assertNotNull("category is null and it shouldn't be", result);
     assertEquals(result + " is not equal to e", "e", result.getLabel());
     
-    document = new String[] {"ff"};
+    document = new String[] {"dd"};
     result = classifier.classifyDocument(document, "unknown");
     assertNotNull("category is null and it shouldn't be", result);
     assertEquals(result + " is not equal to d", "d", result.getLabel());

Modified: 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java?rev=909063&r1=909062&r2=909063&view=diff
==============================================================================
--- 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java
 (original)
+++ 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java
 Thu Feb 11 16:32:44 2010
@@ -74,7 +74,7 @@
     store.loadFeatureWeight("ee", "e", 100);
     store.loadFeatureWeight("aa", "e", 50);
     store.loadFeatureWeight("dd", "e", 50);
-    store.updateVocabCount();
+    
   }
   
   public void test() throws InvalidDatastoreException {
@@ -84,7 +84,7 @@
     assertNotNull("category is null and it shouldn't be", result);
     assertEquals(result + " is not equal to e", "e", result.getLabel());
     
-    document = new String[] {"ff"};
+    document = new String[] {"dd"};
     result = classifier.classifyDocument(document, "unknown");
     assertNotNull("category is null and it shouldn't be", result);
     assertEquals(result + " is not equal to d", "d", result.getLabel());


Reply via email to