Author: srowen
Date: Tue May 11 07:50:13 2010
New Revision: 943028
URL: http://svn.apache.org/viewvc?rev=943028&view=rev
Log:
Interim commit of more fixes and improvements
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyReducer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java?rev=943028&r1=943027&r2=943028&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java Tue May 11 07:50:13 2010
@@ -105,12 +105,11 @@ public final class AggregateAndRecommend
while (recommendationVectorIterator.hasNext()) {
Vector.Element element = recommendationVectorIterator.next();
int index = element.index();
- if (topItems.size() < recommendationsPerUser) {
- long theItemID = indexItemIDMap.get(index);
- topItems.add(new GenericRecommendedItem(theItemID, (float)
element.get()));
- } else if (element.get() > topItems.peek().getValue()) {
- long theItemID = indexItemIDMap.get(index);
- topItems.add(new GenericRecommendedItem(theItemID, (float)
element.get()));
+ float value = (float) element.get();
+ if (topItems.size() < recommendationsPerUser && !Float.isNaN(value)) {
+ topItems.add(new GenericRecommendedItem(indexItemIDMap.get(index),
value));
+ } else if (value > topItems.peek().getValue()) {
+ topItems.add(new GenericRecommendedItem(indexItemIDMap.get(index),
value));
topItems.poll();
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyReducer.java?rev=943028&r1=943027&r2=943028&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyReducer.java Tue May 11 07:50:13 2010
@@ -19,6 +19,7 @@ package org.apache.mahout.cf.taste.hadoo
import java.io.IOException;
import java.util.Iterator;
+import java.util.PriorityQueue;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.VLongWritable;
@@ -26,76 +27,123 @@ import org.apache.hadoop.mapred.MapReduc
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.LongFloatProcedure;
+import org.apache.mahout.math.function.LongProcedure;
import org.apache.mahout.math.map.OpenLongFloatHashMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
public final class PartialMultiplyReducer extends MapReduceBase implements
Reducer<IntWritable,VectorOrPrefWritable,VLongWritable,VectorWritable> {
+ private static final Logger log =
LoggerFactory.getLogger(PartialMultiplyReducer.class);
+
+ private static final int MAX_PRODUCTS_PER_ITEM = 1000;
+
+ private enum Counters {
+ PRODUCTS_OUTPUT,
+ PRODUCTS_SKIPPED,
+ }
+
@Override
public void reduce(IntWritable key,
Iterator<VectorOrPrefWritable> values,
final OutputCollector<VLongWritable,VectorWritable>
output,
- Reporter reporter) throws IOException {
+ final Reporter reporter) {
+ int itemIndex = key.get();
OpenLongFloatHashMap savedValues = new OpenLongFloatHashMap();
- final int itemIndex = key.get();
- final VLongWritable userIDWritable = new VLongWritable();
- final VectorWritable vectorWritable = new VectorWritable();
- vectorWritable.setWritesLaxPrecision(true);
Vector cooccurrenceColumn = null;
while (values.hasNext()) {
-
VectorOrPrefWritable value = values.next();
if (value.getVector() == null) {
-
// Then this is a user-pref value
- long userID = value.getUserID();
- float preferenceValue = value.getValue();
-
- if (cooccurrenceColumn == null) {
- // Haven't seen the co-occurrencce column yet; save it
- savedValues.put(userID, preferenceValue);
- } else {
- // Have seen it
- Vector partialProduct = preferenceValue == 1.0f ?
- cooccurrenceColumn : cooccurrenceColumn.times(preferenceValue);
- // This makes sure this item isn't recommended for this user:
- partialProduct.set(itemIndex, Double.NEGATIVE_INFINITY);
- userIDWritable.set(userID);
- vectorWritable.set(partialProduct);
- output.collect(userIDWritable, vectorWritable);
- }
-
+ savedValues.put(value.getUserID(), value.getValue());
} else {
-
// Then this is the column vector
+ if (cooccurrenceColumn != null) {
+ throw new IllegalStateException("Found two co-occurrence columns for
item index " + itemIndex);
+ }
cooccurrenceColumn = value.getVector();
+ }
+ }
+
+ if (cooccurrenceColumn == null) {
+ log.info("Column vector missing for {}; continuing", itemIndex);
+ return;
+ }
- final Vector theColumn = cooccurrenceColumn;
- savedValues.forEachPair(new LongFloatProcedure() {
- @Override
- public boolean apply(long userID, float value) {
- Vector partialProduct = theColumn.times(value);
- // This makes sure this item isn't recommended for this user:
- partialProduct.set(itemIndex, Double.NEGATIVE_INFINITY);
+ final VLongWritable userIDWritable = new VLongWritable();
+
+ // These single-element vectors ensure that each user will not be
recommended
+ // this item
+ Vector excludeVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
+ excludeVector.set(itemIndex, Double.NaN);
+ final VectorWritable excludeWritable = new VectorWritable(excludeVector);
+ excludeWritable.setWritesLaxPrecision(true);
+ savedValues.forEachKey(new LongProcedure() {
+ @Override
+ public boolean apply(long userID) {
+ userIDWritable.set(userID);
+ try {
+ output.collect(userIDWritable, excludeWritable);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ return true;
+ }
+ });
+
+ final float smallestLargeValue = findSmallestLargeValue(savedValues);
+
+ final VectorWritable vectorWritable = new VectorWritable();
+ vectorWritable.setWritesLaxPrecision(true);
+
+ final Vector theColumn = cooccurrenceColumn;
+ savedValues.forEachPair(new LongFloatProcedure() {
+ @Override
+ public boolean apply(long userID, float value) {
+ if (Math.abs(value) < smallestLargeValue) {
+ reporter.incrCounter(Counters.PRODUCTS_SKIPPED, 1L);
+ } else {
+ try {
+ Vector partialProduct = value == 1.0f ? theColumn :
theColumn.times(value);
userIDWritable.set(userID);
vectorWritable.set(partialProduct);
- try {
- output.collect(userIDWritable, vectorWritable);
- } catch (IOException ioe) {
- throw new IllegalStateException(ioe);
- }
- return true;
+ output.collect(userIDWritable, vectorWritable);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
}
- });
- savedValues.clear();
+ reporter.incrCounter(Counters.PRODUCTS_OUTPUT, 1L);
+ }
+ return true;
}
- }
+ });
+
+ }
+ private static float findSmallestLargeValue(OpenLongFloatHashMap
savedValues) {
+ final PriorityQueue<Float> topPrefValues = new
PriorityQueue<Float>(MAX_PRODUCTS_PER_ITEM + 1);
+ savedValues.forEachPair(new LongFloatProcedure() {
+ @Override
+ public boolean apply(long userID, float value) {
+ if (topPrefValues.size() < MAX_PRODUCTS_PER_ITEM) {
+ topPrefValues.add(value);
+ } else {
+ float absValue = Math.abs(value);
+ if (absValue > topPrefValues.peek()) {
+ topPrefValues.add(absValue);
+ topPrefValues.poll();
+ }
+ }
+ return true;
+ }
+ });
+ return topPrefValues.peek();
}
}
\ No newline at end of file
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java?rev=943028&r1=943027&r2=943028&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java Tue May 11 07:50:13 2010
@@ -18,8 +18,9 @@
package org.apache.mahout.cf.taste.hadoop.item;
import java.io.IOException;
-import java.util.Arrays;
+import java.util.Collections;
import java.util.Iterator;
+import java.util.PriorityQueue;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.VLongWritable;
@@ -29,7 +30,6 @@ import org.apache.hadoop.mapred.OutputCo
import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.list.IntArrayList;
import org.apache.mahout.math.map.OpenIntIntHashMap;
public final class UserVectorToCooccurrenceMapper extends MapReduceBase
implements
@@ -73,37 +73,28 @@ public final class UserVectorToCooccurre
return userVector;
}
- OpenIntIntHashMap countCounts = new OpenIntIntHashMap();
+ PriorityQueue<Integer> smallCounts =
+ new PriorityQueue<Integer>(MAX_PREFS_CONSIDERED + 1,
Collections.reverseOrder());
+
Iterator<Vector.Element> it = userVector.iterateNonZero();
while (it.hasNext()) {
- int index = it.next().index();
- int count = indexCounts.get(index);
- countCounts.adjustOrPutValue(count, 1, 1);
- }
-
- IntArrayList countsList = new IntArrayList(countCounts.size());
- countCounts.keys(countsList);
- int[] counts = countsList.elements();
- Arrays.sort(counts);
-
- int resultingSizeAtCutoff = 0;
- int cutoffIndex = 0;
- while (cutoffIndex < counts.length && resultingSizeAtCutoff <=
MAX_PREFS_CONSIDERED) {
- int cutoff = counts[cutoffIndex];
- cutoffIndex++;
- int count = countCounts.get(cutoff);
- resultingSizeAtCutoff += count;
+ int count = indexCounts.get(it.next().index());
+ if (count > 0) {
+ if (smallCounts.size() < MAX_PREFS_CONSIDERED) {
+ smallCounts.add(count);
+ } else if (count < smallCounts.peek()) {
+ smallCounts.add(count);
+ smallCounts.poll();
+ }
+ }
}
- cutoffIndex--;
+ int greatestSmallCount = smallCounts.peek();
- if (resultingSizeAtCutoff > MAX_PREFS_CONSIDERED) {
- int cutoff = counts[cutoffIndex];
+ if (greatestSmallCount > 0) {
Iterator<Vector.Element> it2 = userVector.iterateNonZero();
while (it2.hasNext()) {
Vector.Element e = it2.next();
- int index = e.index();
- int count = indexCounts.get(index);
- if (count >= cutoff) {
+ if (indexCounts.get(e.index()) > greatestSmallCount) {
e.set(0.0);
}
}