Author: ssc
Date: Tue Aug 23 09:13:03 2011
New Revision: 1160591
URL: http://svn.apache.org/viewvc?rev=1160591&view=rev
Log:
MAHOUT-777 Improve TransposeJob to use a Combiner
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java?rev=1160591&r1=1160590&r2=1160591&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java Tue Aug 23 09:13:03 2011
@@ -21,6 +21,7 @@ import org.apache.hadoop.conf.Configurat
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobConf;
@@ -42,10 +43,8 @@ import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
-/**
- * TODO: rewrite to use helpful combiner.
- */
public class TransposeJob extends AbstractJob {
+
public static final String NUM_ROWS_KEY = "SparseRowMatrix.numRows";
public static void main(String[] args) throws Exception {
@@ -59,16 +58,13 @@ public class TransposeJob extends Abstra
addOption("numCols", "nc", "Number of columns of the input matrix");
Map<String,String> parsedArgs = parseArguments(strings);
if (parsedArgs == null) {
- // FIXME
- return 0;
+ return -1;
}
- Path inputPath = getInputPath();
- Path outputTmpPath = new Path(parsedArgs.get("--tempDir"));
int numRows = Integer.parseInt(parsedArgs.get("--numRows"));
int numCols = Integer.parseInt(parsedArgs.get("--numCols"));
-    DistributedRowMatrix matrix = new DistributedRowMatrix(inputPath, outputTmpPath, numRows, numCols);
+    DistributedRowMatrix matrix = new DistributedRowMatrix(getInputPath(), getTempPath(), numRows, numCols);
matrix.setConf(new Configuration(getConf()));
matrix.transpose();
@@ -96,9 +92,10 @@ public class TransposeJob extends Abstra
conf.setInputFormat(SequenceFileInputFormat.class);
FileOutputFormat.setOutputPath(conf, matrixOutputPath);
conf.setMapperClass(TransposeMapper.class);
- conf.setReducerClass(TransposeReducer.class);
conf.setMapOutputKeyClass(IntWritable.class);
-    conf.setMapOutputValueClass(DistributedRowMatrix.MatrixEntryWritable.class);
+ conf.setMapOutputValueClass(VectorWritable.class);
+ conf.setCombinerClass(MergeVectorsCombiner.class);
+ conf.setReducerClass(MergeVectorsReducer.class);
conf.setOutputFormat(SequenceFileOutputFormat.class);
conf.setOutputKeyClass(IntWritable.class);
conf.setOutputValueClass(VectorWritable.class);
@@ -106,49 +103,62 @@ public class TransposeJob extends Abstra
}
public static class TransposeMapper extends MapReduceBase
-      implements Mapper<IntWritable,VectorWritable,IntWritable,DistributedRowMatrix.MatrixEntryWritable> {
+      implements Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private int newNumCols;
@Override
-    public void map(IntWritable r,
-                    VectorWritable v,
-                    OutputCollector<IntWritable, DistributedRowMatrix.MatrixEntryWritable> out,
-                    Reporter reporter) throws IOException {
-      DistributedRowMatrix.MatrixEntryWritable entry = new DistributedRowMatrix.MatrixEntryWritable();
- Iterator<Vector.Element> it = v.get().iterateNonZero();
+ public void configure(JobConf conf) {
+ newNumCols = conf.getInt(NUM_ROWS_KEY, Integer.MAX_VALUE);
+ }
+
+ @Override
+    public void map(IntWritable r, VectorWritable v, OutputCollector<IntWritable, VectorWritable> out,
+        Reporter reporter) throws IOException {
int row = r.get();
- entry.setCol(row);
- entry.setRow(-1); // output "row" is captured in the key
+ Iterator<Vector.Element> it = v.get().iterateNonZero();
while (it.hasNext()) {
Vector.Element e = it.next();
+        RandomAccessSparseVector tmp = new RandomAccessSparseVector(newNumCols, 1);
+ tmp.setQuick(row, e.get());
r.set(e.index());
- entry.setVal(e.get());
- out.collect(r, entry);
+ out.collect(r, new VectorWritable(tmp));
}
}
}
-  public static class TransposeReducer extends MapReduceBase
-      implements Reducer<IntWritable,DistributedRowMatrix.MatrixEntryWritable,IntWritable,VectorWritable> {
+ static Vector merge(Iterator<VectorWritable> vectors) {
+ Vector accumulator = vectors.next().get();
+ while (vectors.hasNext()) {
+ VectorWritable v = vectors.next();
+ if (v != null) {
+ Iterator<Vector.Element> nonZeroElements = v.get().iterateNonZero();
+ while (nonZeroElements.hasNext()) {
+ Vector.Element nonZeroElement = nonZeroElements.next();
+ accumulator.setQuick(nonZeroElement.index(), nonZeroElement.get());
+ }
+ }
+ }
+ return accumulator;
+ }
- private int newNumCols;
+ public static class MergeVectorsCombiner extends MapReduceBase
+      implements Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> {
@Override
- public void configure(JobConf conf) {
- newNumCols = conf.getInt(NUM_ROWS_KEY, Integer.MAX_VALUE);
+    public void reduce(WritableComparable<?> key, Iterator<VectorWritable> vectors,
+        OutputCollector<WritableComparable<?>, VectorWritable> out, Reporter reporter) throws IOException {
+ out.collect(key, new VectorWritable(merge(vectors)));
}
+ }
+
+ public static class MergeVectorsReducer extends MapReduceBase
+      implements Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> {
@Override
- public void reduce(IntWritable outRow,
- Iterator<DistributedRowMatrix.MatrixEntryWritable> it,
- OutputCollector<IntWritable, VectorWritable> out,
- Reporter reporter) throws IOException {
-      RandomAccessSparseVector tmp = new RandomAccessSparseVector(newNumCols, 100);
- while (it.hasNext()) {
- DistributedRowMatrix.MatrixEntryWritable e = it.next();
- tmp.setQuick(e.getCol(), e.getVal());
- }
-      SequentialAccessSparseVector outVector = new SequentialAccessSparseVector(tmp);
- out.collect(outRow, new VectorWritable(outVector));
+    public void reduce(WritableComparable<?> key, Iterator<VectorWritable> vectors,
+        OutputCollector<WritableComparable<?>, VectorWritable> out, Reporter reporter) throws IOException {
+      out.collect(key, new VectorWritable(new SequentialAccessSparseVector(merge(vectors))));
}
}
}