http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java new file mode 100644 index 0000000..08aa5ae --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java @@ -0,0 +1,97 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.cf.taste.impl.recommender;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.cf.taste.recommender.IDRescorer;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.common.RandomUtils;

/**
 * Produces random recommendations and preference estimates. This is likely only useful as a novelty and for
 * benchmarking.
 *
 * <p>Estimated values are drawn uniformly from the interval between the smallest and the largest preference
 * value found in the {@link DataModel} at construction time. The interval is computed once; {@link #refresh}
 * only refreshes the underlying data model and does not recompute it.</p>
 */
public final class RandomRecommender extends AbstractRecommender {

  private final Random random = RandomUtils.getRandom();
  /** Smallest preference value seen while scanning the model in the constructor. */
  private final float minPref;
  /** Largest preference value seen while scanning the model in the constructor. */
  private final float maxPref;

  /**
   * Scans every preference of every user once to determine the range of preference values.
   *
   * @param dataModel model to draw item IDs and the preference value range from
   * @throws TasteException if the data model cannot be read
   */
  public RandomRecommender(DataModel dataModel) throws TasteException {
    super(dataModel);
    // Distinct local names so the running extremes are not confused with the fields they initialize.
    float highest = Float.NEGATIVE_INFINITY;
    float lowest = Float.POSITIVE_INFINITY;
    LongPrimitiveIterator userIterator = dataModel.getUserIDs();
    while (userIterator.hasNext()) {
      PreferenceArray prefs = dataModel.getPreferencesFromUser(userIterator.next());
      int length = prefs.length();
      for (int i = 0; i < length; i++) {
        float prefValue = prefs.getValue(i);
        if (prefValue < lowest) {
          lowest = prefValue;
        }
        if (prefValue > highest) {
          highest = prefValue;
        }
      }
    }
    // NOTE: if the model holds no preferences at all, the range stays (+inf, -inf) and randomPref() yields NaN.
    this.minPref = lowest;
    this.maxPref = highest;
  }

  /**
   * Picks {@code howMany} items uniformly at random (with replacement, so an item may appear more than once)
   * and assigns each a random value.
   *
   * @param userID user to recommend for; only consulted when {@code includeKnownItems} is {@code false}
   * @param howMany number of recommendations to return
   * @param rescorer ignored by this implementation
   * @param includeKnownItems whether items the user already has a preference for may be returned
   * @return a list of exactly {@code howMany} items, or an empty list when the model has no items or, with
   *         {@code includeKnownItems == false}, when the user already has a preference for every item
   * @throws TasteException if the data model cannot be read
   */
  @Override
  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
    throws TasteException {
    DataModel dataModel = getDataModel();
    int numItems = dataModel.getNumItems();
    List<RecommendedItem> result = new ArrayList<>(howMany);
    if (howMany == 0 || numItems <= 0) {
      // Random.nextInt(0) would throw; with no items there is nothing to recommend.
      return result;
    }
    if (!includeKnownItems && dataModel.getItemIDsFromUser(userID).size() >= numItems) {
      // Every item is already known to the user: the sampling loop below could never accept a candidate
      // and would spin forever.
      return result;
    }
    while (result.size() < howMany) {
      // A fresh iterator per draw: skip to a random position and take the item found there.
      LongPrimitiveIterator it = dataModel.getItemIDs();
      it.skip(random.nextInt(numItems));
      long itemID = it.next();
      if (includeKnownItems || dataModel.getPreferenceValue(userID, itemID) == null) {
        result.add(new GenericRecommendedItem(itemID, randomPref()));
      }
    }
    return result;
  }

  /**
   * @return a random value in the preference range of the model; both arguments are ignored
   */
  @Override
  public float estimatePreference(long userID, long itemID) {
    return randomPref();
  }

  /** Draws a value uniformly from {@code [minPref, maxPref)}. */
  private float randomPref() {
    return minPref + random.nextFloat() * (maxPref - minPref);
  }

