Author: srowen
Date: Tue May 11 07:50:13 2010
New Revision: 943028

URL: http://svn.apache.org/viewvc?rev=943028&view=rev
Log:
Interim commit of more fixes and improvements

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyReducer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java?rev=943028&r1=943027&r2=943028&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java Tue May 11 07:50:13 2010
@@ -105,12 +105,11 @@ public final class AggregateAndRecommend
     while (recommendationVectorIterator.hasNext()) {
       Vector.Element element = recommendationVectorIterator.next();
       int index = element.index();
-      if (topItems.size() < recommendationsPerUser) {
-        long theItemID = indexItemIDMap.get(index);
-        topItems.add(new GenericRecommendedItem(theItemID, (float) element.get()));
-      } else if (element.get() > topItems.peek().getValue()) {
-        long theItemID = indexItemIDMap.get(index);
-        topItems.add(new GenericRecommendedItem(theItemID, (float) element.get()));
+      float value = (float) element.get();
+      if (topItems.size() < recommendationsPerUser && !Float.isNaN(value)) {
+        topItems.add(new GenericRecommendedItem(indexItemIDMap.get(index), value));
+      } else if (value > topItems.peek().getValue()) {
+        topItems.add(new GenericRecommendedItem(indexItemIDMap.get(index), value));
         topItems.poll();
       }
     }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyReducer.java?rev=943028&r1=943027&r2=943028&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyReducer.java Tue May 11 07:50:13 2010
@@ -19,6 +19,7 @@ package org.apache.mahout.cf.taste.hadoo
 
 import java.io.IOException;
 import java.util.Iterator;
+import java.util.PriorityQueue;
 
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.VLongWritable;
@@ -26,76 +27,123 @@ import org.apache.hadoop.mapred.MapReduc
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.hadoop.mapred.Reducer;
 import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.function.LongFloatProcedure;
+import org.apache.mahout.math.function.LongProcedure;
 import org.apache.mahout.math.map.OpenLongFloatHashMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public final class PartialMultiplyReducer extends MapReduceBase implements
     Reducer<IntWritable,VectorOrPrefWritable,VLongWritable,VectorWritable> {
 
+  private static final Logger log = LoggerFactory.getLogger(PartialMultiplyReducer.class);
+
+  private static final int MAX_PRODUCTS_PER_ITEM = 1000;
+
+  private enum Counters {
+    PRODUCTS_OUTPUT,
+    PRODUCTS_SKIPPED,
+  }
+
   @Override
   public void reduce(IntWritable key,
                      Iterator<VectorOrPrefWritable> values,
                      final OutputCollector<VLongWritable,VectorWritable> output,
-                     Reporter reporter) throws IOException {
+                     final Reporter reporter) {
 
+    int itemIndex = key.get();
     OpenLongFloatHashMap savedValues = new OpenLongFloatHashMap();
-    final int itemIndex = key.get();
-    final VLongWritable userIDWritable = new VLongWritable();
-    final VectorWritable vectorWritable = new VectorWritable();
-    vectorWritable.setWritesLaxPrecision(true);
 
     Vector cooccurrenceColumn = null;
     while (values.hasNext()) {
-
       VectorOrPrefWritable value = values.next();
       if (value.getVector() == null) {
-
         // Then this is a user-pref value
-        long userID = value.getUserID();
-        float preferenceValue = value.getValue();
-        
-        if (cooccurrenceColumn == null) {
-          // Haven't seen the co-occurrencce column yet; save it
-          savedValues.put(userID, preferenceValue);
-        } else {
-          // Have seen it
-          Vector partialProduct = preferenceValue == 1.0f ?
-              cooccurrenceColumn : cooccurrenceColumn.times(preferenceValue);
-          // This makes sure this item isn't recommended for this user:
-          partialProduct.set(itemIndex, Double.NEGATIVE_INFINITY);
-          userIDWritable.set(userID);
-          vectorWritable.set(partialProduct);
-          output.collect(userIDWritable, vectorWritable);
-        }
-
+        savedValues.put(value.getUserID(), value.getValue());
       } else {
-
         // Then this is the column vector
+        if (cooccurrenceColumn != null) {
+          throw new IllegalStateException("Found two co-occurrence columns for item index " + itemIndex);
+        }
         cooccurrenceColumn = value.getVector();
+      }
+    }
+
+    if (cooccurrenceColumn == null) {
+      log.info("Column vector missing for {}; continuing", itemIndex);
+      return;
+    }
 
-        final Vector theColumn = cooccurrenceColumn;
-        savedValues.forEachPair(new LongFloatProcedure() {
-          @Override
-          public boolean apply(long userID, float value) {
-            Vector partialProduct = theColumn.times(value);
-            // This makes sure this item isn't recommended for this user:
-            partialProduct.set(itemIndex, Double.NEGATIVE_INFINITY);
+    final VLongWritable userIDWritable = new VLongWritable();
+
+    // These single-element vectors ensure that each user will not be recommended
+    // this item
+    Vector excludeVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
+    excludeVector.set(itemIndex, Double.NaN);
+    final VectorWritable excludeWritable = new VectorWritable(excludeVector);
+    excludeWritable.setWritesLaxPrecision(true);
+    savedValues.forEachKey(new LongProcedure() {
+      @Override
+      public boolean apply(long userID) {
+        userIDWritable.set(userID);
+        try {
+          output.collect(userIDWritable, excludeWritable);
+        } catch (IOException ioe) {
+          throw new IllegalStateException(ioe);
+        }
+        return true;
+      }
+    });
+
+    final float smallestLargeValue = findSmallestLargeValue(savedValues);
+
+    final VectorWritable vectorWritable = new VectorWritable();
+    vectorWritable.setWritesLaxPrecision(true);
+
+    final Vector theColumn = cooccurrenceColumn;
+    savedValues.forEachPair(new LongFloatProcedure() {
+      @Override
+      public boolean apply(long userID, float value) {
+        if (Math.abs(value) < smallestLargeValue) {
+          reporter.incrCounter(Counters.PRODUCTS_SKIPPED, 1L);
+        } else {
+          try {
+            Vector partialProduct = value == 1.0f ? theColumn : theColumn.times(value);
             userIDWritable.set(userID);
             vectorWritable.set(partialProduct);
-            try {
-              output.collect(userIDWritable, vectorWritable);
-            } catch (IOException ioe) {
-              throw new IllegalStateException(ioe);
-            }
-            return true;
+            output.collect(userIDWritable, vectorWritable);
+          } catch (IOException ioe) {
+            throw new IllegalStateException(ioe);
           }
-        });
-        savedValues.clear();
+          reporter.incrCounter(Counters.PRODUCTS_OUTPUT, 1L);
+        }
+        return true;
       }
-    }
+    });
+
+  }
 
+  private static float findSmallestLargeValue(OpenLongFloatHashMap savedValues) {
+    final PriorityQueue<Float> topPrefValues = new PriorityQueue<Float>(MAX_PRODUCTS_PER_ITEM + 1);
+    savedValues.forEachPair(new LongFloatProcedure() {
+      @Override
+      public boolean apply(long userID, float value) {
+        if (topPrefValues.size() < MAX_PRODUCTS_PER_ITEM) {
+          topPrefValues.add(value);
+        } else {
+          float absValue = Math.abs(value);
+          if (absValue > topPrefValues.peek()) {
+            topPrefValues.add(absValue);
+            topPrefValues.poll();
+          }
+        }
+        return true;
+      }
+    });
+    return topPrefValues.peek();
   }
 
 }
\ No newline at end of file

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java?rev=943028&r1=943027&r2=943028&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java Tue May 11 07:50:13 2010
@@ -18,8 +18,9 @@
 package org.apache.mahout.cf.taste.hadoop.item;
 
 import java.io.IOException;
-import java.util.Arrays;
+import java.util.Collections;
 import java.util.Iterator;
+import java.util.PriorityQueue;
 
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.VLongWritable;
@@ -29,7 +30,6 @@ import org.apache.hadoop.mapred.OutputCo
 import org.apache.hadoop.mapred.Reporter;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.list.IntArrayList;
 import org.apache.mahout.math.map.OpenIntIntHashMap;
 
 public final class UserVectorToCooccurrenceMapper extends MapReduceBase implements
@@ -73,37 +73,28 @@ public final class UserVectorToCooccurre
       return userVector;
     }
 
