http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java new file mode 100644 index 0000000..f7b4385 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java @@ -0,0 +1,211 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.mahout.cf.taste.impl.recommender;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;

import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.common.NoSuchItemException;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
import org.apache.mahout.cf.taste.impl.similarity.GenericUserSimilarity;
import org.apache.mahout.cf.taste.recommender.IDRescorer;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;

/**
 * <p>
 * A simple class that refactors the "find top N things" logic that is used in several places.
 * </p>
 *
 * <p>
 * Every method scans its candidates once and keeps at most {@code howMany} of them in a bounded
 * {@link PriorityQueue}. The queue is ordered with the reverse of the final result ordering, so its head is the
 * weakest retained candidate. A new candidate is offered only while the queue is not yet full, or when its value
 * beats that weakest one. The head is evicted as soon as the queue grows past {@code howMany}. Candidates whose
 * (rescored) value is {@code NaN} are ignored.
 * </p>
 *
 * <p>
 * A {@code howMany} of zero yields an empty result. A negative {@code howMany} is rejected with an
 * {@link IllegalArgumentException}.
 * </p>
 */
public final class TopItems {

  private static final long[] NO_IDS = new long[0];

  /**
   * Upper bound on the queue capacity allocated up front. The queue still grows on demand. The cap only keeps
   * very large {@code howMany} values from pre-allocating a huge array, or from overflowing
   * {@code howMany + 1} for {@code Integer.MAX_VALUE}.
   */
  private static final int MAX_INITIAL_CAPACITY = 1024;

  private TopItems() { }

  /**
   * Finds the {@code howMany} items with the highest estimated, and optionally rescored, preference.
   *
   * @param howMany maximum number of items to return; must be non-negative
   * @param possibleItemIDs candidate item IDs; this call consumes the iterator
   * @param rescorer optional; when non-null it can filter items out and adjust their estimates
   * @param estimator supplies the raw estimate for each item; items for which it throws
   *                  {@link NoSuchItemException} are skipped
   * @return the top items sorted with {@link ByValueRecommendedItemComparator}; empty (never null) if none qualify
   * @throws IllegalArgumentException if {@code possibleItemIDs} or {@code estimator} is null, or {@code howMany}
   *                                  is negative
   * @throws TasteException if the estimator fails for any other reason
   */
  public static List<RecommendedItem> getTopItems(int howMany,
                                                  LongPrimitiveIterator possibleItemIDs,
                                                  IDRescorer rescorer,
                                                  Estimator<Long> estimator) throws TasteException {
    Preconditions.checkArgument(possibleItemIDs != null, "possibleItemIDs is null");
    Preconditions.checkArgument(estimator != null, "estimator is null");
    checkHowMany(howMany);
    if (howMany == 0) {
      // Nothing can be retained. Without this guard the loop below would add one item, evict it again and
      // then call peek() on the now-empty queue.
      return Collections.emptyList();
    }

    Queue<RecommendedItem> topItems = new PriorityQueue<>(initialCapacity(howMany),
        Collections.reverseOrder(ByValueRecommendedItemComparator.getInstance()));
    boolean full = false;
    double lowestTopValue = Double.NEGATIVE_INFINITY;
    while (possibleItemIDs.hasNext()) {
      long itemID = possibleItemIDs.nextLong();
      if (rescorer == null || !rescorer.isFiltered(itemID)) {
        double preference;
        try {
          preference = estimator.estimate(itemID);
        } catch (NoSuchItemException nsie) {
          // No estimate for this item: it simply cannot be recommended.
          continue;
        }
        double rescoredPref = rescorer == null ? preference : rescorer.rescore(itemID, preference);
        if (!Double.isNaN(rescoredPref) && (!full || rescoredPref > lowestTopValue)) {
          // Values are stored as float, so lowestTopValue tracks the float-rounded value of the head.
          topItems.add(new GenericRecommendedItem(itemID, (float) rescoredPref));
          if (full) {
            topItems.poll();
          } else if (topItems.size() > howMany) {
            full = true;
            topItems.poll();
          }
          lowestTopValue = topItems.peek().getValue();
        }
      }
    }
    if (topItems.isEmpty()) {
      return Collections.emptyList();
    }
    List<RecommendedItem> result = new ArrayList<>(topItems);
    Collections.sort(result, ByValueRecommendedItemComparator.getInstance());
    return result;
  }

  /**
   * Finds the IDs of the {@code howMany} users with the highest estimated, and optionally rescored, similarity.
   *
   * @param howMany maximum number of user IDs to return; must be non-negative
   * @param allUserIDs candidate user IDs; this call consumes the iterator
   * @param rescorer optional; when non-null it can filter users out and adjust their similarity
   * @param estimator supplies the raw similarity for each user; users for which it throws
   *                  {@link NoSuchUserException} are skipped
   * @return user IDs in the natural ordering of {@link SimilarUser}; an empty array if none qualify
   * @throws IllegalArgumentException if {@code howMany} is negative
   * @throws TasteException if the estimator fails for any other reason
   */
  public static long[] getTopUsers(int howMany,
                                   LongPrimitiveIterator allUserIDs,
                                   IDRescorer rescorer,
                                   Estimator<Long> estimator) throws TasteException {
    checkHowMany(howMany);
    if (howMany == 0) {
      return NO_IDS;
    }
    Queue<SimilarUser> topUsers = new PriorityQueue<>(initialCapacity(howMany), Collections.reverseOrder());
    boolean full = false;
    double lowestTopValue = Double.NEGATIVE_INFINITY;
    while (allUserIDs.hasNext()) {
      long userID = allUserIDs.nextLong();
      if (rescorer != null && rescorer.isFiltered(userID)) {
        continue;
      }
      double similarity;
      try {
        similarity = estimator.estimate(userID);
      } catch (NoSuchUserException nsue) {
        // No similarity for this user: skip them.
        continue;
      }
      double rescoredSimilarity = rescorer == null ? similarity : rescorer.rescore(userID, similarity);
      if (!Double.isNaN(rescoredSimilarity) && (!full || rescoredSimilarity > lowestTopValue)) {
        topUsers.add(new SimilarUser(userID, rescoredSimilarity));
        if (full) {
          topUsers.poll();
        } else if (topUsers.size() > howMany) {
          full = true;
          topUsers.poll();
        }
        lowestTopValue = topUsers.peek().getSimilarity();
      }
    }
    int size = topUsers.size();
    if (size == 0) {
      return NO_IDS;
    }
    List<SimilarUser> sorted = new ArrayList<>(topUsers);
    Collections.sort(sorted);
    long[] result = new long[size];
    int i = 0;
    for (SimilarUser similarUser : sorted) {
      result[i++] = similarUser.getUserID();
    }
    return result;
  }

  /**
   * <p>
   * Thanks to tsmorton for suggesting this functionality and writing part of the code.
   * </p>
   *
   * <p>
   * Keeps the {@code howMany} item-item similarities with the highest non-NaN value.
   * </p>
   *
   * @param howMany maximum number of similarities to return; must be non-negative
   * @param allSimilarities candidates; this call consumes the iterator
   * @return the retained similarities in their natural ordering; empty (never null) if none qualify
   * @throws IllegalArgumentException if {@code howMany} is negative
   * @see GenericItemSimilarity#GenericItemSimilarity(Iterable, int)
   * @see GenericItemSimilarity#GenericItemSimilarity(org.apache.mahout.cf.taste.similarity.ItemSimilarity,
   *      org.apache.mahout.cf.taste.model.DataModel, int)
   */
  public static List<GenericItemSimilarity.ItemItemSimilarity> getTopItemItemSimilarities(
    int howMany, Iterator<GenericItemSimilarity.ItemItemSimilarity> allSimilarities) {

    checkHowMany(howMany);
    if (howMany == 0) {
      return Collections.emptyList();
    }
    Queue<GenericItemSimilarity.ItemItemSimilarity> topSimilarities
      = new PriorityQueue<>(initialCapacity(howMany), Collections.reverseOrder());
    boolean full = false;
    double lowestTopValue = Double.NEGATIVE_INFINITY;
    while (allSimilarities.hasNext()) {
      GenericItemSimilarity.ItemItemSimilarity similarity = allSimilarities.next();
      double value = similarity.getValue();
      if (!Double.isNaN(value) && (!full || value > lowestTopValue)) {
        topSimilarities.add(similarity);
        if (full) {
          topSimilarities.poll();
        } else if (topSimilarities.size() > howMany) {
          full = true;
          topSimilarities.poll();
        }
        lowestTopValue = topSimilarities.peek().getValue();
      }
    }
    if (topSimilarities.isEmpty()) {
      return Collections.emptyList();
    }
    List<GenericItemSimilarity.ItemItemSimilarity> result = new ArrayList<>(topSimilarities);
    Collections.sort(result);
    return result;
  }

