Author: robinanil
Date: Sat Apr 20 17:14:16 2013
New Revision: 1470196
URL: http://svn.apache.org/r1470196
Log:
Minor cleanups in naivebayes, Adds Norm2 benchmark
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
mahout/trunk/integration/src/main/java/org/apache/mahout/benchmark/DotBenchmark.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java?rev=1470196&r1=1470195&r2=1470196&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
Sat Apr 20 17:14:16 2013
@@ -17,10 +17,12 @@
package org.apache.mahout.classifier.naivebayes;
-import com.google.common.base.Preconditions;
-import com.google.common.collect.Maps;
-import com.google.common.collect.Sets;
-import com.google.common.io.Closeables;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.regex.Pattern;
+
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
@@ -42,11 +44,10 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.regex.Pattern;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.google.common.io.Closeables;
public final class BayesUtils {
@@ -87,7 +88,7 @@ public final class BayesUtils {
if
(entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER))
{
perlabelThetaNormalizer = entry.getSecond().get();
}
- }
+ }
Preconditions.checkNotNull(perlabelThetaNormalizer);
*/
@@ -126,7 +127,7 @@ public final class BayesUtils {
}
}
} finally {
- Closeables.closeQuietly(writer);
+ Closeables.close(writer, true);
}
return i;
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java?rev=1470196&r1=1470195&r2=1470196&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
Sat Apr 20 17:14:16 2013
@@ -17,7 +17,10 @@
package org.apache.mahout.classifier.naivebayes.training;
-import com.google.common.base.Splitter;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
@@ -38,9 +41,7 @@ import org.apache.mahout.common.iterator
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.VectorWritable;
-import java.io.IOException;
-import java.util.List;
-import java.util.Map;
+import com.google.common.base.Splitter;
/**
* This class trains a Naive Bayes Classifier (Parameters for both Naive Bayes and Complementary Naive Bayes)
@@ -131,10 +132,10 @@ public final class TrainNaiveBayesJob ex
if (!succeeded) {
return -1;
}
-
+
//put the per label and per feature vectors into the cache
HadoopUtil.cacheFiles(getTempPath(WEIGHTS), getConf());
-
+
//calculate the Thetas, write out to LABEL_THETA_NORMALIZER vectors --
// TODO: add reference here to the part of the Rennie paper that discusses this
Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
@@ -155,7 +156,7 @@ public final class TrainNaiveBayesJob ex
if (!succeeded) {
return -1;
}*/
-
+
//validate our model and then write it out to the official output
getConf().setFloat(ThetaMapper.ALPHA_I, alphaI);
NaiveBayesModel naiveBayesModel =
BayesUtils.readModelFromDir(getTempPath(), getConf());
@@ -180,5 +181,4 @@ public final class TrainNaiveBayesJob ex
}
return labelSize;
}
-
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java?rev=1470196&r1=1470195&r2=1470196&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
Sat Apr 20 17:14:16 2013
@@ -19,15 +19,17 @@ package org.apache.mahout.classifier.nai
import java.io.IOException;
-import com.google.common.base.Preconditions;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
+import com.google.common.base.Preconditions;
+
public class WeightsMapper extends Mapper<IntWritable, VectorWritable, Text,
VectorWritable> {
static final String NUM_LABELS = WeightsMapper.class.getName() +
".numLabels";
@@ -40,7 +42,7 @@ public class WeightsMapper extends Mappe
super.setup(ctx);
int numLabels = Integer.parseInt(ctx.getConfiguration().get(NUM_LABELS));
Preconditions.checkArgument(numLabels > 0);
- weightsPerLabel = new RandomAccessSparseVector(numLabels);
+ weightsPerLabel = new DenseVector(numLabels);
}
@Override
Modified:
mahout/trunk/integration/src/main/java/org/apache/mahout/benchmark/DotBenchmark.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/benchmark/DotBenchmark.java?rev=1470196&r1=1470195&r2=1470196&view=diff
==============================================================================
---
mahout/trunk/integration/src/main/java/org/apache/mahout/benchmark/DotBenchmark.java
(original)
+++
mahout/trunk/integration/src/main/java/org/apache/mahout/benchmark/DotBenchmark.java
Sat Apr 20 17:14:16 2013
@@ -10,12 +10,15 @@ import static org.apache.mahout.benchmar
import static org.apache.mahout.benchmark.VectorBenchmarks.SEQ_FN_RAND;
import static org.apache.mahout.benchmark.VectorBenchmarks.SEQ_SPARSE_VECTOR;
+import java.io.IOException;
+
import org.apache.mahout.benchmark.BenchmarkRunner.BenchmarkFn;
import org.apache.mahout.benchmark.BenchmarkRunner.BenchmarkFnD;
public class DotBenchmark {
private static final String DOT_PRODUCT = "DotProduct";
private static final String NORM1 = "Norm1";
+ private static final String NORM2 = "Norm2";
private static final String LOG_NORMALIZE = "LogNormalize";
private final VectorBenchmarks mark;
@@ -26,6 +29,7 @@ public class DotBenchmark {
public void benchmark() {
benchmarkDot();
benchmarkNorm1();
+ benchmarkNorm2();
benchmarkLogNormalize();
}
@@ -75,6 +79,29 @@ public class DotBenchmark {
}), NORM1, SEQ_SPARSE_VECTOR);
}
+ private void benchmarkNorm2() {
+ mark.printStats(mark.getRunner().benchmarkD(new BenchmarkFnD() {
+ @Override
+ public Double apply(Integer i) {
+ return mark.vectors[0][mark.vIndex(i)].norm(2);
+ }
+ }), NORM2, DENSE_VECTOR);
+
+ mark.printStats(mark.getRunner().benchmarkD(new BenchmarkFnD() {
+ @Override
+ public Double apply(Integer i) {
+ return mark.vectors[1][mark.vIndex(i)].norm(2);
+ }
+ }), NORM2, RAND_SPARSE_VECTOR);
+
+ mark.printStats(mark.getRunner().benchmarkD(new BenchmarkFnD() {
+ @Override
+ public Double apply(Integer i) {
+ return mark.vectors[2][mark.vIndex(i)].norm(2);
+ }
+ }), NORM2, SEQ_SPARSE_VECTOR);
+ }
+
private void benchmarkDot() {
mark.printStats(mark.getRunner().benchmarkD(new BenchmarkFnD() {
@Override
@@ -139,4 +166,11 @@ public class DotBenchmark {
}
}), DOT_PRODUCT, SEQ_FN_RAND);
}
+
+ public static void main(String[] args) throws IOException {
+ VectorBenchmarks mark = new VectorBenchmarks(1000000, 100, 1000, 10, 1);
+ mark.createData();
+ new DotBenchmark(mark).benchmarkNorm2();
+ System.out.println(mark);
+ }
}