Author: robinanil
Date: Sun May 13 22:55:08 2012
New Revision: 1338002
URL: http://svn.apache.org/viewvc?rev=1338002&view=rev
Log:
MAHOUT-1014: Add TrainNewsGroups for naive Bayes, which encodes vectors exactly like
the SGD example.
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TestNewsGroups.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TrainNewsGroups.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
mahout/trunk/examples/bin/classify-20newsgroups.sh
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
Sun May 13 22:55:08 2012
@@ -36,5 +36,4 @@ public class ComplementaryNaiveBayesClas
return Math.log(numerator / denominator);
}
-
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
Sun May 13 22:55:08 2012
@@ -34,5 +34,4 @@ public class StandardNaiveBayesClassifie
return -Math.log(numerator / denominator);
}
-
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
Sun May 13 22:55:08 2012
@@ -1,18 +1,16 @@
/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
contributor license
+ * agreements. See the NOTICE file distributed with this work for additional
information regarding
+ * copyright ownership. The ASF licenses this file to You under the Apache
License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the
License. You may obtain a
+ * copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express
+ * or implied. See the License for the specific language governing permissions
and limitations under
+ * the License.
*/
package org.apache.mahout.classifier.naivebayes.training;
@@ -20,30 +18,16 @@ package org.apache.mahout.classifier.nai
import java.io.IOException;
import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.math.MultiLabelVectorWritable;
import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.map.OpenObjectIntHashMap;
-
-public class IndexInstancesMapper extends Mapper<Text, VectorWritable,
IntWritable, VectorWritable> {
-
- public enum Counter { SKIPPED_INSTANCES }
-
- private OpenObjectIntHashMap<String> labelIndex;
-
- @Override
- protected void setup(Context ctx) throws IOException, InterruptedException {
- labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
- }
+public class IndexInstancesMapper
+ extends Mapper<IntWritable, MultiLabelVectorWritable, IntWritable,
VectorWritable> {
@Override
- protected void map(Text labelText, VectorWritable instance, Context ctx)
throws IOException, InterruptedException {
- String label = labelText.toString();
- if (labelIndex.containsKey(label)) {
- ctx.write(new IntWritable(labelIndex.get(label)), instance);
- } else {
- ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
- }
+ protected void map(IntWritable key, MultiLabelVectorWritable instance,
Context ctx)
+ throws IOException, InterruptedException {
+ VectorWritable vw = new VectorWritable(instance.getVector());
+ ctx.write(new IntWritable(instance.getLabels()[0]), vw);
}
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
Sun May 13 22:55:08 2012
@@ -17,9 +17,10 @@
package org.apache.mahout.classifier.naivebayes.training;
-import com.google.common.base.Splitter;
+import java.util.List;
+import java.util.Map;
+
import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
@@ -31,16 +32,9 @@ import org.apache.mahout.classifier.naiv
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.mahout.common.iterator.sequencefile.PathFilters;
-import org.apache.mahout.common.iterator.sequencefile.PathType;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.VectorWritable;
-import java.io.IOException;
-import java.util.List;
-import java.util.Map;
-
/**
* This class trains a Naive Bayes Classifier (Parameters for both Naive Bayes
and Complementary Naive Bayes)
*/
@@ -60,15 +54,11 @@ public final class TrainNaiveBayesJob ex
@Override
public int run(String[] args) throws Exception {
-
addInputOption();
addOutputOption();
- addOption("labels", "l", "comma-separated list of labels to include in
training", false);
-
- addOption(buildOption("extractLabels", "el", "Extract the labels from the
input", false, false, ""));
+ addOption("labelSize", "ls", "Number of labels in the input data",
String.valueOf(2));
addOption("alphaI", "a", "smoothing parameter", String.valueOf(1.0f));
addOption(buildOption("trainComplementary", "c", "train complementary?",
false, false, String.valueOf(false)));
- addOption("labelIndex", "li", "The path to store the label index in",
false);
addOption(DefaultOptionCreator.overwriteOption().create());
Map<String, List<String>> parsedArgs = parseArguments(args);
if (parsedArgs == null) {
@@ -78,21 +68,12 @@ public final class TrainNaiveBayesJob ex
HadoopUtil.delete(getConf(), getOutputPath());
HadoopUtil.delete(getConf(), getTempPath());
}
- Path labPath;
- String labPathStr = getOption("labelIndex");
- if (labPathStr != null) {
- labPath = new Path(labPathStr);
- } else {
- labPath = getTempPath("labelIndex");
- }
- long labelSize = createLabelIndex(labPath);
+ int labelSize = Integer.parseInt(getOption("labelSize"));
float alphaI = Float.parseFloat(getOption("alphaI"));
boolean trainComplementary =
Boolean.parseBoolean(getOption("trainComplementary"));
-
HadoopUtil.setSerializations(getConf());
- HadoopUtil.cacheFiles(labPath, getConf());
-
+
//add up all the vectors with the same labels, while mapping the labels
into our index
Job indexInstances = prepareJob(getInputPath(),
getTempPath(SUMMED_OBSERVATIONS), SequenceFileInputFormat.class,
IndexInstancesMapper.class, IntWritable.class,
VectorWritable.class, VectorSumReducer.class, IntWritable.class,
@@ -132,18 +113,4 @@ public final class TrainNaiveBayesJob ex
return 0;
}
-
- private long createLabelIndex(Path labPath) throws IOException {
- long labelSize = 0;
- if (hasOption("labels")) {
- Iterable<String> labels = Splitter.on(",").split(getOption("labels"));
- labelSize = BayesUtils.writeLabelIndex(getConf(), labels, labPath);
- } else if (hasOption("extractLabels")) {
- SequenceFileDirIterable<Text, IntWritable> iterable =
- new SequenceFileDirIterable<Text, IntWritable>(getInputPath(),
PathType.LIST, PathFilters.logsCRCFilter(), getConf());
- labelSize = BayesUtils.writeLabelIndex(getConf(), labPath, iterable);
- }
- return labelSize;
- }
-
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
Sun May 13 22:55:08 2012
@@ -17,23 +17,24 @@
package org.apache.mahout.classifier.naivebayes;
-import com.google.common.io.Closeables;
+import java.io.File;
+
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
-import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MultiLabelVectorWritable;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.MathHelper;
import org.junit.Before;
import org.junit.Test;
-import java.io.File;
+import com.google.common.io.Closeables;
public class NaiveBayesTest extends MahoutTestCase {
@@ -42,8 +43,8 @@ public class NaiveBayesTest extends Maho
private File outputDir;
private File tempDir;
- static final Text LABEL_STOLEN = new Text("stolen");
- static final Text LABEL_NOT_STOLEN = new Text("not_stolen");
+ static final String LABEL_STOLEN = "stolen";
+ static final String LABEL_NOT_STOLEN = "not_stolen";
static final Vector.Element COLOR_RED = MathHelper.elem(0, 1);
static final Vector.Element COLOR_YELLOW = MathHelper.elem(1, 1);
@@ -66,19 +67,19 @@ public class NaiveBayesTest extends Maho
tempDir = getTestTempDir("tmp");
SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(conf),
conf,
- new Path(inputFile.getAbsolutePath()), Text.class,
VectorWritable.class);
+ new Path(inputFile.getAbsolutePath()), IntWritable.class,
MultiLabelVectorWritable.class);
try {
- writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED,
TYPE_SPORTS, ORIGIN_DOMESTIC));
- writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS,
ORIGIN_DOMESTIC));
- writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED,
TYPE_SPORTS, ORIGIN_DOMESTIC));
- writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW,
TYPE_SPORTS, ORIGIN_DOMESTIC));
- writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW,
TYPE_SPORTS, ORIGIN_IMPORTED));
- writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV,
ORIGIN_IMPORTED));
- writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW,
TYPE_SUV, ORIGIN_IMPORTED));
- writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV,
ORIGIN_DOMESTIC));
- writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SUV,
ORIGIN_IMPORTED));
- writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED,
TYPE_SPORTS, ORIGIN_IMPORTED));
+ writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN,
COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN,
COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN,
COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN,
COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN,
COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
+ writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN,
COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN,
COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN,
COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
+ writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN,
COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN,
COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
} finally {
Closeables.closeQuietly(writer);
}
@@ -89,7 +90,7 @@ public class NaiveBayesTest extends Maho
TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
trainNaiveBayes.setConf(conf);
trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(),
"--output", outputDir.getAbsolutePath(),
- "--labels", "stolen,not_stolen", "--tempDir",
tempDir.getAbsolutePath() });
+ "--labelSize", "2", "--tempDir", tempDir.getAbsolutePath() });
NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new
Path(outputDir.getAbsolutePath()), conf);
@@ -97,7 +98,7 @@ public class NaiveBayesTest extends Maho
assertEquals(2, classifier.numCategories());
- Vector prediction = classifier.classify(trainingInstance(COLOR_RED,
TYPE_SUV, ORIGIN_DOMESTIC).get());
+ Vector prediction = classifier.classify(trainingInstance("", COLOR_RED,
TYPE_SUV, ORIGIN_DOMESTIC).getVector());
// should be classified as not stolen
assertTrue(prediction.get(0) < prediction.get(1));
@@ -108,7 +109,7 @@ public class NaiveBayesTest extends Maho
TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
trainNaiveBayes.setConf(conf);
trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(),
"--output", outputDir.getAbsolutePath(),
- "--labels", "stolen,not_stolen", "--trainComplementary",
+ "--labelSize", "2", "--trainComplementary",
"--tempDir", tempDir.getAbsolutePath() });
NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new
Path(outputDir.getAbsolutePath()), conf);
@@ -117,18 +118,18 @@ public class NaiveBayesTest extends Maho
assertEquals(2, classifier.numCategories());
- Vector prediction = classifier.classify(trainingInstance(COLOR_RED,
TYPE_SUV, ORIGIN_DOMESTIC).get());
+ Vector prediction = classifier.classify(trainingInstance("", COLOR_RED,
TYPE_SUV, ORIGIN_DOMESTIC).getVector());
// should be classified as not stolen
assertTrue(prediction.get(0) < prediction.get(1));
}
- static VectorWritable trainingInstance(Vector.Element... elems) {
+ static MultiLabelVectorWritable trainingInstance(String label,
Vector.Element... elems) {
DenseVector trainingInstance = new DenseVector(6);
for (Vector.Element elem : elems) {
trainingInstance.set(elem.index(), elem.get());
}
- return new VectorWritable(trainingInstance);
+ return new MultiLabelVectorWritable(trainingInstance, new int[]
{label.equals("stolen") ? 0 : 1});
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
Sun May 13 22:55:08 2012
@@ -18,22 +18,20 @@
package org.apache.mahout.classifier.naivebayes.training;
import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Counter;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MultiLabelVectorWritable;
import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.easymock.EasyMock;
import org.junit.Before;
import org.junit.Test;
+@SuppressWarnings("unchecked")
public class IndexInstancesMapperTest extends MahoutTestCase {
-
+ private static final DenseVector VECTOR = new DenseVector(new double[] { 1,
0, 1, 1, 0 });
private Mapper.Context ctx;
- private OpenObjectIntHashMap<String> labelIndex;
- private VectorWritable instance;
+ private MultiLabelVectorWritable instance;
@Override
@Before
@@ -41,45 +39,16 @@ public class IndexInstancesMapperTest ex
super.setUp();
ctx = EasyMock.createMock(Mapper.Context.class);
- instance = new VectorWritable(new DenseVector(new double[] { 1, 0, 1, 1, 0
}));
-
- labelIndex = new OpenObjectIntHashMap<String>();
- labelIndex.put("bird", 0);
- labelIndex.put("cat", 1);
+ instance = new MultiLabelVectorWritable(VECTOR,
+ new int[] {0});
}
-
-
+
@Test
public void index() throws Exception {
-
- ctx.write(new IntWritable(0), instance);
-
+ ctx.write(new IntWritable(0), new VectorWritable(VECTOR));
EasyMock.replay(ctx);
-
IndexInstancesMapper indexInstances = new IndexInstancesMapper();
- setField(indexInstances, "labelIndex", labelIndex);
-
- indexInstances.map(new Text("bird"), instance, ctx);
-
+ indexInstances.map(new IntWritable(-1), instance, ctx);
EasyMock.verify(ctx);
}
-
- @Test
- public void skip() throws Exception {
-
- Counter skippedInstances = EasyMock.createMock(Counter.class);
-
-
EasyMock.expect(ctx.getCounter(IndexInstancesMapper.Counter.SKIPPED_INSTANCES)).andReturn(skippedInstances);
- skippedInstances.increment(1);
-
- EasyMock.replay(ctx, skippedInstances);
-
- IndexInstancesMapper indexInstances = new IndexInstancesMapper();
- setField(indexInstances, "labelIndex", labelIndex);
-
- indexInstances.map(new Text("fish"), instance, ctx);
-
- EasyMock.verify(ctx, skippedInstances);
- }
-
}
Modified: mahout/trunk/examples/bin/classify-20newsgroups.sh
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/bin/classify-20newsgroups.sh?rev=1338002&r1=1338001&r2=1338002&view=diff
==============================================================================
--- mahout/trunk/examples/bin/classify-20newsgroups.sh (original)
+++ mahout/trunk/examples/bin/classify-20newsgroups.sh Sun May 13 22:55:08 2012
@@ -23,7 +23,7 @@
# examples/bin/build-20news.sh
if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
- echo "This script runs the SGD classifier over the classic 20 News Groups."
+ echo "This script runs SGD and Bayes classifiers over the classic 20 News
Groups."
exit
fi
@@ -34,13 +34,14 @@ fi
START_PATH=`pwd`
WORK_DIR=/tmp/mahout-work-${USER}
-algorithm=( sgd clean)
+algorithm=( naivebayes sgd clean)
if [ -n "$1" ]; then
choice=$1
else
echo "Please select a number to choose the corresponding task to run"
echo "1. ${algorithm[0]}"
- echo "2. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
+ echo "2. ${algorithm[1]}"
+ echo "3. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
read -p "Enter your choice : " choice
fi
@@ -67,7 +68,15 @@ cd ../..
set -e
-if [ "x$alg" == "xsgd" ]; then
+if [ "x$alg" == "xnaivebayes" ]; then
+  # Naive Bayes keeps its own model; /tmp/news-group.model belongs to the SGD example and must
+  # not decide whether naive Bayes training runs.
+  nbModel=${WORK_DIR}/news-group.naivebayes.model
+  if [ ! -e "${nbModel}" ]; then
+    echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
+    # TrainNewsGroups writes the encoded vectors to /tmp/news-group-train/, hence --input below.
+    ./bin/mahout org.apache.mahout.classifier.naivebayes.TrainNewsGroups ${WORK_DIR}/20news-bydate/20news-bydate-train/ 0 \
+      --input /tmp/news-group-train/ --output ${nbModel} -ls 20 --tempDir ${WORK_DIR}/tmp/ -ow
+  fi
+  echo "Testing on ${WORK_DIR}/20news-bydate/20news-bydate-test/ with model: ${nbModel}"
+  # ./bin/mahout org.apache.mahout.classifier.naivebayes.TestNewsGroups --input ${WORK_DIR}/20news-bydate/20news-bydate-test/ --model ${nbModel}
+elif [ "x$alg" == "xsgd" ]; then
if [ ! -e "/tmp/news-group.model" ]; then
echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
./bin/mahout org.apache.mahout.classifier.sgd.TrainNewsGroups
${WORK_DIR}/20news-bydate/20news-bydate-train/
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TestNewsGroups.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TestNewsGroups.java?rev=1338002&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TestNewsGroups.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TestNewsGroups.java
Sun May 13 22:55:08 2012
@@ -0,0 +1,136 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.NewsgroupHelper;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+
+/**
+ * Runs the 20 newsgroups test data through a naive Bayes model trained by
+ * {@link org.apache.mahout.classifier.naivebayes.TrainNewsGroups} and prints the
+ * {@link ResultAnalyzer} report.
+ * <p/>
+ * Documents are encoded the same way as during training (same encoder probes, leakType 0), and
+ * newsgroup labels are interned in sorted directory order so that label ids agree with the
+ * ones assigned by the trainer.
+ */
+public final class TestNewsGroups {
+
+  /** Directory holding one sub-directory of test documents per newsgroup. */
+  private String inputFile;
+  /** Path of the model produced by TrainNaiveBayesJob. */
+  private String modelFile;
+
+  private TestNewsGroups() {
+  }
+
+  public static void main(String[] args) throws IOException {
+    TestNewsGroups runner = new TestNewsGroups();
+    if (runner.parseArgs(args)) {
+      runner.run(new PrintWriter(System.out, true));
+    }
+  }
+
+  /**
+   * Classifies every test document with the model at {@code modelFile} and writes the results.
+   *
+   * @param output destination for the file count and the result report
+   * @throws IOException if the model cannot be read
+   * @throws IllegalArgumentException if {@code inputFile} is not a readable directory
+   */
+  public void run(PrintWriter output) throws IOException {
+    File base = new File(inputFile);
+    File[] newsgroupDirs = base.listFiles();
+    if (newsgroupDirs == null) {
+      throw new IllegalArgumentException("Not a readable directory: " + base);
+    }
+    // File.listFiles() guarantees no ordering; sort so newsgroup names are interned in the same
+    // (alphabetical) order as in TrainNewsGroups and the label ids line up with the model.
+    Arrays.sort(newsgroupDirs);
+
+    Dictionary newsGroups = new Dictionary();
+    Multiset<String> overallCounts = HashMultiset.create();
+
+    List<File> files = Lists.newArrayList();
+    for (File newsgroup : newsgroupDirs) {
+      if (newsgroup.isDirectory()) {
+        newsGroups.intern(newsgroup.getName());
+        files.addAll(Arrays.asList(newsgroup.listFiles()));
+      }
+    }
+    output.printf("%d test files%n", files.size());
+
+    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelFile), new Configuration());
+    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);
+
+    // Must mirror the encoder set-up in TrainNewsGroups, otherwise features hash to different
+    // vector positions than the ones the model was trained on.
+    NewsgroupHelper helper = new NewsgroupHelper();
+    helper.getEncoder().setProbes(2);
+
+    ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
+    for (File file : files) {
+      int actual = newsGroups.intern(file.getParentFile().getName());
+      // leakType 0: no synthetic target leak; subject and body text are used as features
+      Vector input = helper.encodeFeatureVector(file, actual, 0, overallCounts);
+      // classify() yields one score per label; the highest score is the predicted label
+      Vector result = classifier.classify(input);
+      int cat = result.maxValueIndex();
+      double score = result.maxValue();
+      // log-likelihood is not reported for the naive Bayes scores
+      double ll = 0;
+      ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll);
+      ra.addInstance(newsGroups.values().get(actual), cr);
+    }
+    output.printf("%s%n%n", ra.toString());
+  }
+
+  /**
+   * Parses {@code --input} and {@code --model}.
+   *
+   * @return {@code false} if parsing failed or help was requested
+   */
+  boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help").withDescription("print this list").create();
+
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    Option inputFileOption = builder.withLongName("input")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+        .withDescription("where to get test data")
+        .create();
+
+    Option modelFileOption = builder.withLongName("model")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
+        .withDescription("where to get a model")
+        .create();
+
+    Group normalArgs = new GroupBuilder()
+        .withOption(help)
+        .withOption(inputFileOption)
+        .withOption(modelFileOption)
+        .create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    inputFile = (String) cmdLine.getValue(inputFileOption);
+    modelFile = (String) cmdLine.getValue(modelFileOption);
+    return true;
+  }
+
+}
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TrainNewsGroups.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TrainNewsGroups.java?rev=1338002&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TrainNewsGroups.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/TrainNewsGroups.java
Sun May 13 22:55:08 2012
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.File;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.NewsgroupHelper;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.math.MultiLabelVectorWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+
+/**
+ * Reads and trains an naive bayes model on the 20 newsgroups data.
+ * The first command line argument gives the path of the directory holding the
training
+ * data. The optional second argument, leakType, defines which classes of
features to use.
+ * Importantly, leakType controls whether a synthetic date is injected into
the data as
+ * a target leak and if so, how.
+ * <p/>
+ * The value of leakType % 3 determines whether the target leak is injected
according to
+ * the following table:
+ * <p/>
+ * <table>
+ * <tr><td valign='top'>0</td><td>No leak injected</td></tr>
+ * <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format.
This will be a single token and
+ * is a perfect target leak since each newsgroup is given a different
month</td></tr>
+ * <tr><td valign='top'>2</td><td>Synthetic date injected in dd-MMM-yyyy
HH:mm:ss format. The day varies
+ * and thus there are more leak symbols that need to be learned. Ultimately
this is just
+ * as big a leak as case 1.</td></tr>
+ * </table>
+ * <p/>
+ * Leaktype also determines what other text will be indexed. If leakType is
greater
+ * than or equal to 6, then neither headers nor text body will be used for
features and the leak is the only
+ * source of data. If leakType is greater than or equal to 3, then subject
words will be used as features.
+ * If leakType is less than 3, then both subject and body text will be used as
features.
+ * <p/>
+ * A leakType of 0 gives no leak and all textual features.
+ * <p/>
+ * See the following table for a summary of commonly used values for leakType
+ * <p/>
+ * <table>
+ *
<tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>0</td><td>no</td><td>yes</td><td>yes</td></tr>
+ * <tr><td>1</td><td>mmm-yyyy</td><td>yes</td><td>yes</td></tr>
+ * <tr><td>2</td><td>dd-mmm-yyyy</td><td>yes</td><td>yes</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>3</td><td>no</td><td>yes</td><td>no</td></tr>
+ * <tr><td>4</td><td>mmm-yyyy</td><td>yes</td><td>no</td></tr>
+ * <tr><td>5</td><td>dd-mmm-yyyy</td><td>yes</td><td>no</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * <tr><td>6</td><td>no</td><td>no</td><td>no</td></tr>
+ * <tr><td>7</td><td>mmm-yyyy</td><td>no</td><td>no</td></tr>
+ * <tr><td>8</td><td>dd-mmm-yyyy</td><td>no</td><td>no</td></tr>
+ * <tr><td colspan=4><hr></td></tr>
+ * </table>
+ */
+public final class TrainNewsGroups {
+
+  private TrainNewsGroups() {}
+
+  /**
+   * Encodes the 20 newsgroups training files into {@code /tmp/news-group-train/data} and then
+   * runs {@link TrainNaiveBayesJob} on them.
+   *
+   * @param args {@code args[0]}: training directory with one sub-directory per newsgroup;
+   *     {@code args[1]} (optional): leakType, default 0; {@code args[2..]}: forwarded unchanged
+   *     to {@link TrainNaiveBayesJob}
+   * @throws IllegalArgumentException if no training directory is given or it is not readable
+   */
+  public static void main(String[] args) throws Exception {
+    if (args.length < 1) {
+      throw new IllegalArgumentException(
+          "Usage: TrainNewsGroups <training-dir> [leakType] [TrainNaiveBayesJob options]");
+    }
+    File base = new File(args[0]);
+    File[] newsgroupDirs = base.listFiles();
+    if (newsgroupDirs == null) {
+      throw new IllegalArgumentException("Not a readable directory: " + base);
+    }
+    // File.listFiles() guarantees no ordering. Sorting the directories *before* interning is what
+    // makes the label ids deterministic; sorting the file list afterwards does not affect them.
+    Arrays.sort(newsgroupDirs);
+
+    Multiset<String> overallCounts = HashMultiset.create();
+
+    int leakType = 0;
+    if (args.length > 1) {
+      leakType = Integer.parseInt(args[1]);
+    }
+
+    Dictionary newsGroups = new Dictionary();
+
+    NewsgroupHelper helper = new NewsgroupHelper();
+    helper.getEncoder().setProbes(2);
+
+    List<File> files = Lists.newArrayList();
+    for (File newsgroup : newsgroupDirs) {
+      if (newsgroup.isDirectory()) {
+        newsGroups.intern(newsgroup.getName());
+        files.addAll(Arrays.asList(newsgroup.listFiles()));
+      }
+    }
+    Collections.sort(files); // deterministic order of the written training instances
+    System.out.printf("%d training files%n", files.size());
+
+    Configuration conf = new Configuration(true);
+    FileSystem fs = new Path("/tmp").getFileSystem(conf);
+    SequenceFile.Writer writer =
+        SequenceFile.createWriter(fs, conf, new Path("/tmp/news-group-train/data"),
+            IntWritable.class, MultiLabelVectorWritable.class);
+    try {
+      for (File file : files) {
+        String ng = file.getParentFile().getName();
+        int actual = newsGroups.intern(ng);
+        Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
+        MultiLabelVectorWritable vw = new MultiLabelVectorWritable(v, new int[] {actual});
+        // IndexInstancesMapper ignores the key; the label travels inside the value.
+        writer.append(new IntWritable(0), vw);
+      }
+    } finally {
+      writer.close();
+    }
+
+    // Everything after the training directory and leakType goes to TrainNaiveBayesJob. Guard the
+    // copy: Arrays.copyOfRange(args, 2, 1) throws when only the directory is given.
+    String[] jobArgs = args.length > 2 ? Arrays.copyOfRange(args, 2, args.length) : new String[0];
+    ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), jobArgs);
+  }
+}