Author: robinanil
Date: Sun May 13 22:55:08 2012
New Revision: 1338002

URL: http://svn.apache.org/viewvc?rev=1338002&view=rev
Log:
MAHOUT-1014 TrainNewsGroups for naive Bayes, which encodes vectors exactly like
the SGD example

Added:
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/
    
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TestNewsGroups.java
    
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TrainNewsGroups.java
Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
    mahout/trunk/examples/bin/classify-20newsgroups.sh

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
 Sun May 13 22:55:08 2012
@@ -36,5 +36,4 @@ public class ComplementaryNaiveBayesClas
 
     return Math.log(numerator / denominator);
   }
-
 }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
 Sun May 13 22:55:08 2012
@@ -34,5 +34,4 @@ public class StandardNaiveBayesClassifie
 
     return -Math.log(numerator / denominator);
   }
-  
 }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
 Sun May 13 22:55:08 2012
@@ -1,18 +1,16 @@
 /**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Licensed to the Apache Software Foundation (ASF) under one or more 
contributor license
+ * agreements. See the NOTICE file distributed with this work for additional 
information regarding
+ * copyright ownership. The ASF licenses this file to You under the Apache 
License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the 
License. You may obtain a
+ * copy of the License at
+ * 
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software 
distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 
KIND, either express
+ * or implied. See the License for the specific language governing permissions 
and limitations under
+ * the License.
  */
 
 package org.apache.mahout.classifier.naivebayes.training;
@@ -20,30 +18,16 @@ package org.apache.mahout.classifier.nai
 import java.io.IOException;
 
 import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.math.MultiLabelVectorWritable;
 import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.map.OpenObjectIntHashMap;