  /**
   * Keeps the {@code howMany} user-user similarities with the highest non-NaN value.
   *
   * @param howMany maximum number of similarities to return; must be non-negative
   * @param allSimilarities candidates; this call consumes the iterator
   * @return the retained similarities in their natural ordering; empty (never null) if none qualify
   * @throws IllegalArgumentException if {@code howMany} is negative
   */
  public static List<GenericUserSimilarity.UserUserSimilarity> getTopUserUserSimilarities(
    int howMany, Iterator<GenericUserSimilarity.UserUserSimilarity> allSimilarities) {

    checkHowMany(howMany);
    if (howMany == 0) {
      return Collections.emptyList();
    }
    Queue<GenericUserSimilarity.UserUserSimilarity> topSimilarities
      = new PriorityQueue<>(initialCapacity(howMany), Collections.reverseOrder());
    boolean full = false;
    double lowestTopValue = Double.NEGATIVE_INFINITY;
    while (allSimilarities.hasNext()) {
      GenericUserSimilarity.UserUserSimilarity similarity = allSimilarities.next();
      double value = similarity.getValue();
      if (!Double.isNaN(value) && (!full || value > lowestTopValue)) {
        topSimilarities.add(similarity);
        if (full) {
          topSimilarities.poll();
        } else if (topSimilarities.size() > howMany) {
          full = true;
          topSimilarities.poll();
        }
        lowestTopValue = topSimilarities.peek().getValue();
      }
    }
    if (topSimilarities.isEmpty()) {
      return Collections.emptyList();
    }
    List<GenericUserSimilarity.UserUserSimilarity> result = new ArrayList<>(topSimilarities);
    Collections.sort(result);
    return result;
  }

  /**
   * Rejects negative result sizes. The exception type matches what the {@link PriorityQueue} constructor
   * previously threw for them.
   */
  private static void checkHowMany(int howMany) {
    Preconditions.checkArgument(howMany >= 0, "howMany must be non-negative: %s", howMany);
  }

  /** Initial queue capacity: room for {@code howMany + 1} entries, capped at {@link #MAX_INITIAL_CAPACITY} + 1. */
  private static int initialCapacity(int howMany) {
    return Math.min(howMany, MAX_INITIAL_CAPACITY) + 1;
  }

  /**
   * Supplies the value that candidates are ranked by.
   *
   * @param <T> type of the thing being estimated (for example a boxed ID)
   */
  public interface Estimator<T> {
    double estimate(T thing) throws TasteException;
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java new file mode 100644 index 0000000..0ba5139 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java @@ -0,0 +1,312 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * factorizes the rating matrix using "Alternating-Least-Squares with Weighted-λ-Regularization" as described in
 * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
 * "Large-scale Collaborative Filtering for the Netflix Prize"</a>
 *
 * also supports the implicit feedback variant of this approach as described in "Collaborative Filtering for Implicit
 * Feedback Datasets" available at http://research.yahoo.com/pub/2433
 *
 * <p>Each iteration first keeps the item features (M) fixed and solves for every user's features (U). It then keeps
 * U fixed and solves for every item's features. Individual users (and items) are solved as independent tasks on a
 * fixed thread pool of {@code numTrainingThreads} threads.</p>
 */
public class ALSWRFactorizer extends AbstractFactorizer {

  private final DataModel dataModel;

  /** number of features used to compute this factorization */
  private final int numFeatures;
  /** parameter to control the regularization */
  private final double lambda;
  /** number of iterations */
  private final int numIterations;

  /** whether to use the implicit feedback variant of the solver */
  private final boolean usesImplicitFeedback;
  /** confidence weighting parameter, only necessary when working with implicit feedback */
  private final double alpha;

  /** size of the thread pool that solves users (and items) in parallel */
  private final int numTrainingThreads;

  private static final double DEFAULT_ALPHA = 40;

  private static final Logger log = LoggerFactory.getLogger(ALSWRFactorizer.class);

  /**
   * @param dataModel ratings to factorize
   * @param numFeatures number of latent features per user and item
   * @param lambda regularization parameter
   * @param numIterations number of alternating (users, then items) passes
   * @param usesImplicitFeedback whether to use the implicit feedback solver
   * @param alpha confidence weighting, only used with implicit feedback
   * @param numTrainingThreads size of the thread pool used for solving
   * @throws TasteException if the ID mappings cannot be built from {@code dataModel}
   */
  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
      boolean usesImplicitFeedback, double alpha, int numTrainingThreads) throws TasteException {
    super(dataModel);
    this.dataModel = dataModel;
    this.numFeatures = numFeatures;
    this.lambda = lambda;
    this.numIterations = numIterations;
    this.usesImplicitFeedback = usesImplicitFeedback;
    this.alpha = alpha;
    this.numTrainingThreads = numTrainingThreads;
  }

  /** Like the full constructor, using one training thread per available processor. */
  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
      boolean usesImplicitFeedback, double alpha) throws TasteException {
    this(dataModel, numFeatures, lambda, numIterations, usesImplicitFeedback, alpha,
        Runtime.getRuntime().availableProcessors());
  }

