Author: ssc
Date: Tue May  7 08:04:26 2013
New Revision: 1479793

URL: http://svn.apache.org/r1479793
Log:
MAHOUT-1205 ParallelALSFactorizationJob should leverage the distributed cache

Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java 
(original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java 
Tue May  7 08:04:26 2013
@@ -17,10 +17,16 @@
 
 package org.apache.mahout.cf.taste.hadoop.als;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalFileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
 import org.apache.mahout.common.Pair;
 import org.apache.mahout.common.iterator.sequencefile.PathFilters;
 import org.apache.mahout.common.iterator.sequencefile.PathType;
@@ -29,6 +35,7 @@ import org.apache.mahout.common.iterator
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
+import org.apache.mahout.math.function.Functions;
 import org.apache.mahout.math.map.OpenIntObjectHashMap;
 
 import java.io.IOException;
@@ -45,13 +52,60 @@ final class ALS {
     return iterator.hasNext() ? iterator.next().get() : null;
   }
 
-  public static OpenIntObjectHashMap<Vector> readMatrixByRows(Path dir, 
Configuration conf) {
-    OpenIntObjectHashMap<Vector> matrix = new OpenIntObjectHashMap<Vector>();
+  /**
+   * assumes that first entry always exists
+   *
+   * @param vectors
+   */
+  public static Vector sum(Iterator<VectorWritable> vectors) {
+    Vector sum = vectors.next().get();
+    while (vectors.hasNext()) {
+      sum.assign(vectors.next().get(), Functions.PLUS);
+    }
+    return sum;
+  }
+
+  public static OpenIntObjectHashMap<Vector> 
readMatrixByRowsFromDistributedCache(int numEntities,
+      Configuration conf) throws IOException {
+
+    IntWritable rowIndex = new IntWritable();
+    VectorWritable row = new VectorWritable();
+
+    LocalFileSystem localFs = FileSystem.getLocal(conf);
+    Path[] cacheFiles = DistributedCache.getLocalCacheFiles(conf);
 
+    OpenIntObjectHashMap<Vector> featureMatrix = numEntities > 0
+        ? new OpenIntObjectHashMap<Vector>(numEntities) : new 
OpenIntObjectHashMap<Vector>();
+
+    for (int n = 0; n < cacheFiles.length; n++) {
+      Path localCacheFile = localFs.makeQualified(cacheFiles[n]);
+
+      // fallback for local execution
+      if (!localFs.exists(localCacheFile)) {
+        localCacheFile = new 
Path(DistributedCache.getCacheFiles(conf)[n].getPath());
+      }
+
+      SequenceFile.Reader reader = null;
+      try {
+        reader = new SequenceFile.Reader(localFs, localCacheFile, conf);
+        while (reader.next(rowIndex, row)) {
+          featureMatrix.put(rowIndex.get(), row.get());
+        }
+      } finally {
+        Closeables.close(reader, true);
+      }
+    }
+
+    Preconditions.checkState(!featureMatrix.isEmpty(), "Feature matrix is 
empty");
+    return featureMatrix;
+  }
+
+  public static OpenIntObjectHashMap<Vector> readMatrixByRows(Path dir, 
Configuration conf) {
+    OpenIntObjectHashMap matrix = new OpenIntObjectHashMap<Vector>();
     for (Pair<IntWritable,VectorWritable> pair
         : new SequenceFileDirIterable<IntWritable,VectorWritable>(dir, 
PathType.LIST, PathFilters.partFilter(), conf)) {
       int rowIndex = pair.getFirst().get();
-      Vector row = pair.getSecond().get().clone();
+      Vector row = pair.getSecond().get();
       matrix.put(rowIndex, row);
     }
     return matrix;

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
 Tue May  7 08:04:26 2013
@@ -19,14 +19,18 @@ package org.apache.mahout.cf.taste.hadoo
 
 import com.google.common.io.Closeables;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.mapreduce.Job;
 import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
 import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
 import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper;
@@ -37,15 +41,16 @@ import org.apache.mahout.cf.taste.impl.c
 import org.apache.mahout.cf.taste.impl.common.RunningAverage;
 import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
 import org.apache.mahout.common.mapreduce.MergeVectorsCombiner;
 import org.apache.mahout.common.mapreduce.MergeVectorsReducer;
 import org.apache.mahout.common.mapreduce.TransposeMapper;
-import org.apache.mahout.common.mapreduce.VectorSumReducer;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.SequentialAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -83,7 +88,7 @@ public class ParallelALSFactorizationJob
   static final String NUM_FEATURES = 
ParallelALSFactorizationJob.class.getName() + ".numFeatures";
   static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + 
".lambda";
   static final String ALPHA = ParallelALSFactorizationJob.class.getName() + 
".alpha";
-  static final String FEATURE_MATRIX = 
ParallelALSFactorizationJob.class.getName() + ".featureMatrix";
+  static final String NUM_ENTITIES = 
ParallelALSFactorizationJob.class.getName() + ".numEntities";
 
   private boolean implicitFeedback;
   private int numIterations;
@@ -92,6 +97,11 @@ public class ParallelALSFactorizationJob
   private double alpha;
   private int numThreadsPerSolver;
 
+  private int numItems;
+  private int numUsers;
+
+  enum Stats { NUM_USERS }
+
   public static void main(String[] args) throws Exception {
     ToolRunner.run(new ParallelALSFactorizationJob(), args);
   }
@@ -142,8 +152,8 @@ public class ParallelALSFactorizationJob
 
     /* create A */
     Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(),
-        TransposeMapper.class, IntWritable.class, VectorWritable.class, 
MergeVectorsReducer.class, IntWritable.class,
-        VectorWritable.class);
+        TransposeMapper.class, IntWritable.class, VectorWritable.class, 
MergeUserVectorsReducer.class,
+        IntWritable.class, VectorWritable.class);
     userRatings.setCombinerClass(MergeVectorsCombiner.class);
     succeeded = userRatings.waitForCompletion(true);
     if (!succeeded) {
@@ -162,16 +172,23 @@ public class ParallelALSFactorizationJob
 
     Vector averageRatings = ALS.readFirstRow(getTempPath("averageRatings"), 
getConf());
 
+    numItems = averageRatings.getNumNondefaultElements();
+    numUsers = (int) 
userRatings.getCounters().findCounter(Stats.NUM_USERS).getValue();
+
+    log.info("Found {} users and {} items", numUsers, numItems);
+
     /* create an initial M */
     initializeM(averageRatings);
 
     for (int currentIteration = 0; currentIteration < numIterations; 
currentIteration++) {
       /* broadcast M, read A row-wise, recompute U row-wise */
       log.info("Recomputing U (iteration {}/{})", currentIteration, 
numIterations);
-      runSolver(pathToUserRatings(), pathToU(currentIteration), 
pathToM(currentIteration - 1), currentIteration, "U");
+      runSolver(pathToUserRatings(), pathToU(currentIteration), 
pathToM(currentIteration - 1), currentIteration, "U",
+                numItems);
       /* broadcast U, read A' row-wise, recompute M row-wise */
       log.info("Recomputing M (iteration {}/{})", currentIteration, 
numIterations);
-      runSolver(pathToItemRatings(), pathToM(currentIteration), 
pathToU(currentIteration), currentIteration, "M");
+      runSolver(pathToItemRatings(), pathToM(currentIteration), 
pathToU(currentIteration), currentIteration, "M",
+                numUsers);
     }
 
     return 0;
@@ -202,7 +219,49 @@ public class ParallelALSFactorizationJob
         writer.append(index, featureVector);
       }
     } finally {
-      Closeables.closeQuietly(writer);
+      Closeables.close(writer, true);
+    }
+  }
+
+  static class VectorSumCombiner
+      extends Reducer<WritableComparable<?>, VectorWritable, 
WritableComparable<?>, VectorWritable> {
+
+    private final VectorWritable result = new VectorWritable();
+
+    @Override
+    protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> 
values, Context ctx)
+        throws IOException, InterruptedException {
+      result.set(ALS.sum(values.iterator()));
+      ctx.write(key, result);
+    }
+  }
+
+  static class VectorSumReducer
+      extends Reducer<WritableComparable<?>, VectorWritable, 
WritableComparable<?>, VectorWritable> {
+
+    private final VectorWritable result = new VectorWritable();
+
+    @Override
+    protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> 
values, Context ctx)
+        throws IOException, InterruptedException {
+      Vector sum = ALS.sum(values.iterator());
+      result.set(new SequentialAccessSparseVector(sum));
+      ctx.write(key, result);
+    }
+  }
+
+  static class MergeUserVectorsReducer extends
+      
Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable>
 {
+
+    private final VectorWritable result = new VectorWritable();
+
+    @Override
+    public void reduce(WritableComparable<?> key, Iterable<VectorWritable> 
vectors, Context ctx)
+        throws IOException, InterruptedException {
+      Vector merged = VectorWritable.merge(vectors.iterator()).get();
+      result.set(new SequentialAccessSparseVector(merged));
+      ctx.write(key, result);
+      ctx.getCounter(Stats.NUM_USERS).increment(1);
     }
   }
 