-
-public class IndexInstancesMapper extends Mapper<Text, VectorWritable, 
IntWritable, VectorWritable> {
-
-  public enum Counter { SKIPPED_INSTANCES }
-
-  private OpenObjectIntHashMap<String> labelIndex;
-
-  @Override
-  protected void setup(Context ctx) throws IOException, InterruptedException {
-    labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
-  }
 
+/**
+ * Re-keys each training instance by its label index so that all vectors sharing a label can
+ * be summed downstream. The input key is ignored; only the first entry of
+ * {@link MultiLabelVectorWritable#getLabels()} is used as the label. Instances that carry no
+ * label are counted in {@link Counter#SKIPPED_INSTANCES} and dropped.
+ */
+public class IndexInstancesMapper
+    extends Mapper<IntWritable, MultiLabelVectorWritable, IntWritable, VectorWritable> {
+
+  public enum Counter { SKIPPED_INSTANCES }
+
   @Override
-  protected void map(Text labelText, VectorWritable instance, Context ctx) 
throws IOException, InterruptedException {
-    String label = labelText.toString();
-    if (labelIndex.containsKey(label)) {
-      ctx.write(new IntWritable(labelIndex.get(label)), instance);
-    } else {
-      ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
-    }
+  protected void map(IntWritable key, MultiLabelVectorWritable instance, Context ctx)
+      throws IOException, InterruptedException {
+    int[] labels = instance.getLabels();
+    // An unlabeled instance cannot be attributed to any class; count it instead of failing
+    // the task with an ArrayIndexOutOfBoundsException.
+    if (labels == null || labels.length == 0) {
+      ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
+      return;
+    }
+    ctx.write(new IntWritable(labels[0]), new VectorWritable(instance.getVector()));
   }
 }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
 Sun May 13 22:55:08 2012
@@ -17,9 +17,10 @@
 
 package org.apache.mahout.classifier.naivebayes.training;
 
-import com.google.common.base.Splitter;
+import java.util.List;
+import java.util.Map;
+
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapreduce.Job;
@@ -31,16 +32,9 @@ import org.apache.mahout.classifier.naiv
 import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.mahout.common.iterator.sequencefile.PathFilters;
-import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
 import org.apache.mahout.common.mapreduce.VectorSumReducer;
 import org.apache.mahout.math.VectorWritable;
 
-import java.io.IOException;
-import java.util.List;
-import java.util.Map;
-
 /**
  * This class trains a Naive Bayes Classifier (Parameters for both Naive Bayes 
and Complementary Naive Bayes)
  */
@@ -60,15 +54,11 @@ public final class TrainNaiveBayesJob ex
 
   @Override
   public int run(String[] args) throws Exception {
-
     addInputOption();
     addOutputOption();
-    addOption("labels", "l", "comma-separated list of labels to include in 
training", false);
-
-    addOption(buildOption("extractLabels", "el", "Extract the labels from the 
input", false, false, ""));
+    addOption("labelSize", "ls", "Number of labels in the input data", 
String.valueOf(2));
     addOption("alphaI", "a", "smoothing parameter", String.valueOf(1.0f));
     addOption(buildOption("trainComplementary", "c", "train complementary?", 
false, false, String.valueOf(false)));
-    addOption("labelIndex", "li", "The path to store the label index in", 
false);
     addOption(DefaultOptionCreator.overwriteOption().create());
     Map<String, List<String>> parsedArgs = parseArguments(args);
     if (parsedArgs == null) {
@@ -78,21 +68,12 @@ public final class TrainNaiveBayesJob ex
       HadoopUtil.delete(getConf(), getOutputPath());
       HadoopUtil.delete(getConf(), getTempPath());
     }
-    Path labPath;
-    String labPathStr = getOption("labelIndex");
-    if (labPathStr != null) {
-      labPath = new Path(labPathStr);
-    } else {
-      labPath = getTempPath("labelIndex");
-    }
-    long labelSize = createLabelIndex(labPath);
+    int labelSize = Integer.parseInt(getOption("labelSize"));
     float alphaI = Float.parseFloat(getOption("alphaI"));
     boolean trainComplementary = 
Boolean.parseBoolean(getOption("trainComplementary"));
 
-
     HadoopUtil.setSerializations(getConf());
-    HadoopUtil.cacheFiles(labPath, getConf());
-
+    
     //add up all the vectors with the same labels, while mapping the labels 
into our index
     Job indexInstances = prepareJob(getInputPath(), 
getTempPath(SUMMED_OBSERVATIONS), SequenceFileInputFormat.class,
             IndexInstancesMapper.class, IntWritable.class, 
VectorWritable.class, VectorSumReducer.class, IntWritable.class,
@@ -132,18 +113,4 @@ public final class TrainNaiveBayesJob ex
 
     return 0;
   }
-
-  private long createLabelIndex(Path labPath) throws IOException {
-    long labelSize = 0;
-    if (hasOption("labels")) {
-      Iterable<String> labels = Splitter.on(",").split(getOption("labels"));
-      labelSize = BayesUtils.writeLabelIndex(getConf(), labels, labPath);
-    } else if (hasOption("extractLabels")) {
-      SequenceFileDirIterable<Text, IntWritable> iterable =
-              new SequenceFileDirIterable<Text, IntWritable>(getInputPath(), 
PathType.LIST, PathFilters.logsCRCFilter(), getConf());
-      labelSize = BayesUtils.writeLabelIndex(getConf(), labPath, iterable);
-    }
-    return labelSize;
-  }
-
 }

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
 Sun May 13 22:55:08 2012
@@ -17,23 +17,24 @@
 
 package org.apache.mahout.classifier.naivebayes;
 
-import com.google.common.io.Closeables;
+import java.io.File;
+
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.SequenceFile;
-import org.apache.hadoop.io.Text;
 import org.apache.mahout.classifier.AbstractVectorClassifier;
 import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
 import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MultiLabelVectorWritable;
 import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.hadoop.MathHelper;
 import org.junit.Before;
 import org.junit.Test;
 
-import java.io.File;
+import com.google.common.io.Closeables;
 
 public class NaiveBayesTest extends MahoutTestCase {
 
@@ -42,8 +43,8 @@ public class NaiveBayesTest extends Maho
   private File outputDir;
   private File tempDir;
 
-  static final Text LABEL_STOLEN = new Text("stolen");
-  static final Text LABEL_NOT_STOLEN = new Text("not_stolen");
+  static final String LABEL_STOLEN = "stolen";
+  static final String LABEL_NOT_STOLEN = "not_stolen";
 
   static final Vector.Element COLOR_RED = MathHelper.elem(0, 1);
   static final Vector.Element COLOR_YELLOW = MathHelper.elem(1, 1);
@@ -66,19 +67,19 @@ public class NaiveBayesTest extends Maho
     tempDir = getTestTempDir("tmp");
 
     SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(conf), 
conf,
-        new Path(inputFile.getAbsolutePath()), Text.class, 
VectorWritable.class);
+        new Path(inputFile.getAbsolutePath()), IntWritable.class, 
MultiLabelVectorWritable.class);
 
     try {
-      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_RED, 
TYPE_SPORTS, ORIGIN_DOMESTIC));
-      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, 
ORIGIN_DOMESTIC));
-      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_RED, 
TYPE_SPORTS, ORIGIN_DOMESTIC));
-      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, 
TYPE_SPORTS, ORIGIN_DOMESTIC));
-      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_YELLOW, 
TYPE_SPORTS, ORIGIN_IMPORTED));
-      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, 
ORIGIN_IMPORTED));
-      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_YELLOW, 
TYPE_SUV, ORIGIN_IMPORTED));
-      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, 
ORIGIN_DOMESTIC));
-      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SUV, 
ORIGIN_IMPORTED));
-      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_RED, 
TYPE_SPORTS, ORIGIN_IMPORTED));
+      writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, 
COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, 
COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, 
COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, 
COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, 
COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
+      writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, 
COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+      writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, 
COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+      writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, 
COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
+      writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, 
COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
+      writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, 
COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
     } finally {
       Closeables.closeQuietly(writer);
     }