  /** Explicit-feedback factorizer with one training thread per available processor. */
  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations) throws TasteException {
    this(dataModel, numFeatures, lambda, numIterations, false, DEFAULT_ALPHA);
  }

  /**
   * Mutable state of a running factorization: the item feature matrix M and the user feature matrix U, one row
   * per item / user index.
   */
  static class Features {

    private final DataModel dataModel;
    private final int numFeatures;

    private final double[][] M;
    private final double[][] U;

    /**
     * Initializes each row of M with the item's average rating in column 0 and random values in [0, 0.1) in the
     * remaining columns. U starts out all zeros.
     */
    Features(ALSWRFactorizer factorizer) throws TasteException {
      dataModel = factorizer.dataModel;
      numFeatures = factorizer.numFeatures;
      Random random = RandomUtils.getRandom();
      M = new double[dataModel.getNumItems()][numFeatures];
      LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
      while (itemIDsIterator.hasNext()) {
        long itemID = itemIDsIterator.nextLong();
        int itemIDIndex = factorizer.itemIndex(itemID);
        M[itemIDIndex][0] = averateRating(itemID);
        for (int feature = 1; feature < numFeatures; feature++) {
          M[itemIDIndex][feature] = random.nextDouble() * 0.1;
        }
      }
      U = new double[dataModel.getNumUsers()][numFeatures];
    }

    double[][] getM() {
      return M;
    }

    double[][] getU() {
      return U;
    }

    Vector getUserFeatureColumn(int index) {
      return new DenseVector(U[index]);
    }

    Vector getItemFeatureColumn(int index) {
      return new DenseVector(M[index]);
    }

    void setFeatureColumnInU(int idIndex, Vector vector) {
      setFeatureColumn(U, idIndex, vector);
    }

    void setFeatureColumnInM(int idIndex, Vector vector) {
      setFeatureColumn(M, idIndex, vector);
    }

    /** Copies the first {@code numFeatures} entries of {@code vector} into row {@code idIndex} in place. */
    protected void setFeatureColumn(double[][] matrix, int idIndex, Vector vector) {
      for (int feature = 0; feature < numFeatures; feature++) {
        matrix[idIndex][feature] = vector.get(feature);
      }
    }

    /** Mean of all preference values for {@code itemID}. (The method name's typo is kept for compatibility.) */
    protected double averateRating(long itemID) throws TasteException {
      PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
      RunningAverage avg = new FullRunningAverage();
      for (Preference pref : prefs) {
        avg.addDatum(pref.getValue());
      }
      return avg.getAverage();
    }
  }

  /**
   * Runs {@code numIterations} alternating passes and returns the resulting factorization.
   *
   * @throws TasteException if the data model cannot be read, or if the calling thread is interrupted while
   *                        waiting for a pass to finish (the interrupt status is restored)
   */
  @Override
  public Factorization factorize() throws TasteException {
    log.info("starting to compute the factorization...");
    final Features features = new Features(this);

    /* feature maps necessary for solving for implicit feedback. The vectors are built with
       new DenseVector(row, true) over the rows of U and M. Presumably that is a shallow wrap, so they observe the
       in-place updates of later iterations. TODO confirm against DenseVector. */
    OpenIntObjectHashMap<Vector> userY = null;
    OpenIntObjectHashMap<Vector> itemY = null;

    if (usesImplicitFeedback) {
      userY = userFeaturesMapping(dataModel.getUserIDs(), dataModel.getNumUsers(), features.getU());
      itemY = itemFeaturesMapping(dataModel.getItemIDs(), dataModel.getNumItems(), features.getM());
    }

    for (int iteration = 0; iteration < numIterations; iteration++) {
      log.info("iteration {}", iteration);

      /* fix M - compute U */
      ExecutorService queue = createQueue();
      LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
      try {

        // NOTE(review): one solver instance is shared by all worker tasks; assumes solve() is thread-safe.
        final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackSolver = usesImplicitFeedback
            ? new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, itemY, numTrainingThreads)
            : null;

        while (userIDsIterator.hasNext()) {
          final long userID = userIDsIterator.nextLong();
          final LongPrimitiveIterator itemIDsFromUser = dataModel.getItemIDsFromUser(userID).iterator();
          final PreferenceArray userPrefs = dataModel.getPreferencesFromUser(userID);
          queue.execute(new Runnable() {
            @Override
            public void run() {
              List<Vector> featureVectors = new ArrayList<>();
              while (itemIDsFromUser.hasNext()) {
                long itemID = itemIDsFromUser.nextLong();
                featureVectors.add(features.getItemFeatureColumn(itemIndex(itemID)));
              }

              Vector userFeatures = usesImplicitFeedback
                  ? implicitFeedbackSolver.solve(sparseUserRatingVector(userPrefs))
                  : AlternatingLeastSquaresSolver.solve(featureVectors, ratingVector(userPrefs), lambda, numFeatures);

              features.setFeatureColumnInU(userIndex(userID), userFeatures);
            }
          });
        }
      } finally {
        awaitCompletion(queue, dataModel.getNumUsers(), "user");
      }

      /* fix U - compute M */
      queue = createQueue();
      LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
      try {

        final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackSolver = usesImplicitFeedback
            ? new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, userY, numTrainingThreads)
            : null;

        while (itemIDsIterator.hasNext()) {
          final long itemID = itemIDsIterator.nextLong();
          final PreferenceArray itemPrefs = dataModel.getPreferencesForItem(itemID);
          queue.execute(new Runnable() {
            @Override
            public void run() {
              List<Vector> featureVectors = new ArrayList<>();
              for (Preference pref : itemPrefs) {
                long userID = pref.getUserID();
                featureVectors.add(features.getUserFeatureColumn(userIndex(userID)));
              }

              Vector itemFeatures = usesImplicitFeedback
                  ? implicitFeedbackSolver.solve(sparseItemRatingVector(itemPrefs))
                  : AlternatingLeastSquaresSolver.solve(featureVectors, ratingVector(itemPrefs), lambda, numFeatures);

              features.setFeatureColumnInM(itemIndex(itemID), itemFeatures);
            }
          });
        }
      } finally {
        awaitCompletion(queue, dataModel.getNumItems(), "item");
      }
    }

    log.info("finished computation of the factorization...");
    return createFactorization(features.getU(), features.getM());
  }

  /**
   * Shuts {@code queue} down and blocks until its tasks finish or {@code timeoutSeconds} elapse.
   *
   * <p>A timeout is logged and computation continues, as it always has. An interrupt cancels the outstanding
   * tasks, restores the interrupt status and aborts the factorization. Continuing would only produce
   * half-computed features.</p>
   *
   * @param queue executor that the per-user / per-item tasks were submitted to
   * @param timeoutSeconds maximum time to wait, in seconds
   * @param featureType "user" or "item", used in messages only
   * @throws TasteException if the calling thread is interrupted while waiting
   */
  private static void awaitCompletion(ExecutorService queue, long timeoutSeconds, String featureType)
      throws TasteException {
    queue.shutdown();
    try {
      if (!queue.awaitTermination(timeoutSeconds, TimeUnit.SECONDS)) {
        log.warn("Timed out after {} seconds while computing {} features; continuing with partial results",
            timeoutSeconds, featureType);
      }
    } catch (InterruptedException e) {
      queue.shutdownNow();
      Thread.currentThread().interrupt();
      throw new TasteException("Interrupted while computing " + featureType + " features", e);
    }
  }

  /** Creates the thread pool used for one half-iteration; overridable for testing. */
  protected ExecutorService createQueue() {
    return Executors.newFixedThreadPool(numTrainingThreads);
  }

  /** Dense vector of the preference values in {@code prefs}, in array order. */
  protected static Vector ratingVector(PreferenceArray prefs) {
    double[] ratings = new double[prefs.length()];
    for (int n = 0; n < prefs.length(); n++) {
      ratings[n] = prefs.get(n).getValue();
    }
    return new DenseVector(ratings, true);
  }

  /** Maps each item index to a vector over that item's row of {@code featureMatrix}. */
  //TODO find a way to get rid of the object overhead here
  protected OpenIntObjectHashMap<Vector> itemFeaturesMapping(LongPrimitiveIterator itemIDs, int numItems,
      double[][] featureMatrix) {
    OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<>(numItems);
    while (itemIDs.hasNext()) {
      int itemIndex = itemIndex(itemIDs.nextLong());
      mapping.put(itemIndex, new DenseVector(featureMatrix[itemIndex], true));
    }

    return mapping;
  }

  /** Maps each user index to a vector over that user's row of {@code featureMatrix}. */
  protected OpenIntObjectHashMap<Vector> userFeaturesMapping(LongPrimitiveIterator userIDs, int numUsers,
      double[][] featureMatrix) {
    OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<>(numUsers);

    while (userIDs.hasNext()) {
      int userIndex = userIndex(userIDs.nextLong());
      mapping.put(userIndex, new DenseVector(featureMatrix[userIndex], true));
    }

    return mapping;
  }

  /** Sparse vector of an item's preference values, indexed by user index. */
  protected Vector sparseItemRatingVector(PreferenceArray prefs) {
    SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
    for (Preference preference : prefs) {
      ratings.set(userIndex(preference.getUserID()), preference.getValue());
    }
    return ratings;
  }

  /** Sparse vector of a user's preference values, indexed by item index. */
  protected Vector sparseUserRatingVector(PreferenceArray prefs) {
    SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
    for (Preference preference : prefs) {
      ratings.set(itemIndex(preference.getItemID()), preference.getValue());
    }
    return ratings;
  }
}
 http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java new file mode 100644 index 0000000..0a39a1d --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java @@ -0,0 +1,94 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.Collection;
import java.util.concurrent.Callable;

import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.model.DataModel;

/**
 * Base class for {@link Factorizer}s.
 *
 * <p>Assigns every user ID and every item ID a dense index (0, 1, 2, ... in the order the data model's
 * iterators yield them). The mappings are rebuilt through a {@link RefreshHelper} that depends on the data
 * model.</p>
 */
public abstract class AbstractFactorizer implements Factorizer {

  private final DataModel dataModel;
  private FastByIDMap<Integer> userIDMapping;
  private FastByIDMap<Integer> itemIDMapping;
  private final RefreshHelper refreshHelper;

  /**
   * @param dataModel source of the user and item IDs to index
   * @throws TasteException if the IDs cannot be read from {@code dataModel}
   */
  protected AbstractFactorizer(DataModel dataModel) throws TasteException {
    this.dataModel = dataModel;
    buildMappings();
    Callable<Object> rebuild = new Callable<Object>() {
      @Override
      public Object call() throws TasteException {
        buildMappings();
        return null;
      }
    };
    refreshHelper = new RefreshHelper(rebuild);
    refreshHelper.addDependency(dataModel);
  }

  /** Replaces both mappings with freshly built ones; the user mapping is built first. */
  private void buildMappings() throws TasteException {
    int numUsers = dataModel.getNumUsers();
    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
    userIDMapping = createIDMapping(numUsers, userIDs);

    int numItems = dataModel.getNumItems();
    LongPrimitiveIterator itemIDs = dataModel.getItemIDs();
    itemIDMapping = createIDMapping(numItems, itemIDs);
  }

  /** Wraps the given feature matrices together with the current ID mappings. */
  protected Factorization createFactorization(double[][] userFeatures, double[][] itemFeatures) {
    return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
  }

  /**
   * Returns the index of {@code userID}. An unknown ID is registered under the next free index.
   * NOTE(review): that registration is an unsynchronized put, so it is not safe to trigger from several threads.
   */
  protected Integer userIndex(long userID) {
    return indexFor(userIDMapping, userID);
  }

  /**
   * Returns the index of {@code itemID}. An unknown ID is registered under the next free index.
   * NOTE(review): that registration is an unsynchronized put, so it is not safe to trigger from several threads.
   */
  protected Integer itemIndex(long itemID) {
    return indexFor(itemIDMapping, itemID);
  }

  /** Looks {@code id} up in {@code mapping}, appending it with index {@code mapping.size()} when absent. */
  private static Integer indexFor(FastByIDMap<Integer> mapping, long id) {
    Integer existing = mapping.get(id);
    if (existing != null) {
      return existing;
    }
    Integer assigned = mapping.size();
    mapping.put(id, assigned);
    return assigned;
  }

  /** Builds an ID-to-index map that numbers the IDs 0, 1, 2, ... in iteration order. */
  private static FastByIDMap<Integer> createIDMapping(int size, LongPrimitiveIterator idIterator) {
    FastByIDMap<Integer> mapping = new FastByIDMap<>(size);
    for (int next = 0; idIterator.hasNext(); next++) {
      mapping.put(idIterator.nextLong(), next);
    }
    return mapping;
  }

  @Override
  public void refresh(Collection<Refreshable> alreadyRefreshed) {
    refreshHelper.refresh(alreadyRefreshed);
  }

}
 http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java new file mode 100644 index 0000000..f169a60 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java @@ -0,0 +1,137 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.Arrays;