-    OpenIntIntHashMap countCounts = new OpenIntIntHashMap();
+    PriorityQueue<Integer> smallCounts =
+        new PriorityQueue<Integer>(MAX_PREFS_CONSIDERED + 1, Collections.reverseOrder());
+
     Iterator<Vector.Element> it = userVector.iterateNonZero();
     while (it.hasNext()) {
-      int index = it.next().index();
-      int count = indexCounts.get(index);
-      countCounts.adjustOrPutValue(count, 1, 1);
-    }
-
-    IntArrayList countsList = new IntArrayList(countCounts.size());
-    countCounts.keys(countsList);
-    int[] counts = countsList.elements();
-    Arrays.sort(counts);
-
-    int resultingSizeAtCutoff = 0;
-    int cutoffIndex = 0;
-    while (cutoffIndex < counts.length && resultingSizeAtCutoff <= MAX_PREFS_CONSIDERED) {
-      int cutoff = counts[cutoffIndex];
-      cutoffIndex++;
-      int count = countCounts.get(cutoff);
-      resultingSizeAtCutoff += count;
+      int count = indexCounts.get(it.next().index());
+      if (count > 0) {
+        if (smallCounts.size() < MAX_PREFS_CONSIDERED) {
+          smallCounts.add(count);
+        } else if (count < smallCounts.peek()) {
+          smallCounts.add(count);
+          smallCounts.poll();
+        }
+      }
     }
-    cutoffIndex--;    
+    int greatestSmallCount = smallCounts.peek();
 
-    if (resultingSizeAtCutoff > MAX_PREFS_CONSIDERED) {
-      int cutoff = counts[cutoffIndex];
+    if (greatestSmallCount > 0) {
       Iterator<Vector.Element> it2 = userVector.iterateNonZero();
       while (it2.hasNext()) {
         Vector.Element e = it2.next();
-        int index = e.index();
-        int count = indexCounts.get(index);
-        if (count >= cutoff) {
+        if (indexCounts.get(e.index()) > greatestSmallCount) {
           e.set(0.0);
         }
       }


Reply via email to