@@ -89,7 +90,7 @@ public class NaiveBayesTest extends Maho
     TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
     trainNaiveBayes.setConf(conf);
     trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), 
"--output", outputDir.getAbsolutePath(),
-        "--labels", "stolen,not_stolen", "--tempDir", 
tempDir.getAbsolutePath() });
+        "--labelSize", "2", "--tempDir", tempDir.getAbsolutePath() });
 
     NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new 
Path(outputDir.getAbsolutePath()), conf);
 
@@ -97,7 +98,7 @@ public class NaiveBayesTest extends Maho
 
     assertEquals(2, classifier.numCategories());
 
-    Vector prediction = classifier.classify(trainingInstance(COLOR_RED, 
TYPE_SUV, ORIGIN_DOMESTIC).get());
+    Vector prediction = classifier.classify(trainingInstance("", COLOR_RED, 
TYPE_SUV, ORIGIN_DOMESTIC).getVector());
 
     // should be classified as not stolen
     assertTrue(prediction.get(0) < prediction.get(1));
@@ -108,7 +109,7 @@ public class NaiveBayesTest extends Maho
     TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
     trainNaiveBayes.setConf(conf);
     trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), 
"--output", outputDir.getAbsolutePath(),
-        "--labels", "stolen,not_stolen", "--trainComplementary",
+        "--labelSize", "2", "--trainComplementary",
         "--tempDir", tempDir.getAbsolutePath() });
 
     NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new 
Path(outputDir.getAbsolutePath()), conf);
@@ -117,18 +118,18 @@ public class NaiveBayesTest extends Maho
 
     assertEquals(2, classifier.numCategories());
 
-    Vector prediction = classifier.classify(trainingInstance(COLOR_RED, 
TYPE_SUV, ORIGIN_DOMESTIC).get());
+    Vector prediction = classifier.classify(trainingInstance("", COLOR_RED, 
TYPE_SUV, ORIGIN_DOMESTIC).getVector());
 
     // should be classified as not stolen
     assertTrue(prediction.get(0) < prediction.get(1));
   }
 