@@ -210,7 +269,7 @@ public class ParallelALSFactorizationJob
 
     private final IntWritable itemIDWritable = new IntWritable();
     private final VectorWritable ratingsWritable = new VectorWritable(true);
-    private final Vector ratings = new 
SequentialAccessSparseVector(Integer.MAX_VALUE, 1);
+    private final Vector ratings = new 
RandomAccessSparseVector(Integer.MAX_VALUE, 1);
 
     @Override
     protected void map(LongWritable offset, Text line, Context ctx) throws 
IOException, InterruptedException {
@@ -231,8 +290,11 @@ public class ParallelALSFactorizationJob
     }
   }
 
-  private void runSolver(Path ratings, Path output, Path pathToUorI, int 
currentIteration, String matrixName)
-    throws ClassNotFoundException, IOException, InterruptedException {
+  private void runSolver(Path ratings, Path output, Path pathToUorM, int 
currentIteration, String matrixName,
+                         int numEntities) throws ClassNotFoundException, 
IOException, InterruptedException {
+
+    // necessary for local execution in the same JVM only
+    SharingMapper.reset();
 
     int iterationNumber = currentIteration + 1;
     Class<? extends 
Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable>> 
solverMapperClassInternal;
@@ -241,11 +303,11 @@ public class ParallelALSFactorizationJob
     if (implicitFeedback) {
       solverMapperClassInternal = SolveImplicitFeedbackMapper.class;
       name = "Recompute " + matrixName + ", iteration (" + iterationNumber + 
"/" + numIterations + "), "
-          + "(" + numThreadsPerSolver + " threads, implicit feedback)";
+          + "(" + numThreadsPerSolver + " threads, " + numFeatures +" 
features, implicit feedback)";
     } else {
       solverMapperClassInternal = SolveExplicitFeedbackMapper.class;
       name = "Recompute " + matrixName + ", iteration (" + iterationNumber + 
"/" + numIterations + "), "
-          + "(" + numThreadsPerSolver + " threads, explicit feedback)";
+          + "(" + numThreadsPerSolver + " threads, " + numFeatures + " 
features, explicit feedback)";
     }
 
     Job solverForUorI = prepareJob(ratings, output, 
SequenceFileInputFormat.class, MultithreadedSharingMapper.class,
@@ -254,7 +316,16 @@ public class ParallelALSFactorizationJob
     solverConf.set(LAMBDA, String.valueOf(lambda));
     solverConf.set(ALPHA, String.valueOf(alpha));
     solverConf.setInt(NUM_FEATURES, numFeatures);
-    solverConf.set(FEATURE_MATRIX, pathToUorI.toString());
+    solverConf.set(NUM_ENTITIES, String.valueOf(numEntities));
+
+    FileSystem fs = FileSystem.get(pathToUorM.toUri(), solverConf);
+    FileStatus[] parts = fs.listStatus(pathToUorM, PathFilters.partFilter());
+    for (FileStatus part : parts) {
+      if (log.isDebugEnabled()) {
+        log.debug("Adding {} to distributed cache", part.getPath().toString());
+      }
+      DistributedCache.addCacheFile(part.getPath().toUri(), solverConf);
+    }
 
     MultithreadedMapper.setMapperClass(solverForUorI, 
solverMapperClassInternal);
     MultithreadedMapper.setNumberOfThreads(solverForUorI, numThreadsPerSolver);

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
 Tue May  7 08:04:26 2013
@@ -19,6 +19,8 @@ package org.apache.mahout.cf.taste.hadoo
 
 import org.apache.hadoop.mapreduce.Mapper;
 
+import java.io.IOException;
+
 /**
  * Mapper class to be used by {@link MultithreadedSharingMapper}. Offers 
"global" before() and after() methods
  * that will typically be used to set up static variables.
@@ -32,20 +34,26 @@ import org.apache.hadoop.mapreduce.Mappe
  */
 public abstract class SharingMapper<K1,V1,K2,V2,S> extends Mapper<K1,V1,K2,V2> 
{
 
-  private static Object sharedInstance;
+  private static Object SHARED_INSTANCE;
 
   /**
    * Called before the multithreaded execution
    *
    * @param context mapper's context
    */
-  abstract S createSharedInstance(Context context);
+  abstract S createSharedInstance(Context context) throws IOException;
 
-  final void setupSharedInstance(Context context) {
-    sharedInstance = createSharedInstance(context);
+  final void setupSharedInstance(Context context) throws IOException {
+    if (SHARED_INSTANCE == null) {
+      SHARED_INSTANCE = createSharedInstance(context);
+    }
   }
 
   final S getSharedInstance() {
-    return (S) sharedInstance;
+    return (S) SHARED_INSTANCE;
+  }
+
+  final static void reset() {
+    SHARED_INSTANCE = null;
   }
 }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
 Tue May  7 08:04:26 2013
@@ -18,6 +18,7 @@
 package org.apache.mahout.cf.taste.hadoop.als;
 
 import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.mapreduce.Mapper;
@@ -37,9 +38,10 @@ public class SolveExplicitFeedbackMapper
   private final VectorWritable uiOrmj = new VectorWritable();
 
   @Override
-  OpenIntObjectHashMap<Vector> createSharedInstance(Context ctx) {
-    Path UOrIPath = new 
Path(ctx.getConfiguration().get(ParallelALSFactorizationJob.FEATURE_MATRIX));
-    return ALS.readMatrixByRows(UOrIPath, ctx.getConfiguration());
+  OpenIntObjectHashMap<Vector> createSharedInstance(Context ctx) throws 
IOException {
+    Configuration conf = ctx.getConfiguration();
+    int numEntities = 
Integer.parseInt(conf.get(ParallelALSFactorizationJob.NUM_ENTITIES));
+    return ALS.readMatrixByRowsFromDistributedCache(numEntities, conf);
   }
 
   @Override

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
 Tue May  7 08:04:26 2013
@@ -34,18 +34,18 @@ public class SolveImplicitFeedbackMapper
   private final VectorWritable uiOrmj = new VectorWritable();
 
   @Override
-  ImplicitFeedbackAlternatingLeastSquaresSolver createSharedInstance(Context 
ctx) {
+  ImplicitFeedbackAlternatingLeastSquaresSolver createSharedInstance(Context 
ctx) throws IOException {
     Configuration conf = ctx.getConfiguration();
 
     double lambda = 
Double.parseDouble(conf.get(ParallelALSFactorizationJob.LAMBDA));
     double alpha = 
Double.parseDouble(conf.get(ParallelALSFactorizationJob.ALPHA));
     int numFeatures = conf.getInt(ParallelALSFactorizationJob.NUM_FEATURES, 
-1);
-    Path YPath = new 
Path(conf.get(ParallelALSFactorizationJob.FEATURE_MATRIX));
+    int numEntities = 
Integer.parseInt(conf.get(ParallelALSFactorizationJob.NUM_ENTITIES));
 
     Preconditions.checkArgument(numFeatures > 0, "numFeatures was not set 
correctly!");
 
     return new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, 
lambda, alpha,
-        ALS.readMatrixByRows(YPath, conf));
+        ALS.readMatrixByRowsFromDistributedCache(numEntities, conf));
   }
 
   @Override

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java?rev=1479793&r1=1479792&r2=1479793&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
 Tue May  7 08:04:26 2013
@@ -55,6 +55,8 @@ public class ParallelALSFactorizationJob
     tmpDir = getTestTempDir("tmp");
 
     conf = new Configuration();
+    // reset as we run all tests in the same JVM
+    SharingMapper.reset();
   }
 
   @Test


Reply via email to