Author: ssc
Date: Tue May 7 08:04:26 2013
New Revision: 1479793
URL: http://svn.apache.org/r1479793
Log:
MAHOUT-1205 ParallelALSFactorizationJob should leverage the distributed cache
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
Tue May 7 08:04:26 2013
@@ -17,10 +17,16 @@
package org.apache.mahout.cf.taste.hadoop.als;
+import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalFileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
@@ -29,6 +35,7 @@ import org.apache.mahout.common.iterator
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
+import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import java.io.IOException;
@@ -45,13 +52,60 @@ final class ALS {
return iterator.hasNext() ? iterator.next().get() : null;
}
- public static OpenIntObjectHashMap<Vector> readMatrixByRows(Path dir,
Configuration conf) {
- OpenIntObjectHashMap<Vector> matrix = new OpenIntObjectHashMap<Vector>();
+ /**
+ * Sums all given vectors. The vector obtained from the first element serves as the
+ * accumulator: every following vector is combined into it via
+ * {@code assign(other, Functions.PLUS)}, and that same instance is returned.
+ * Assumes that the first entry always exists; {@code next()} is invoked without a
+ * preceding {@code hasNext()} check.
+ *
+ * @param vectors iterator over the vectors to add up, must yield at least one element
+ * @return the vector taken from the first element, after all other vectors were added to it
+ */
+ public static Vector sum(Iterator<VectorWritable> vectors) {
+ Vector sum = vectors.next().get();
+ while (vectors.hasNext()) {
+ sum.assign(vectors.next().get(), Functions.PLUS);
+ }
+ return sum;
+ }
+ /**
+ * Reads a feature matrix row by row from the files registered in the distributed cache.
+ * Every cache file is opened with a {@link SequenceFile.Reader} on the local file system
+ * and read as pairs of {@link IntWritable} row index and {@link VectorWritable} row; each
+ * row is put into the returned map under its row index.
+ *
+ * @param numEntities if positive, handed to the map's constructor (presumably as the
+ * expected number of rows -- confirm); otherwise the no-arg constructor is used
+ * @param conf configuration from which the cache file locations are looked up
+ * @return map from row index to row vector
+ * @throws IOException on errors while accessing the file system or reading a cache file
+ * @throws IllegalStateException if the resulting map is empty
+ */
+ public static OpenIntObjectHashMap<Vector>
readMatrixByRowsFromDistributedCache(int numEntities,
+ Configuration conf) throws IOException {
+
+ IntWritable rowIndex = new IntWritable();
+ VectorWritable row = new VectorWritable();
+
+ LocalFileSystem localFs = FileSystem.getLocal(conf);
+ Path[] cacheFiles = DistributedCache.getLocalCacheFiles(conf); // NOTE(review): dereferenced below without a null check -- confirm this cannot be null here
+ OpenIntObjectHashMap<Vector> featureMatrix = numEntities > 0
+ ? new OpenIntObjectHashMap<Vector>(numEntities) : new
OpenIntObjectHashMap<Vector>();
+
+ for (int n = 0; n < cacheFiles.length; n++) {
+ Path localCacheFile = localFs.makeQualified(cacheFiles[n]);
+
+ // fallback for local execution: the localized file does not exist, so use the path of the n-th originally registered cache file
+ if (!localFs.exists(localCacheFile)) {
+ localCacheFile = new
Path(DistributedCache.getCacheFiles(conf)[n].getPath());
+ }
+
+ SequenceFile.Reader reader = null;
+ try {
+ reader = new SequenceFile.Reader(localFs, localCacheFile, conf);
+ while (reader.next(rowIndex, row)) {
+ featureMatrix.put(rowIndex.get(), row.get()); // NOTE(review): row is reused across next() calls; presumably each read yields a fresh vector -- confirm
+ }
+ } finally {
+ Closeables.close(reader, true); // true: an IOException thrown while closing is swallowed
+ }
+ }
+
+ Preconditions.checkState(!featureMatrix.isEmpty(), "Feature matrix is
empty");
+ return featureMatrix;
+ }
+
+ public static OpenIntObjectHashMap<Vector> readMatrixByRows(Path dir,
Configuration conf) {
+ OpenIntObjectHashMap matrix = new OpenIntObjectHashMap<Vector>();
for (Pair<IntWritable,VectorWritable> pair
: new SequenceFileDirIterable<IntWritable,VectorWritable>(dir,
PathType.LIST, PathFilters.partFilter(), conf)) {
int rowIndex = pair.getFirst().get();
- Vector row = pair.getSecond().get().clone();
+ Vector row = pair.getSecond().get();
matrix.put(rowIndex, row);
}
return matrix;
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
Tue May 7 08:04:26 2013
@@ -19,14 +19,18 @@ package org.apache.mahout.cf.taste.hadoo
import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper;
@@ -37,15 +41,16 @@ import org.apache.mahout.cf.taste.impl.c
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.mapreduce.MergeVectorsCombiner;
import org.apache.mahout.common.mapreduce.MergeVectorsReducer;
import org.apache.mahout.common.mapreduce.TransposeMapper;
-import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -83,7 +88,7 @@ public class ParallelALSFactorizationJob
static final String NUM_FEATURES =
ParallelALSFactorizationJob.class.getName() + ".numFeatures";
static final String LAMBDA = ParallelALSFactorizationJob.class.getName() +
".lambda";
static final String ALPHA = ParallelALSFactorizationJob.class.getName() +
".alpha";
- static final String FEATURE_MATRIX =
ParallelALSFactorizationJob.class.getName() + ".featureMatrix";
+ static final String NUM_ENTITIES =
ParallelALSFactorizationJob.class.getName() + ".numEntities";
private boolean implicitFeedback;
private int numIterations;
@@ -92,6 +97,11 @@ public class ParallelALSFactorizationJob
private double alpha;
private int numThreadsPerSolver;
+ private int numItems;
+ private int numUsers;
+
+ enum Stats { NUM_USERS }
+
public static void main(String[] args) throws Exception {
ToolRunner.run(new ParallelALSFactorizationJob(), args);
}
@@ -142,8 +152,8 @@ public class ParallelALSFactorizationJob
/* create A */
Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(),
- TransposeMapper.class, IntWritable.class, VectorWritable.class,
MergeVectorsReducer.class, IntWritable.class,
- VectorWritable.class);
+ TransposeMapper.class, IntWritable.class, VectorWritable.class,
MergeUserVectorsReducer.class,
+ IntWritable.class, VectorWritable.class);
userRatings.setCombinerClass(MergeVectorsCombiner.class);
succeeded = userRatings.waitForCompletion(true);
if (!succeeded) {
@@ -162,16 +172,23 @@ public class ParallelALSFactorizationJob
Vector averageRatings = ALS.readFirstRow(getTempPath("averageRatings"),
getConf());
+ numItems = averageRatings.getNumNondefaultElements();
+ numUsers = (int)
userRatings.getCounters().findCounter(Stats.NUM_USERS).getValue();
+
+ log.info("Found {} users and {} items", numUsers, numItems);
+
/* create an initial M */
initializeM(averageRatings);
for (int currentIteration = 0; currentIteration < numIterations;
currentIteration++) {
/* broadcast M, read A row-wise, recompute U row-wise */
log.info("Recomputing U (iteration {}/{})", currentIteration,
numIterations);
- runSolver(pathToUserRatings(), pathToU(currentIteration),
pathToM(currentIteration - 1), currentIteration, "U");
+ runSolver(pathToUserRatings(), pathToU(currentIteration),
pathToM(currentIteration - 1), currentIteration, "U",
+ numItems);
/* broadcast U, read A' row-wise, recompute M row-wise */
log.info("Recomputing M (iteration {}/{})", currentIteration,
numIterations);
- runSolver(pathToItemRatings(), pathToM(currentIteration),
pathToU(currentIteration), currentIteration, "M");
+ runSolver(pathToItemRatings(), pathToM(currentIteration),
pathToU(currentIteration), currentIteration, "M",
+ numUsers);
}
return 0;
@@ -202,7 +219,49 @@ public class ParallelALSFactorizationJob
writer.append(index, featureVector);
}
} finally {
- Closeables.closeQuietly(writer);
+ Closeables.close(writer, true);
+ }
+ }
+ /**
+ * Sums all vectors received for a key via {@link ALS#sum(Iterator)} and writes the sum
+ * under the same key. The summed vector is written as returned by {@code ALS.sum}, without
+ * converting it to another vector type. Intended for use as a combiner, going by its name.
+ */
+ static class VectorSumCombiner
+ extends Reducer<WritableComparable<?>, VectorWritable,
WritableComparable<?>, VectorWritable> {
+
+ private final VectorWritable result = new VectorWritable(); // reused for every written record
+
+ @Override
+ protected void reduce(WritableComparable<?> key, Iterable<VectorWritable>
values, Context ctx)
+ throws IOException, InterruptedException {
+ result.set(ALS.sum(values.iterator()));
+ ctx.write(key, result);
+ }
+ }
+ /**
+ * Sums all vectors received for a key via {@link ALS#sum(Iterator)}, copies the sum into a
+ * new {@link SequentialAccessSparseVector} and writes that copy under the same key.
+ */
+ static class VectorSumReducer
+ extends Reducer<WritableComparable<?>, VectorWritable,
WritableComparable<?>, VectorWritable> {
+
+ private final VectorWritable result = new VectorWritable(); // reused for every written record
+
+ @Override
+ protected void reduce(WritableComparable<?> key, Iterable<VectorWritable>
values, Context ctx)
+ throws IOException, InterruptedException {
+ Vector sum = ALS.sum(values.iterator());
+ result.set(new SequentialAccessSparseVector(sum));
+ ctx.write(key, result);
+ }
+ }
+ /**
+ * Merges all vectors received for a key via {@code VectorWritable.merge}, copies the merged
+ * vector into a new {@link SequentialAccessSparseVector} and writes it under the same key.
+ * Increments the {@code Stats.NUM_USERS} counter by one on every invocation of
+ * {@code reduce}, i.e. once per key.
+ */
+ static class MergeUserVectorsReducer extends
+
Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable>
{
+
+ private final VectorWritable result = new VectorWritable(); // reused for every written record
+
+ @Override
+ public void reduce(WritableComparable<?> key, Iterable<VectorWritable>
vectors, Context ctx)
+ throws IOException, InterruptedException {
+ Vector merged = VectorWritable.merge(vectors.iterator()).get();
+ result.set(new SequentialAccessSparseVector(merged));
+ ctx.write(key, result);
+ ctx.getCounter(Stats.NUM_USERS).increment(1);
}
}
@@ -210,7 +269,7 @@ public class ParallelALSFactorizationJob
private final IntWritable itemIDWritable = new IntWritable();
private final VectorWritable ratingsWritable = new VectorWritable(true);
- private final Vector ratings = new
SequentialAccessSparseVector(Integer.MAX_VALUE, 1);
+ private final Vector ratings = new
RandomAccessSparseVector(Integer.MAX_VALUE, 1);
@Override
protected void map(LongWritable offset, Text line, Context ctx) throws
IOException, InterruptedException {
@@ -231,8 +290,11 @@ public class ParallelALSFactorizationJob
}
}
- private void runSolver(Path ratings, Path output, Path pathToUorI, int
currentIteration, String matrixName)
- throws ClassNotFoundException, IOException, InterruptedException {
+ private void runSolver(Path ratings, Path output, Path pathToUorM, int
currentIteration, String matrixName,
+ int numEntities) throws ClassNotFoundException,
IOException, InterruptedException {
+
+ // only necessary for local execution, where all jobs run in the same JVM and would otherwise reuse SharingMapper's static shared instance
+ SharingMapper.reset();
int iterationNumber = currentIteration + 1;
Class<? extends
Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable>>
solverMapperClassInternal;
@@ -241,11 +303,11 @@ public class ParallelALSFactorizationJob
if (implicitFeedback) {
solverMapperClassInternal = SolveImplicitFeedbackMapper.class;
name = "Recompute " + matrixName + ", iteration (" + iterationNumber +
"/" + numIterations + "), "
- + "(" + numThreadsPerSolver + " threads, implicit feedback)";
+ + "(" + numThreadsPerSolver + " threads, " + numFeatures +"
features, implicit feedback)";
} else {
solverMapperClassInternal = SolveExplicitFeedbackMapper.class;
name = "Recompute " + matrixName + ", iteration (" + iterationNumber +
"/" + numIterations + "), "
- + "(" + numThreadsPerSolver + " threads, explicit feedback)";
+ + "(" + numThreadsPerSolver + " threads, " + numFeatures + "
features, explicit feedback)";
}
Job solverForUorI = prepareJob(ratings, output,
SequenceFileInputFormat.class, MultithreadedSharingMapper.class,
@@ -254,7 +316,16 @@ public class ParallelALSFactorizationJob
solverConf.set(LAMBDA, String.valueOf(lambda));
solverConf.set(ALPHA, String.valueOf(alpha));
solverConf.setInt(NUM_FEATURES, numFeatures);
- solverConf.set(FEATURE_MATRIX, pathToUorI.toString());
+ solverConf.set(NUM_ENTITIES, String.valueOf(numEntities));
+
+ FileSystem fs = FileSystem.get(pathToUorM.toUri(), solverConf);
+ FileStatus[] parts = fs.listStatus(pathToUorM, PathFilters.partFilter());
+ for (FileStatus part : parts) {
+ if (log.isDebugEnabled()) {
+ log.debug("Adding {} to distributed cache", part.getPath().toString());
+ }
+ DistributedCache.addCacheFile(part.getPath().toUri(), solverConf);
+ }
MultithreadedMapper.setMapperClass(solverForUorI,
solverMapperClassInternal);
MultithreadedMapper.setNumberOfThreads(solverForUorI, numThreadsPerSolver);
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
Tue May 7 08:04:26 2013
@@ -19,6 +19,8 @@ package org.apache.mahout.cf.taste.hadoo
import org.apache.hadoop.mapreduce.Mapper;
+import java.io.IOException;
+
/**
* Mapper class to be used by {@link MultithreadedSharingMapper}. Offers
"global" before() and after() methods
* that will typically be used to set up static variables.
@@ -32,20 +34,26 @@ import org.apache.hadoop.mapreduce.Mappe
*/
public abstract class SharingMapper<K1,V1,K2,V2,S> extends Mapper<K1,V1,K2,V2>
{
- private static Object sharedInstance;
+ private static Object SHARED_INSTANCE;
/**
* Called before the multithreaded execution
*
* @param context mapper's context
*/
- abstract S createSharedInstance(Context context);
+ abstract S createSharedInstance(Context context) throws IOException;
- final void setupSharedInstance(Context context) {
- sharedInstance = createSharedInstance(context);
+ final void setupSharedInstance(Context context) throws IOException {
+ if (SHARED_INSTANCE == null) {
+ SHARED_INSTANCE = createSharedInstance(context);
+ }
}
final S getSharedInstance() {
- return (S) sharedInstance;
+ return (S) SHARED_INSTANCE;
+ }
+
+ final static void reset() {
+ SHARED_INSTANCE = null;
}
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
Tue May 7 08:04:26 2013
@@ -18,6 +18,7 @@
package org.apache.mahout.cf.taste.hadoop.als;
import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Mapper;
@@ -37,9 +38,10 @@ public class SolveExplicitFeedbackMapper
private final VectorWritable uiOrmj = new VectorWritable();
@Override
- OpenIntObjectHashMap<Vector> createSharedInstance(Context ctx) {
- Path UOrIPath = new
Path(ctx.getConfiguration().get(ParallelALSFactorizationJob.FEATURE_MATRIX));
- return ALS.readMatrixByRows(UOrIPath, ctx.getConfiguration());
+ OpenIntObjectHashMap<Vector> createSharedInstance(Context ctx) throws
IOException {
+ Configuration conf = ctx.getConfiguration();
+ int numEntities =
Integer.parseInt(conf.get(ParallelALSFactorizationJob.NUM_ENTITIES));
+ return ALS.readMatrixByRowsFromDistributedCache(numEntities, conf);
}
@Override
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
Tue May 7 08:04:26 2013
@@ -34,18 +34,18 @@ public class SolveImplicitFeedbackMapper
private final VectorWritable uiOrmj = new VectorWritable();
@Override
- ImplicitFeedbackAlternatingLeastSquaresSolver createSharedInstance(Context
ctx) {
+ ImplicitFeedbackAlternatingLeastSquaresSolver createSharedInstance(Context
ctx) throws IOException {
Configuration conf = ctx.getConfiguration();
double lambda =
Double.parseDouble(conf.get(ParallelALSFactorizationJob.LAMBDA));
double alpha =
Double.parseDouble(conf.get(ParallelALSFactorizationJob.ALPHA));
int numFeatures = conf.getInt(ParallelALSFactorizationJob.NUM_FEATURES,
-1);
- Path YPath = new
Path(conf.get(ParallelALSFactorizationJob.FEATURE_MATRIX));
+ int numEntities =
Integer.parseInt(conf.get(ParallelALSFactorizationJob.NUM_ENTITIES));
Preconditions.checkArgument(numFeatures > 0, "numFeatures was not set
correctly!");
return new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures,
lambda, alpha,
- ALS.readMatrixByRows(YPath, conf));
+ ALS.readMatrixByRowsFromDistributedCache(numEntities, conf));
}
@Override
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
Tue May 7 08:04:26 2013
@@ -55,6 +55,8 @@ public class ParallelALSFactorizationJob
tmpDir = getTestTempDir("tmp");
conf = new Configuration();
+ // reset the static shared instance of SharingMapper, as all tests run in the same JVM
+ SharingMapper.reset();
}
@Test