-  static VectorWritable trainingInstance(Vector.Element... elems) {
+  /**
+   * Builds a 6-dimensional instance with the given elements set. {@link #LABEL_STOLEN} maps to
+   * label index 0; any other label string maps to label index 1.
+   */
+  static MultiLabelVectorWritable trainingInstance(String label, Vector.Element... elems) {
     DenseVector trainingInstance = new DenseVector(6);
     for (Vector.Element elem : elems) {
       trainingInstance.set(elem.index(), elem.get());
     }
-    return new VectorWritable(trainingInstance);
+    return new MultiLabelVectorWritable(trainingInstance,
+        new int[] {LABEL_STOLEN.equals(label) ? 0 : 1});
   }
 
 

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
 Sun May 13 22:55:08 2012
@@ -18,22 +18,20 @@
 package org.apache.mahout.classifier.naivebayes.training;
 
 import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Counter;
 import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MultiLabelVectorWritable;
 import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.map.OpenObjectIntHashMap;
 import org.easymock.EasyMock;
 import org.junit.Before;
 import org.junit.Test;
 
+@SuppressWarnings("unchecked")
 public class IndexInstancesMapperTest extends MahoutTestCase {
-
+  private static final DenseVector VECTOR = new DenseVector(new double[] { 1, 
0, 1, 1, 0 });
   private Mapper.Context ctx;
-  private OpenObjectIntHashMap<String> labelIndex;
-  private VectorWritable instance;
+  private MultiLabelVectorWritable instance;
 
   @Override
   @Before
@@ -41,45 +39,16 @@ public class IndexInstancesMapperTest ex
     super.setUp();
 
     ctx = EasyMock.createMock(Mapper.Context.class);
-    instance = new VectorWritable(new DenseVector(new double[] { 1, 0, 1, 1, 0 
}));
-
-    labelIndex = new OpenObjectIntHashMap<String>();
-    labelIndex.put("bird", 0);
-    labelIndex.put("cat", 1);
+    instance = new MultiLabelVectorWritable(VECTOR,
+      new int[] {0});
   }
-
-
+  
   @Test
   public void index() throws Exception {
-
-    ctx.write(new IntWritable(0), instance);
-
+    ctx.write(new IntWritable(0), new VectorWritable(VECTOR));
     EasyMock.replay(ctx);
-
     IndexInstancesMapper indexInstances = new IndexInstancesMapper();
-    setField(indexInstances, "labelIndex", labelIndex);
-
-    indexInstances.map(new Text("bird"), instance, ctx);
-
+    indexInstances.map(new IntWritable(-1), instance, ctx);
     EasyMock.verify(ctx);
   }
-
-  @Test
-  public void skip() throws Exception {
-
-    Counter skippedInstances = EasyMock.createMock(Counter.class);
-
-    
EasyMock.expect(ctx.getCounter(IndexInstancesMapper.Counter.SKIPPED_INSTANCES)).andReturn(skippedInstances);
-    skippedInstances.increment(1);
-
-    EasyMock.replay(ctx, skippedInstances);
-
-    IndexInstancesMapper indexInstances = new IndexInstancesMapper();
-    setField(indexInstances, "labelIndex", labelIndex);
-
-    indexInstances.map(new Text("fish"), instance, ctx);
-
-    EasyMock.verify(ctx, skippedInstances);
-  }
-
 }

Modified: mahout/trunk/examples/bin/classify-20newsgroups.sh
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/bin/classify-20newsgroups.sh?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
--- mahout/trunk/examples/bin/classify-20newsgroups.sh (original)
+++ mahout/trunk/examples/bin/classify-20newsgroups.sh Sun May 13 22:55:08 2012
@@ -23,7 +23,7 @@
 #  examples/bin/build-20news.sh
 
 if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
-  echo "This script runs the SGD classifier over the classic 20 News Groups."
+  echo "This script runs SGD and Bayes classifiers over the classic 20 News 
Groups."
   exit
 fi
 
@@ -34,13 +34,14 @@ fi
 START_PATH=`pwd`
 
 WORK_DIR=/tmp/mahout-work-${USER}
-algorithm=( sgd clean)
+algorithm=( naivebayes sgd clean)
 if [ -n "$1" ]; then
   choice=$1
 else
   echo "Please select a number to choose the corresponding task to run"
   echo "1. ${algorithm[0]}"
