Author: ssc
Date: Mon Aug 13 09:07:29 2012
New Revision: 1372332
URL: http://svn.apache.org/viewvc?rev=1372332&view=rev
Log:
MAHOUT-1056 ALSWRFactorizer should also handle implicit feedback data
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java?rev=1372332&r1=1372331&r2=1372332&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
Mon Aug 13 09:07:29 2012
@@ -27,11 +27,15 @@ import org.apache.mahout.cf.taste.model.
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
+import
org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
@@ -39,10 +43,12 @@ import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
/**
- * factorizes the rating matrix using "Alternating-Least-Squares with
Weighted-λ-Regularization" as described in
- * the paper
+ * factorizes the rating matrix using "Alternating-Least-Squares with
Weighted-λ-Regularization" as described in the paper
* <a
href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
* "Large-scale Collaborative Filtering for the Netflix Prize"</a>
+ *
+ * also supports the implicit feedback variant of this approach as described
in "Collaborative Filtering for Implicit Feedback Datasets"
+ * available at http://research.yahoo.com/pub/2433
*/
public class ALSWRFactorizer extends AbstractFactorizer {
@@ -55,14 +61,27 @@ public class ALSWRFactorizer extends Abs
/** number of iterations */
private final int numIterations;
+ private final boolean usesImplicitFeedback;
+ /** confidence weighting parameter, only necessary when working with
implicit feedback */
+ private final double alpha;
+
+ private static final double DEFAULT_ALPHA = 40;
+
private static final Logger log =
LoggerFactory.getLogger(ALSWRFactorizer.class);
- public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda,
int numIterations) throws TasteException {
+  /**
+   * Creates a factorizer that works on either explicit ratings or implicit feedback data.
+   *
+   * @param dataModel the preferences to factorize
+   * @param numFeatures number of latent features per user and per item
+   * @param lambda regularization parameter handed to the least-squares solvers
+   * @param numIterations number of alternating iterations; each one recomputes all user features, then all item features
+   * @param usesImplicitFeedback if {@code true}, solve with {@code ImplicitFeedbackAlternatingLeastSquaresSolver},
+   *     otherwise with {@code AlternatingLeastSquaresSolver}
+   * @param alpha confidence weighting parameter; only used when {@code usesImplicitFeedback} is {@code true}
+   * @throws TasteException if {@code super(dataModel)} throws
+   */
+ public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda,
int numIterations,
+ boolean usesImplicitFeedback, double alpha) throws TasteException {
super(dataModel);
this.dataModel = dataModel;
this.numFeatures = numFeatures;
this.lambda = lambda;
this.numIterations = numIterations;
+ this.usesImplicitFeedback = usesImplicitFeedback;
+ this.alpha = alpha;
+ }
+
+  /**
+   * Creates a factorizer for explicit rating data. This is equivalent to the six-argument constructor with
+   * {@code usesImplicitFeedback = false} and {@code alpha = DEFAULT_ALPHA} (alpha is ignored in that mode).
+   *
+   * @throws TasteException if {@code super(dataModel)} throws
+   */
+ public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda,
int numIterations) throws TasteException {
+ this(dataModel, numFeatures, lambda, numIterations, false, DEFAULT_ALPHA);
}
static class Features {
@@ -135,6 +154,15 @@ public class ALSWRFactorizer extends Abs
log.info("starting to compute the factorization...");
final Features features = new Features(this);
+ /* feature maps necessary for solving for implicit feedback */
+ OpenIntObjectHashMap<Vector> userY = null;
+ OpenIntObjectHashMap<Vector> itemY = null;
+
+ if (usesImplicitFeedback) {
+ userY = userFeaturesMapping(dataModel.getUserIDs(),
dataModel.getNumUsers(), features.getU());
+ itemY = itemFeaturesMapping(dataModel.getItemIDs(),
dataModel.getNumItems(), features.getM());
+ }
+
for (int iteration = 0; iteration < numIterations; iteration++) {
log.info("iteration {}", iteration);
@@ -142,6 +170,10 @@ public class ALSWRFactorizer extends Abs
ExecutorService queue = createQueue();
LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
try {
+
+ final ImplicitFeedbackAlternatingLeastSquaresSolver
implicitFeedbackSolver = usesImplicitFeedback ?
+ new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures,
lambda, alpha, itemY) : null;
+
while (userIDsIterator.hasNext()) {
final long userID = userIDsIterator.nextLong();
final LongPrimitiveIterator itemIDsFromUser =
dataModel.getItemIDsFromUser(userID).iterator();
@@ -154,8 +186,11 @@ public class ALSWRFactorizer extends Abs
long itemID = itemIDsFromUser.nextLong();
featureVectors.add(features.getItemFeatureColumn(itemIndex(itemID)));
}
- Vector userFeatures =
+
+ Vector userFeatures = usesImplicitFeedback ?
+
implicitFeedbackSolver.solve(sparseUserRatingVector(userPrefs)) :
AlternatingLeastSquaresSolver.solve(featureVectors,
ratingVector(userPrefs), lambda, numFeatures);
+
features.setFeatureColumnInU(userIndex(userID), userFeatures);
}
});
@@ -173,6 +208,10 @@ public class ALSWRFactorizer extends Abs
queue = createQueue();
LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
try {
+
+ final ImplicitFeedbackAlternatingLeastSquaresSolver
implicitFeedbackSolver = usesImplicitFeedback ?
+ new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures,
lambda, alpha, userY) : null;
+
while (itemIDsIterator.hasNext()) {
final long itemID = itemIDsIterator.nextLong();
final PreferenceArray itemPrefs =
dataModel.getPreferencesForItem(itemID);
@@ -184,8 +223,11 @@ public class ALSWRFactorizer extends Abs
long userID = pref.getUserID();
featureVectors.add(features.getUserFeatureColumn(userIndex(userID)));
}
- Vector itemFeatures =
+
+ Vector itemFeatures = usesImplicitFeedback ?
+
implicitFeedbackSolver.solve(sparseItemRatingVector(itemPrefs)) :
AlternatingLeastSquaresSolver.solve(featureVectors,
ratingVector(itemPrefs), lambda, numFeatures);
+
features.setFeatureColumnInM(itemIndex(itemID), itemFeatures);
}
});
@@ -204,7 +246,7 @@ public class ALSWRFactorizer extends Abs
return createFactorization(features.getU(), features.getM());
}
- protected static ExecutorService createQueue() {
+  /**
+   * Creates a fixed-size thread pool with one thread per available processor. factorize() asks for a fresh pool
+   * in every half-iteration: once to solve the user features, once to solve the item features.
+   * NOTE(review): no longer static, presumably so subclasses can override the pool; confirm.
+   */
+ protected ExecutorService createQueue() {
return
Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
}
@@ -213,6 +255,46 @@ public class ALSWRFactorizer extends Abs
for (int n = 0; n < prefs.length(); n++) {
ratings[n] = prefs.get(n).getValue();
}
- return new DenseVector(ratings);
+ return new DenseVector(ratings, true);
+ }
+
+  // TODO find a way to get rid of the object overhead here
+  /**
+   * Wraps the feature row of every item into a {@link Vector}, keyed by the item's dense index
+   * (as returned by {@code itemIndex(long)}).
+   * <p>
+   * Keys are dense indices rather than raw item IDs. Casting a {@code long} ID to {@code int} truncates it, so
+   * distinct IDs could collide or become negative. The implicit feedback solver receives both this map and the
+   * rating vectors built by {@link #sparseUserRatingVector(PreferenceArray)}, so the two must use the same index space.
+   *
+   * @param itemIDs the item IDs to include
+   * @param numItems expected number of items, used to presize the map
+   * @param featureMatrix item feature matrix, one row per item index
+   * @return map from item index to a vector wrapping the corresponding row; the {@code true} flag of
+   *     {@code DenseVector} requests a shallow copy, so each vector shares storage with its row
+   */
+  protected OpenIntObjectHashMap<Vector> itemFeaturesMapping(LongPrimitiveIterator itemIDs, int numItems,
+      double[][] featureMatrix) {
+    OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<Vector>(numItems);
+    while (itemIDs.hasNext()) {
+      // nextLong() avoids boxing each ID; the index is looked up once and used for both key and row
+      int index = itemIndex(itemIDs.nextLong());
+      mapping.put(index, new DenseVector(featureMatrix[index], true));
+    }
+    return mapping;
+  }
+
+  /**
+   * Wraps the feature row of every user into a {@link Vector}, keyed by the user's dense index
+   * (as returned by {@code userIndex(long)}).
+   * <p>
+   * The keys must match the indices written by {@link #sparseItemRatingVector(PreferenceArray)}, because the
+   * implicit feedback solver receives both.
+   *
+   * @param userIDs the user IDs to include
+   * @param numUsers expected number of users, used to presize the map
+   * @param featureMatrix user feature matrix, one row per user index
+   * @return map from user index to a vector sharing storage with the corresponding row
+   */
+  protected OpenIntObjectHashMap<Vector> userFeaturesMapping(LongPrimitiveIterator userIDs, int numUsers,
+      double[][] featureMatrix) {
+    OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<Vector>(numUsers);
+    while (userIDs.hasNext()) {
+      int index = userIndex(userIDs.nextLong());
+      mapping.put(index, new DenseVector(featureMatrix[index], true));
+    }
+    return mapping;
+  }
+
+  /**
+   * Builds a sparse vector holding one item's preferences, indexed by the dense user index of each preference.
+   *
+   * @param prefs all preferences for a single item
+   * @return sparse vector with {@code Integer.MAX_VALUE} cardinality; entry {@code userIndex(u)} holds u's value
+   */
+  protected Vector sparseItemRatingVector(PreferenceArray prefs) {
+    SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
+    for (Preference preference : prefs) {
+      ratings.set(userIndex(preference.getUserID()), preference.getValue());
+    }
+    return ratings;
+  }
+
+  /**
+   * Builds a sparse vector holding one user's preferences, indexed by the dense item index of each preference.
+   *
+   * @param prefs all preferences of a single user
+   * @return sparse vector with {@code Integer.MAX_VALUE} cardinality; entry {@code itemIndex(i)} holds the value for i
+   */
+  protected Vector sparseUserRatingVector(PreferenceArray prefs) {
+    SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
+    for (Preference preference : prefs) {
+      ratings.set(itemIndex(preference.getItemID()), preference.getValue());
+    }
+    return ratings;
}
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java?rev=1372332&r1=1372331&r2=1372332&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
Mon Aug 13 09:07:29 2012
@@ -29,26 +29,35 @@ import org.apache.mahout.cf.taste.model.
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.junit.Before;
import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.util.Arrays;
+import java.util.Iterator;
public class ALSWRFactorizerTest extends TasteTestCase {
private ALSWRFactorizer factorizer;
private DataModel dataModel;
- /**
- * rating-matrix
- *
- * burger hotdog berries icecream
- * dog 5 5 2 -
- * rabbit 2 - 3 5
- * cow - 5 - 3
- * donkey 3 - - 5
- */
+ private static final Logger log =
LoggerFactory.getLogger(ALSWRFactorizerTest.class);
+
+ /**
+ * rating-matrix
+ *
+ * burger hotdog berries icecream
+ * dog 5 5 2 -
+ * rabbit 2 - 3 5
+ * cow - 5 - 3
+ * donkey 3 - - 5
+ */
+
@Override
@Before
public void setUp() throws Exception {
@@ -145,4 +154,52 @@ public class ALSWRFactorizerTest extends
double rmse = Math.sqrt(avg.getAverage());
assertTrue(rmse < 0.2);
}
+
+  /**
+   * Factorizes the toy rating matrix in implicit feedback mode. Binarized preferences (1 where a rating exists)
+   * are the targets, and each squared error is weighted by the confidence 1 + alpha * rating. The test asserts
+   * that the root of the averaged weighted error stays below 0.4.
+   */
+ @Test
+ public void toyExampleImplicit() throws Exception {
+
+  // raw ratings of the toy matrix (0 = unrated); only used below to derive confidences
+ Matrix observations = new SparseRowMatrix(4, 4, new Vector[] {
+ new DenseVector(new double[] { 5.0, 5.0, 2.0, 0 }),
+ new DenseVector(new double[] { 2.0, 0, 3.0, 5.0 }),
+ new DenseVector(new double[] { 0, 5.0, 0, 3.0 }),
+ new DenseVector(new double[] { 3.0, 0, 0, 5.0 }) });
+
+  // binarized preferences: 1.0 where a rating was observed, 0 otherwise; these are the expected estimates
+ Matrix preferences = new SparseRowMatrix(4, 4, new Vector[] {
+ new DenseVector(new double[] { 1.0, 1.0, 1.0, 0 }),
+ new DenseVector(new double[] { 1.0, 0, 1.0, 1.0 }),
+ new DenseVector(new double[] { 0, 1.0, 0, 1.0 }),
+ new DenseVector(new double[] { 1.0, 0, 0, 1.0 }) });
+
+ double alpha = 20;
+
+  // 3 features, lambda 0.065, 5 iterations, implicit feedback enabled
+  // NOTE(review): this local shadows the 'factorizer' field of the test class
+ ALSWRFactorizer factorizer = new ALSWRFactorizer(dataModel, 3, 0.065, 5,
true, alpha);
+
+ SVDRecommender svdRecommender = new SVDRecommender(dataModel, factorizer);
+
+ RunningAverage avg = new FullRunningAverage();
+ Iterator<MatrixSlice> sliceIterator = preferences.iterateAll();
+ while (sliceIterator.hasNext()) {
+ MatrixSlice slice = sliceIterator.next();
+ for (Vector.Element e : slice.vector()) {
+
+  // matrix indices are 0-based; assumes setUp() builds the data model with user and item IDs 1..4 (confirm)
+ long userID = slice.index() + 1;
+ long itemID = e.index() + 1;
+
+  // NOTE(review): the literal arrays above contain no NaN values, so this check is always true here
+ if (!Double.isNaN(e.get())) {
+ double pref = e.get();
+ double estimate = svdRecommender.estimatePreference(userID, itemID);
+
+  // confidence c = 1 + alpha * r, with r the raw rating (0 for unrated cells)
+ double confidence = 1 + alpha * observations.getQuick(slice.index(),
e.index());
+ double err = confidence * (pref - estimate) * (pref - estimate);
+ avg.addDatum(err);
+ log.info("Comparing preference of user [{}] towards item [{}], was
[{}] with confidence [{}] "
+ + "estimate is [{}]", new Object[] { slice.index(), e.index(),
pref, confidence, estimate });
+ }
+ }
+ }
+  // confidence-weighted root mean squared error between binarized preferences and estimates
+ double rmse = Math.sqrt(avg.getAverage());
+ log.info("RMSE: {}", rmse);
+
+ assertTrue(rmse < 0.4);
+ }
}