import java.util.Map;

import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.common.NoSuchItemException;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;

/**
 * a factorization of the rating matrix
 *
 * <p>Holds a user feature matrix and an item feature matrix, plus the ID-to-row mappings used to find a given
 * user's or item's row. The arrays are stored and handed out without copying.</p>
 */
public class Factorization {

  /** used to find the rows in the user features matrix by userID */
  private final FastByIDMap<Integer> userIDMapping;
  /** used to find the rows in the item features matrix by itemID */
  private final FastByIDMap<Integer> itemIDMapping;

  /** user features matrix */
  private final double[][] userFeatures;
  /** item features matrix */
  private final double[][] itemFeatures;

  /**
   * @throws NullPointerException if either ID mapping is null
   */
  public Factorization(FastByIDMap<Integer> userIDMapping, FastByIDMap<Integer> itemIDMapping, double[][] userFeatures,
      double[][] itemFeatures) {
    this.userIDMapping = Preconditions.checkNotNull(userIDMapping);
    this.itemIDMapping = Preconditions.checkNotNull(itemIDMapping);
    this.userFeatures = userFeatures;
    this.itemFeatures = itemFeatures;
  }

  /** The whole user feature matrix (not a copy). */
  public double[][] allUserFeatures() {
    return userFeatures;
  }

  /**
   * @return the feature row of {@code userID} (not a copy)
   * @throws NoSuchUserException if {@code userID} has no mapping
   */
  public double[] getUserFeatures(long userID) throws NoSuchUserException {
    Integer index = userIDMapping.get(userID);
    if (index == null) {
      throw new NoSuchUserException(userID);
    }
    return userFeatures[index];
  }

  /** The whole item feature matrix (not a copy). */
  public double[][] allItemFeatures() {
    return itemFeatures;
  }

  /**
   * @return the feature row of {@code itemID} (not a copy)
   * @throws NoSuchItemException if {@code itemID} has no mapping
   */
  public double[] getItemFeatures(long itemID) throws NoSuchItemException {
    Integer index = itemIDMapping.get(itemID);
    if (index == null) {
      throw new NoSuchItemException(itemID);
    }
    return itemFeatures[index];
  }

  /** @throws NoSuchUserException if {@code userID} has no mapping */
  public int userIndex(long userID) throws NoSuchUserException {
    Integer index = userIDMapping.get(userID);
    if (index == null) {
      throw new NoSuchUserException(userID);
    }
    return index;
  }

  public Iterable<Map.Entry<Long,Integer>> getUserIDMappings() {
    return userIDMapping.entrySet();
  }

  public LongPrimitiveIterator getUserIDMappingKeys() {
    return userIDMapping.keySetIterator();
  }

  /** @throws NoSuchItemException if {@code itemID} has no mapping */
  public int itemIndex(long itemID) throws NoSuchItemException {
    Integer index = itemIDMapping.get(itemID);
    if (index == null) {
      throw new NoSuchItemException(itemID);
    }
    return index;
  }

  public Iterable<Map.Entry<Long,Integer>> getItemIDMappings() {
    return itemIDMapping.entrySet();
  }

  public LongPrimitiveIterator getItemIDMappingKeys() {
    return itemIDMapping.keySetIterator();
  }

  /**
   * Number of latent features, i.e. the row length of the feature matrices.
   *
   * <p>Taken from the user matrix when it has rows. Otherwise it comes from the item matrix, so an item-only
   * factorization still reports (and persists) its feature count. Returns 0 when both matrices are empty.</p>
   */
  public int numFeatures() {
    if (userFeatures.length > 0) {
      return userFeatures[0].length;
    }
    return itemFeatures.length > 0 ? itemFeatures[0].length : 0;
  }

  public int numUsers() {
    return userIDMapping.size();
  }

  public int numItems() {
    return itemIDMapping.size();
  }

  /** Equal when both ID mappings are equal and both feature matrices are deeply equal. */
  @Override
  public boolean equals(Object o) {
    if (o instanceof Factorization) {
      Factorization other = (Factorization) o;
      return userIDMapping.equals(other.userIDMapping) && itemIDMapping.equals(other.itemIDMapping)
          && Arrays.deepEquals(userFeatures, other.userFeatures) && Arrays.deepEquals(itemFeatures, other.itemFeatures);
    }
    return false;
  }

  @Override
  public int hashCode() {
    int hashCode = 31 * userIDMapping.hashCode() + itemIDMapping.hashCode();
    hashCode = 31 * hashCode + Arrays.deepHashCode(userFeatures);
    hashCode = 31 * hashCode + Arrays.deepHashCode(itemFeatures);
    return hashCode;
  }
}
 http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java new file mode 100644 index 0000000..2cabe73 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;

/**
 * Something that can compute a {@link Factorization} of a rating matrix.
 *
 * <p>Implementations are {@link Refreshable}, so they can react when the data they factorize changes.</p>
 */
public interface Factorizer extends Refreshable {

  /**
   * Computes a factorization of the rating matrix.
   *
   * @return the computed factorization
   * @throws TasteException if the underlying data cannot be read or the factorization fails
   */
  Factorization factorize() throws TasteException;

}
 http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java new file mode 100644 index 0000000..08c038a --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java @@ -0,0 +1,139 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Map;