-  echo "2. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
+  echo "2. ${algorithm[1]}"
+  echo "3. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
   read -p "Enter your choice : " choice
 fi
 
@@ -67,7 +68,15 @@ cd ../..
 
 set -e
 
-if [ "x$alg" == "xsgd" ]; then
+if [ "x$alg" == "xnaivebayes" ]; then
+  # TrainNewsGroups always writes its encoded training vectors to /tmp/news-group-train/,
+  # which is why the training job's --input points there.
+  nb_model=${WORK_DIR}/news-group.naivebayes.model
+  # Check for the naive Bayes model, not the SGD model (/tmp/news-group.model), so that a
+  # previous SGD run does not suppress naive Bayes training.
+  if [ ! -e "$nb_model" ]; then
+    echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
+    ./bin/mahout org.apache.mahout.classifier.naivebayes.TrainNewsGroups ${WORK_DIR}/20news-bydate/20news-bydate-train/ 0 \
+       --input /tmp/news-group-train/ --output "$nb_model" -ls 20 --tempDir ${WORK_DIR}/tmp/ -ow
+  fi
+  # NOTE(review): evaluation is not enabled yet; when it is, run:
+  # ./bin/mahout org.apache.mahout.classifier.naivebayes.TestNewsGroups --input ${WORK_DIR}/20news-bydate/20news-bydate-test/ --model "$nb_model"
+elif [ "x$alg" == "xsgd" ]; then
   if [ ! -e "/tmp/news-group.model" ]; then
     echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
     ./bin/mahout org.apache.mahout.classifier.sgd.TrainNewsGroups 
${WORK_DIR}/20news-bydate/20news-bydate-train/

Added: 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TestNewsGroups.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TestNewsGroups.java?rev=1338002&view=auto
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TestNewsGroups.java
 (added)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TestNewsGroups.java
 Sun May 13 22:55:08 2012
@@ -0,0 +1,136 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.NewsgroupHelper;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+
+/**
+ * Runs the 20 newsgroups test data through a naive Bayes model trained by
+ * {@link org.apache.mahout.classifier.naivebayes.TrainNewsGroups}.
+ * <p/>
+ * Test documents are encoded with {@link NewsgroupHelper} using the same settings as the
+ * training driver (two probes, leakType 0). Newsgroup directories are interned in sorted
+ * order so that label indices do not depend on the unspecified {@link File#listFiles()}
+ * ordering. Training must assign label indices in the same sorted order for the report to be
+ * meaningful.
+ */
+public final class TestNewsGroups {
+
+  /** Number of hash probes; must match the encoder setting used at training time. */
+  private static final int PROBES = 2;
+
+  private String inputFile;
+  private String modelFile;
+
+  private TestNewsGroups() {
+  }
+
+  public static void main(String[] args) throws IOException {
+    TestNewsGroups runner = new TestNewsGroups();
+    if (runner.parseArgs(args)) {
+      runner.run(new PrintWriter(System.out, true));
+    }
+  }
+
+  /**
+   * Classifies every document below {@code --input} with the model stored at {@code --model}
+   * and writes a {@link ResultAnalyzer} summary.
+   *
+   * @param output destination for the evaluation report
+   * @throws IOException if the input directory or the model cannot be read
+   */
+  public void run(PrintWriter output) throws IOException {
+    File base = new File(inputFile);
+    File[] groups = base.listFiles();
+    if (groups == null) {
+      throw new IOException("Not a readable directory: " + base);
+    }
+    // Label ids are assigned in interning order; sort for a deterministic mapping.
+    Arrays.sort(groups);
+
+    Dictionary newsGroups = new Dictionary();
+    Multiset<String> overallCounts = HashMultiset.create();
+
+    List<File> files = Lists.newArrayList();
+    for (File newsgroup : groups) {
+      if (newsgroup.isDirectory()) {
+        newsGroups.intern(newsgroup.getName());
+        File[] docs = newsgroup.listFiles();
+        if (docs != null) {
+          files.addAll(Arrays.asList(docs));
+        }
+      }
+    }
+    System.out.printf("%d test files\n", files.size());
+
+    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelFile), new Configuration());
+    AbstractVectorClassifier classifier = new StandardNaiveBayesClassifier(model);
+
+    // A single helper for all documents, configured like the training driver.
+    NewsgroupHelper helper = new NewsgroupHelper();
+    helper.getEncoder().setProbes(PROBES);
+
+    ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
+    for (File file : files) {
+      String ng = file.getParentFile().getName();
+      int actual = newsGroups.intern(ng);
+      // leakType 0: no synthetic target leak; subject and body text are both encoded.
+      Vector input = helper.encodeFeatureVector(file, actual, 0, overallCounts);
+      // One score per label; the highest-scoring label is the prediction.
+      Vector result = classifier.classify(input);
+      int cat = result.maxValueIndex();
+      double score = result.maxValue();
+      // NOTE(review): naive Bayes scores are sums of per-feature log terms rather than
+      // probabilities, so no log-likelihood is reported.
+      double ll = 0;
+      ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll);
+      ra.addInstance(newsGroups.values().get(actual), cr);
+    }
+    output.printf("%s\n\n", ra.toString());
+  }
+
+  /**
+   * Parses {@code --input} and {@code --model}.
+   *
+   * @return {@code false} if parsing failed or help was requested
+   */
+  boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help").withDescription("print this list").create();
+
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    Option inputFileOption = builder.withLongName("input")
+            .withRequired(true)
+            .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+            .withDescription("where to get test data")
+            .create();
+
+    Option modelFileOption = builder.withLongName("model")
+            .withRequired(true)
+            .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
+            .withDescription("where to get a model")
+            .create();
+
+    Group normalArgs = new GroupBuilder()
+            .withOption(help)
+            .withOption(inputFileOption)
+            .withOption(modelFileOption)
+            .create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    inputFile = (String) cmdLine.getValue(inputFileOption);
+    modelFile = (String) cmdLine.getValue(modelFileOption);
+    return true;
+  }
+
+}

