Author: robinanil Date: Thu Feb 11 16:32:44 2010 New Revision: 909063 URL: http://svn.apache.org/viewvc?rev=909063&view=rev Log: Bayes Classifier: some classes modified to use Mahout math collections
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesFeatureMapper.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfMapper.java lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java?rev=909063&r1=909062&r2=909063&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java Thu Feb 11 16:32:44 2010 @@ -20,19 +20,21 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.PriorityQueue; +import org.apache.commons.lang.mutable.MutableDouble; import org.apache.mahout.classifier.ClassifierResult; import 
org.apache.mahout.classifier.bayes.common.ByScoreLabelResultComparator; import org.apache.mahout.classifier.bayes.exceptions.InvalidDatastoreException; import org.apache.mahout.classifier.bayes.interfaces.Algorithm; import org.apache.mahout.classifier.bayes.interfaces.Datastore; +import org.apache.mahout.math.function.ObjectIntProcedure; +import org.apache.mahout.math.map.OpenObjectIntHashMap; + /** * Class implementing the Naive Bayes Classifier Algorithm - * + * */ public class BayesAlgorithm implements Algorithm { @@ -106,25 +108,33 @@ } @Override - public double documentWeight(Datastore datastore, - String label, + public double documentWeight(final Datastore datastore, + final String label, String[] document) throws InvalidDatastoreException { - Map<String,int[]> wordList = new HashMap<String,int[]>(1000); + OpenObjectIntHashMap<String> wordList = new OpenObjectIntHashMap<String>( + document.length / 2); for (String word : document) { - int[] count = wordList.get(word); - if (count == null) { - count = new int[] {0}; - wordList.put(word, count); + if (wordList.containsKey(word) == false) { + wordList.put(word, 1); + } else { + wordList.put(word, wordList.get(word) + 1); } - count[0]++; - } - double result = 0.0; - for (Map.Entry<String,int[]> entry : wordList.entrySet()) { - String word = entry.getKey(); - int count = entry.getValue()[0]; - result += count * featureWeight(datastore, label, word); } - return result; + final MutableDouble result = new MutableDouble(0.0d); + + wordList.forEachPair(new ObjectIntProcedure<String>() { + + @Override + public boolean apply(String word, int frequency) { + try { + result.add(frequency * featureWeight(datastore, label, word)); + } catch (InvalidDatastoreException e) { + throw new RuntimeException(e); + } + return true; + } + }); + return result.doubleValue(); } @Override Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java?rev=909063&r1=909062&r2=909063&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java Thu Feb 11 16:32:44 2010 @@ -20,19 +20,21 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.PriorityQueue; +import org.apache.commons.lang.mutable.MutableDouble; import org.apache.mahout.classifier.ClassifierResult; import org.apache.mahout.classifier.bayes.common.ByScoreLabelResultComparator; import org.apache.mahout.classifier.bayes.exceptions.InvalidDatastoreException; import org.apache.mahout.classifier.bayes.interfaces.Algorithm; import org.apache.mahout.classifier.bayes.interfaces.Datastore; +import org.apache.mahout.math.function.ObjectIntProcedure; +import org.apache.mahout.math.map.OpenObjectIntHashMap; + /** * Class implementing the Complementary Naive Bayes Classifier Algorithm - * + * */ public class CBayesAlgorithm implements Algorithm { @@ -116,25 +118,33 @@ } @Override - public double documentWeight(Datastore datastore, - String label, - String[] document) throws InvalidDatastoreException { - Map<String,int[]> wordList = new HashMap<String,int[]>(1000); + public double documentWeight(final Datastore datastore, + final String label, + final String[] document) throws InvalidDatastoreException { + OpenObjectIntHashMap<String> wordList = new OpenObjectIntHashMap<String>( + document.length / 2); for (String word : document) { - int[] count = wordList.get(word); - if (count == null) { - count = new int[] {0}; - wordList.put(word, count); + if 
(wordList.containsKey(word) == false) { + wordList.put(word, 1); + } else { + wordList.put(word, wordList.get(word) + 1); } - count[0]++; - } - double result = 0.0; - for (Map.Entry<String,int[]> entry : wordList.entrySet()) { - String word = entry.getKey(); - int count = entry.getValue()[0]; - result += count * featureWeight(datastore, label, word); } - return result; + final MutableDouble result = new MutableDouble(0.0d); + + wordList.forEachPair(new ObjectIntProcedure<String>() { + + @Override + public boolean apply(String word, int frequency) { + try { + result.add(frequency * featureWeight(datastore, label, word)); + } catch (InvalidDatastoreException e) { + throw new RuntimeException(e); + } + return true; + } + }); + return result.doubleValue(); } @Override Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java?rev=909063&r1=909062&r2=909063&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java Thu Feb 11 16:32:44 2010 @@ -19,8 +19,6 @@ import java.io.IOException; import java.util.Collection; -import java.util.HashMap; -import java.util.Map; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; @@ -29,45 +27,53 @@ import org.apache.mahout.classifier.bayes.interfaces.Datastore; import org.apache.mahout.classifier.bayes.io.SequenceFileModelReader; import org.apache.mahout.common.Parameters; +import org.apache.mahout.math.SparseMatrix; +import org.apache.mahout.math.map.OpenIntDoubleHashMap; +import org.apache.mahout.math.map.OpenObjectIntHashMap; 
import org.slf4j.Logger; import org.slf4j.LoggerFactory; + /** * Class implementing the Datastore for Algorithms to read In-Memory model * */ public class InMemoryBayesDatastore implements Datastore { - - private static final Logger log = LoggerFactory.getLogger(InMemoryBayesDatastore.class); - - private final Map<String,Map<String,Map<String,Double>>> matrices - = new HashMap<String,Map<String,Map<String,Double>>>(); - private final Map<String,Map<String,Double>> vectors = new HashMap<String,Map<String,Double>>(); + private static final Logger log = LoggerFactory + .getLogger(InMemoryBayesDatastore.class); + + private final OpenObjectIntHashMap<String> featureDictionary = new OpenObjectIntHashMap<String>(); + + private final OpenObjectIntHashMap<String> labelDictionary = new OpenObjectIntHashMap<String>(); + + private final OpenIntDoubleHashMap sigma_j = new OpenIntDoubleHashMap(); + + private final OpenIntDoubleHashMap sigma_k = new OpenIntDoubleHashMap(); + + private final OpenIntDoubleHashMap thetaNormalizerPerLabel = new OpenIntDoubleHashMap(); + + private double sigma_jSigma_k = 1.0; + + private final SparseMatrix weightMatrix = new SparseMatrix(new int[] {1,0}); private final Parameters params; private double thetaNormalizer = 1.0; - + private double alphaI = 1.0; - + public InMemoryBayesDatastore(Parameters params) { - - matrices.put("weight", new HashMap<String, Map<String, Double>>()); - vectors.put("sumWeight", new HashMap<String, Double>()); - matrices.put("weight", new HashMap<String, Map<String, Double>>()); - vectors.put("labelWeight", new HashMap<String, Double>()); - vectors.put("thetaNormalizer", new HashMap<String, Double>()); String basePath = params.get("basePath"); this.params = params; params.set("sigma_j", basePath + "/trainer-weights/Sigma_j/part-*"); params.set("sigma_k", basePath + "/trainer-weights/Sigma_k/part-*"); params.set("sigma_kSigma_j", basePath - + "/trainer-weights/Sigma_kSigma_j/part-*"); + + 
"/trainer-weights/Sigma_kSigma_j/part-*"); params.set("thetaNormalizer", basePath + "/trainer-thetaNormalizer/part-*"); params.set("weight", basePath + "/trainer-tfIdf/trainer-tfIdf/part-*"); alphaI = Double.valueOf(params.get("alpha_i", "1.0")); } - + @Override public void initialize() throws InvalidDatastoreException { Configuration conf = new Configuration(); @@ -78,129 +84,95 @@ } catch (IOException e) { throw new InvalidDatastoreException(e.getMessage()); } - updateVocabCount(); - Collection<String> labels = getKeys("thetaNormalizer"); - for (String label : labels) { - thetaNormalizer = Math.max(thetaNormalizer, Math.abs(vectorGetCell( - "thetaNormalizer", label))); - } - for (String label : labels) { - log.info("{} {} {} {}", new Object[] {label, - vectorGetCell("thetaNormalizer", - label), + for (String label : getKeys("")) { + log.info("{} {} {} {}", new Object[] { + label, + thetaNormalizerPerLabel + .get(getLabelID(label)), thetaNormalizer, - vectorGetCell("thetaNormalizer", - label) / thetaNormalizer}); + thetaNormalizerPerLabel + .get(getLabelID(label)) + / thetaNormalizer}); } } - + @Override public Collection<String> getKeys(String name) throws InvalidDatastoreException { - return vectors.get("labelWeight").keySet(); + return labelDictionary.keys(); } @Override public double getWeight(String matrixName, String row, String column) throws InvalidDatastoreException { - return matrixGetCell(matrixName, row, column); + if (matrixName.equals("weight")) { + if (column.equals("sigma_j")) { + return sigma_j.get(getFeatureID(row)); + } else return weightMatrix.getQuick(getFeatureID(row), getLabelID(column)); + } else throw new InvalidDatastoreException("Matrix not found: " + + matrixName); } @Override public double getWeight(String vectorName, String index) throws InvalidDatastoreException { - if (vectorName.equals("thetaNormalizer")) return vectorGetCell(vectorName, - index) - / thetaNormalizer; - else if (vectorName.equals("params")) { + if 
(vectorName.equals("sumWeight")) { + if (index.equals("sigma_jSigma_k")) { + return sigma_jSigma_k; + } else if (index.equals("vocabCount")) { + return featureDictionary.size(); + } else throw new InvalidDatastoreException(); + } else if (vectorName.equals("thetaNormalizer")) { + return thetaNormalizerPerLabel.get(getLabelID(index)) / thetaNormalizer; + } else if (vectorName.equals("params")) { if (index.equals("alpha_i")) { return alphaI; - } else { - throw new InvalidDatastoreException(); - } + } else throw new InvalidDatastoreException(); + } else if (vectorName.equals("labelWeight")) { + return sigma_k.get(getLabelID(index)); + } else throw new InvalidDatastoreException(); + } + + private int getFeatureID(String feature) { + if (featureDictionary.containsKey(feature)) { + return featureDictionary.get(feature); + } else { + int id = featureDictionary.size(); + featureDictionary.put(feature, id); + return id; } - return vectorGetCell(vectorName, index); } - private double matrixGetCell(String matrixName, String row, String col) throws InvalidDatastoreException { - Map<String,Map<String,Double>> matrix = matrices.get(matrixName); - if (matrix == null) { - throw new InvalidDatastoreException(); + private int getLabelID(String label) { + if (labelDictionary.containsKey(label)) { + return labelDictionary.get(label); + } else { + int id = labelDictionary.size(); + labelDictionary.put(label, id); + return id; } - Map<String,Double> rowVector = matrix.get(row); - if (rowVector == null) { - return 0.0; - } - return nullToZero(rowVector.get(col)); - } - - private double vectorGetCell(String vectorName, String index) throws InvalidDatastoreException { - - Map<String,Double> vector = vectors.get(vectorName); - if (vector == null) { - throw new InvalidDatastoreException(); - } - return nullToZero(vector.get(index)); - } - - private void matrixPutCell(String matrixName, - String row, - String col, - double weight) { - Map<String,Map<String,Double>> matrix = 
matrices.get(matrixName); - if (matrix == null) { - matrix = new HashMap<String,Map<String,Double>>(); - matrices.put(matrixName, matrix); - } - Map<String,Double> rowVector = matrix.get(row); - if (rowVector == null) { - rowVector = new HashMap<String,Double>(); - matrix.put(row, rowVector); - } - rowVector.put(col, weight); - } - - private void vectorPutCell(String vectorName, String index, double weight) { - - Map<String,Double> vector = vectors.get(vectorName); - if (vector == null) { - vector = new HashMap<String,Double>(); - vectors.put(vectorName, vector); - } - vector.put(index, weight); - } - - private long sizeOfMatrix(String matrixName) { - Map<String,Map<String,Double>> matrix = matrices.get(matrixName); - if (matrix == null) { - return 0; - } - return matrix.size(); } public void loadFeatureWeight(String feature, String label, double weight) { - matrixPutCell("weight", feature, label, weight); + int fid = getFeatureID(feature); + int lid = getLabelID(label); + weightMatrix.setQuick(fid, lid, weight); } public void setSumFeatureWeight(String feature, double weight) { - matrixPutCell("weight", feature, "sigma_j", weight); + int fid = getFeatureID(feature); + sigma_j.put(fid, weight); } public void setSumLabelWeight(String label, double weight) { - vectorPutCell("labelWeight", label, weight); + int lid = getLabelID(label); + sigma_k.put(lid, weight); } public void setThetaNormalizer(String label, double weight) { - vectorPutCell("thetaNormalizer", label, weight); + int lid = getLabelID(label); + thetaNormalizerPerLabel.put(lid, weight); + thetaNormalizer = Math.max(thetaNormalizer, Math.abs(weight)); } public void setSigmaJSigmaK(double weight) { - vectorPutCell("sumWeight", "sigma_jSigma_k", weight); + this.sigma_jSigma_k = weight; } - - public void updateVocabCount() { - vectorPutCell("sumWeight", "vocabCount", sizeOfMatrix("weight")); - } - - private static double nullToZero(Double value) { - return value == null ? 
0.0 : value; - } - } Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java?rev=909063&r1=909062&r2=909063&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java Thu Feb 11 16:32:44 2010 @@ -32,11 +32,13 @@ import org.apache.mahout.classifier.bayes.mapreduce.common.BayesConstants; import org.apache.mahout.common.Parameters; import org.apache.mahout.common.StringTuple; +import org.apache.mahout.math.map.OpenObjectDoubleHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * Mapper for Calculating the ThetaNormalizer for a label in Naive Bayes Algorithm + * Mapper for Calculating the ThetaNormalizer for a label in Naive Bayes + * Algorithm * */ public class BayesThetaNormalizerMapper extends MapReduceBase implements @@ -45,7 +47,7 @@ private static final Logger log = LoggerFactory .getLogger(BayesThetaNormalizerMapper.class); - private Map<String,Double> labelWeightSum; + private OpenObjectDoubleHashMap<String> labelWeightSum = new OpenObjectDoubleHashMap<String>(); private double sigmaJSigmaK; private double vocabCount; private double alphaI = 1.0; @@ -79,33 +81,34 @@ @Override public void configure(JobConf job) { try { - if (labelWeightSum == null) { - labelWeightSum = new HashMap<String,Double>(); - - DefaultStringifier<Map<String,Double>> mapStringifier = new DefaultStringifier<Map<String,Double>>( - job, GenericsUtil.getClass(labelWeightSum)); - - String labelWeightSumString = mapStringifier.toString(labelWeightSum); - 
labelWeightSumString = job.get("cnaivebayes.sigma_k", - labelWeightSumString); - labelWeightSum = mapStringifier.fromString(labelWeightSumString); - - DefaultStringifier<Double> stringifier = new DefaultStringifier<Double>( - job, GenericsUtil.getClass(sigmaJSigmaK)); - String sigmaJSigmaKString = stringifier.toString(sigmaJSigmaK); - sigmaJSigmaKString = job.get("cnaivebayes.sigma_jSigma_k", - sigmaJSigmaKString); - sigmaJSigmaK = stringifier.fromString(sigmaJSigmaKString); - - String vocabCountString = stringifier.toString(vocabCount); - vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString); - vocabCount = stringifier.fromString(vocabCountString); - - Parameters params = Parameters.fromString(job.get("bayes.parameters", - "")); - alphaI = Double.valueOf(params.get("alpha_i", "1.0")); - + labelWeightSum.clear(); + Map<String,Double> labelWeightSumTemp = new HashMap<String,Double>(); + + DefaultStringifier<Map<String,Double>> mapStringifier = new DefaultStringifier<Map<String,Double>>( + job, GenericsUtil.getClass(labelWeightSumTemp)); + + String labelWeightSumString = mapStringifier.toString(labelWeightSumTemp); + labelWeightSumString = job.get("cnaivebayes.sigma_k", + labelWeightSumString); + labelWeightSumTemp = mapStringifier.fromString(labelWeightSumString); + for (String key : labelWeightSumTemp.keySet()) { + this.labelWeightSum.put(key, labelWeightSumTemp.get(key)); } + DefaultStringifier<Double> stringifier = new DefaultStringifier<Double>( + job, GenericsUtil.getClass(sigmaJSigmaK)); + String sigmaJSigmaKString = stringifier.toString(sigmaJSigmaK); + sigmaJSigmaKString = job.get("cnaivebayes.sigma_jSigma_k", + sigmaJSigmaKString); + sigmaJSigmaK = stringifier.fromString(sigmaJSigmaKString); + + String vocabCountString = stringifier.toString(vocabCount); + vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString); + vocabCount = stringifier.fromString(vocabCountString); + + Parameters params = Parameters + 
.fromString(job.get("bayes.parameters", "")); + alphaI = Double.valueOf(params.get("alpha_i", "1.0")); + } catch (IOException ex) { log.warn(ex.toString(), ex); } Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java?rev=909063&r1=909062&r2=909063&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java Thu Feb 11 16:32:44 2010 @@ -32,6 +32,8 @@ import org.apache.mahout.classifier.bayes.mapreduce.common.BayesConstants; import org.apache.mahout.common.Parameters; import org.apache.mahout.common.StringTuple; +import org.apache.mahout.math.function.ObjectDoubleProcedure; +import org.apache.mahout.math.map.OpenObjectDoubleHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,7 +48,7 @@ private static final Logger log = LoggerFactory .getLogger(CBayesThetaNormalizerMapper.class); - private Map<String,Double> labelWeightSum; + private OpenObjectDoubleHashMap<String> labelWeightSum = new OpenObjectDoubleHashMap<String>(); private double sigmaJSigmaK; private double vocabCount; private double alphaI = 1.0; @@ -60,30 +62,33 @@ */ @Override public void map(StringTuple key, - DoubleWritable value, - OutputCollector<StringTuple,DoubleWritable> output, - Reporter reporter) throws IOException { + final DoubleWritable value, + final OutputCollector<StringTuple,DoubleWritable> output, + final Reporter reporter) throws IOException { if (key.stringAt(0).equals(BayesConstants.FEATURE_SUM)) { // if it is from // the Sigma_j // folder - - 
for (Map.Entry<String,Double> stringDoubleEntry : labelWeightSum - .entrySet()) { - String label = stringDoubleEntry.getKey(); - double weight = Math - .log((value.get() + alphaI) - / (sigmaJSigmaK - stringDoubleEntry.getValue() + vocabCount)); - - reporter.setStatus("Complementary Bayes Theta Normalizer Mapper: " - + stringDoubleEntry + " => " + weight); - StringTuple normalizerTuple = new StringTuple( - BayesConstants.LABEL_THETA_NORMALIZER); - normalizerTuple.add(label); - output.collect(normalizerTuple, new DoubleWritable(weight)); // output - // Sigma_j + labelWeightSum.forEachPair(new ObjectDoubleProcedure<String>() { - } + @Override + public boolean apply(String label, double sigmaJ) { + double weight = Math.log((value.get() + alphaI) + / (sigmaJSigmaK - sigmaJ + vocabCount)); + + reporter.setStatus("Complementary Bayes Theta Normalizer Mapper: " + + label + " => " + weight); + StringTuple normalizerTuple = new StringTuple( + BayesConstants.LABEL_THETA_NORMALIZER); + normalizerTuple.add(label); + try { + output.collect(normalizerTuple, new DoubleWritable(weight)); + } catch (IOException e) { + throw new RuntimeException(e); + } // output Sigma_j + return true; + } + }); } else { String label = key.stringAt(1); @@ -110,33 +115,35 @@ @Override public void configure(JobConf job) { try { - if (labelWeightSum == null) { - labelWeightSum = new HashMap<String,Double>(); - - DefaultStringifier<Map<String,Double>> mapStringifier = new DefaultStringifier<Map<String,Double>>( - job, GenericsUtil.getClass(labelWeightSum)); - - String labelWeightSumString = mapStringifier.toString(labelWeightSum); - labelWeightSumString = job.get("cnaivebayes.sigma_k", - labelWeightSumString); - labelWeightSum = mapStringifier.fromString(labelWeightSumString); - - DefaultStringifier<Double> stringifier = new DefaultStringifier<Double>( - job, GenericsUtil.getClass(sigmaJSigmaK)); - String sigmaJSigmaKString = stringifier.toString(sigmaJSigmaK); - sigmaJSigmaKString = 
job.get("cnaivebayes.sigma_jSigma_k", - sigmaJSigmaKString); - sigmaJSigmaK = stringifier.fromString(sigmaJSigmaKString); - - String vocabCountString = stringifier.toString(vocabCount); - vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString); - vocabCount = stringifier.fromString(vocabCountString); - - Parameters params = Parameters.fromString(job.get("bayes.parameters", - "")); - alphaI = Double.valueOf(params.get("alpha_i", "1.0")); - + labelWeightSum.clear(); + Map<String,Double> labelWeightSumTemp = new HashMap<String,Double>(); + + DefaultStringifier<Map<String,Double>> mapStringifier = new DefaultStringifier<Map<String,Double>>( + job, GenericsUtil.getClass(labelWeightSumTemp)); + + String labelWeightSumString = mapStringifier.toString(labelWeightSumTemp); + labelWeightSumString = job.get("cnaivebayes.sigma_k", + labelWeightSumString); + labelWeightSumTemp = mapStringifier.fromString(labelWeightSumString); + for (String key : labelWeightSumTemp.keySet()) { + this.labelWeightSum.put(key, labelWeightSumTemp.get(key)); } + + DefaultStringifier<Double> stringifier = new DefaultStringifier<Double>( + job, GenericsUtil.getClass(sigmaJSigmaK)); + String sigmaJSigmaKString = stringifier.toString(sigmaJSigmaK); + sigmaJSigmaKString = job.get("cnaivebayes.sigma_jSigma_k", + sigmaJSigmaKString); + sigmaJSigmaK = stringifier.fromString(sigmaJSigmaKString); + + String vocabCountString = stringifier.toString(vocabCount); + vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString); + vocabCount = stringifier.fromString(vocabCountString); + + Parameters params = Parameters + .fromString(job.get("bayes.parameters", "")); + alphaI = Double.valueOf(params.get("alpha_i", "1.0")); + } catch (IOException ex) { log.warn(ex.toString(), ex); } Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesFeatureMapper.java URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesFeatureMapper.java?rev=909063&r1=909062&r2=909063&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesFeatureMapper.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesFeatureMapper.java Thu Feb 11 16:32:44 2010 @@ -18,10 +18,11 @@ package org.apache.mahout.classifier.bayes.mapreduce.common; import java.io.IOException; -import java.util.HashMap; +import java.util.Arrays; +import java.util.Iterator; import java.util.List; -import java.util.Map; +import org.apache.commons.lang.mutable.MutableDouble; import org.apache.hadoop.io.DoubleWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.JobConf; @@ -29,14 +30,20 @@ import org.apache.hadoop.mapred.Mapper; import org.apache.hadoop.mapred.OutputCollector; import org.apache.hadoop.mapred.Reporter; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.shingle.ShingleFilter; +import org.apache.lucene.analysis.tokenattributes.TermAttribute; import org.apache.mahout.common.Parameters; import org.apache.mahout.common.StringTuple; -import org.apache.mahout.common.nlp.NGrams; +import org.apache.mahout.math.function.ObjectIntProcedure; +import org.apache.mahout.math.function.ObjectProcedure; +import org.apache.mahout.math.map.OpenObjectIntHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * Reads the input train set(preprocessed using the {...@link org.apache.mahout.classifier.BayesFileFormatter}). + * Reads the input train set(preprocessed using the + * {...@link org.apache.mahout.classifier.BayesFileFormatter}). 
*/ public class BayesFeatureMapper extends MapReduceBase implements Mapper<Text,Text,StringTuple,DoubleWritable> { @@ -60,7 +67,8 @@ * @param key * The label * @param value - * the features (all unique) associated w/ this label + * the features (all unique) associated w/ this label in stringtuple + * format * @param output * The OutputCollector to write the results to * @param reporter @@ -69,65 +77,91 @@ @Override public void map(Text key, Text value, - OutputCollector<StringTuple,DoubleWritable> output, + final OutputCollector<StringTuple,DoubleWritable> output, Reporter reporter) throws IOException { // String line = value.toString(); - String label = key.toString(); - - Map<String,int[]> wordList = new HashMap<String,int[]>(1000); - - List<String> ngrams = new NGrams(value.toString(), gramSize) - .generateNGramsWithoutLabel(); - - for (String ngram : ngrams) { - int[] count = wordList.get(ngram); - if (count == null) { - count = new int[1]; - count[0] = 0; - wordList.put(ngram, count); + final String label = key.toString(); + List<String> tokens = Arrays.asList(value.toString().split("[ ]+")); + OpenObjectIntHashMap<String> wordList = new OpenObjectIntHashMap<String>( + tokens.size() * gramSize); + + if (gramSize > 1) { + ShingleFilter sf = new ShingleFilter(new IteratorTokenStream(tokens + .iterator()), gramSize); + do { + String term = ((TermAttribute) sf.getAttribute(TermAttribute.class)) + .term(); + if (term.length() > 0) { + if (wordList.containsKey(term) == false) { + wordList.put(term, 1); + } else { + wordList.put(term, 1 + wordList.get(term)); + } + } + } while (sf.incrementToken()); + } else { + for (String term : tokens) { + if (wordList.containsKey(term) == false) { + wordList.put(term, 1); + } else { + wordList.put(term, 1 + wordList.get(term)); + } } - count[0]++; } - double lengthNormalisation = 0.0; - for (int[] dKJ : wordList.values()) { - // key is label,word - double dkjValue = (double) dKJ[0]; - lengthNormalisation += dkjValue * dkjValue; - 
} - lengthNormalisation = Math.sqrt(lengthNormalisation); + final MutableDouble lengthNormalisationMut = new MutableDouble(0); + wordList.forEachPair(new ObjectIntProcedure<String>() { + @Override + public boolean apply(String word, int dKJ) { + lengthNormalisationMut.add(dKJ * dKJ); + return true; + } + }); + + final double lengthNormalisation = Math.sqrt(lengthNormalisationMut + .doubleValue()); // Output Length Normalized + TF Transformed Frequency per Word per Class // Log(1 + D_ij)/SQRT( SIGMA(k, D_kj) ) - for (Map.Entry<String,int[]> entry : wordList.entrySet()) { - // key is label,word - String token = entry.getKey(); - StringTuple tuple = new StringTuple(); - tuple.add(BayesConstants.WEIGHT); - tuple.add(label); - tuple.add(token); - DoubleWritable f = new DoubleWritable(Math.log(1.0 + entry.getValue()[0]) - / lengthNormalisation); - output.collect(tuple, f); - } + wordList.forEachPair(new ObjectIntProcedure<String>() { + @Override + public boolean apply(String token, int dKJ) { + try { + StringTuple tuple = new StringTuple(); + tuple.add(BayesConstants.WEIGHT); + tuple.add(label); + tuple.add(token); + DoubleWritable f = new DoubleWritable(Math.log(1.0 + dKJ) + / lengthNormalisation); + output.collect(tuple, f); + } catch (IOException e) { + throw new RuntimeException(e); + } + return true; + } + }); reporter.setStatus("Bayes Feature Mapper: Document Label: " + label); // Output Document Frequency per Word per Class - - for (String token : wordList.keySet()) { - // key is label,word - - StringTuple dfTuple = new StringTuple(); - dfTuple.add(BayesConstants.DOCUMENT_FREQUENCY); - dfTuple.add(label); - dfTuple.add(token); - output.collect(dfTuple, ONE); - - StringTuple tokenCountTuple = new StringTuple(); - tokenCountTuple.add(BayesConstants.FEATURE_COUNT); - tokenCountTuple.add(token); - output.collect(tokenCountTuple, ONE); - - } + wordList.forEachKey(new ObjectProcedure<String>() { + @Override + public boolean apply(String token) { + try { + StringTuple 
dfTuple = new StringTuple(); + dfTuple.add(BayesConstants.DOCUMENT_FREQUENCY); + dfTuple.add(label); + dfTuple.add(token); + output.collect(dfTuple, ONE); + + StringTuple tokenCountTuple = new StringTuple(); + tokenCountTuple.add(BayesConstants.FEATURE_COUNT); + tokenCountTuple.add(token); + output.collect(tokenCountTuple, ONE); + } catch (IOException e) { + throw new RuntimeException(e); + } + return true; + } + }); // output that we have seen the label to calculate the Count of Document per // class @@ -149,4 +183,26 @@ log.warn(ex.toString(), ex); } } + + /** Used to emit tokens from an input string array in the style of TokenStream */ + public static class IteratorTokenStream extends TokenStream { + private final TermAttribute termAtt; + private final Iterator<String> iterator; + + public IteratorTokenStream(Iterator<String> iterator) { + this.iterator = iterator; + this.termAtt = (TermAttribute) addAttribute(TermAttribute.class); + } + + @Override + public boolean incrementToken() throws IOException { + if (iterator.hasNext()) { + clearAttributes(); + termAtt.setTermBuffer(iterator.next()); + return true; + } else { + return false; + } + } + } } Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfMapper.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfMapper.java?rev=909063&r1=909062&r2=909063&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfMapper.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfMapper.java Thu Feb 11 16:32:44 2010 @@ -30,6 +30,7 @@ import org.apache.hadoop.mapred.Reporter; import org.apache.hadoop.util.GenericsUtil; import org.apache.mahout.common.StringTuple; +import 
org.apache.mahout.math.map.OpenObjectDoubleHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,7 +49,7 @@ private static final DoubleWritable ONE = new DoubleWritable(1.0); - private Map<String,Double> labelDocumentCounts; + private OpenObjectDoubleHashMap<String> labelDocumentCounts = new OpenObjectDoubleHashMap<String>(); /** * We need to calculate the Tf-Idf of each feature in each label @@ -87,20 +88,22 @@ @Override public void configure(JobConf job) { try { - if (labelDocumentCounts == null) { - labelDocumentCounts = new HashMap<String,Double>(); - - DefaultStringifier<Map<String,Double>> mapStringifier = new DefaultStringifier<Map<String,Double>>( - job, GenericsUtil.getClass(labelDocumentCounts)); - - String labelDocumentCountString = mapStringifier - .toString(labelDocumentCounts); - labelDocumentCountString = job.get("cnaivebayes.labelDocumentCounts", - labelDocumentCountString); - - labelDocumentCounts = mapStringifier - .fromString(labelDocumentCountString); + this.labelDocumentCounts.clear(); + Map<String,Double> labelDocCountTemp = new HashMap<String,Double>(); + + DefaultStringifier<Map<String,Double>> mapStringifier = new DefaultStringifier<Map<String,Double>>( + job, GenericsUtil.getClass(labelDocCountTemp)); + + String labelDocumentCountString = mapStringifier + .toString(labelDocCountTemp); + labelDocumentCountString = job.get("cnaivebayes.labelDocumentCounts", + labelDocumentCountString); + + labelDocCountTemp = mapStringifier.fromString(labelDocumentCountString); + for (String key : labelDocCountTemp.keySet()) { + this.labelDocumentCounts.put(key, labelDocCountTemp.get(key)); } + } catch (IOException ex) { log.warn(ex.toString(), ex); } Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java?rev=909063&r1=909062&r2=909063&view=diff 
============================================================================== --- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java (original) +++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java Thu Feb 11 16:32:44 2010 @@ -69,7 +69,7 @@ store.loadFeatureWeight("ee", "e", 100); store.loadFeatureWeight("aa", "e", 50); store.loadFeatureWeight("dd", "e", 50); - store.updateVocabCount(); + } public void test() throws InvalidDatastoreException { @@ -79,7 +79,7 @@ assertNotNull("category is null and it shouldn't be", result); assertEquals(result + " is not equal to e", "e", result.getLabel()); - document = new String[] {"ff"}; + document = new String[] {"dd"}; result = classifier.classifyDocument(document, "unknown"); assertNotNull("category is null and it shouldn't be", result); assertEquals(result + " is not equal to d", "d", result.getLabel()); Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java?rev=909063&r1=909062&r2=909063&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java (original) +++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java Thu Feb 11 16:32:44 2010 @@ -74,7 +74,7 @@ store.loadFeatureWeight("ee", "e", 100); store.loadFeatureWeight("aa", "e", 50); store.loadFeatureWeight("dd", "e", 50); - store.updateVocabCount(); + } public void test() throws InvalidDatastoreException { @@ -84,7 +84,7 @@ assertNotNull("category is null and it shouldn't be", result); assertEquals(result + " is not equal to e", "e", result.getLabel()); - document = new String[] {"ff"}; + document = new String[] 
{"dd"}; result = classifier.classifyDocument(document, "unknown"); assertNotNull("category is null and it shouldn't be", result); assertEquals(result + " is not equal to d", "d", result.getLabel());