  /** Delegates to the underlying data model; the cached preference range is not recomputed. */
  @Override
  public void refresh(Collection<Refreshable> alreadyRefreshed) {
    getDataModel().refresh(alreadyRefreshed);
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java new file mode 100644 index 0000000..623a60b --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java @@ -0,0 +1,165 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.cf.taste.impl.recommender;

import com.google.common.base.Preconditions;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveArrayIterator;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.SamplingLongPrimitiveIterator;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.iterator.FixedSizeSamplingIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Iterator;

/**
 * <p>Returns all items that have not been rated by the user <em>(3)</em> and that were preferred by another user
 * <em>(2)</em> that has preferred at least one item <em>(1)</em> that the current user has preferred too.</p>
 *
 * <p>This strategy uses sampling to limit the number of items that are considered, by sampling three different
 * things, noted above:</p>
 *
 * <ol>
 *   <li>The items that the user has preferred</li>
 *   <li>The users who also prefer each of those items</li>
 *   <li>The items those users also prefer</li>
 * </ol>
 *
 * <p>There is a maximum associated with each of these three things; if the number of items or users exceeds
 * that max, it is sampled so that the expected number of items or users actually used in that part of the
 * computation is equal to the max.</p>
 *
 * <p>Three arguments control these three maxima. Each is a "factor" f, which establishes the max at
 * f * (1 + log2(n)), where n is the number of users or items in the data. For example if factor #2 is 5,
 * which controls the number of users sampled per item, then 5 * (1 + log2(# users)) is the maximum for this
 * part of the computation.</p>
 *
 * <p>Each can be set to not do any limiting with value {@link #NO_LIMIT_FACTOR}.</p>
 */
public class SamplingCandidateItemsStrategy extends AbstractCandidateItemsStrategy {

  private static final Logger log = LoggerFactory.getLogger(SamplingCandidateItemsStrategy.class);

  /**
   * Default factor used if not otherwise specified, for all limits. (30).
   */
  public static final int DEFAULT_FACTOR = 30;
  /**
   * Specify this value as a factor to mean no limit.
   */
  public static final int NO_LIMIT_FACTOR = Integer.MAX_VALUE;
  private static final int MAX_LIMIT = Integer.MAX_VALUE;
  private static final double LOG2 = Math.log(2.0);

  /** Max number of the user's own preferred items that are expanded. */
  private final int maxItems;
  /** Max number of co-preferring users looked at per preferred item. */
  private final int maxUsersPerItem;
  /** Max number of candidate items taken from each of those users. */
  private final int maxItemsPerUser;

  /**
   * Uses {@link #DEFAULT_FACTOR} for all three factors.
   *
   * @param numUsers number of users currently in the data
   * @param numItems number of items in the data
   * @see #SamplingCandidateItemsStrategy(int, int, int, int, int)
   */
  public SamplingCandidateItemsStrategy(int numUsers, int numItems) {
    this(DEFAULT_FACTOR, DEFAULT_FACTOR, DEFAULT_FACTOR, numUsers, numItems);
  }

  /**
   * @param itemsFactor factor controlling max items considered for a user
   * @param usersPerItemFactor factor controlling max users considered for each of those items
   * @param candidatesPerUserFactor factor controlling max candidate items considered from each of those users
   * @param numUsers number of users currently in the data
   * @param numItems number of items in the data
   * @throws IllegalArgumentException if any argument is not strictly positive
   */
  public SamplingCandidateItemsStrategy(int itemsFactor,
                                        int usersPerItemFactor,
                                        int candidatesPerUserFactor,
                                        int numUsers,
                                        int numItems) {
    Preconditions.checkArgument(itemsFactor > 0, "itemsFactor must be greater than 0!");
    Preconditions.checkArgument(usersPerItemFactor > 0, "usersPerItemFactor must be greater than 0!");
    Preconditions.checkArgument(candidatesPerUserFactor > 0, "candidatesPerUserFactor must be greater than 0!");
    Preconditions.checkArgument(numUsers > 0, "numUsers must be greater than 0!");
    Preconditions.checkArgument(numItems > 0, "numItems must be greater than 0!");
    maxItems = computeMaxFrom(itemsFactor, numItems);
    maxUsersPerItem = computeMaxFrom(usersPerItemFactor, numUsers);
    maxItemsPerUser = computeMaxFrom(candidatesPerUserFactor, numItems);
    log.debug("maxItems {}, maxUsersPerItem {}, maxItemsPerUser {}", maxItems, maxUsersPerItem, maxItemsPerUser);
  }

  /**
   * Turns a factor into an absolute limit: {@code factor * (1 + log2(numThings))}, clamped to
   * {@link Integer#MAX_VALUE}. {@link #NO_LIMIT_FACTOR} short-circuits to "no limit".
   */
  private static int computeMaxFrom(int factor, int numThings) {
    if (factor == NO_LIMIT_FACTOR) {
      return MAX_LIMIT;
    }
    // Computed in double and narrowed to long first, so a large factor cannot overflow int arithmetic.
    long max = (long) (factor * (1.0 + Math.log(numThings) / LOG2));
    return max > MAX_LIMIT ? MAX_LIMIT : (int) max;
  }

  /**
   * Collects candidate items by walking user's items -> users who prefer them -> items of those users,
   * sampling at each of the three levels once the configured maximum is exceeded.
   *
   * @param preferredItemIDs items the current user has a preference for
   * @param dataModel model to look up co-preferring users and their items in
   * @param includeKnownItems if {@code false}, {@code preferredItemIDs} are removed from the result
   * @return the (possibly sampled) set of candidate item IDs
   * @throws TasteException if the data model cannot be read
   */
  @Override
  protected FastIDSet doGetCandidateItems(long[] preferredItemIDs, DataModel dataModel, boolean includeKnownItems)
    throws TasteException {
    LongPrimitiveIterator preferredItemIDsIterator = new LongPrimitiveArrayIterator(preferredItemIDs);
    if (preferredItemIDs.length > maxItems) {
      // Level 1: thin out the user's own items so that about maxItems of them are expanded.
      double samplingRate = (double) maxItems / preferredItemIDs.length;
      preferredItemIDsIterator =
          new SamplingLongPrimitiveIterator(preferredItemIDsIterator, samplingRate);
    }
    FastIDSet possibleItemsIDs = new FastIDSet();
    while (preferredItemIDsIterator.hasNext()) {
      long itemID = preferredItemIDsIterator.nextLong();
      PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
      int prefsLength = prefs.length();
      if (prefsLength > maxUsersPerItem) {
        // Level 2: too many users prefer this item, look at a fixed-size sample of them only.
        Iterator<Preference> sampledPrefs =
            new FixedSizeSamplingIterator<>(maxUsersPerItem, prefs.iterator());
        while (sampledPrefs.hasNext()) {
          addSomeOf(possibleItemsIDs, dataModel.getItemIDsFromUser(sampledPrefs.next().getUserID()));
        }
      } else {
        for (int i = 0; i < prefsLength; i++) {
          addSomeOf(possibleItemsIDs, dataModel.getItemIDsFromUser(prefs.getUserID(i)));
        }
      }
    }
    if (!includeKnownItems) {
      possibleItemsIDs.removeAll(preferredItemIDs);
    }
    return possibleItemsIDs;
  }

  /**
   * Level 3: adds all of {@code itemIDs} to {@code possibleItemIDs}, or a sample of them when there are more
   * than {@code maxItemsPerUser}.
   */
  private void addSomeOf(FastIDSet possibleItemIDs, FastIDSet itemIDs) {
    if (itemIDs.size() > maxItemsPerUser) {
      LongPrimitiveIterator it =
          new SamplingLongPrimitiveIterator(itemIDs.iterator(), (double) maxItemsPerUser / itemIDs.size());
      while (it.hasNext()) {
        possibleItemIDs.add(it.nextLong());
      }
    } else {
      possibleItemIDs.addAll(itemIDs);
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java ---------------------------------------------------------------------- diff
--git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java new file mode 100644 index 0000000..c6d417f --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java @@ -0,0 +1,80 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.cf.taste.impl.recommender;

import org.apache.mahout.common.RandomUtils;

/**
 * Simply encapsulates a user and a similarity value.
 *
 * <p>Instances are immutable. {@link #equals(Object)}, {@link #hashCode()} and {@link #compareTo(SimilarUser)}
 * all treat the similarity the way {@link Double#compare(double, double)} does, so a {@code NaN} similarity is
 * equal to itself and {@code 0.0} is distinct from {@code -0.0}.</p>
 */
public final class SimilarUser implements Comparable<SimilarUser> {

  private final long userID;
  private final double similarity;

  /**
   * @param userID ID of the user
   * @param similarity similarity of that user to some reference user
   */
  public SimilarUser(long userID, double similarity) {
    this.userID = userID;
    this.similarity = similarity;
  }

  long getUserID() {
    return userID;
  }

  double getSimilarity() {
    return similarity;
  }

  @Override
  public int hashCode() {
    // NOTE(review): presumably hashDouble hashes the IEEE bit pattern, which keeps this consistent with the
    // Double.compare-based equals() below - confirm against RandomUtils.
    return (int) userID ^ RandomUtils.hashDouble(similarity);
  }

  /**
   * Two instances are equal when they have the same user ID and the same similarity as decided by
   * {@link Double#compare(double, double)}. Plain {@code ==} on the doubles would make an instance holding
   * {@code NaN} unequal to itself.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof SimilarUser)) {
      return false;
    }
    SimilarUser other = (SimilarUser) o;
    return userID == other.userID && Double.compare(similarity, other.similarity) == 0;
  }

  @Override
  public String toString() {
    return "SimilarUser[user:" + userID + ", similarity:" + similarity + ']';
  }

  /**
   * Defines an ordering from most similar to least similar; ties are broken by ascending user ID. The result
   * is 0 exactly when {@link #equals(Object)} is true.
   */
  @Override
  public int compareTo(SimilarUser other) {
    // Arguments swapped on purpose: a larger similarity must sort first.
    int bySimilarity = Double.compare(other.similarity, similarity);
    if (bySimilarity != 0) {
      return bySimilarity;
    }
    return Long.compare(userID, other.userID);
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java new file mode 100644 index 0000000..f7b4385 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java @@ -0,0 +1,211 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.recommender; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.PriorityQueue; +import java.util.Queue; + +import com.google.common.base.Preconditions; +import org.apache.mahout.cf.taste.common.NoSuchItemException; +import org.apache.mahout.cf.taste.common.NoSuchUserException; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity; +import org.apache.mahout.cf.taste.impl.similarity.GenericUserSimilarity; +import org.apache.mahout.cf.taste.recommender.IDRescorer; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; + +/** + * <p> + * A simple class that refactors the "find top N things" logic that is used in several places. 
+ * </p> + */ +public final class TopItems { + + private static final long[] NO_IDS = new long[0]; + + private TopItems() { } + + public static List<RecommendedItem> getTopItems(int howMany, + LongPrimitiveIterator possibleItemIDs, + IDRescorer rescorer, + Estimator<Long> estimator) throws TasteException { + Preconditions.checkArgument(possibleItemIDs != null, "possibleItemIDs is null"); + Preconditions.checkArgument(estimator != null, "estimator is null"); + + Queue<RecommendedItem> topItems = new PriorityQueue<>(howMany + 1, + Collections.reverseOrder(ByValueRecommendedItemComparator.getInstance())); + boolean full = false; + double lowestTopValue = Double.NEGATIVE_INFINITY; + while (possibleItemIDs.hasNext()) { + long itemID = possibleItemIDs.next(); + if (rescorer == null || !rescorer.isFiltered(itemID)) { + double preference; + try { + preference = estimator.estimate(itemID); + } catch (NoSuchItemException nsie) { + continue; + } + double rescoredPref = rescorer == null ? preference : rescorer.rescore(itemID, preference); + if (!Double.isNaN(rescoredPref) && (!full || rescoredPref > lowestTopValue)) { + topItems.add(new GenericRecommendedItem(itemID, (float) rescoredPref)); + if (full) { + topItems.poll(); + } else if (topItems.size() > howMany) { + full = true; + topItems.poll(); + } + lowestTopValue = topItems.peek().getValue(); + } + } + } + int size = topItems.size(); + if (size == 0) { + return Collections.emptyList(); + } + List<RecommendedItem> result = new ArrayList<>(size); + result.addAll(topItems); + Collections.sort(result, ByValueRecommendedItemComparator.getInstance()); + return result; + } + + public static long[] getTopUsers(int howMany, + LongPrimitiveIterator allUserIDs, + IDRescorer rescorer, + Estimator<Long> estimator) throws TasteException { + Queue<SimilarUser> topUsers = new PriorityQueue<>(howMany + 1, Collections.reverseOrder()); + boolean full = false; + double lowestTopValue = Double.NEGATIVE_INFINITY; + while (allUserIDs.hasNext()) 
{ + long userID = allUserIDs.next(); + if (rescorer != null && rescorer.isFiltered(userID)) { + continue; + } + double similarity; + try { + similarity = estimator.estimate(userID); + } catch (NoSuchUserException nsue) { + continue; + } + double rescoredSimilarity = rescorer == null ? similarity : rescorer.rescore(userID, similarity); + if (!Double.isNaN(rescoredSimilarity) && (!full || rescoredSimilarity > lowestTopValue)) { + topUsers.add(new SimilarUser(userID, rescoredSimilarity)); + if (full) { + topUsers.poll(); + } else if (topUsers.size() > howMany) { + full = true; + topUsers.poll(); + } + lowestTopValue = topUsers.peek().getSimilarity(); + } + } + int size = topUsers.size(); + if (size == 0) { + return NO_IDS; + } + List<SimilarUser> sorted = new ArrayList<>(size); + sorted.addAll(topUsers); + Collections.sort(sorted); + long[] result = new long[size]; + int i = 0; + for (SimilarUser similarUser : sorted) { + result[i++] = similarUser.getUserID(); + } + return result; + } + + /** + * <p> + * Thanks to tsmorton for suggesting this functionality and writing part of the code. 
+ * </p> + * + * @see GenericItemSimilarity#GenericItemSimilarity(Iterable, int) + * @see GenericItemSimilarity#GenericItemSimilarity(org.apache.mahout.cf.taste.similarity.ItemSimilarity, + * org.apache.mahout.cf.taste.model.DataModel, int) + */ + public static List<GenericItemSimilarity.ItemItemSimilarity> getTopItemItemSimilarities( + int howMany, Iterator<GenericItemSimilarity.ItemItemSimilarity> allSimilarities) { + + Queue<GenericItemSimilarity.ItemItemSimilarity> topSimilarities + = new PriorityQueue<>(howMany + 1, Collections.reverseOrder()); + boolean full = false; + double lowestTopValue = Double.NEGATIVE_INFINITY; + while (allSimilarities.hasNext()) { + GenericItemSimilarity.ItemItemSimilarity similarity = allSimilarities.next(); + double value = similarity.getValue(); + if (!Double.isNaN(value) && (!full || value > lowestTopValue)) { + topSimilarities.add(similarity); + if (full) { + topSimilarities.poll(); + } else if (topSimilarities.size() > howMany) { + full = true; + topSimilarities.poll(); + } + lowestTopValue = topSimilarities.peek().getValue(); + } + } + int size = topSimilarities.size(); + if (size == 0) { + return Collections.emptyList(); + } + List<GenericItemSimilarity.ItemItemSimilarity> result = new ArrayList<>(size); + result.addAll(topSimilarities); + Collections.sort(result); + return result; + } + + public static List<GenericUserSimilarity.UserUserSimilarity> getTopUserUserSimilarities( + int howMany, Iterator<GenericUserSimilarity.UserUserSimilarity> allSimilarities) { + + Queue<GenericUserSimilarity.UserUserSimilarity> topSimilarities + = new PriorityQueue<>(howMany + 1, Collections.reverseOrder()); + boolean full = false; + double lowestTopValue = Double.NEGATIVE_INFINITY; + while (allSimilarities.hasNext()) { + GenericUserSimilarity.UserUserSimilarity similarity = allSimilarities.next(); + double value = similarity.getValue(); + if (!Double.isNaN(value) && (!full || value > lowestTopValue)) { + topSimilarities.add(similarity); + if 
(full) { + topSimilarities.poll(); + } else if (topSimilarities.size() > howMany) { + full = true; + topSimilarities.poll(); + } + lowestTopValue = topSimilarities.peek().getValue(); + } + } + int size = topSimilarities.size(); + if (size == 0) { + return Collections.emptyList(); + } + List<GenericUserSimilarity.UserUserSimilarity> result = new ArrayList<>(size); + result.addAll(topSimilarities); + Collections.sort(result); + return result; + } + + public interface Estimator<T> { + double estimate(T thing) throws TasteException; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java new file mode 100644 index 0000000..0ba5139 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java @@ -0,0 +1,312 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * factorizes the rating matrix using "Alternating-Least-Squares with Weighted-λ-Regularization" as described in
 * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
 * "Large-scale Collaborative Filtering for the Netflix Prize"</a>
 *
 *  also supports the implicit feedback variant of this approach as described in "Collaborative Filtering for Implicit
 *  Feedback Datasets" available at http://research.yahoo.com/pub/2433
 *
 * <p>Each iteration runs two phases on a fixed-size thread pool: first the user features are recomputed with
 * the item features held fixed, then the item features with the user features held fixed. If the calling
 * thread is interrupted while waiting for a phase, the workers are cancelled, the interrupt status is restored
 * and the remaining iterations are skipped.</p>
 */
public class ALSWRFactorizer extends AbstractFactorizer {

  private final DataModel dataModel;

  /** number of features used to compute this factorization */
  private final int numFeatures;
  /** parameter to control the regularization */
  private final double lambda;
  /** number of iterations */
  private final int numIterations;

  private final boolean usesImplicitFeedback;
  /** confidence weighting parameter, only necessary when working with implicit feedback */
  private final double alpha;

  private final int numTrainingThreads;

  private static final double DEFAULT_ALPHA = 40;

  private static final Logger log = LoggerFactory.getLogger(ALSWRFactorizer.class);

  /**
   * @param dataModel model holding the preferences to factorize
   * @param numFeatures number of features per user/item vector
   * @param lambda regularization parameter
   * @param numIterations number of alternating iterations
   * @param usesImplicitFeedback whether to use the implicit feedback solver
   * @param alpha confidence weighting, only used with implicit feedback
   * @param numTrainingThreads size of the thread pool used for each phase
   */
  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
      boolean usesImplicitFeedback, double alpha, int numTrainingThreads) throws TasteException {
    super(dataModel);
    this.dataModel = dataModel;
    this.numFeatures = numFeatures;
    this.lambda = lambda;
    this.numIterations = numIterations;
    this.usesImplicitFeedback = usesImplicitFeedback;
    this.alpha = alpha;
    this.numTrainingThreads = numTrainingThreads;
  }

  /** Same as the full constructor, with one training thread per available processor. */
  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
                         boolean usesImplicitFeedback, double alpha) throws TasteException {
    this(dataModel, numFeatures, lambda, numIterations, usesImplicitFeedback, alpha,
        Runtime.getRuntime().availableProcessors());
  }

  /** Explicit feedback, default alpha, one training thread per available processor. */
  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations) throws TasteException {
    this(dataModel, numFeatures, lambda, numIterations, false, DEFAULT_ALPHA);
  }

  /** Holds the item feature matrix M and the user feature matrix U, indexed by the factorizer's ID mapping. */
  static class Features {

    private final DataModel dataModel;
    private final int numFeatures;

    private final double[][] M;
    private final double[][] U;

    /**
     * Initializes M with the item's average rating in feature 0 and small random values elsewhere; U starts
     * out as all zeros.
     */
    Features(ALSWRFactorizer factorizer) throws TasteException {
      dataModel = factorizer.dataModel;
      numFeatures = factorizer.numFeatures;
      Random random = RandomUtils.getRandom();
      M = new double[dataModel.getNumItems()][numFeatures];
      LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
      while (itemIDsIterator.hasNext()) {
        long itemID = itemIDsIterator.nextLong();
        int itemIDIndex = factorizer.itemIndex(itemID);
        M[itemIDIndex][0] = averateRating(itemID);
        for (int feature = 1; feature < numFeatures; feature++) {
          M[itemIDIndex][feature] = random.nextDouble() * 0.1;
        }
      }
      U = new double[dataModel.getNumUsers()][numFeatures];
    }

    double[][] getM() {
      return M;
    }

    double[][] getU() {
      return U;
    }

    Vector getUserFeatureColumn(int index) {
      return new DenseVector(U[index]);
    }

    Vector getItemFeatureColumn(int index) {
      return new DenseVector(M[index]);
    }

    void setFeatureColumnInU(int idIndex, Vector vector) {
      setFeatureColumn(U, idIndex, vector);
    }

    void setFeatureColumnInM(int idIndex, Vector vector) {
      setFeatureColumn(M, idIndex, vector);
    }

    /** Copies the first {@code numFeatures} entries of {@code vector} into row {@code idIndex}. */
    protected void setFeatureColumn(double[][] matrix, int idIndex, Vector vector) {
      for (int feature = 0; feature < numFeatures; feature++) {
        matrix[idIndex][feature] = vector.get(feature);
      }
    }

    /**
     * Average of all preference values for the item. (The misspelt name is kept because the method is
     * protected and may be overridden.)
     */
    protected double averateRating(long itemID) throws TasteException {
      PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
      RunningAverage avg = new FullRunningAverage();
      for (Preference pref : prefs) {
        avg.addDatum(pref.getValue());
      }
      return avg.getAverage();
    }
  }

  /**
   * Runs up to {@code numIterations} alternating iterations and returns the resulting factorization.
   *
   * @return factorization built from the final user and item feature matrices; if the calling thread was
   *         interrupted, the matrices reflect only the work completed up to that point
   * @throws TasteException if the data model cannot be read
   */
  @Override
  public Factorization factorize() throws TasteException {
    log.info("starting to compute the factorization...");
    final Features features = new Features(this);

    /* feature maps necessary for solving for implicit feedback */
    OpenIntObjectHashMap<Vector> userY = null;
    OpenIntObjectHashMap<Vector> itemY = null;

    if (usesImplicitFeedback) {
      userY = userFeaturesMapping(dataModel.getUserIDs(), dataModel.getNumUsers(), features.getU());
      itemY = itemFeaturesMapping(dataModel.getItemIDs(), dataModel.getNumItems(), features.getM());
    }

    for (int iteration = 0; iteration < numIterations; iteration++) {
      log.info("iteration {}", iteration);

      /* fix M - compute U */
      ExecutorService queue = createQueue();
      LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
      boolean completed;
      try {

        final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackSolver = usesImplicitFeedback
            ? new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, itemY, numTrainingThreads)
            : null;

        while (userIDsIterator.hasNext()) {
          final long userID = userIDsIterator.nextLong();
          final LongPrimitiveIterator itemIDsFromUser = dataModel.getItemIDsFromUser(userID).iterator();
          final PreferenceArray userPrefs = dataModel.getPreferencesFromUser(userID);
          queue.execute(new Runnable() {
            @Override
            public void run() {
              List<Vector> featureVectors = new ArrayList<>();
              while (itemIDsFromUser.hasNext()) {
                long itemID = itemIDsFromUser.nextLong();
                featureVectors.add(features.getItemFeatureColumn(itemIndex(itemID)));
              }

              Vector userFeatures = usesImplicitFeedback
                  ? implicitFeedbackSolver.solve(sparseUserRatingVector(userPrefs))
                  : AlternatingLeastSquaresSolver.solve(featureVectors, ratingVector(userPrefs), lambda, numFeatures);

              features.setFeatureColumnInU(userIndex(userID), userFeatures);
            }
          });
        }
      } finally {
        // Shut down first so the pool is released even if reading the user count fails.
        queue.shutdown();
        completed = awaitCompletion(queue, dataModel.getNumUsers(), "Error when computing user features");
      }
      if (!completed) {
        log.warn("interrupted in iteration {}, skipping the remaining computation", iteration);
        break;
      }

      /* fix U - compute M */
      queue = createQueue();
      LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
      try {

        final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackSolver = usesImplicitFeedback
            ? new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, userY, numTrainingThreads)
            : null;

        while (itemIDsIterator.hasNext()) {
          final long itemID = itemIDsIterator.nextLong();
          final PreferenceArray itemPrefs = dataModel.getPreferencesForItem(itemID);
          queue.execute(new Runnable() {
            @Override
            public void run() {
              List<Vector> featureVectors = new ArrayList<>();
              for (Preference pref : itemPrefs) {
                long userID = pref.getUserID();
                featureVectors.add(features.getUserFeatureColumn(userIndex(userID)));
              }

              Vector itemFeatures = usesImplicitFeedback
                  ? implicitFeedbackSolver.solve(sparseItemRatingVector(itemPrefs))
                  : AlternatingLeastSquaresSolver.solve(featureVectors, ratingVector(itemPrefs), lambda, numFeatures);

              features.setFeatureColumnInM(itemIndex(itemID), itemFeatures);
            }
          });
        }
      } finally {
        queue.shutdown();
        completed = awaitCompletion(queue, dataModel.getNumItems(), "Error when computing item features");
      }
      if (!completed) {
        log.warn("interrupted in iteration {}, skipping the remaining computation", iteration);
        break;
      }
    }

    log.info("finished computation of the factorization...");
    return createFactorization(features.getU(), features.getM());
  }

  /**
   * Waits for an already shut-down pool to drain.
   *
   * @param queue pool on which {@code shutdown()} has been called
   * @param timeoutSeconds how long to wait
   * @param interruptedMessage message logged if the wait is interrupted
   * @return {@code false} if the calling thread was interrupted while waiting (workers are then cancelled and
   *         the interrupt status is restored); {@code true} otherwise, including after a timeout, which is
   *         logged but otherwise tolerated
   */
  private static boolean awaitCompletion(ExecutorService queue, long timeoutSeconds, String interruptedMessage) {
    try {
      if (!queue.awaitTermination(timeoutSeconds, TimeUnit.SECONDS)) {
        log.warn("Timed out after {} seconds waiting for feature computation to finish", timeoutSeconds);
      }
      return true;
    } catch (InterruptedException e) {
      log.warn(interruptedMessage, e);
      // Do not leave workers writing into the feature matrices behind our back.
      queue.shutdownNow();
      // Preserve the interrupt for callers further up the stack.
      Thread.currentThread().interrupt();
      return false;
    }
  }

  /** Creates the thread pool used for one phase; sized by {@code numTrainingThreads}. */
  protected ExecutorService createQueue() {
    return Executors.newFixedThreadPool(numTrainingThreads);
  }

  /** Copies the preference values, in array order, into a dense vector. */
  protected static Vector ratingVector(PreferenceArray prefs) {
    double[] ratings = new double[prefs.length()];
    for (int n = 0; n < prefs.length(); n++) {
      ratings[n] = prefs.get(n).getValue();
    }
    return new DenseVector(ratings, true);
  }

  //TODO find a way to get rid of the object overhead here
  /** Maps each item's index to a vector wrapping that item's row of {@code featureMatrix}. */
  protected OpenIntObjectHashMap<Vector> itemFeaturesMapping(LongPrimitiveIterator itemIDs, int numItems,
      double[][] featureMatrix) {
    OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<>(numItems);
    while (itemIDs.hasNext()) {
      int index = itemIndex(itemIDs.next());
      mapping.put(index, new DenseVector(featureMatrix[index], true));
    }

    return mapping;
  }

  /** Maps each user's index to a vector wrapping that user's row of {@code featureMatrix}. */
  protected OpenIntObjectHashMap<Vector> userFeaturesMapping(LongPrimitiveIterator userIDs, int numUsers,
      double[][] featureMatrix) {
    OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<>(numUsers);

    while (userIDs.hasNext()) {
      int index = userIndex(userIDs.next());
      mapping.put(index, new DenseVector(featureMatrix[index], true));
    }

    return mapping;
  }

  /** Sparse vector of an item's preference values, keyed by user index. */
  protected Vector sparseItemRatingVector(PreferenceArray prefs) {
    SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
    for (Preference preference : prefs) {
      ratings.set(userIndex(preference.getUserID()), preference.getValue());
    }
    return ratings;
  }

  /** Sparse vector of a user's preference values, keyed by item index. */
  protected Vector sparseUserRatingVector(PreferenceArray prefs) {
    SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
    for (Preference preference : prefs) {
      ratings.set(itemIndex(preference.getItemID()), preference.getValue());
    }
    return ratings;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java new file mode 100644 index 0000000..0a39a1d --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java @@ -0,0 +1,94 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.recommender.svd; + +import java.util.Collection; +import java.util.concurrent.Callable; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RefreshHelper; +import org.apache.mahout.cf.taste.model.DataModel; + +/** + * base class for {@link Factorizer}s, provides ID to index mapping + */ +public abstract class AbstractFactorizer implements Factorizer { + + private final DataModel dataModel; + private FastByIDMap<Integer> userIDMapping; + private FastByIDMap<Integer> itemIDMapping; + private final RefreshHelper refreshHelper; + + protected AbstractFactorizer(DataModel dataModel) throws TasteException { + this.dataModel = dataModel; + buildMappings(); + refreshHelper = new RefreshHelper(new Callable<Object>() { + @Override + public Object call() throws TasteException { + buildMappings(); + return null; + } + }); + refreshHelper.addDependency(dataModel); + } + + private void buildMappings() throws TasteException { + userIDMapping = createIDMapping(dataModel.getNumUsers(), dataModel.getUserIDs()); + itemIDMapping = createIDMapping(dataModel.getNumItems(), dataModel.getItemIDs()); + } + + protected Factorization 
createFactorization(double[][] userFeatures, double[][] itemFeatures) { + return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures); + } + + protected Integer userIndex(long userID) { + Integer userIndex = userIDMapping.get(userID); + if (userIndex == null) { + userIndex = userIDMapping.size(); + userIDMapping.put(userID, userIndex); + } + return userIndex; + } + + protected Integer itemIndex(long itemID) { + Integer itemIndex = itemIDMapping.get(itemID); + if (itemIndex == null) { + itemIndex = itemIDMapping.size(); + itemIDMapping.put(itemID, itemIndex); + } + return itemIndex; + } + + private static FastByIDMap<Integer> createIDMapping(int size, LongPrimitiveIterator idIterator) { + FastByIDMap<Integer> mapping = new FastByIDMap<>(size); + int index = 0; + while (idIterator.hasNext()) { + mapping.put(idIterator.nextLong(), index++); + } + return mapping; + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + refreshHelper.refresh(alreadyRefreshed); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java new file mode 100644 index 0000000..f169a60 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java @@ -0,0 +1,137 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.recommender.svd; + +import java.util.Arrays; +import java.util.Map; + +import com.google.common.base.Preconditions; +import org.apache.mahout.cf.taste.common.NoSuchItemException; +import org.apache.mahout.cf.taste.common.NoSuchUserException; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; + +/** + * a factorization of the rating matrix + */ +public class Factorization { + + /** used to find the rows in the user features matrix by userID */ + private final FastByIDMap<Integer> userIDMapping; + /** used to find the rows in the item features matrix by itemID */ + private final FastByIDMap<Integer> itemIDMapping; + + /** user features matrix */ + private final double[][] userFeatures; + /** item features matrix */ + private final double[][] itemFeatures; + + public Factorization(FastByIDMap<Integer> userIDMapping, FastByIDMap<Integer> itemIDMapping, double[][] userFeatures, + double[][] itemFeatures) { + this.userIDMapping = Preconditions.checkNotNull(userIDMapping); + this.itemIDMapping = Preconditions.checkNotNull(itemIDMapping); + this.userFeatures = userFeatures; + this.itemFeatures = itemFeatures; + } + + public double[][] allUserFeatures() { + return userFeatures; + } + + public double[] getUserFeatures(long userID) throws NoSuchUserException { + Integer index 
= userIDMapping.get(userID); + if (index == null) { + throw new NoSuchUserException(userID); + } + return userFeatures[index]; + } + + public double[][] allItemFeatures() { + return itemFeatures; + } + + public double[] getItemFeatures(long itemID) throws NoSuchItemException { + Integer index = itemIDMapping.get(itemID); + if (index == null) { + throw new NoSuchItemException(itemID); + } + return itemFeatures[index]; + } + + public int userIndex(long userID) throws NoSuchUserException { + Integer index = userIDMapping.get(userID); + if (index == null) { + throw new NoSuchUserException(userID); + } + return index; + } + + public Iterable<Map.Entry<Long,Integer>> getUserIDMappings() { + return userIDMapping.entrySet(); + } + + public LongPrimitiveIterator getUserIDMappingKeys() { + return userIDMapping.keySetIterator(); + } + + public int itemIndex(long itemID) throws NoSuchItemException { + Integer index = itemIDMapping.get(itemID); + if (index == null) { + throw new NoSuchItemException(itemID); + } + return index; + } + + public Iterable<Map.Entry<Long,Integer>> getItemIDMappings() { + return itemIDMapping.entrySet(); + } + + public LongPrimitiveIterator getItemIDMappingKeys() { + return itemIDMapping.keySetIterator(); + } + + public int numFeatures() { + return userFeatures.length > 0 ? 
userFeatures[0].length : 0; + } + + public int numUsers() { + return userIDMapping.size(); + } + + public int numItems() { + return itemIDMapping.size(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof Factorization) { + Factorization other = (Factorization) o; + return userIDMapping.equals(other.userIDMapping) && itemIDMapping.equals(other.itemIDMapping) + && Arrays.deepEquals(userFeatures, other.userFeatures) && Arrays.deepEquals(itemFeatures, other.itemFeatures); + } + return false; + } + + @Override + public int hashCode() { + int hashCode = 31 * userIDMapping.hashCode() + itemIDMapping.hashCode(); + hashCode = 31 * hashCode + Arrays.deepHashCode(userFeatures); + hashCode = 31 * hashCode + Arrays.deepHashCode(itemFeatures); + return hashCode; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java new file mode 100644 index 0000000..2cabe73 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/**
 * Implementation must be able to create a factorization of a rating matrix
 */
public interface Factorizer extends Refreshable {

  /**
   * Creates a factorization of the rating matrix.
   *
   * @return the resulting {@link Factorization}
   * @throws TasteException if the factorization cannot be computed
   */
  Factorization factorize() throws TasteException;

}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.recommender.svd; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.DataOutput; +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.Map; + +import com.google.common.base.Preconditions; +import org.apache.mahout.cf.taste.common.NoSuchItemException; +import org.apache.mahout.cf.taste.common.NoSuchUserException; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Provides a file-based persistent store. */ +public class FilePersistenceStrategy implements PersistenceStrategy { + + private final File file; + + private static final Logger log = LoggerFactory.getLogger(FilePersistenceStrategy.class); + + /** + * @param file the file to use for storage. If the file does not exist it will be created when required. 
+ */ + public FilePersistenceStrategy(File file) { + this.file = Preconditions.checkNotNull(file); + } + + @Override + public Factorization load() throws IOException { + if (!file.exists()) { + log.info("{} does not yet exist, no factorization found", file.getAbsolutePath()); + return null; + } + try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))){ + log.info("Reading factorization from {}...", file.getAbsolutePath()); + return readBinary(in); + } + } + + @Override + public void maybePersist(Factorization factorization) throws IOException { + try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))){ + log.info("Writing factorization to {}...", file.getAbsolutePath()); + writeBinary(factorization, out); + } + } + + protected static void writeBinary(Factorization factorization, DataOutput out) throws IOException { + out.writeInt(factorization.numFeatures()); + out.writeInt(factorization.numUsers()); + out.writeInt(factorization.numItems()); + + for (Map.Entry<Long,Integer> mappingEntry : factorization.getUserIDMappings()) { + long userID = mappingEntry.getKey(); + out.writeInt(mappingEntry.getValue()); + out.writeLong(userID); + try { + double[] userFeatures = factorization.getUserFeatures(userID); + for (int feature = 0; feature < factorization.numFeatures(); feature++) { + out.writeDouble(userFeatures[feature]); + } + } catch (NoSuchUserException e) { + throw new IOException("Unable to persist factorization", e); + } + } + + for (Map.Entry<Long,Integer> entry : factorization.getItemIDMappings()) { + long itemID = entry.getKey(); + out.writeInt(entry.getValue()); + out.writeLong(itemID); + try { + double[] itemFeatures = factorization.getItemFeatures(itemID); + for (int feature = 0; feature < factorization.numFeatures(); feature++) { + out.writeDouble(itemFeatures[feature]); + } + } catch (NoSuchItemException e) { + throw new IOException("Unable to persist factorization", 
e); + } + } + } + + public static Factorization readBinary(DataInput in) throws IOException { + int numFeatures = in.readInt(); + int numUsers = in.readInt(); + int numItems = in.readInt(); + + FastByIDMap<Integer> userIDMapping = new FastByIDMap<>(numUsers); + double[][] userFeatures = new double[numUsers][numFeatures]; + + for (int n = 0; n < numUsers; n++) { + int userIndex = in.readInt(); + long userID = in.readLong(); + userIDMapping.put(userID, userIndex); + for (int feature = 0; feature < numFeatures; feature++) { + userFeatures[userIndex][feature] = in.readDouble(); + } + } + + FastByIDMap<Integer> itemIDMapping = new FastByIDMap<>(numItems); + double[][] itemFeatures = new double[numItems][numFeatures]; + + for (int n = 0; n < numItems; n++) { + int itemIndex = in.readInt(); + long itemID = in.readLong(); + itemIDMapping.put(itemID, itemIndex); + for (int feature = 0; feature < numFeatures; feature++) { + itemFeatures[itemIndex][feature] = in.readDouble(); + } + } + + return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java new file mode 100644 index 0000000..0d1aab0 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java @@ -0,0 +1,37 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
/**
 * A {@link PersistenceStrategy} which does nothing.
 */
public class NoPersistenceStrategy implements PersistenceStrategy {

  /**
   * Nothing is ever stored, so there is nothing to load.
   *
   * @return always null
   */
  @Override
  public Factorization load() throws IOException {
    return null;
  }

  /**
   * Ignores the given factorization.
   *
   * @param factorization not used
   */
  @Override
  public void maybePersist(Factorization factorization) throws IOException {
    // do nothing.
  }

}
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.recommender.svd; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.RandomWrapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Minimalistic implementation of Parallel SGD factorizer based on + * <a href="http://www.sze.hu/~gtakacs/download/jmlr_2009.pdf"> + * "Scalable Collaborative Filtering Approaches for Large Recommender Systems"</a> + * and + * <a href="hwww.cs.wisc.edu/~brecht/papers/hogwildTR.pdf"> + * "Hogwild!: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent"</a> */ +public class ParallelSGDFactorizer extends AbstractFactorizer { + + private final DataModel dataModel; + /** Parameter used to prevent overfitting. 
*/ + private final double lambda; + /** Number of features used to compute this factorization */ + private final int rank; + /** Number of iterations */ + private final int numEpochs; + + private int numThreads; + + // these next two control decayFactor^steps exponential type of annealing learning rate and decay factor + private double mu0 = 0.01; + private double decayFactor = 1; + // these next two control 1/steps^forget type annealing + private int stepOffset = 0; + // -1 equals even weighting of all examples, 0 means only use exponential annealing + private double forgettingExponent = 0; + + // The following two should be inversely proportional :) + private double biasMuRatio = 0.5; + private double biasLambdaRatio = 0.1; + + /** TODO: this is not safe as += is not atomic on many processors, can be replaced with AtomicDoubleArray + * but it works just fine right now */ + /** user features */ + protected volatile double[][] userVectors; + /** item features */ + protected volatile double[][] itemVectors; + + private final PreferenceShuffler shuffler; + + private int epoch = 1; + /** place in user vector where the bias is stored */ + private static final int USER_BIAS_INDEX = 1; + /** place in item vector where the bias is stored */ + private static final int ITEM_BIAS_INDEX = 2; + private static final int FEATURE_OFFSET = 3; + /** Standard deviation for random initialization of features */ + private static final double NOISE = 0.02; + + private static final Logger logger = LoggerFactory.getLogger(ParallelSGDFactorizer.class); + + protected static class PreferenceShuffler { + + private Preference[] preferences; + private Preference[] unstagedPreferences; + + protected final RandomWrapper random = RandomUtils.getRandom(); + + public PreferenceShuffler(DataModel dataModel) throws TasteException { + cachePreferences(dataModel); + shuffle(); + stage(); + } + + private int countPreferences(DataModel dataModel) throws TasteException { + int numPreferences = 0; + 
LongPrimitiveIterator userIDs = dataModel.getUserIDs(); + while (userIDs.hasNext()) { + PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userIDs.nextLong()); + numPreferences += preferencesFromUser.length(); + } + return numPreferences; + } + + private void cachePreferences(DataModel dataModel) throws TasteException { + int numPreferences = countPreferences(dataModel); + preferences = new Preference[numPreferences]; + + LongPrimitiveIterator userIDs = dataModel.getUserIDs(); + int index = 0; + while (userIDs.hasNext()) { + long userID = userIDs.nextLong(); + PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID); + for (Preference preference : preferencesFromUser) { + preferences[index++] = preference; + } + } + } + + public final void shuffle() { + unstagedPreferences = preferences.clone(); + /* Durstenfeld shuffle */ + for (int i = unstagedPreferences.length - 1; i > 0; i--) { + int rand = random.nextInt(i + 1); + swapCachedPreferences(i, rand); + } + } + + //merge this part into shuffle() will make compiler-optimizer do some real absurd stuff, test on OpenJDK7 + private void swapCachedPreferences(int x, int y) { + Preference p = unstagedPreferences[x]; + + unstagedPreferences[x] = unstagedPreferences[y]; + unstagedPreferences[y] = p; + } + + public final void stage() { + preferences = unstagedPreferences; + } + + public Preference get(int i) { + return preferences[i]; + } + + public int size() { + return preferences.length; + } + + } + + public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numEpochs) + throws TasteException { + super(dataModel); + this.dataModel = dataModel; + this.rank = numFeatures + FEATURE_OFFSET; + this.lambda = lambda; + this.numEpochs = numEpochs; + + shuffler = new PreferenceShuffler(dataModel); + + //max thread num set to n^0.25 as suggested by hogwild! 
paper + numThreads = Math.min(Runtime.getRuntime().availableProcessors(), (int) Math.pow((double) shuffler.size(), 0.25)); + } + + public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, + double mu0, double decayFactor, int stepOffset, double forgettingExponent) throws TasteException { + this(dataModel, numFeatures, lambda, numIterations); + + this.mu0 = mu0; + this.decayFactor = decayFactor; + this.stepOffset = stepOffset; + this.forgettingExponent = forgettingExponent; + } + + public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, + double mu0, double decayFactor, int stepOffset, double forgettingExponent, int numThreads) throws TasteException { + this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent); + + this.numThreads = numThreads; + } + + public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, + double mu0, double decayFactor, int stepOffset, double forgettingExponent, + double biasMuRatio, double biasLambdaRatio) throws TasteException { + this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent); + + this.biasMuRatio = biasMuRatio; + this.biasLambdaRatio = biasLambdaRatio; + } + + public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations, + double mu0, double decayFactor, int stepOffset, double forgettingExponent, + double biasMuRatio, double biasLambdaRatio, int numThreads) throws TasteException { + this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent, biasMuRatio, + biasLambdaRatio); + + this.numThreads = numThreads; + } + + protected void initialize() throws TasteException { + RandomWrapper random = RandomUtils.getRandom(); + userVectors = new double[dataModel.getNumUsers()][rank]; + itemVectors = new double[dataModel.getNumItems()][rank]; + 
+ double globalAverage = getAveragePreference(); + for (int userIndex = 0; userIndex < userVectors.length; userIndex++) { + userVectors[userIndex][0] = globalAverage; + userVectors[userIndex][USER_BIAS_INDEX] = 0; // will store user bias + userVectors[userIndex][ITEM_BIAS_INDEX] = 1; // corresponding item feature contains item bias + for (int feature = FEATURE_OFFSET; feature < rank; feature++) { + userVectors[userIndex][feature] = random.nextGaussian() * NOISE; + } + } + for (int itemIndex = 0; itemIndex < itemVectors.length; itemIndex++) { + itemVectors[itemIndex][0] = 1; // corresponding user feature contains global average + itemVectors[itemIndex][USER_BIAS_INDEX] = 1; // corresponding user feature contains user bias + itemVectors[itemIndex][ITEM_BIAS_INDEX] = 0; // will store item bias + for (int feature = FEATURE_OFFSET; feature < rank; feature++) { + itemVectors[itemIndex][feature] = random.nextGaussian() * NOISE; + } + } + } + + //TODO: needs optimization + private double getMu(int i) { + return mu0 * Math.pow(decayFactor, i - 1) * Math.pow(i + stepOffset, forgettingExponent); + } + + @Override + public Factorization factorize() throws TasteException { + initialize(); + + if (logger.isInfoEnabled()) { + logger.info("starting to compute the factorization..."); + } + + for (epoch = 1; epoch <= numEpochs; epoch++) { + shuffler.stage(); + + final double mu = getMu(epoch); + int subSize = shuffler.size() / numThreads + 1; + + ExecutorService executor=Executors.newFixedThreadPool(numThreads); + + try { + for (int t = 0; t < numThreads; t++) { + final int iStart = t * subSize; + final int iEnd = Math.min((t + 1) * subSize, shuffler.size()); + + executor.execute(new Runnable() { + @Override + public void run() { + for (int i = iStart; i < iEnd; i++) { + update(shuffler.get(i), mu); + } + } + }); + } + } finally { + executor.shutdown(); + shuffler.shuffle(); + + try { + boolean terminated = executor.awaitTermination(numEpochs * shuffler.size(), 
TimeUnit.MICROSECONDS); + if (!terminated) { + logger.error("subtasks takes forever, return anyway"); + } + } catch (InterruptedException e) { + throw new TasteException("waiting fof termination interrupted", e); + } + } + + } + + return createFactorization(userVectors, itemVectors); + } + + double getAveragePreference() throws TasteException { + RunningAverage average = new FullRunningAverage(); + LongPrimitiveIterator it = dataModel.getUserIDs(); + while (it.hasNext()) { + for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) { + average.addDatum(pref.getValue()); + } + } + return average.getAverage(); + } + + /** TODO: this is the vanilla sgd by Tacaks 2009, I speculate that using scaling technique proposed in: + * Towards Optimal One Pass Large Scale Learning with Averaged Stochastic Gradient Descent section 5, page 6 + * can be beneficial in term s of both speed and accuracy. + * + * Tacaks' method doesn't calculate gradient of regularization correctly, which has non-zero elements everywhere of + * the matrix. While Tacaks' method can only updates a single row/column, if one user has a lot of recommendation, + * her vector will be more affected by regularization using an isolated scaling factor for both user vectors and + * item vectors can remove this issue without inducing more update cost it even reduces it a bit by only performing + * one addition and one multiplication. + * + * BAD SIDE1: the scaling factor decreases fast, it has to be scaled up from time to time before dropped to zero or + * caused roundoff error + * BAD SIDE2: no body experiment on it before, and people generally use very small lambda + * so it's impact on accuracy may still be unknown. + * BAD SIDE3: don't know how to make it work for L1-regularization or + * "pseudorank?" 
(sum of singular values)-regularization */ + protected void update(Preference preference, double mu) { + int userIndex = userIndex(preference.getUserID()); + int itemIndex = itemIndex(preference.getItemID()); + + double[] userVector = userVectors[userIndex]; + double[] itemVector = itemVectors[itemIndex]; + + double prediction = dot(userVector, itemVector); + double err = preference.getValue() - prediction; + + // adjust features + for (int k = FEATURE_OFFSET; k < rank; k++) { + double userFeature = userVector[k]; + double itemFeature = itemVector[k]; + + userVector[k] += mu * (err * itemFeature - lambda * userFeature); + itemVector[k] += mu * (err * userFeature - lambda * itemFeature); + } + + // adjust user and item bias + userVector[USER_BIAS_INDEX] += biasMuRatio * mu * (err - biasLambdaRatio * lambda * userVector[USER_BIAS_INDEX]); + itemVector[ITEM_BIAS_INDEX] += biasMuRatio * mu * (err - biasLambdaRatio * lambda * itemVector[ITEM_BIAS_INDEX]); + } + + private double dot(double[] userVector, double[] itemVector) { + double sum = 0; + for (int k = 0; k < rank; k++) { + sum += userVector[k] * itemVector[k]; + } + return sum; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java new file mode 100644 index 0000000..abf3eca --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.recommender.svd; + +import java.io.IOException; + +/** + * Provides storage for {@link Factorization}s: implementations load a stored factorization and persist a new one. + */ +public interface PersistenceStrategy { + + /** + * Load a factorization from a persistent store. + * + * @return a Factorization or null if the persistent store is empty. + * + * @throws IOException if the persistent store cannot be read (exact conditions depend on the implementation) + */ + Factorization load() throws IOException; + + /** + * Write a factorization to a persistent store unless it already + * contains an identical factorization. 
+ * + * @param factorization the factorization to write + * + * @throws IOException if the persistent store cannot be written (exact conditions depend on the implementation) + */ + void maybePersist(Factorization factorization) throws IOException; + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java new file mode 100644 index 0000000..2c9f0ae --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java @@ -0,0 +1,221 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/

package org.apache.mahout.cf.taste.impl.recommender.svd;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;

/**
 * Matrix factorization with user and item biases for rating prediction, trained with plain vanilla SGD.
 *
 * <p>Every user and item vector has {@code numFeatures} entries. The first {@link #FEATURE_OFFSET} entries
 * are reserved so that the plain dot product of a user vector and an item vector already includes the
 * global average and both biases:
 * <ul>
 *   <li>index 0: global average preference in the user vector, constant 1 in the item vector</li>
 *   <li>index {@link #USER_BIAS_INDEX}: user bias in the user vector, constant 1 in the item vector</li>
 *   <li>index {@link #ITEM_BIAS_INDEX}: constant 1 in the user vector, item bias in the item vector</li>
 *   <li>indexes from {@link #FEATURE_OFFSET} on: the latent features</li>
 * </ul>
 */
public class RatingSGDFactorizer extends AbstractFactorizer {

  /** Index of the first latent feature; the entries before it hold the global average and the biases. */
  protected static final int FEATURE_OFFSET = 3;

  /** Multiplicative decay factor for learning_rate */
  protected final double learningRateDecay;
  /** Learning rate (step size) */
  protected final double learningRate;
  /** Parameter used to prevent overfitting. */
  protected final double preventOverfitting;
  /** Number of features used to compute this factorization (requested features plus {@link #FEATURE_OFFSET}) */
  protected final int numFeatures;
  /** Number of iterations */
  private final int numIterations;
  /** Standard deviation for random initialization of features */
  protected final double randomNoise;
  /** User features */
  protected double[][] userVectors;
  /** Item features */
  protected double[][] itemVectors;
  /** Source of the preferences to factorize */
  protected final DataModel dataModel;
  /** User ID of every cached preference; entry i pairs with entry i of {@link #cachedItemIDs} */
  private long[] cachedUserIDs;
  /** Item ID of every cached preference; entry i pairs with entry i of {@link #cachedUserIDs} */
  private long[] cachedItemIDs;

  /** Factor applied to the current learning rate when updating the biases */
  protected double biasLearningRate = 0.5;
  /** Factor applied to {@link #preventOverfitting} when updating the biases */
  protected double biasReg = 0.1;

  /** place in user vector where the bias is stored */
  protected static final int USER_BIAS_INDEX = 1;
  /** place in item vector where the bias is stored */
  protected static final int ITEM_BIAS_INDEX = 2;

  /**
   * Creates a factorizer with learning rate 0.01, regularization 0.1, initialization noise 0.01 and no
   * learning rate decay (factor 1.0).
   *
   * @param dataModel source of the preferences
   * @param numFeatures number of latent features (the reserved entries are added on top)
   * @param numIterations number of passes over all preferences
   * @throws TasteException propagated from the delegated constructor
   */
  public RatingSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
    this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
  }

  /**
   * @param dataModel source of the preferences
   * @param numFeatures number of latent features (the reserved entries are added on top)
   * @param learningRate initial step size
   * @param preventOverfitting regularization weight
   * @param randomNoise standard deviation of the Gaussian used to initialize the latent features
   * @param numIterations number of passes over all preferences
   * @param learningRateDecay factor the step size is multiplied with after every pass
   * @throws TasteException propagated from the superclass constructor
   */
  public RatingSGDFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting,
      double randomNoise, int numIterations, double learningRateDecay) throws TasteException {
    super(dataModel);
    this.dataModel = dataModel;
    this.numFeatures = numFeatures + FEATURE_OFFSET;
    this.numIterations = numIterations;

    this.learningRate = learningRate;
    this.learningRateDecay = learningRateDecay;
    this.preventOverfitting = preventOverfitting;
    this.randomNoise = randomNoise;
  }

  /**
   * Allocates and initializes the user and item vectors, then caches and shuffles the (user, item) pairs
   * of all preferences.
   *
   * @throws TasteException if the data model cannot be read
   */
  protected void prepareTraining() throws TasteException {
    RandomWrapper random = RandomUtils.getRandom();
    userVectors = new double[dataModel.getNumUsers()][numFeatures];
    itemVectors = new double[dataModel.getNumItems()][numFeatures];

    double globalAverage = getAveragePreference();
    for (double[] userVector : userVectors) {
      userVector[0] = globalAverage;
      userVector[USER_BIAS_INDEX] = 0; // will store user bias
      userVector[ITEM_BIAS_INDEX] = 1; // corresponding item feature contains item bias
      for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
        userVector[feature] = random.nextGaussian() * randomNoise;
      }
    }
    for (double[] itemVector : itemVectors) {
      itemVector[0] = 1; // corresponding user feature contains global average
      itemVector[USER_BIAS_INDEX] = 1; // corresponding user feature contains user bias
      itemVector[ITEM_BIAS_INDEX] = 0; // will store item bias
      for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
        itemVector[feature] = random.nextGaussian() * randomNoise;
      }
    }

    cachePreferences();
    shufflePreferences();
  }

  /** Counts the preferences of all users in the data model. */
  private int countPreferences() throws TasteException {
    int numPreferences = 0;
    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
    while (userIDs.hasNext()) {
      PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userIDs.nextLong());
      numPreferences += preferencesFromUser.length();
    }
    return numPreferences;
  }

  /** Stores the (user ID, item ID) pair of every preference in the two parallel cache arrays. */
  private void cachePreferences() throws TasteException {
    int numPreferences = countPreferences();
    cachedUserIDs = new long[numPreferences];
    cachedItemIDs = new long[numPreferences];

    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
    int index = 0;
    while (userIDs.hasNext()) {
      long userID = userIDs.nextLong();
      PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID);
      for (Preference preference : preferencesFromUser) {
        cachedUserIDs[index] = userID;
        cachedItemIDs[index] = preference.getItemID();
        index++;
      }
    }
  }

  /** Randomly permutes the cached preferences, keeping each user ID paired with its item ID. */
  protected void shufflePreferences() {
    RandomWrapper random = RandomUtils.getRandom();
    /* Durstenfeld shuffle */
    for (int currentPos = cachedUserIDs.length - 1; currentPos > 0; currentPos--) {
      int swapPos = random.nextInt(currentPos + 1);
      swapCachedPreferences(currentPos, swapPos);
    }
  }

  /** Exchanges the cached (user ID, item ID) pairs stored at the two positions. */
  private void swapCachedPreferences(int posA, int posB) {
    long tmpUserID = cachedUserIDs[posA];
    long tmpItemID = cachedItemIDs[posA];

    cachedUserIDs[posA] = cachedUserIDs[posB];
    cachedItemIDs[posA] = cachedItemIDs[posB];

    cachedUserIDs[posB] = tmpUserID;
    cachedItemIDs[posB] = tmpItemID;
  }

  /**
   * Runs {@code numIterations} SGD passes over the cached preferences, multiplying the learning rate with
   * {@code learningRateDecay} after every pass.
   *
   * @return the factorization built from the trained user and item vectors
   * @throws TasteException if the data model cannot be read
   */
  @Override
  public Factorization factorize() throws TasteException {
    prepareTraining();
    double currentLearningRate = learningRate;

    for (int it = 0; it < numIterations; it++) {
      for (int index = 0; index < cachedUserIDs.length; index++) {
        long userId = cachedUserIDs[index];
        long itemId = cachedItemIDs[index];
        // Keep the boxed value: unboxing a null result straight into a float would throw a
        // NullPointerException. A cached pair without a value (e.g. the model changed after the
        // preferences were cached) carries no training signal, so it is skipped.
        Float rating = dataModel.getPreferenceValue(userId, itemId);
        if (rating == null) {
          continue;
        }
        updateParameters(userId, itemId, rating, currentLearningRate);
      }
      currentLearningRate *= learningRateDecay;
    }
    return createFactorization(userVectors, itemVectors);
  }

  /**
   * @return the average over the preference values of all users in the data model
   * @throws TasteException if the data model cannot be read
   */
  double getAveragePreference() throws TasteException {
    RunningAverage average = new FullRunningAverage();
    LongPrimitiveIterator it = dataModel.getUserIDs();
    while (it.hasNext()) {
      for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
        average.addDatum(pref.getValue());
      }
    }
    return average.getAverage();
  }

  /**
   * Performs one SGD step for a single rating: the prediction error is computed first, then the user
   * bias, the item bias and all latent features of the two vectors are adjusted in place.
   *
   * @param userID ID of the rating user
   * @param itemID ID of the rated item
   * @param rating observed preference value
   * @param currentLearningRate step size to use for this update
   */
  protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
    int userIndex = userIndex(userID);
    int itemIndex = itemIndex(itemID);

    double[] userVector = userVectors[userIndex];
    double[] itemVector = itemVectors[itemIndex];
    double prediction = predictRating(userIndex, itemIndex);
    double err = rating - prediction;

    // adjust user bias
    userVector[USER_BIAS_INDEX] +=
        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * userVector[USER_BIAS_INDEX]);

    // adjust item bias
    itemVector[ITEM_BIAS_INDEX] +=
        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * itemVector[ITEM_BIAS_INDEX]);

    // adjust features
    for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
      // read both values before writing so the item update uses the old user feature
      double userFeature = userVector[feature];
      double itemFeature = itemVector[feature];

      double deltaUserFeature = err * itemFeature - preventOverfitting * userFeature;
      userVector[feature] += currentLearningRate * deltaUserFeature;

      double deltaItemFeature = err * userFeature - preventOverfitting * itemFeature;
      itemVector[feature] += currentLearningRate * deltaItemFeature;
    }
  }

  /**
   * Computes the dot product over all {@code numFeatures} entries of the two vectors.
   *
   * @param userIndex row of the user in {@link #userVectors} (an index, not a user ID)
   * @param itemIndex row of the item in {@link #itemVectors} (an index, not an item ID)
   */
  private double predictRating(int userIndex, int itemIndex) {
    double[] userVector = userVectors[userIndex];
    double[] itemVector = itemVectors[itemIndex];
    double sum = 0;
    for (int feature = 0; feature < numFeatures; feature++) {
      sum += userVector[feature] * itemVector[feature];
    }
    return sum;
  }
}