Added: 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TrainNewsGroups.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TrainNewsGroups.java?rev=1338002&view=auto
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TrainNewsGroups.java
 (added)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TrainNewsGroups.java
 Sun May 13 22:55:08 2012
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.File;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.NewsgroupHelper;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.math.MultiLabelVectorWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+
+/**
+ * Encodes the 20 newsgroups training data and trains a naive Bayes model on it.
+ * The first command line argument gives the path of the directory holding the 
training
+ * data.  The optional second argument, leakType, defines which classes of 
features to use.
+ * Importantly, leakType controls whether a synthetic date is injected into 
the data as
+ * a target leak and if so, how.
+ * <p/>
+ * The value of leakType % 3 determines whether the target leak is injected 
according to
+ * the following table:
+ * <p/>
+ * <table>
+ * <tr><td valign='top'>0</td><td>No leak injected</td></tr>
+ * <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format. 
This will be a single token and
+ * is a perfect target leak since each newsgroup is given a different 
month</td></tr>
+ * <tr><td valign='top'>2</td><td>Synthetic date injected in dd-MMM-yyyy 
HH:mm:ss format.  The day varies
+ * and thus there are more leak symbols that need to be learned.  Ultimately 
this is just
+ * as big a leak as case 1.</td></tr>
+ * </table>
+ * <p/>
+ * Leaktype also determines what other text will be indexed.  If leakType is 
greater
+ * than or equal to 6, then neither headers nor text body will be used for 
features and the leak is the only
+ * source of data.  If leakType is greater than or equal to 3, then subject 
words will be used as features.
+ * If leakType is less than 3, then both subject and body text will be used as 
features.
+ * <p/>
+ * A leakType of 0 gives no leak and all textual features.
+ * <p/>
+ * See the following table for a summary of commonly used values for leakType
+ * <p/>
+ * <table>
+ * 
<tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>0</td><td>no</td><td>yes</td><td>yes</td></tr>
+ * <tr><td>1</td><td>mmm-yyyy</td><td>yes</td><td>yes</td></tr>
+ * <tr><td>2</td><td>dd-mmm-yyyy</td><td>yes</td><td>yes</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>3</td><td>no</td><td>yes</td><td>no</td></tr>
+ * <tr><td>4</td><td>mmm-yyyy</td><td>yes</td><td>no</td></tr>
+ * <tr><td>5</td><td>dd-mmm-yyyy</td><td>yes</td><td>no</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>6</td><td>no</td><td>no</td><td>no</td></tr>
+ * <tr><td>7</td><td>mmm-yyyy</td><td>no</td><td>no</td></tr>
+ * <tr><td>8</td><td>dd-mmm-yyyy</td><td>no</td><td>no</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * </table>
+ */
+public final class TrainNewsGroups {
+
+  /** Where the encoded training vectors are written; the script passes this dir as --input. */
+  private static final String TRAIN_DATA = "/tmp/news-group-train/data";
+
+  /** Number of hash probes for the text encoder; must match the value used at test time. */
+  private static final int PROBES = 2;
+
+  private TrainNewsGroups() {}
+
+  /**
+   * Encodes every document under {@code args[0]} and trains a naive Bayes model on it.
+   *
+   * @param args {@code <trainDir> [leakType] [TrainNaiveBayesJob options...]}; everything
+   *     after the first two arguments is forwarded to {@link TrainNaiveBayesJob}
+   * @throws IllegalArgumentException if no training directory is given or it cannot be listed
+   */
+  public static void main(String[] args) throws Exception {
+    if (args.length < 1) {
+      throw new IllegalArgumentException(
+          "Usage: TrainNewsGroups <trainDir> [leakType] [TrainNaiveBayesJob options...]");
+    }
+    File base = new File(args[0]);
+    File[] groups = base.listFiles();
+    if (groups == null) {
+      throw new IllegalArgumentException("Not a readable directory: " + base);
+    }
+    // Label ids come from interning order, and File.listFiles() order is unspecified, so sort
+    // the newsgroup directories to make the label mapping deterministic and reproducible.
+    Arrays.sort(groups);
+
+    int leakType = args.length > 1 ? Integer.parseInt(args[1]) : 0;
+
+    Multiset<String> overallCounts = HashMultiset.create();
+    Dictionary newsGroups = new Dictionary();
+
+    NewsgroupHelper helper = new NewsgroupHelper();
+    helper.getEncoder().setProbes(PROBES);
+
+    List<File> files = Lists.newArrayList();
+    for (File newsgroup : groups) {
+      if (newsgroup.isDirectory()) {
+        newsGroups.intern(newsgroup.getName());
+        File[] docs = newsgroup.listFiles();
+        if (docs != null) {
+          files.addAll(Arrays.asList(docs));
+        }
+      }
+    }
+    // Fixes only the order in which vectors are written; label ids were already fixed by the
+    // sorted interning above.
+    Collections.sort(files);
+    System.out.printf("%d training files\n", files.size());
+
+    Configuration conf = new Configuration(true);
+    FileSystem fs = new Path("/tmp").getFileSystem(conf);
+    SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, new Path(TRAIN_DATA),
+        IntWritable.class, MultiLabelVectorWritable.class);
+    try {
+      for (File file : files) {
+        String ng = file.getParentFile().getName();
+        int actual = newsGroups.intern(ng);
+        Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
+        // The key is a placeholder; the label travels inside the MultiLabelVectorWritable.
+        writer.append(new IntWritable(0), new MultiLabelVectorWritable(v, new int[] {actual}));
+      }
+    } finally {
+      writer.close();
+    }
+
+    // Drop <trainDir> and [leakType]; Math.min avoids copyOfRange's from > to failure when
+    // only the training directory was given.
+    String[] jobArgs = Arrays.copyOfRange(args, Math.min(2, args.length), args.length);
+    int exitCode = ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), jobArgs);
+    if (exitCode != 0) {
+      // Surface job failures to calling scripts (classify-20newsgroups.sh uses set -e).
+      System.exit(exitCode);
+    }
+  }
+}


Reply via email to