import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.common.NoSuchItemException;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Provides a file-based persistent store.
 *
 * <p>Binary layout, as written by {@link #writeBinary} and read by {@link #readBinary}:</p>
 * <ol>
 *   <li>three ints: {@code numFeatures}, {@code numUsers} and {@code numItems};</li>
 *   <li>for each user: int index, long userID, then {@code numFeatures} doubles;</li>
 *   <li>for each item: int index, long itemID, then {@code numFeatures} doubles.</li>
 * </ol>
 */
public class FilePersistenceStrategy implements PersistenceStrategy {

  private final File file;

  private static final Logger log = LoggerFactory.getLogger(FilePersistenceStrategy.class);

  /**
   * @param file the file to use for storage. If the file does not exist it will be created when required.
   */
  public FilePersistenceStrategy(File file) {
    this.file = Preconditions.checkNotNull(file);
  }

  /**
   * @return the stored factorization, or {@code null} if the file does not exist yet
   * @throws IOException if the file cannot be read or its contents are corrupt or truncated
   */
  @Override
  public Factorization load() throws IOException {
    if (!file.exists()) {
      log.info("{} does not yet exist, no factorization found", file.getAbsolutePath());
      return null;
    }
    try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))){
      log.info("Reading factorization from {}...", file.getAbsolutePath());
      return readBinary(in);
    }
  }

  /**
   * Writes {@code factorization} to the file, replacing any previous contents.
   * NOTE(review): the file is overwritten in place, so a failure mid-write leaves a truncated file, which
   * {@link #load()} then reports as an IOException.
   */
  @Override
  public void maybePersist(Factorization factorization) throws IOException {
    try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))){
      log.info("Writing factorization to {}...", file.getAbsolutePath());
      writeBinary(factorization, out);
    }
  }

  /** Serializes {@code factorization} in the layout described on the class. */
  protected static void writeBinary(Factorization factorization, DataOutput out) throws IOException {
    int numFeatures = factorization.numFeatures();
    out.writeInt(numFeatures);
    out.writeInt(factorization.numUsers());
    out.writeInt(factorization.numItems());

    for (Map.Entry<Long,Integer> mappingEntry : factorization.getUserIDMappings()) {
      long userID = mappingEntry.getKey();
      out.writeInt(mappingEntry.getValue());
      out.writeLong(userID);
      try {
        double[] userFeatures = factorization.getUserFeatures(userID);
        for (int feature = 0; feature < numFeatures; feature++) {
          out.writeDouble(userFeatures[feature]);
        }
      } catch (NoSuchUserException e) {
        throw new IOException("Unable to persist factorization", e);
      }
    }

    for (Map.Entry<Long,Integer> entry : factorization.getItemIDMappings()) {
      long itemID = entry.getKey();
      out.writeInt(entry.getValue());
      out.writeLong(itemID);
      try {
        double[] itemFeatures = factorization.getItemFeatures(itemID);
        for (int feature = 0; feature < numFeatures; feature++) {
          out.writeDouble(itemFeatures[feature]);
        }
      } catch (NoSuchItemException e) {
        throw new IOException("Unable to persist factorization", e);
      }
    }
  }

  /**
   * Deserializes a factorization written by {@link #writeBinary}.
   *
   * @throws IOException if reading fails, the input ends early, or the header or an index is out of range
   */
  public static Factorization readBinary(DataInput in) throws IOException {
    int numFeatures = in.readInt();
    int numUsers = in.readInt();
    int numItems = in.readInt();
    if (numFeatures < 0 || numUsers < 0 || numItems < 0) {
      throw new IOException("Corrupt factorization header: numFeatures=" + numFeatures + ", numUsers=" + numUsers
          + ", numItems=" + numItems);
    }

    FastByIDMap<Integer> userIDMapping = new FastByIDMap<>(numUsers);
    double[][] userFeatures = readFeatures(in, numUsers, numFeatures, userIDMapping, "user");

    FastByIDMap<Integer> itemIDMapping = new FastByIDMap<>(numItems);
    double[][] itemFeatures = readFeatures(in, numItems, numFeatures, itemIDMapping, "item");

    return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
  }

  /**
   * Reads {@code count} (index, ID, features) records into a new {@code count x numFeatures} matrix and records
   * each ID-to-index pair in {@code idMapping}.
   *
   * @throws IOException if a record's index falls outside {@code [0, count)}, or reading fails
   */
  private static double[][] readFeatures(DataInput in, int count, int numFeatures, FastByIDMap<Integer> idMapping,
      String kind) throws IOException {
    double[][] features = new double[count][numFeatures];
    for (int n = 0; n < count; n++) {
      int index = in.readInt();
      long id = in.readLong();
      if (index < 0 || index >= count) {
        throw new IOException("Corrupt factorization: " + kind + " index " + index + " for ID " + id
            + " is outside [0, " + count + ")");
      }
      idMapping.put(id, index);
      for (int feature = 0; feature < numFeatures; feature++) {
        features[index][feature] = in.readDouble();
      }
    }
    return features;
  }

}
 http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java new file mode 100644 index 0000000..0d1aab0 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java @@ -0,0 +1,37 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.recommender.svd; + +import java.io.IOException; + +/** + * A {@link PersistenceStrategy} which does nothing. + */ +public class NoPersistenceStrategy implements PersistenceStrategy { + + @Override + public Factorization load() throws IOException { + return null; + } + + @Override + public void maybePersist(Factorization factorization) throws IOException { + // do nothing. + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java new file mode 100644 index 0000000..8a6a702 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java @@ -0,0 +1,340 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.recommender.svd; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.RandomWrapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Minimalistic implementation of Parallel SGD factorizer based on + * <a href="http://www.sze.hu/~gtakacs/download/jmlr_2009.pdf"> + * "Scalable Collaborative Filtering Approaches for Large Recommender Systems"</a> + * and + * <a href="hwww.cs.wisc.edu/~brecht/papers/hogwildTR.pdf"> + * "Hogwild!: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent"</a> */ +public class ParallelSGDFactorizer extends AbstractFactorizer { + + private final DataModel dataModel; + /** Parameter used to prevent overfitting. 
*/ + private final double lambda; + /** Number of features used to compute this factorization */ + private final int rank; + /** Number of iterations */ + private final int numEpochs; + + private int numThreads; + + // these next two control decayFactor^steps exponential type of annealing learning rate and decay factor + private double mu0 = 0.01; + private double decayFactor = 1; + // these next two control 1/steps^forget type annealing + private int stepOffset = 0; + // -1 equals even weighting of all examples, 0 means only use exponential annealing + private double forgettingExponent = 0; + + // The following two should be inversely proportional :) + private double biasMuRatio = 0.5; + private double biasLambdaRatio = 0.1; + + /** TODO: this is not safe as += is not atomic on many processors, can be replaced with AtomicDoubleArray + * but it works just fine right now */ + /** user features */ + protected volatile double[][] userVectors; + /** item features */ + protected volatile double[][] itemVectors; + + private final PreferenceShuffler shuffler; + + private int epoch = 1; + /** place in user vector where the bias is stored */ + private static final int USER_BIAS_INDEX = 1; + /** place in item vector where the bias is stored */ + private static final int ITEM_BIAS_INDEX = 2; + private static final int FEATURE_OFFSET = 3; + /** Standard deviation for random initialization of features */ + private static final double NOISE = 0.02; + + private static final Logger logger = LoggerFactory.getLogger(ParallelSGDFactorizer.class); + + protected static class PreferenceShuffler { + + private Preference[] preferences; + private Preference[] unstagedPreferences; + + protected final RandomWrapper random = RandomUtils.getRandom(); + + public PreferenceShuffler(DataModel dataModel) throws TasteException { + cachePreferences(dataModel); + shuffle(); + stage(); + } + + private int countPreferences(DataModel dataModel) throws TasteException { + int numPreferences = 0; + 
LongPrimitiveIterator userIDs = dataModel.getUserIDs(); + while (userIDs.hasNext()) { + PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userIDs.nextLong()); + numPreferences += preferencesFromUser.length(); + } + return numPreferences; + } + + private void cachePreferences(DataModel dataModel) throws TasteException { + int numPreferences = countPreferences(dataModel); + preferences = new Preference[numPreferences]; + + LongPrimitiveIterator userIDs = dataModel.getUserIDs(); + int index = 0; + while (userIDs.hasNext()) { + long userID = userIDs.nextLong(); + PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID); + for (Preference preference : preferencesFromUser) { + preferences[index++] = preference; + } + } + } + + public final void shuffle() { + unstagedPreferences = preferences.clone(); + /* Durstenfeld shuffle */ + for (int i = unstagedPreferences.length - 1; i > 0; i--) { + int rand = random.nextInt(i + 1); + swapCachedPreferences(i, rand); + } + } + + //merge this part into shuffle() will make compiler-optimizer do some real absurd stuff, test on OpenJDK7 + private void swapCachedPreferences(int x, int y) { + Preference p = unstagedPreferences[x]; + + unstagedPreferences[x] = unstagedPreferences[y]; + unstagedPreferences[y] = p; + } + + public final void stage() { + preferences = unstagedPreferences; + } + + public Preference get(int i) { + return preferences[i]; + } + + public int size() { + return preferences.length; + } + + } + + public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numEpochs) + throws TasteException { + super(dataModel); + this.dataModel = dataModel; + this.rank = numFeatures + FEATURE_OFFSET; + this.lambda = lambda; + this.numEpochs = numEpochs; + + shuffler = new PreferenceShuffler(dataModel); + + //max thread num set to n^0.25 as suggested by hogwild! 
paper + numThreads = Math.min(Runtime.getRuntime().availableProcessors(), (int) Math.pow((double) shuffler.size(), 0.25)); + } + + public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, + double mu0, double decayFactor, int stepOffset, double forgettingExponent) throws TasteException { + this(dataModel, numFeatures, lambda, numIterations); + + this.mu0 = mu0; + this.decayFactor = decayFactor; + this.stepOffset = stepOffset; + this.forgettingExponent = forgettingExponent; + } + + public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, + double mu0, double decayFactor, int stepOffset, double forgettingExponent, int numThreads) throws TasteException { + this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent); + + this.numThreads = numThreads; + } + + public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, + double mu0, double decayFactor, int stepOffset, double forgettingExponent, + double biasMuRatio, double biasLambdaRatio) throws TasteException { + this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent); + + this.biasMuRatio = biasMuRatio; + this.biasLambdaRatio = biasLambdaRatio; + } + + public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, + double mu0, double decayFactor, int stepOffset, double forgettingExponent, + double biasMuRatio, double biasLambdaRatio, int numThreads) throws TasteException { + this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent, biasMuRatio, + biasLambdaRatio); + + this.numThreads = numThreads; + } + + protected void initialize() throws TasteException { + RandomWrapper random = RandomUtils.getRandom(); + userVectors = new double[dataModel.getNumUsers()][rank]; + itemVectors = new double[dataModel.getNumItems()][rank]; + 
+ double globalAverage = getAveragePreference(); + for (int userIndex = 0; userIndex < userVectors.length; userIndex++) { + userVectors[userIndex][0] = globalAverage; + userVectors[userIndex][USER_BIAS_INDEX] = 0; // will store user bias + userVectors[userIndex][ITEM_BIAS_INDEX] = 1; // corresponding item feature contains item bias + for (int feature = FEATURE_OFFSET; feature < rank; feature++) { + userVectors[userIndex][feature] = random.nextGaussian() * NOISE; + } + } + for (int itemIndex = 0; itemIndex < itemVectors.length; itemIndex++) { + itemVectors[itemIndex][0] = 1; // corresponding user feature contains global average + itemVectors[itemIndex][USER_BIAS_INDEX] = 1; // corresponding user feature contains user bias + itemVectors[itemIndex][ITEM_BIAS_INDEX] = 0; // will store item bias + for (int feature = FEATURE_OFFSET; feature < rank; feature++) { + itemVectors[itemIndex][feature] = random.nextGaussian() * NOISE; + } + } + } + + //TODO: needs optimization + private double getMu(int i) { + return mu0 * Math.pow(decayFactor, i - 1) * Math.pow(i + stepOffset, forgettingExponent); + } + + @Override + public Factorization factorize() throws TasteException { + initialize(); + + if (logger.isInfoEnabled()) { + logger.info("starting to compute the factorization..."); + } + + for (epoch = 1; epoch <= numEpochs; epoch++) { + shuffler.stage(); + + final double mu = getMu(epoch); + int subSize = shuffler.size() / numThreads + 1; + + ExecutorService executor=Executors.newFixedThreadPool(numThreads); + + try { + for (int t = 0; t < numThreads; t++) { + final int iStart = t * subSize; + final int iEnd = Math.min((t + 1) * subSize, shuffler.size()); + + executor.execute(new Runnable() { + @Override + public void run() { + for (int i = iStart; i < iEnd; i++) { + update(shuffler.get(i), mu); + } + } + }); + } + } finally { + executor.shutdown(); + shuffler.shuffle(); + + try { + boolean terminated = executor.awaitTermination(numEpochs * shuffler.size(), 
TimeUnit.MICROSECONDS); + if (!terminated) { + logger.error("subtasks takes forever, return anyway"); + } + } catch (InterruptedException e) { + throw new TasteException("waiting fof termination interrupted", e); + } + } + + } + + return createFactorization(userVectors, itemVectors); + } + + double getAveragePreference() throws TasteException { + RunningAverage average = new FullRunningAverage(); + LongPrimitiveIterator it = dataModel.getUserIDs(); + while (it.hasNext()) { + for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) { + average.addDatum(pref.getValue()); + } + } + return average.getAverage(); + } + + /** TODO: this is the vanilla sgd by Tacaks 2009, I speculate that using scaling technique proposed in: + * Towards Optimal One Pass Large Scale Learning with Averaged Stochastic Gradient Descent section 5, page 6 + * can be beneficial in term s of both speed and accuracy. + * + * Tacaks' method doesn't calculate gradient of regularization correctly, which has non-zero elements everywhere of + * the matrix. While Tacaks' method can only updates a single row/column, if one user has a lot of recommendation, + * her vector will be more affected by regularization using an isolated scaling factor for both user vectors and + * item vectors can remove this issue without inducing more update cost it even reduces it a bit by only performing + * one addition and one multiplication. + * + * BAD SIDE1: the scaling factor decreases fast, it has to be scaled up from time to time before dropped to zero or + * caused roundoff error + * BAD SIDE2: no body experiment on it before, and people generally use very small lambda + * so it's impact on accuracy may still be unknown. + * BAD SIDE3: don't know how to make it work for L1-regularization or + * "pseudorank?" 
(sum of singular values)-regularization */ + protected void update(Preference preference, double mu) { + int userIndex = userIndex(preference.getUserID()); + int itemIndex = itemIndex(preference.getItemID()); + + double[] userVector = userVectors[userIndex]; + double[] itemVector = itemVectors[itemIndex]; + + double prediction = dot(userVector, itemVector); + double err = preference.getValue() - prediction; + + // adjust features + for (int k = FEATURE_OFFSET; k < rank; k++) { + double userFeature = userVector[k]; + double itemFeature = itemVector[k]; + + userVector[k] += mu * (err * itemFeature - lambda * userFeature); + itemVector[k] += mu * (err * userFeature - lambda * itemFeature); + } + + // adjust user and item bias + userVector[USER_BIAS_INDEX] += biasMuRatio * mu * (err - biasLambdaRatio * lambda * userVector[USER_BIAS_INDEX]); + itemVector[ITEM_BIAS_INDEX] += biasMuRatio * mu * (err - biasLambdaRatio * lambda * itemVector[ITEM_BIAS_INDEX]); + } + + private double dot(double[] userVector, double[] itemVector) { + double sum = 0; + for (int k = 0; k < rank; k++) { + sum += userVector[k] * itemVector[k]; + } + return sum; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java new file mode 100644 index 0000000..abf3eca --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.recommender.svd; + +import java.io.IOException; + +/** + * Provides storage for {@link Factorization}s + */ +public interface PersistenceStrategy { + + /** + * Load a factorization from a persistent store. + * + * @return a Factorization or null if the persistent store is empty. + * + * @throws IOException + */ + Factorization load() throws IOException; + + /** + * Write a factorization to a persistent store unless it already + * contains an identical factorization. 
+ * + * @param factorization + * + * @throws IOException + */ + void maybePersist(Factorization factorization) throws IOException; + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java new file mode 100644 index 0000000..2c9f0ae --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java @@ -0,0 +1,221 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.recommender.svd; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.RandomWrapper; + +/** Matrix factorization with user and item biases for rating prediction, trained with plain vanilla SGD */ +public class RatingSGDFactorizer extends AbstractFactorizer { + + protected static final int FEATURE_OFFSET = 3; + + /** Multiplicative decay factor for learning_rate */ + protected final double learningRateDecay; + /** Learning rate (step size) */ + protected final double learningRate; + /** Parameter used to prevent overfitting. 
*/ + protected final double preventOverfitting; + /** Number of features used to compute this factorization */ + protected final int numFeatures; + /** Number of iterations */ + private final int numIterations; + /** Standard deviation for random initialization of features */ + protected final double randomNoise; + /** User features */ + protected double[][] userVectors; + /** Item features */ + protected double[][] itemVectors; + protected final DataModel dataModel; + private long[] cachedUserIDs; + private long[] cachedItemIDs; + + protected double biasLearningRate = 0.5; + protected double biasReg = 0.1; + + /** place in user vector where the bias is stored */ + protected static final int USER_BIAS_INDEX = 1; + /** place in item vector where the bias is stored */ + protected static final int ITEM_BIAS_INDEX = 2; + + public RatingSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException { + this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0); + } + + public RatingSGDFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting, + double randomNoise, int numIterations, double learningRateDecay) throws TasteException { + super(dataModel); + this.dataModel = dataModel; + this.numFeatures = numFeatures + FEATURE_OFFSET; + this.numIterations = numIterations; + + this.learningRate = learningRate; + this.learningRateDecay = learningRateDecay; + this.preventOverfitting = preventOverfitting; + this.randomNoise = randomNoise; + } + + protected void prepareTraining() throws TasteException { + RandomWrapper random = RandomUtils.getRandom(); + userVectors = new double[dataModel.getNumUsers()][numFeatures]; + itemVectors = new double[dataModel.getNumItems()][numFeatures]; + + double globalAverage = getAveragePreference(); + for (int userIndex = 0; userIndex < userVectors.length; userIndex++) { + userVectors[userIndex][0] = globalAverage; + userVectors[userIndex][USER_BIAS_INDEX] = 0; // will 
store user bias + userVectors[userIndex][ITEM_BIAS_INDEX] = 1; // corresponding item feature contains item bias + for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) { + userVectors[userIndex][feature] = random.nextGaussian() * randomNoise; + } + } + for (int itemIndex = 0; itemIndex < itemVectors.length; itemIndex++) { + itemVectors[itemIndex][0] = 1; // corresponding user feature contains global average + itemVectors[itemIndex][USER_BIAS_INDEX] = 1; // corresponding user feature contains user bias + itemVectors[itemIndex][ITEM_BIAS_INDEX] = 0; // will store item bias + for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) { + itemVectors[itemIndex][feature] = random.nextGaussian() * randomNoise; + } + } + + cachePreferences(); + shufflePreferences(); + } + + private int countPreferences() throws TasteException { + int numPreferences = 0; + LongPrimitiveIterator userIDs = dataModel.getUserIDs(); + while (userIDs.hasNext()) { + PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userIDs.nextLong()); + numPreferences += preferencesFromUser.length(); + } + return numPreferences; + } + + private void cachePreferences() throws TasteException { + int numPreferences = countPreferences(); + cachedUserIDs = new long[numPreferences]; + cachedItemIDs = new long[numPreferences]; + + LongPrimitiveIterator userIDs = dataModel.getUserIDs(); + int index = 0; + while (userIDs.hasNext()) { + long userID = userIDs.nextLong(); + PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID); + for (Preference preference : preferencesFromUser) { + cachedUserIDs[index] = userID; + cachedItemIDs[index] = preference.getItemID(); + index++; + } + } + } + + protected void shufflePreferences() { + RandomWrapper random = RandomUtils.getRandom(); + /* Durstenfeld shuffle */ + for (int currentPos = cachedUserIDs.length - 1; currentPos > 0; currentPos--) { + int swapPos = random.nextInt(currentPos + 1); + 
swapCachedPreferences(currentPos, swapPos); + } + } + + private void swapCachedPreferences(int posA, int posB) { + long tmpUserIndex = cachedUserIDs[posA]; + long tmpItemIndex = cachedItemIDs[posA]; + + cachedUserIDs[posA] = cachedUserIDs[posB]; + cachedItemIDs[posA] = cachedItemIDs[posB]; + + cachedUserIDs[posB] = tmpUserIndex; + cachedItemIDs[posB] = tmpItemIndex; + } + + @Override + public Factorization factorize() throws TasteException { + prepareTraining(); + double currentLearningRate = learningRate; + + + for (int it = 0; it < numIterations; it++) { + for (int index = 0; index < cachedUserIDs.length; index++) { + long userId = cachedUserIDs[index]; + long itemId = cachedItemIDs[index]; + float rating = dataModel.getPreferenceValue(userId, itemId); + updateParameters(userId, itemId, rating, currentLearningRate); + } + currentLearningRate *= learningRateDecay; + } + return createFactorization(userVectors, itemVectors); + } + + double getAveragePreference() throws TasteException { + RunningAverage average = new FullRunningAverage(); + LongPrimitiveIterator it = dataModel.getUserIDs(); + while (it.hasNext()) { + for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) { + average.addDatum(pref.getValue()); + } + } + return average.getAverage(); + } + + protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) { + int userIndex = userIndex(userID); + int itemIndex = itemIndex(itemID); + + double[] userVector = userVectors[userIndex]; + double[] itemVector = itemVectors[itemIndex]; + double prediction = predictRating(userIndex, itemIndex); + double err = rating - prediction; + + // adjust user bias + userVector[USER_BIAS_INDEX] += + biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * userVector[USER_BIAS_INDEX]); + + // adjust item bias + itemVector[ITEM_BIAS_INDEX] += + biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * itemVector[ITEM_BIAS_INDEX]); + + 
// adjust features + for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) { + double userFeature = userVector[feature]; + double itemFeature = itemVector[feature]; + + double deltaUserFeature = err * itemFeature - preventOverfitting * userFeature; + userVector[feature] += currentLearningRate * deltaUserFeature; + + double deltaItemFeature = err * userFeature - preventOverfitting * itemFeature; + itemVector[feature] += currentLearningRate * deltaItemFeature; + } + } + + private double predictRating(int userID, int itemID) { + double sum = 0; + for (int feature = 0; feature < numFeatures; feature++) { + sum += userVectors[userID][feature] * itemVectors[itemID][feature]; + } + return sum; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java new file mode 100644 index 0000000..20446f8 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java @@ -0,0 +1,178 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.common.RandomUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * SVD++, an enhancement of classical matrix factorization for rating prediction.
 * In addition to using ratings (how did people rate?) for learning, this model also takes into account
 * who rated what.
 *
 * Yehuda Koren: Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model, KDD 2008.
 * http://research.yahoo.com/files/kdd08koren.pdf
 */
public final class SVDPlusPlusFactorizer extends RatingSGDFactorizer {

  /** Explicit user factors, one row per internal user index. */
  private double[][] p;
  /** Implicit-feedback item factors, one row per internal item index. */
  private double[][] y;
  /** Internal item indexes each user has rated, keyed by internal user index. */
  private Map<Integer, List<Integer>> itemsByUser;

  /**
   * Creates a factorizer with learning rate 0.01, regularisation 0.1, init noise 0.01, no learning-rate decay,
   * bias learning rate 0.7 and bias regularisation 0.33.
   */
  public SVDPlusPlusFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
    this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
    biasLearningRate = 0.7;
    biasReg = 0.33;
  }

  public SVDPlusPlusFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting,
      double randomNoise, int numIterations, double learningRateDecay) throws TasteException {
    super(dataModel, numFeatures, learningRate, preventOverfitting, randomNoise, numIterations, learningRateDecay);
  }

  /**
   * Initialises the superclass state, the random {@code p} and {@code y} factor matrices, and the per-user lists
   * of rated item indexes.
   */
  @Override
  protected void prepareTraining() throws TasteException {
    super.prepareTraining();
    Random random = RandomUtils.getRandom();

    // Same draw order as before: all of p, then all of y, row by row.
    p = randomFactors(dataModel.getNumUsers(), random);
    y = randomFactors(dataModel.getNumItems(), random);

    // Cache internal item indexes; they are needed for every rating of the user.
    itemsByUser = new HashMap<>();
    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
    while (userIDs.hasNext()) {
      long userId = userIDs.nextLong();
      FastIDSet itemIDsFromUser = dataModel.getItemIDsFromUser(userId);
      List<Integer> itemIndexes = new ArrayList<>(itemIDsFromUser.size());
      for (long ratedItemID : itemIDsFromUser) {
        itemIndexes.add(itemIndex(ratedItemID));
      }
      itemsByUser.put(userIndex(userId), itemIndexes);
    }
  }

  /**
   * Allocates a {@code rows x numFeatures} matrix.
   * Columns from {@code FEATURE_OFFSET} upward get Gaussian noise scaled by {@code randomNoise}; the columns
   * below it keep Java's default of 0.
   */
  private double[][] randomFactors(int rows, Random random) {
    double[][] factors = new double[rows][numFeatures];
    for (double[] row : factors) {
      for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
        row[feature] = random.nextGaussian() * randomNoise;
      }
    }
    return factors;
  }

  /**
   * Trains with the superclass SGD loop, then folds the implicit-feedback term and {@code p} into
   * {@code userVectors} before building the factorization.
   */
  @Override
  public Factorization factorize() throws TasteException {
    // super.factorize() already calls prepareTraining(), which dispatches to the override above. Calling it here
    // as well (as this method used to) built every matrix and the itemsByUser map twice.
    super.factorize();

    for (int userIndex = 0; userIndex < userVectors.length; userIndex++) {
      List<Integer> ratedItems = itemsByUser.get(userIndex);
      double[] userVector = userVectors[userIndex];
      for (int itemIndex : ratedItems) {
        for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
          userVector[feature] += y[itemIndex][feature];
        }
      }
      double denominator = Math.sqrt(ratedItems.size());
      for (int feature = 0; feature < userVector.length; feature++) {
        // No (float) cast: the arrays hold doubles, so truncating to float only lost precision.
        userVector[feature] = userVector[feature] / denominator + p[userIndex][feature];
      }
    }

    return createFactorization(userVectors, itemVectors);
  }

  /**
   * Performs one SGD step for a single observed rating.
   *
   * <p>The effective user vector is {@code p_u + |N(u)|^-1/2 * sum(y_j, j in N(u))}, where {@code N(u)} is the
   * set of items the user rated. The step updates the biases, the explicit user factors {@code p}, the item
   * factors, and the implicit factors {@code y} of every item in {@code N(u)}.</p>
   *
   * @param userID external user ID
   * @param itemID external item ID
   * @param rating observed rating
   * @param currentLearningRate learning rate for the current epoch, as decayed by the training loop
   */
  @Override
  protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
    int userIndex = userIndex(userID);
    int itemIndex = itemIndex(itemID);
    List<Integer> ratedItems = itemsByUser.get(userIndex);

    double[] userVector = p[userIndex];
    double[] itemVector = itemVectors[itemIndex];

    double[] pPlusY = new double[numFeatures];
    for (int ratedItem : ratedItems) {
      for (int f = FEATURE_OFFSET; f < numFeatures; f++) {
        pPlusY[f] += y[ratedItem][f];
      }
    }
    double denominator = Math.sqrt(ratedItems.size());
    for (int feature = 0; feature < pPlusY.length; feature++) {
      pPlusY[feature] = pPlusY[feature] / denominator + userVector[feature];
    }

    double prediction = predictRating(pPlusY, itemIndex);
    double err = rating - prediction;
    double normalizedError = err / denominator;

    // adjust user bias
    userVector[USER_BIAS_INDEX] +=
        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * userVector[USER_BIAS_INDEX]);

    // adjust item bias
    itemVector[ITEM_BIAS_INDEX] +=
        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * itemVector[ITEM_BIAS_INDEX]);

    // adjust features
    for (int feature = FEATURE_OFFSET; f
  < numFeatures; feature++) {
      double pF = userVector[feature];
      double iF = itemVector[feature];

      double deltaU = err * iF - preventOverfitting * pF;
      userVector[feature] += currentLearningRate * deltaU;

      double deltaI = err * pPlusY[feature] - preventOverfitting * iF;
      itemVector[feature] += currentLearningRate * deltaI;

      // Use the decayed rate like every other parameter. The base learningRate was used here before, so the
      // implicit factors never annealed.
      double commonUpdate = normalizedError * iF;
      for (int ratedItem : ratedItems) {
        double deltaI2 = commonUpdate - preventOverfitting * y[ratedItem][feature];
        y[ratedItem][feature] += currentLearningRate * deltaI2;
      }
    }
  }

  /** Returns the dot product of {@code userVector} with the item vector at {@code itemIndex}. */
  private double predictRating(double[] userVector, int itemIndex) {
    double sum = 0;
    for (int feature = 0; feature < numFeatures; feature++) {
      sum += userVector[feature] * itemVectors[itemIndex][feature];
    }
    return sum;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java new file mode 100644 index 0000000..45c54da --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.impl.model.GenericPreference;

/**
 * A {@link GenericPreference} that also carries a {@code double} "cache" value.
 * The cache is validated on every write and is never allowed to be NaN.
 */
final class SVDPreference extends GenericPreference {

  private double cache;

  SVDPreference(long userID, long itemID, float value, double cache) {
    super(userID, itemID, value);
    this.cache = requireNotNaN(cache);
  }

  /** Returns the currently cached value. */
  public double getCache() {
    return cache;
  }

  /**
   * Replaces the cached value.
   *
   * @throws IllegalArgumentException if {@code value} is NaN
   */
  public void setCache(double value) {
    cache = requireNotNaN(value);
  }

  /** Returns {@code candidate} unchanged, or throws {@link IllegalArgumentException} if it is NaN. */
  private static double requireNotNaN(double candidate) {
    Preconditions.checkArgument(!Double.isNaN(candidate), "NaN cache value");
    return candidate;
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java new file mode 100644 index 0000000..45d4af7 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java @@ -0,0 +1,185 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;

import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.impl.recommender.AbstractRecommender;
import org.apache.mahout.cf.taste.impl.recommender.AllUnknownItemsCandidateItemsStrategy;
import org.apache.mahout.cf.taste.impl.recommender.TopItems;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
import org.apache.mahout.cf.taste.recommender.IDRescorer;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link org.apache.mahout.cf.taste.recommender.Recommender} that uses matrix factorization (a projection of users
 * and items onto a feature space).
 */
public final class SVDRecommender extends AbstractRecommender {

  private static final Logger log = LoggerFactory.getLogger(SVDRecommender.class);

  // NOTE(review): reassigned by train() during refresh() without synchronization or volatile; confirm whether
  // recommend()/estimatePreference() may run concurrently with a refresh on another thread.
  private Factorization factorization;
  private final Factorizer factorizer;
  private final PersistenceStrategy persistenceStrategy;
  private final RefreshHelper refreshHelper;

  /**
   * Creates an SVDRecommender with an {@link AllUnknownItemsCandidateItemsStrategy} and the default persistence
   * strategy ({@link NoPersistenceStrategy}).
   *
   * @param dataModel preference data to recommend from
   * @param factorizer computes the user/item factorization
   * @throws TasteException if loading or computing the initial factorization fails
   */
  public SVDRecommender(DataModel dataModel, Factorizer factorizer) throws TasteException {
    this(dataModel, factorizer, new AllUnknownItemsCandidateItemsStrategy(), getDefaultPersistenceStrategy());
  }

  /**
   * Creates an SVDRecommender with the given candidate strategy and the default persistence strategy
   * ({@link NoPersistenceStrategy}).
   *
   * @param dataModel preference data to recommend from
   * @param factorizer computes the user/item factorization
   * @param candidateItemsStrategy selects which items are considered for recommendation
   * @throws TasteException if loading or computing the initial factorization fails
   */
  public SVDRecommender(DataModel dataModel, Factorizer factorizer, CandidateItemsStrategy candidateItemsStrategy)
    throws TasteException {
    this(dataModel, factorizer, candidateItemsStrategy, getDefaultPersistenceStrategy());
  }

  /**
   * Create an SVDRecommender using a persistent store to cache factorizations. A factorization is loaded from the
   * store if present, otherwise a new factorization is computed and saved in the store.
   *
   * The {@link #refresh(java.util.Collection) refresh} method recomputes the factorization and overwrites the store.
   *
   * @param dataModel preference data to recommend from
   * @param factorizer computes the user/item factorization
   * @param persistenceStrategy loads and stores factorizations
   * @throws TasteException if loading, computing or persisting the initial factorization fails (I/O errors are
   *     wrapped in a TasteException)
   */
  public SVDRecommender(DataModel dataModel, Factorizer factorizer, PersistenceStrategy persistenceStrategy)
    throws TasteException {
    this(dataModel, factorizer, getDefaultCandidateItemsStrategy(), persistenceStrategy);
  }

  /**
   * Create an SVDRecommender using a persistent store to cache factorizations. A factorization is loaded from the
   * store if present, otherwise a new factorization is computed and saved in the store.
   *
   * The {@link #refresh(java.util.Collection) refresh} method recomputes the factorization and overwrites the store.
   *
   * @param dataModel preference data to recommend from
   * @param factorizer computes the user/item factorization; must not be null
   * @param candidateItemsStrategy selects which items are considered for recommendation
   * @param persistenceStrategy loads and stores factorizations; must not be null
   * @throws TasteException if loading, computing or persisting the initial factorization fails (I/O errors are
   *     wrapped in a TasteException)
   */
  public SVDRecommender(DataModel dataModel, Factorizer factorizer, CandidateItemsStrategy candidateItemsStrategy,
      PersistenceStrategy persistenceStrategy) throws TasteException {
    super(dataModel, candidateItemsStrategy);
    this.factorizer = Preconditions.checkNotNull(factorizer, "factorizer is null");
    this.persistenceStrategy = Preconditions.checkNotNull(persistenceStrategy, "persistenceStrategy is null");
    try {
      factorization = persistenceStrategy.load();
    } catch (IOException e) {
      throw new TasteException("Error loading factorization", e);
    }

    // Nothing stored yet: compute (and persist) a fresh factorization.
    if (factorization == null) {
      train();
    }

    refreshHelper = new RefreshHelper(new Callable<Object>() {
      @Override
      public Object call() throws TasteException {
        train();
        return null;
      }
    });
    refreshHelper.addDependency(getDataModel());
    refreshHelper.addDependency(factorizer);
    refreshHelper.addDependency(candidateItemsStrategy);
  }

  static PersistenceStrategy getDefaultPersistenceStrategy() {
    return new NoPersistenceStrategy();
  }

  /** Recomputes the factorization and hands it to the persistence strategy. */
  private void train() throws TasteException {
    factorization = factorizer.factorize();
    try {
      persistenceStrategy.maybePersist(factorization);
    } catch (IOException e) {
      throw new TasteException("Error persisting factorization", e);
    }
  }

  @Override
  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
    throws TasteException {
    Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
    log.debug("Recommending items for user ID '{}'", userID);

    PreferenceArray preferencesFromUser = getDataModel().getPreferencesFromUser(userID);
    FastIDSet possibleItemIDs = getAllOtherItems(userID, preferencesFromUser, includeKnownItems);

    List<RecommendedItem> topItems = TopItems.getTopItems(howMany, possibleItemIDs.iterator(), rescorer,
        new Estimator(userID));
    log.debug("Recommendations are: {}", topItems);

    return topItems;
  }

  /**
   * Estimates a preference as the dot product of the user and item feature vectors.
   */
  @Override
  public float estimatePreference(long userID, long itemID) throws TasteException {
    double[] userFeatures = factorization.getUserFeatures(userID);
    double[] itemFeatures = factorization.getItemFeatures(itemID);
    double estimate = 0;
    for (int feature = 0; feature < userFeatures.length; feature++) {
      estimate += userFeatures[feature] * itemFeatures[feature];
    }
    return (float) estimate;
  }

  /** Scores candidate items for one fixed user via {@link #estimatePreference(long, long)}. */
  private final class Estimator implements TopItems.Estimator<Long> {

    private final long theUserID;

    private Estimator(long theUserID) {
      this.theUserID = theUserID;
    }

    @Override
    public double estimate(Long itemID) throws TasteException {
      return estimatePreference(theUserID, itemID);
    }
  }

  /**
   * Refresh the data model and factorization.
   */
  @Override
  public void refresh(Collection<Refreshable> alreadyRefreshed) {
    refreshHelper.refresh(alreadyRefreshed);
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java new file mode 100644 index 0000000..e0d6f59 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.similarity;

import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.similarity.ItemSimilarity;

import java.util.Collection;

/**
 * Base class for {@link ItemSimilarity} implementations that read from a {@link DataModel}.
 * Refreshing this similarity refreshes the underlying data model.
 */
public abstract class AbstractItemSimilarity implements ItemSimilarity {

  private final DataModel dataModel;
  private final RefreshHelper refreshHelper;

  /**
   * @param dataModel the model whose items are compared; must not be null
   * @throws IllegalArgumentException if {@code dataModel} is null
   */
  protected AbstractItemSimilarity(DataModel dataModel) {
    Preconditions.checkArgument(dataModel != null, "dataModel is null");
    this.dataModel = dataModel;
    refreshHelper = new RefreshHelper(null);
    refreshHelper.addDependency(dataModel);
  }

  protected DataModel getDataModel() {
    return dataModel;
  }

  /**
   * Scans every item ID in the data model and keeps those whose similarity to {@code itemID} is not NaN.
   * {@code itemID} itself is also checked, so it is included whenever its self-similarity is defined.
   */
  @Override
  public long[] allSimilarItemIDs(long itemID) throws TasteException {
    FastIDSet similar = new FastIDSet();
    for (LongPrimitiveIterator candidates = dataModel.getItemIDs(); candidates.hasNext();) {
      long candidate = candidates.nextLong();
      double similarity = itemSimilarity(itemID, candidate);
      if (Double.isNaN(similarity)) {
        continue;
      }
      similar.add(candidate);
    }
    return similar.toArray();
  }

  @Override
  public void refresh(Collection<Refreshable> alreadyRefreshed) {
    refreshHelper.refresh(alreadyRefreshed);
  }
}
