http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java new file mode 100644 index 0000000..20446f8 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java @@ -0,0 +1,178 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.common.RandomUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * SVD++, an enhancement of classical matrix factorization for rating prediction.
 * In addition to using ratings (how did people rate?) for learning, this model also takes into account
 * who rated what.
 *
 * <p>The effective user vector is {@code p[u] + sum(y[j] for j rated by u) / sqrt(#items rated by u)}:
 * {@code p} holds the explicit per-user factors and {@code y} the per-item implicit-feedback factors.</p>
 *
 * Yehuda Koren: Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model, KDD 2008.
 * http://research.yahoo.com/files/kdd08koren.pdf
 */
public final class SVDPlusPlusFactorizer extends RatingSGDFactorizer {

  /** Explicit user factors, indexed by internal user index, then feature. */
  private double[][] p;
  /** Implicit-feedback item factors, indexed by internal item index, then feature. */
  private double[][] y;
  /** Internal item indexes of everything each user has a preference for, keyed by internal user index. */
  private Map<Integer, List<Integer>> itemsByUser;

  /**
   * Creates a factorizer with learning rate 0.01, regularization 0.1, random noise 0.01 and no
   * learning-rate decay (factor 1.0), then sets the inherited bias learning rate to 0.7 and the
   * inherited bias regularization to 0.33.
   *
   * @param dataModel     source of the ratings
   * @param numFeatures   number of features per user/item vector
   * @param numIterations number of training iterations
   * @throws TasteException if the superclass constructor fails
   */
  public SVDPlusPlusFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
    this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
    biasLearningRate = 0.7;
    biasReg = 0.33;
  }

  /**
   * Creates a factorizer; all arguments are passed straight through to the superclass.
   *
   * @param dataModel          source of the ratings
   * @param numFeatures        number of features per user/item vector
   * @param learningRate       base SGD step size
   * @param preventOverfitting regularization weight
   * @param randomNoise        standard deviation used to initialise the non-bias features
   * @param numIterations      number of training iterations
   * @param learningRateDecay  learning-rate decay factor
   * @throws TasteException if the superclass constructor fails
   */
  public SVDPlusPlusFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting,
      double randomNoise, int numIterations, double learningRateDecay) throws TasteException {
    super(dataModel, numFeatures, learningRate, preventOverfitting, randomNoise, numIterations, learningRateDecay);
  }

  /**
   * Runs the superclass preparation, then allocates and initialises {@code p} and {@code y}
   * (zeros below {@code FEATURE_OFFSET}, Gaussian noise scaled by {@code randomNoise} from there on)
   * and builds the user-index to item-indexes lookup.
   */
  @Override
  protected void prepareTraining() throws TasteException {
    super.prepareTraining();
    Random random = RandomUtils.getRandom();

    p = new double[dataModel.getNumUsers()][numFeatures];
    for (int i = 0; i < p.length; i++) {
      for (int feature = 0; feature < FEATURE_OFFSET; feature++) {
        p[i][feature] = 0;
      }
      for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
        p[i][feature] = random.nextGaussian() * randomNoise;
      }
    }

    y = new double[dataModel.getNumItems()][numFeatures];
    for (int i = 0; i < y.length; i++) {
      for (int feature = 0; feature < FEATURE_OFFSET; feature++) {
        y[i][feature] = 0;
      }
      for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
        y[i][feature] = random.nextGaussian() * randomNoise;
      }
    }

    /* get internal item IDs which we will need several times */
    itemsByUser = new HashMap<>();
    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
    while (userIDs.hasNext()) {
      long userId = userIDs.nextLong();
      int userIndex = userIndex(userId);
      FastIDSet itemIDsFromUser = dataModel.getItemIDsFromUser(userId);
      List<Integer> itemIndexes = new ArrayList<>(itemIDsFromUser.size());
      itemsByUser.put(userIndex, itemIndexes);
      for (long itemID2 : itemIDsFromUser) {
        int i2 = itemIndex(itemID2);
        itemIndexes.add(i2);
      }
    }
  }

  /**
   * Trains the model and folds the implicit-feedback factors into the user vectors.
   *
   * @return the factorization built from the combined user vectors and the item vectors
   * @throws TasteException if preparation or training fails
   */
  @Override
  public Factorization factorize() throws TasteException {
    // NOTE(review): if super.factorize() also invokes prepareTraining(), the setup runs twice - confirm.
    prepareTraining();

    super.factorize();

    for (int userIndex = 0; userIndex < userVectors.length; userIndex++) {
      List<Integer> ratedItems = itemsByUser.get(userIndex);
      // NOTE(review): the sum is accumulated on top of whatever userVectors already holds (presumably the
      // superclass initialisation), whereas training starts its sum from zero - confirm this is intended.
      for (int itemIndex : ratedItems) {
        for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
          userVectors[userIndex][feature] += y[itemIndex][feature];
        }
      }
      double denominator = implicitFeedbackNorm(ratedItems.size());
      for (int feature = 0; feature < userVectors[userIndex].length; feature++) {
        userVectors[userIndex][feature] =
            (float) (userVectors[userIndex][feature] / denominator + p[userIndex][feature]);
      }
    }

    return createFactorization(userVectors, itemVectors);
  }

  /**
   * Performs one SGD step for a single (user, item, rating) observation: updates the user bias, the item
   * bias, the explicit user factors, the item factors and the implicit-feedback factors of every item
   * the user has a preference for.
   *
   * @param userID              external user ID
   * @param itemID              external item ID
   * @param rating              observed rating
   * @param currentLearningRate step size for this update
   */
  @Override
  protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
    int userIndex = userIndex(userID);
    int itemIndex = itemIndex(itemID);

    double[] userVector = p[userIndex];
    double[] itemVector = itemVectors[itemIndex];
    List<Integer> ratedItems = itemsByUser.get(userIndex);

    double[] pPlusY = new double[numFeatures];
    for (int i2 : ratedItems) {
      for (int f = FEATURE_OFFSET; f < numFeatures; f++) {
        pPlusY[f] += y[i2][f];
      }
    }
    double denominator = implicitFeedbackNorm(ratedItems.size());
    for (int feature = 0; feature < pPlusY.length; feature++) {
      // The narrowing to float mirrors the combination step in factorize().
      pPlusY[feature] = (float) (pPlusY[feature] / denominator + userVector[feature]);
    }

    double prediction = predictRating(pPlusY, itemIndex);
    double err = rating - prediction;
    double normalizedError = err / denominator;

    // adjust user bias
    userVector[USER_BIAS_INDEX] +=
        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * userVector[USER_BIAS_INDEX]);

    // adjust item bias
    itemVector[ITEM_BIAS_INDEX] +=
        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * itemVector[ITEM_BIAS_INDEX]);

    // adjust features
    for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
      // Capture both values before either is modified so all three updates use the pre-step state.
      double pF = userVector[feature];
      double iF = itemVector[feature];

      double deltaU = err * iF - preventOverfitting * pF;
      userVector[feature] += currentLearningRate * deltaU;

      double deltaI = err * pPlusY[feature] - preventOverfitting * iF;
      itemVector[feature] += currentLearningRate * deltaI;

      double commonUpdate = normalizedError * iF;
      for (int itemIndex2 : ratedItems) {
        double deltaI2 = commonUpdate - preventOverfitting * y[itemIndex2][feature];
        // Use the step size of this update, like the user and item updates above,
        // rather than the undecayed base learning rate.
        y[itemIndex2][feature] += currentLearningRate * deltaI2;
      }
    }
  }

  /**
   * Returns the divisor applied to the summed implicit-feedback factors: the square root of the number of
   * items, or 1 when there are none (the sum is then empty, and dividing by zero would yield NaN).
   */
  private static double implicitFeedbackNorm(int numItems) {
    return numItems > 0 ? Math.sqrt(numItems) : 1.0;
  }

  /** Dot product of the given user vector with the item vector stored at the given internal item index. */
  private double predictRating(double[] userVector, int itemID) {
    double sum = 0;
    for (int feature = 0; feature < numFeatures; feature++) {
      sum += userVector[feature] * itemVectors[itemID][feature];
    }
    return sum;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java new file mode 100644 index 0000000..45c54da --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java @@ -0,0 +1,41 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.impl.model.GenericPreference;

/**
 * A {@link GenericPreference} that additionally carries one mutable {@code double} "cache" value.
 * The cache value is never allowed to be NaN.
 */
final class SVDPreference extends GenericPreference {

  /** Extra value stored next to the preference; guaranteed not to be NaN. */
  private double cache;

  /**
   * @param userID user the preference belongs to
   * @param itemID item the preference is about
   * @param value  preference value
   * @param cache  initial cache value; must not be NaN
   * @throws IllegalArgumentException if {@code cache} is NaN
   */
  SVDPreference(long userID, long itemID, float value, double cache) {
    super(userID, itemID, value);
    setCache(cache);
  }

  /** @return the currently stored cache value */
  public double getCache() {
    return this.cache;
  }

  /**
   * Replaces the stored cache value.
   *
   * @param value new cache value; must not be NaN
   * @throws IllegalArgumentException if {@code value} is NaN
   */
  public void setCache(double value) {
    boolean isNumber = !Double.isNaN(value);
    Preconditions.checkArgument(isNumber, "NaN cache value");
    cache = value;
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java new file mode 100644 index 0000000..45d4af7 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java @@ -0,0 +1,185 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;

import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.impl.recommender.AbstractRecommender;
import org.apache.mahout.cf.taste.impl.recommender.AllUnknownItemsCandidateItemsStrategy;
import org.apache.mahout.cf.taste.impl.recommender.TopItems;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
import org.apache.mahout.cf.taste.recommender.IDRescorer;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link org.apache.mahout.cf.taste.recommender.Recommender} that uses matrix factorization (a projection of users
 * and items onto a feature space).
 */
public final class SVDRecommender extends AbstractRecommender {

  private static final Logger log = LoggerFactory.getLogger(SVDRecommender.class);

  /** Current factorization; replaced every time {@link #train()} runs. */
  private Factorization factorization;
  private final Factorizer factorizer;
  private final PersistenceStrategy persistenceStrategy;
  private final RefreshHelper refreshHelper;

  /**
   * Creates a recommender that considers all items unknown to the user as candidates and uses the default
   * persistence strategy.
   *
   * @param dataModel  data to recommend from
   * @param factorizer computes the factorization
   * @throws TasteException if loading or computing the factorization fails
   */
  public SVDRecommender(DataModel dataModel, Factorizer factorizer) throws TasteException {
    this(dataModel, factorizer, new AllUnknownItemsCandidateItemsStrategy(), getDefaultPersistenceStrategy());
  }

  /**
   * Creates a recommender with an explicit candidate-items strategy and the default persistence strategy.
   *
   * @param dataModel              data to recommend from
   * @param factorizer             computes the factorization
   * @param candidateItemsStrategy chooses the items that may be recommended to a user
   * @throws TasteException if loading or computing the factorization fails
   */
  public SVDRecommender(DataModel dataModel, Factorizer factorizer, CandidateItemsStrategy candidateItemsStrategy)
    throws TasteException {
    this(dataModel, factorizer, candidateItemsStrategy, getDefaultPersistenceStrategy());
  }

  /**
   * Create an SVDRecommender using a persistent store to cache factorizations. A factorization is loaded from the
   * store if present, otherwise a new factorization is computed and saved in the store.
   *
   * The {@link #refresh(java.util.Collection) refresh} method recomputes the factorization and overwrites the store.
   *
   * @param dataModel           data to recommend from
   * @param factorizer          computes the factorization
   * @param persistenceStrategy store the factorization is loaded from and persisted to
   * @throws TasteException if loading, computing or persisting the factorization fails
   */
  public SVDRecommender(DataModel dataModel, Factorizer factorizer, PersistenceStrategy persistenceStrategy)
    throws TasteException {
    this(dataModel, factorizer, getDefaultCandidateItemsStrategy(), persistenceStrategy);
  }

  /**
   * Create an SVDRecommender using a persistent store to cache factorizations. A factorization is loaded from the
   * store if present, otherwise a new factorization is computed and saved in the store.
   *
   * The {@link #refresh(java.util.Collection) refresh} method recomputes the factorization and overwrites the store.
   *
   * @param dataModel              data to recommend from
   * @param factorizer             computes the factorization; must not be null
   * @param candidateItemsStrategy chooses the items that may be recommended to a user
   * @param persistenceStrategy    store the factorization is loaded from and persisted to; must not be null
   *
   * @throws TasteException if loading, computing or persisting the factorization fails
   */
  public SVDRecommender(DataModel dataModel, Factorizer factorizer, CandidateItemsStrategy candidateItemsStrategy,
      PersistenceStrategy persistenceStrategy) throws TasteException {
    super(dataModel, candidateItemsStrategy);
    this.factorizer = Preconditions.checkNotNull(factorizer);
    this.persistenceStrategy = Preconditions.checkNotNull(persistenceStrategy);

    Factorization stored;
    try {
      stored = persistenceStrategy.load();
    } catch (IOException e) {
      throw new TasteException("Error loading factorization", e);
    }
    factorization = stored;

    // Nothing in the store: compute (and possibly persist) a factorization right away.
    if (stored == null) {
      train();
    }

    Callable<Object> retrain = new Callable<Object>() {
      @Override
      public Object call() throws TasteException {
        train();
        return null;
      }
    };
    refreshHelper = new RefreshHelper(retrain);
    refreshHelper.addDependency(getDataModel());
    refreshHelper.addDependency(factorizer);
    refreshHelper.addDependency(candidateItemsStrategy);
  }

  /** @return a strategy instance of type {@link NoPersistenceStrategy} */
  static PersistenceStrategy getDefaultPersistenceStrategy() {
    return new NoPersistenceStrategy();
  }

  /** Computes a new factorization and hands it to the persistence strategy. */
  private void train() throws TasteException {
    factorization = factorizer.factorize();
    try {
      persistenceStrategy.maybePersist(factorization);
    } catch (IOException e) {
      throw new TasteException("Error persisting factorization", e);
    }
  }

  /**
   * Returns the top {@code howMany} candidate items for the user, scored with
   * {@link #estimatePreference(long, long)}.
   *
   * @throws IllegalArgumentException if {@code howMany} is less than 1
   */
  @Override
  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
    throws TasteException {
    Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
    log.debug("Recommending items for user ID '{}'", userID);

    PreferenceArray userPrefs = getDataModel().getPreferencesFromUser(userID);
    FastIDSet candidates = getAllOtherItems(userID, userPrefs, includeKnownItems);

    TopItems.Estimator<Long> scorer = new Estimator(userID);
    List<RecommendedItem> best = TopItems.getTopItems(howMany, candidates.iterator(), rescorer, scorer);
    log.debug("Recommendations are: {}", best);

    return best;
  }

  /**
   * a preference is estimated by computing the dot-product of the user and item feature vectors
   */
  @Override
  public float estimatePreference(long userID, long itemID) throws TasteException {
    double[] userFeatures = factorization.getUserFeatures(userID);
    double[] itemFeatures = factorization.getItemFeatures(itemID);
    double dotProduct = 0;
    int numFeatures = userFeatures.length;
    for (int f = 0; f < numFeatures; f++) {
      dotProduct += userFeatures[f] * itemFeatures[f];
    }
    return (float) dotProduct;
  }

  /** Scores items for one fixed user by delegating to {@link #estimatePreference(long, long)}. */
  private final class Estimator implements TopItems.Estimator<Long> {

    private final long theUserID;

    private Estimator(long theUserID) {
      this.theUserID = theUserID;
    }

    @Override
    public double estimate(Long itemID) throws TasteException {
      return estimatePreference(theUserID, itemID);
    }
  }

  /**
   * Refresh the data model and factorization.
   */
  @Override
  public void refresh(Collection<Refreshable> alreadyRefreshed) {
    refreshHelper.refresh(alreadyRefreshed);
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java new file mode 100644 index 0000000..e0d6f59 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.similarity;

import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.similarity.ItemSimilarity;

import java.util.Collection;

/**
 * Base class for {@link ItemSimilarity} implementations backed by a {@link DataModel}. It holds the model,
 * forwards refresh requests to it and provides a brute-force {@link #allSimilarItemIDs(long)}.
 */
public abstract class AbstractItemSimilarity implements ItemSimilarity {

  private final DataModel dataModel;
  private final RefreshHelper refreshHelper;

  /**
   * @param dataModel model the similarities are computed over; must not be null
   * @throws IllegalArgumentException if {@code dataModel} is null
   */
  protected AbstractItemSimilarity(DataModel dataModel) {
    Preconditions.checkArgument(dataModel != null, "dataModel is null");
    this.dataModel = dataModel;
    // No refresh action of our own; the helper only propagates the refresh to the data model.
    RefreshHelper helper = new RefreshHelper(null);
    helper.addDependency(dataModel);
    this.refreshHelper = helper;
  }

  /** @return the data model given at construction */
  protected DataModel getDataModel() {
    return dataModel;
  }

  /**
   * Scans every item ID in the data model and keeps those for which
   * {@link #itemSimilarity(long, long)} with {@code itemID} is not NaN.
   *
   * @param itemID item to compare every item in the model against
   * @return the IDs of all items with a defined similarity to {@code itemID}
   * @throws TasteException if the data model or the similarity computation fails
   */
  @Override
  public long[] allSimilarItemIDs(long itemID) throws TasteException {
    FastIDSet similarIDs = new FastIDSet();
    for (LongPrimitiveIterator it = dataModel.getItemIDs(); it.hasNext();) {
      long candidateID = it.nextLong();
      double similarity = itemSimilarity(itemID, candidateID);
      if (Double.isNaN(similarity)) {
        continue;
      }
      similarIDs.add(candidateID);
    }
    return similarIDs.toArray();
  }

  @Override
  public void refresh(Collection<Refreshable> alreadyRefreshed) {
    refreshHelper.refresh(alreadyRefreshed);
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java new file mode 100644 index 0000000..59c30d9 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java @@ -0,0 +1,343 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.mahout.cf.taste.impl.similarity;

import java.util.Collection;
import java.util.concurrent.Callable;

import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.common.Weighting;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
import org.apache.mahout.cf.taste.similarity.UserSimilarity;

import com.google.common.base.Preconditions;

/** Abstract superclass encapsulating functionality that is common to most implementations in this package. */
abstract class AbstractSimilarity extends AbstractItemSimilarity implements UserSimilarity {

  private PreferenceInferrer inferrer;
  private final boolean weighted;
  private final boolean centerData;
  /** Item count of the data model, re-read on every refresh. */
  private int cachedNumItems;
  /** User count of the data model, re-read on every refresh. */
  private int cachedNumUsers;
  private final RefreshHelper refreshHelper;

  /**
   * <p>
   * Creates a possibly weighted {@link AbstractSimilarity}.
   * </p>
   *
   * @param dataModel  model the similarities are computed over
   * @param weighting  results are weighted only when this is {@link Weighting#WEIGHTED}
   * @param centerData whether the sums are centered before being passed to
   *                   {@link #computeResult(int, double, double, double, double)}
   * @throws TasteException if the model's item or user count cannot be read
   */
  AbstractSimilarity(final DataModel dataModel, Weighting weighting, boolean centerData) throws TasteException {
    super(dataModel);
    this.weighted = weighting == Weighting.WEIGHTED;
    this.centerData = centerData;
    this.cachedNumItems = dataModel.getNumItems();
    this.cachedNumUsers = dataModel.getNumUsers();
    this.refreshHelper = new RefreshHelper(new Callable<Object>() {
      @Override
      public Object call() throws TasteException {
        cachedNumItems = dataModel.getNumItems();
        cachedNumUsers = dataModel.getNumUsers();
        return null;
      }
    });
  }

  final PreferenceInferrer getPreferenceInferrer() {
    return inferrer;
  }

  /**
   * Installs the inferrer used by {@link #userSimilarity(long, long)} and registers it as a refresh
   * dependency in place of the previous one.
   *
   * @throws IllegalArgumentException if {@code inferrer} is null
   */
  @Override
  public final void setPreferenceInferrer(PreferenceInferrer inferrer) {
    Preconditions.checkArgument(inferrer != null, "inferrer is null");
    // Drop the old dependency first so that re-installing the same inferrer cannot end up unregistered.
    refreshHelper.removeDependency(this.inferrer);
    refreshHelper.addDependency(inferrer);
    this.inferrer = inferrer;
  }

  final boolean isWeighted() {
    return weighted;
  }

  /**
   * <p>
   * Several subclasses in this package implement this method to actually compute the similarity from figures
   * computed over users or items. Note that, when centering is enabled, the computations in this class
   * "center" the data, such that X and Y's mean are 0.
   * </p>
   *
   * <p>
   * Note that the sum of all X and Y values must then be 0. This value isn't passed down into the standard
   * similarity computations as a result.
   * </p>
   *
   * @param n
   *          number of preference pairs the sums were accumulated over
   * @param sumXY
   *          sum of product of user/item preference values, over all items/users preferred by both
   *          users/items
   * @param sumX2
   *          sum of the square of user/item preference values, over the first item/user
   * @param sumY2
   *          sum of the square of the user/item preference values, over the second item/user
   * @param sumXYdiff2
   *          sum of squares of differences in X and Y values
   * @return similarity value between -1.0 and 1.0, inclusive, or {@link Double#NaN} if no similarity can be
   *         computed (e.g. when no items have been rated by both users)
   */
  abstract double computeResult(int n, double sumXY, double sumX2, double sumY2, double sumXYdiff2);

  /**
   * Computes the similarity of two users by walking both preference arrays in step and accumulating sums over
   * the items both users have a preference for. When an inferrer is installed, items only one of the users has
   * a preference for are included too, using the inferred value for the other user.
   *
   * @return the similarity, or {@link Double#NaN} if either user has no preferences or the subclass cannot
   *         compute a value
   */
  @Override
  public double userSimilarity(long userID1, long userID2) throws TasteException {
    DataModel dataModel = getDataModel();
    PreferenceArray xPrefs = dataModel.getPreferencesFromUser(userID1);
    PreferenceArray yPrefs = dataModel.getPreferencesFromUser(userID2);
    int xLength = xPrefs.length();
    int yLength = yPrefs.length();

    if (xLength == 0 || yLength == 0) {
      return Double.NaN;
    }

    // NOTE(review): the in-step walk below relies on both arrays being ordered by item ID - confirm
    // that the data model guarantees this.
    long xIndex = xPrefs.getItemID(0);
    long yIndex = yPrefs.getItemID(0);
    int xPrefIndex = 0;
    int yPrefIndex = 0;

    double sumX = 0.0;
    double sumX2 = 0.0;
    double sumY = 0.0;
    double sumY2 = 0.0;
    double sumXY = 0.0;
    double sumXYdiff2 = 0.0;
    int count = 0;

    boolean hasInferrer = inferrer != null;

    while (true) {
      int compare = xIndex < yIndex ? -1 : xIndex > yIndex ? 1 : 0;
      if (hasInferrer || compare == 0) {
        double x;
        double y;
        if (compare == 0) {
          // Both users expressed a preference for the item
          x = xPrefs.getValue(xPrefIndex);
          y = yPrefs.getValue(yPrefIndex);
        } else if (compare < 0) {
          // Only X has a value; infer Y's and tally as if Y had expressed that preference
          x = xPrefs.getValue(xPrefIndex);
          y = inferrer.inferPreference(userID2, xIndex);
        } else {
          // Only Y has a value; infer X's and tally as if X had expressed that preference
          x = inferrer.inferPreference(userID1, yIndex);
          y = yPrefs.getValue(yPrefIndex);
        }
        sumXY += x * y;
        sumX += x;
        sumX2 += x * x;
        sumY += y;
        sumY2 += y * y;
        double diff = x - y;
        sumXYdiff2 += diff * diff;
        count++;
      }
      if (compare <= 0) {
        if (++xPrefIndex >= xLength) {
          if (hasInferrer) {
            // Must count other Ys; pretend next X is far away
            if (yIndex == Long.MAX_VALUE) {
              // ... but stop if both are done!
              break;
            }
            xIndex = Long.MAX_VALUE;
          } else {
            break;
          }
        } else {
          xIndex = xPrefs.getItemID(xPrefIndex);
        }
      }
      if (compare >= 0) {
        if (++yPrefIndex >= yLength) {
          if (hasInferrer) {
            // Must count other Xs; pretend next Y is far away
            if (xIndex == Long.MAX_VALUE) {
              // ... but stop if both are done!
              break;
            }
            yIndex = Long.MAX_VALUE;
          } else {
            break;
          }
        } else {
          yIndex = yPrefs.getItemID(yPrefIndex);
        }
      }
    }

    return finishSimilarity(count, sumX, sumY, sumXY, sumX2, sumY2, sumXYdiff2, cachedNumItems);
  }

  /**
   * Computes the similarity of two items by walking both preference arrays in step and accumulating sums over
   * the users that have a preference for both items. No preference inference is applied here.
   *
   * @return the similarity, or {@link Double#NaN} if either item has no preferences or the subclass cannot
   *         compute a value
   */
  @Override
  public final double itemSimilarity(long itemID1, long itemID2) throws TasteException {
    DataModel dataModel = getDataModel();
    PreferenceArray xPrefs = dataModel.getPreferencesForItem(itemID1);
    PreferenceArray yPrefs = dataModel.getPreferencesForItem(itemID2);
    int xLength = xPrefs.length();
    int yLength = yPrefs.length();

    if (xLength == 0 || yLength == 0) {
      return Double.NaN;
    }

    // NOTE(review): the in-step walk below relies on both arrays being ordered by user ID - confirm
    // that the data model guarantees this.
    long xIndex = xPrefs.getUserID(0);
    long yIndex = yPrefs.getUserID(0);
    int xPrefIndex = 0;
    int yPrefIndex = 0;

    double sumX = 0.0;
    double sumX2 = 0.0;
    double sumY = 0.0;
    double sumY2 = 0.0;
    double sumXY = 0.0;
    double sumXYdiff2 = 0.0;
    int count = 0;

    // No, pref inferrers and transforms don't apply here. I think.

    while (true) {
      int compare = xIndex < yIndex ? -1 : xIndex > yIndex ? 1 : 0;
      if (compare == 0) {
        // The same user expressed a preference for both items
        double x = xPrefs.getValue(xPrefIndex);
        double y = yPrefs.getValue(yPrefIndex);
        sumXY += x * y;
        sumX += x;
        sumX2 += x * x;
        sumY += y;
        sumY2 += y * y;
        double diff = x - y;
        sumXYdiff2 += diff * diff;
        count++;
      }
      if (compare <= 0) {
        if (++xPrefIndex == xLength) {
          break;
        }
        xIndex = xPrefs.getUserID(xPrefIndex);
      }
      if (compare >= 0) {
        if (++yPrefIndex == yLength) {
          break;
        }
        yIndex = yPrefs.getUserID(yPrefIndex);
      }
    }

    return finishSimilarity(count, sumX, sumY, sumXY, sumX2, sumY2, sumXYdiff2, cachedNumUsers);
  }

  /**
   * Shared tail of {@link #userSimilarity(long, long)} and {@link #itemSimilarity(long, long)}: optionally
   * centers the accumulated sums, delegates to
   * {@link #computeResult(int, double, double, double, double)} and, unless the result is NaN, applies
   * {@link #normalizeWeightResult(double, int, int)}.
   *
   * @param count number of preference pairs that were accumulated
   * @param num   total number of items (for user similarity) or users (for item similarity)
   */
  private double finishSimilarity(int count, double sumX, double sumY, double sumXY, double sumX2, double sumY2,
      double sumXYdiff2, int num) {
    double result;
    if (centerData) {
      double meanX = sumX / count;
      double meanY = sumY / count;
      // Expanding sum((x - meanX) * (y - meanY)) gives
      //   sumXY - meanY * sumX - meanX * sumY + n * meanX * meanY,
      // and because meanX * sumY == n * meanX * meanY the last two terms cancel.
      // The squared sums simplify in the same way.
      double centeredSumXY = sumXY - meanY * sumX;
      double centeredSumX2 = sumX2 - meanX * sumX;
      double centeredSumY2 = sumY2 - meanY * sumY;
      result = computeResult(count, centeredSumXY, centeredSumX2, centeredSumY2, sumXYdiff2);
    } else {
      result = computeResult(count, sumXY, sumX2, sumY2, sumXYdiff2);
    }

    if (!Double.isNaN(result)) {
      result = normalizeWeightResult(result, count, num);
    }
    return result;
  }

  /**
   * Computes {@link #itemSimilarity(long, long)} between {@code itemID1} and each of {@code itemID2s}.
   *
   * @return one similarity per entry of {@code itemID2s}, in the same order
   */
  @Override
  public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
    int length = itemID2s.length;
    double[] result = new double[length];
    for (int i = 0; i < length; i++) {
      result[i] = itemSimilarity(itemID1, itemID2s[i]);
    }
    return result;
  }

  /**
   * When weighting is enabled, pulls the result towards -1.0 (if negative) or 1.0 (otherwise) by a factor
   * that grows with {@code count / (num + 1)}; in all cases clamps the result to [-1.0, 1.0].
   *
   * @param result raw similarity
   * @param count  number of preference pairs the similarity was computed from
   * @param num    total number of items or users
   * @return the weighted (if enabled) and clamped similarity
   */
  final double normalizeWeightResult(double result, int count, int num) {
    double normalizedResult = result;
    if (weighted) {
      double scaleFactor = 1.0 - (double) count / (double) (num + 1);
      if (normalizedResult < 0.0) {
        normalizedResult = -1.0 + scaleFactor * (1.0 + normalizedResult);
      } else {
        normalizedResult = 1.0 - scaleFactor * (1.0 - normalizedResult);
      }
    }
    // Make sure the result is not accidentally a little outside [-1.0, 1.0] due to rounding:
    if (normalizedResult < -1.0) {
      normalizedResult = -1.0;
    } else if (normalizedResult > 1.0) {
      normalizedResult = 1.0;
    }
    return normalizedResult;
  }

  @Override
  public final void refresh(Collection<Refreshable> alreadyRefreshed) {
    super.refresh(alreadyRefreshed);
    refreshHelper.refresh(alreadyRefreshed);
  }

  @Override
  public final String toString() {
    return this.getClass().getSimpleName() + "[dataModel:" + getDataModel() + ",inferrer:" + inferrer + ']';
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java new file mode 100644 index 0000000..7c655fe --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java @@ -0,0 +1,85 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import java.util.Collection; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.Cache; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.Retriever; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.cf.taste.similarity.PreferenceInferrer; + +/** + * <p> + * Implementations of this interface compute an inferred preference for a user and an item that the user has + * not expressed any preference for. This might be an average of other preferences scores from that user, for + * example. This technique is sometimes called "default voting". 
+ * </p> + */ +public final class AveragingPreferenceInferrer implements PreferenceInferrer { + + private static final Float ZERO = 0.0f; + + private final DataModel dataModel; + private final Cache<Long,Float> averagePreferenceValue; + + public AveragingPreferenceInferrer(DataModel dataModel) throws TasteException { + this.dataModel = dataModel; + Retriever<Long,Float> retriever = new PrefRetriever(); + averagePreferenceValue = new Cache<>(retriever, dataModel.getNumUsers()); + refresh(null); + } + + @Override + public float inferPreference(long userID, long itemID) throws TasteException { + return averagePreferenceValue.get(userID); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + averagePreferenceValue.clear(); + } + + private final class PrefRetriever implements Retriever<Long,Float> { + + @Override + public Float get(Long key) throws TasteException { + PreferenceArray prefs = dataModel.getPreferencesFromUser(key); + int size = prefs.length(); + if (size == 0) { + return ZERO; + } + RunningAverage average = new FullRunningAverage(); + for (int i = 0; i < size; i++) { + average.addDatum(prefs.getValue(i)); + } + return (float) average.getAverage(); + } + } + + @Override + public String toString() { + return "AveragingPreferenceInferrer"; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java new file mode 100644 index 0000000..87aeae9 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java @@ -0,0 +1,111 @@ +/** + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import java.util.Collection; +import java.util.concurrent.Callable; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.Cache; +import org.apache.mahout.cf.taste.impl.common.RefreshHelper; +import org.apache.mahout.cf.taste.impl.common.Retriever; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; +import org.apache.mahout.common.LongPair; +import com.google.common.base.Preconditions; + +/** + * Caches the results from an underlying {@link ItemSimilarity} implementation. + */ +public final class CachingItemSimilarity implements ItemSimilarity { + + private final ItemSimilarity similarity; + private final Cache<LongPair,Double> similarityCache; + private final RefreshHelper refreshHelper; + + /** + * Creates this on top of the given {@link ItemSimilarity}. + * The cache is sized according to properties of the given {@link DataModel}. 
+ */ + public CachingItemSimilarity(ItemSimilarity similarity, DataModel dataModel) throws TasteException { + this(similarity, dataModel.getNumItems()); + } + + /** + * Creates this on top of the given {@link ItemSimilarity}. + * The cache size is capped by the given size. + */ + public CachingItemSimilarity(ItemSimilarity similarity, int maxCacheSize) { + Preconditions.checkArgument(similarity != null, "similarity is null"); + this.similarity = similarity; + this.similarityCache = new Cache<>(new SimilarityRetriever(similarity), maxCacheSize); + this.refreshHelper = new RefreshHelper(new Callable<Void>() { + @Override + public Void call() { + similarityCache.clear(); + return null; + } + }); + refreshHelper.addDependency(similarity); + } + + @Override + public double itemSimilarity(long itemID1, long itemID2) throws TasteException { + LongPair key = itemID1 < itemID2 ? new LongPair(itemID1, itemID2) : new LongPair(itemID2, itemID1); + return similarityCache.get(key); + } + + @Override + public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException { + int length = itemID2s.length; + double[] result = new double[length]; + for (int i = 0; i < length; i++) { + result[i] = itemSimilarity(itemID1, itemID2s[i]); + } + return result; + } + + @Override + public long[] allSimilarItemIDs(long itemID) throws TasteException { + return similarity.allSimilarItemIDs(itemID); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + refreshHelper.refresh(alreadyRefreshed); + } + + public void clearCacheForItem(long itemID) { + similarityCache.removeKeysMatching(new LongPairMatchPredicate(itemID)); + } + + private static final class SimilarityRetriever implements Retriever<LongPair,Double> { + private final ItemSimilarity similarity; + + private SimilarityRetriever(ItemSimilarity similarity) { + this.similarity = similarity; + } + + @Override + public Double get(LongPair key) throws TasteException { + return 
similarity.itemSimilarity(key.getFirst(), key.getSecond()); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingUserSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingUserSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingUserSimilarity.java new file mode 100644 index 0000000..873568a --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingUserSimilarity.java @@ -0,0 +1,104 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import java.util.Collection; +import java.util.concurrent.Callable; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.Cache; +import org.apache.mahout.cf.taste.impl.common.RefreshHelper; +import org.apache.mahout.cf.taste.impl.common.Retriever; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.similarity.PreferenceInferrer; +import org.apache.mahout.cf.taste.similarity.UserSimilarity; +import org.apache.mahout.common.LongPair; + +import com.google.common.base.Preconditions; + +/** + * Caches the results from an underlying {@link UserSimilarity} implementation. + */ +public final class CachingUserSimilarity implements UserSimilarity { + + private final UserSimilarity similarity; + private final Cache<LongPair,Double> similarityCache; + private final RefreshHelper refreshHelper; + + /** + * Creates this on top of the given {@link UserSimilarity}. + * The cache is sized according to properties of the given {@link DataModel}. + */ + public CachingUserSimilarity(UserSimilarity similarity, DataModel dataModel) throws TasteException { + this(similarity, dataModel.getNumUsers()); + } + + /** + * Creates this on top of the given {@link UserSimilarity}. + * The cache size is capped by the given size. 
+ */ + public CachingUserSimilarity(UserSimilarity similarity, int maxCacheSize) { + Preconditions.checkArgument(similarity != null, "similarity is null"); + this.similarity = similarity; + this.similarityCache = new Cache<>(new SimilarityRetriever(similarity), maxCacheSize); + this.refreshHelper = new RefreshHelper(new Callable<Void>() { + @Override + public Void call() { + similarityCache.clear(); + return null; + } + }); + refreshHelper.addDependency(similarity); + } + + @Override + public double userSimilarity(long userID1, long userID2) throws TasteException { + LongPair key = userID1 < userID2 ? new LongPair(userID1, userID2) : new LongPair(userID2, userID1); + return similarityCache.get(key); + } + + @Override + public void setPreferenceInferrer(PreferenceInferrer inferrer) { + similarityCache.clear(); + similarity.setPreferenceInferrer(inferrer); + } + + public void clearCacheForUser(long userID) { + similarityCache.removeKeysMatching(new LongPairMatchPredicate(userID)); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + refreshHelper.refresh(alreadyRefreshed); + } + + private static final class SimilarityRetriever implements Retriever<LongPair,Double> { + private final UserSimilarity similarity; + + private SimilarityRetriever(UserSimilarity similarity) { + this.similarity = similarity; + } + + @Override + public Double get(LongPair key) throws TasteException { + return similarity.userSimilarity(key.getFirst(), key.getSecond()); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CityBlockSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CityBlockSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CityBlockSimilarity.java new file mode 100644 index 
0000000..88fbe58 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CityBlockSimilarity.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.cf.taste.impl.similarity; + +import java.util.Collection; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.common.RefreshHelper; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.similarity.PreferenceInferrer; +import org.apache.mahout.cf.taste.similarity.UserSimilarity; + +/** + * Implementation of City Block distance (also known as Manhattan distance) - the absolute value of the difference of + * each direction is summed. The resulting unbounded distance is then mapped between 0 and 1. 
+ */ +public final class CityBlockSimilarity extends AbstractItemSimilarity implements UserSimilarity { + + public CityBlockSimilarity(DataModel dataModel) { + super(dataModel); + } + + /** + * @throws UnsupportedOperationException + */ + @Override + public void setPreferenceInferrer(PreferenceInferrer inferrer) { + throw new UnsupportedOperationException(); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + Collection<Refreshable> refreshed = RefreshHelper.buildRefreshed(alreadyRefreshed); + RefreshHelper.maybeRefresh(refreshed, getDataModel()); + } + + @Override + public double itemSimilarity(long itemID1, long itemID2) throws TasteException { + DataModel dataModel = getDataModel(); + int preferring1 = dataModel.getNumUsersWithPreferenceFor(itemID1); + int preferring2 = dataModel.getNumUsersWithPreferenceFor(itemID2); + int intersection = dataModel.getNumUsersWithPreferenceFor(itemID1, itemID2); + return doSimilarity(preferring1, preferring2, intersection); + } + + @Override + public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException { + DataModel dataModel = getDataModel(); + int preferring1 = dataModel.getNumUsersWithPreferenceFor(itemID1); + double[] distance = new double[itemID2s.length]; + for (int i = 0; i < itemID2s.length; ++i) { + int preferring2 = dataModel.getNumUsersWithPreferenceFor(itemID2s[i]); + int intersection = dataModel.getNumUsersWithPreferenceFor(itemID1, itemID2s[i]); + distance[i] = doSimilarity(preferring1, preferring2, intersection); + } + return distance; + } + + @Override + public double userSimilarity(long userID1, long userID2) throws TasteException { + DataModel dataModel = getDataModel(); + FastIDSet prefs1 = dataModel.getItemIDsFromUser(userID1); + FastIDSet prefs2 = dataModel.getItemIDsFromUser(userID2); + int prefs1Size = prefs1.size(); + int prefs2Size = prefs2.size(); + int intersectionSize = prefs1Size < prefs2Size ? 
prefs2.intersectionSize(prefs1) : prefs1.intersectionSize(prefs2); + return doSimilarity(prefs1Size, prefs2Size, intersectionSize); + } + + /** + * Calculate City Block Distance from total non-zero values and intersections and map to a similarity value. + * + * @param pref1 number of non-zero values in left vector + * @param pref2 number of non-zero values in right vector + * @param intersection number of overlapping non-zero values + */ + private static double doSimilarity(int pref1, int pref2, int intersection) { + int distance = pref1 + pref2 - 2 * intersection; + return 1.0 / (1.0 + distance); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarity.java new file mode 100644 index 0000000..990e9ea --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarity.java @@ -0,0 +1,67 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.common.Weighting; +import org.apache.mahout.cf.taste.model.DataModel; + +import com.google.common.base.Preconditions; + +/** + * <p> + * An implementation of a "similarity" based on the Euclidean "distance" between two users X and Y. Thinking + * of items as dimensions and preferences as points along those dimensions, a distance is computed using all + * items (dimensions) where both users have expressed a preference for that item. This is simply the square + * root of the sum of the squares of differences in position (preference) along each dimension.</p> + * + * <p>The similarity could be computed as 1 / (1 + distance / sqrt(n)), so the resulting values are in the range (0,1]. + * This would weight against pairs that overlap in more dimensions, which should indicate more similarity, + * since more dimensions offer more opportunities to be farther apart. Actually, it is computed as + * sqrt(n) / (1 + distance), where n is the number of dimensions, in order to help correct for this. + * sqrt(n) is chosen since randomly-chosen points have a distance that grows as sqrt(n).</p> + * + * <p>Note that this could cause a similarity to exceed 1; such values are capped at 1.</p> + * + * <p>Note that the distance isn't normalized in any way; it's not valid to compare similarities computed from + * different domains (different rating scales, for example). 
Within one domain, normalizing doesn't matter much as + * it doesn't change ordering.</p> + */ +public final class EuclideanDistanceSimilarity extends AbstractSimilarity { + + /** + * @throws IllegalArgumentException if {@link DataModel} does not have preference values + */ + public EuclideanDistanceSimilarity(DataModel dataModel) throws TasteException { + this(dataModel, Weighting.UNWEIGHTED); + } + + /** + * @throws IllegalArgumentException if {@link DataModel} does not have preference values + */ + public EuclideanDistanceSimilarity(DataModel dataModel, Weighting weighting) throws TasteException { + super(dataModel, weighting, false); + Preconditions.checkArgument(dataModel.hasPreferenceValues(), "DataModel doesn't have preference values"); + } + + @Override + double computeResult(int n, double sumXY, double sumX2, double sumY2, double sumXYdiff2) { + return 1.0 / (1.0 + Math.sqrt(sumXYdiff2) / Math.sqrt(n)); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarity.java new file mode 100644 index 0000000..d0c9b8c --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarity.java @@ -0,0 +1,358 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import java.util.Collection; +import java.util.Iterator; + +import com.google.common.collect.AbstractIterator; +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.recommender.TopItems; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; +import org.apache.mahout.common.RandomUtils; + +import com.google.common.base.Preconditions; + +/** + * <p> + * A "generic" {@link ItemSimilarity} which takes a static list of precomputed item similarities and bases its + * responses on that alone. The values may have been precomputed offline by another process, stored in a file, + * and then read and fed into an instance of this class. + * </p> + * + * <p> + * This is perhaps the best {@link ItemSimilarity} to use with + * {@link org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender}, for now, since the point + * of item-based recommenders is that they can take advantage of the fact that item similarity is relatively + * static, can be precomputed, and then used in computation to gain a significant performance advantage. 
+ * </p> + */ +public final class GenericItemSimilarity implements ItemSimilarity { + + private static final long[] NO_IDS = new long[0]; + + private final FastByIDMap<FastByIDMap<Double>> similarityMaps = new FastByIDMap<>(); + private final FastByIDMap<FastIDSet> similarItemIDsIndex = new FastByIDMap<>(); + + /** + * <p> + * Creates a {@link GenericItemSimilarity} from a precomputed list of {@link ItemItemSimilarity}s. Each + * represents the similarity between two distinct items. Since similarity is assumed to be symmetric, it is + * not necessary to specify similarity between item1 and item2, and item2 and item1. Both are the same. It + * is also not necessary to specify a similarity between any item and itself; these are assumed to be 1.0. + * </p> + * + * <p> + * Note that specifying a similarity between two items twice is not an error, but, the later value will win. + * </p> + * + * @param similarities + * set of {@link ItemItemSimilarity}s on which to base this instance + */ + public GenericItemSimilarity(Iterable<ItemItemSimilarity> similarities) { + initSimilarityMaps(similarities.iterator()); + } + + /** + * <p> + * Like {@link #GenericItemSimilarity(Iterable)}, but will only keep the specified number of similarities + * from the given {@link Iterable} of similarities. It will keep those with the highest similarity -- those + * that are therefore most important. + * </p> + * + * <p> + * Thanks to tsmorton for suggesting this and providing part of the implementation. 
+ * </p> + * + * @param similarities + * set of {@link ItemItemSimilarity}s on which to base this instance + * @param maxToKeep + * maximum number of similarities to keep + */ + public GenericItemSimilarity(Iterable<ItemItemSimilarity> similarities, int maxToKeep) { + Iterable<ItemItemSimilarity> keptSimilarities = + TopItems.getTopItemItemSimilarities(maxToKeep, similarities.iterator()); + initSimilarityMaps(keptSimilarities.iterator()); + } + + /** + * <p> + * Builds a list of item-item similarities given an {@link ItemSimilarity} implementation and a + * {@link DataModel}, rather than a list of {@link ItemItemSimilarity}s. + * </p> + * + * <p> + * It's valid to build a {@link GenericItemSimilarity} this way, but perhaps missing some of the point of an + * item-based recommender. Item-based recommenders use the assumption that item-item similarities are + * relatively fixed, and might be known already independent of user preferences. Hence it is useful to + * inject that information, using {@link #GenericItemSimilarity(Iterable)}. + * </p> + * + * @param otherSimilarity + * other {@link ItemSimilarity} to get similarities from + * @param dataModel + * data model to get items from + * @throws TasteException + * if an error occurs while accessing the {@link DataModel} items + */ + public GenericItemSimilarity(ItemSimilarity otherSimilarity, DataModel dataModel) throws TasteException { + long[] itemIDs = GenericUserSimilarity.longIteratorToList(dataModel.getItemIDs()); + initSimilarityMaps(new DataModelSimilaritiesIterator(otherSimilarity, itemIDs)); + } + + /** + * <p> + * Like {@link #GenericItemSimilarity(ItemSimilarity, DataModel)} )}, but will only keep the specified + * number of similarities from the given {@link DataModel}. It will keep those with the highest similarity + * -- those that are therefore most important. + * </p> + * + * <p> + * Thanks to tsmorton for suggesting this and providing part of the implementation. 
+ * </p> + * + * @param otherSimilarity + * other {@link ItemSimilarity} to get similarities from + * @param dataModel + * data model to get items from + * @param maxToKeep + * maximum number of similarities to keep + * @throws TasteException + * if an error occurs while accessing the {@link DataModel} items + */ + public GenericItemSimilarity(ItemSimilarity otherSimilarity, + DataModel dataModel, + int maxToKeep) throws TasteException { + long[] itemIDs = GenericUserSimilarity.longIteratorToList(dataModel.getItemIDs()); + Iterator<ItemItemSimilarity> it = new DataModelSimilaritiesIterator(otherSimilarity, itemIDs); + Iterable<ItemItemSimilarity> keptSimilarities = TopItems.getTopItemItemSimilarities(maxToKeep, it); + initSimilarityMaps(keptSimilarities.iterator()); + } + + private void initSimilarityMaps(Iterator<ItemItemSimilarity> similarities) { + while (similarities.hasNext()) { + ItemItemSimilarity iic = similarities.next(); + long similarityItemID1 = iic.getItemID1(); + long similarityItemID2 = iic.getItemID2(); + if (similarityItemID1 != similarityItemID2) { + // Order them -- first key should be the "smaller" one + long itemID1; + long itemID2; + if (similarityItemID1 < similarityItemID2) { + itemID1 = similarityItemID1; + itemID2 = similarityItemID2; + } else { + itemID1 = similarityItemID2; + itemID2 = similarityItemID1; + } + FastByIDMap<Double> map = similarityMaps.get(itemID1); + if (map == null) { + map = new FastByIDMap<>(); + similarityMaps.put(itemID1, map); + } + map.put(itemID2, iic.getValue()); + + doIndex(itemID1, itemID2); + doIndex(itemID2, itemID1); + } + // else similarity between item and itself already assumed to be 1.0 + } + } + + private void doIndex(long fromItemID, long toItemID) { + FastIDSet similarItemIDs = similarItemIDsIndex.get(fromItemID); + if (similarItemIDs == null) { + similarItemIDs = new FastIDSet(); + similarItemIDsIndex.put(fromItemID, similarItemIDs); + } + similarItemIDs.add(toItemID); + } + + /** + * <p> + * Returns 
the similarity between two items. Note that similarity is assumed to be symmetric, that + * {@code itemSimilarity(item1, item2) == itemSimilarity(item2, item1)}, and that + * {@code itemSimilarity(item1,item1) == 1.0} for all items. + * </p> + * + * @param itemID1 + * first item + * @param itemID2 + * second item + * @return similarity between the two + */ + @Override + public double itemSimilarity(long itemID1, long itemID2) { + if (itemID1 == itemID2) { + return 1.0; + } + long firstID; + long secondID; + if (itemID1 < itemID2) { + firstID = itemID1; + secondID = itemID2; + } else { + firstID = itemID2; + secondID = itemID1; + } + FastByIDMap<Double> nextMap = similarityMaps.get(firstID); + if (nextMap == null) { + return Double.NaN; + } + Double similarity = nextMap.get(secondID); + return similarity == null ? Double.NaN : similarity; + } + + @Override + public double[] itemSimilarities(long itemID1, long[] itemID2s) { + int length = itemID2s.length; + double[] result = new double[length]; + for (int i = 0; i < length; i++) { + result[i] = itemSimilarity(itemID1, itemID2s[i]); + } + return result; + } + + @Override + public long[] allSimilarItemIDs(long itemID) { + FastIDSet similarItemIDs = similarItemIDsIndex.get(itemID); + return similarItemIDs != null ? similarItemIDs.toArray() : NO_IDS; + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + // Do nothing + } + + /** Encapsulates a similarity between two items. Similarity must be in the range [-1.0,1.0]. 
*/ + public static final class ItemItemSimilarity implements Comparable<ItemItemSimilarity> { + + private final long itemID1; + private final long itemID2; + private final double value; + + /** + * @param itemID1 + * first item + * @param itemID2 + * second item + * @param value + * similarity between the two + * @throws IllegalArgumentException + * if value is NaN, less than -1.0 or greater than 1.0 + */ + public ItemItemSimilarity(long itemID1, long itemID2, double value) { + Preconditions.checkArgument(value >= -1.0 && value <= 1.0, "Illegal value: " + value + ". Must be: -1.0 <= value <= 1.0"); + this.itemID1 = itemID1; + this.itemID2 = itemID2; + this.value = value; + } + + public long getItemID1() { + return itemID1; + } + + public long getItemID2() { + return itemID2; + } + + public double getValue() { + return value; + } + + @Override + public String toString() { + return "ItemItemSimilarity[" + itemID1 + ',' + itemID2 + ':' + value + ']'; + } + + /** Defines an ordering from highest similarity to lowest. */ + @Override + public int compareTo(ItemItemSimilarity other) { + double otherValue = other.getValue(); + return value > otherValue ? -1 : value < otherValue ? 
1 : 0; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ItemItemSimilarity)) { + return false; + } + ItemItemSimilarity otherSimilarity = (ItemItemSimilarity) other; + return otherSimilarity.getItemID1() == itemID1 + && otherSimilarity.getItemID2() == itemID2 + && otherSimilarity.getValue() == value; + } + + @Override + public int hashCode() { + return (int) itemID1 ^ (int) itemID2 ^ RandomUtils.hashDouble(value); + } + + } + + private static final class DataModelSimilaritiesIterator extends AbstractIterator<ItemItemSimilarity> { + + private final ItemSimilarity otherSimilarity; + private final long[] itemIDs; + private int i; + private long itemID1; + private int j; + + private DataModelSimilaritiesIterator(ItemSimilarity otherSimilarity, long[] itemIDs) { + this.otherSimilarity = otherSimilarity; + this.itemIDs = itemIDs; + i = 0; + itemID1 = itemIDs[0]; + j = 1; + } + + @Override + protected ItemItemSimilarity computeNext() { + int size = itemIDs.length; + ItemItemSimilarity result = null; + while (result == null && i < size - 1) { + long itemID2 = itemIDs[j]; + double similarity; + try { + similarity = otherSimilarity.itemSimilarity(itemID1, itemID2); + } catch (TasteException te) { + // ugly: + throw new IllegalStateException(te); + } + if (!Double.isNaN(similarity)) { + result = new ItemItemSimilarity(itemID1, itemID2, similarity); + } + if (++j == size) { + itemID1 = itemIDs[++i]; + j = i + 1; + } + } + if (result == null) { + return endOfData(); + } else { + return result; + } + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericUserSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericUserSimilarity.java 
b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericUserSimilarity.java new file mode 100644 index 0000000..1c221c2 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericUserSimilarity.java @@ -0,0 +1,238 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import java.util.Collection; +import java.util.Iterator; + +import com.google.common.collect.AbstractIterator; +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.recommender.TopItems; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.similarity.PreferenceInferrer; +import org.apache.mahout.cf.taste.similarity.UserSimilarity; +import org.apache.mahout.common.RandomUtils; + +import com.google.common.base.Preconditions; + +public final class GenericUserSimilarity implements UserSimilarity { + + private final FastByIDMap<FastByIDMap<Double>> similarityMaps = new FastByIDMap<>(); + + public GenericUserSimilarity(Iterable<UserUserSimilarity> similarities) { + initSimilarityMaps(similarities.iterator()); + } + + public GenericUserSimilarity(Iterable<UserUserSimilarity> similarities, int maxToKeep) { + Iterable<UserUserSimilarity> keptSimilarities = + TopItems.getTopUserUserSimilarities(maxToKeep, similarities.iterator()); + initSimilarityMaps(keptSimilarities.iterator()); + } + + public GenericUserSimilarity(UserSimilarity otherSimilarity, DataModel dataModel) throws TasteException { + long[] userIDs = longIteratorToList(dataModel.getUserIDs()); + initSimilarityMaps(new DataModelSimilaritiesIterator(otherSimilarity, userIDs)); + } + + public GenericUserSimilarity(UserSimilarity otherSimilarity, + DataModel dataModel, + int maxToKeep) throws TasteException { + long[] userIDs = longIteratorToList(dataModel.getUserIDs()); + Iterator<UserUserSimilarity> it = new DataModelSimilaritiesIterator(otherSimilarity, userIDs); + Iterable<UserUserSimilarity> keptSimilarities = TopItems.getTopUserUserSimilarities(maxToKeep, it); + 
initSimilarityMaps(keptSimilarities.iterator()); + } + + static long[] longIteratorToList(LongPrimitiveIterator iterator) { + long[] result = new long[5]; + int size = 0; + while (iterator.hasNext()) { + if (size == result.length) { + long[] newResult = new long[result.length << 1]; + System.arraycopy(result, 0, newResult, 0, result.length); + result = newResult; + } + result[size++] = iterator.next(); + } + if (size != result.length) { + long[] newResult = new long[size]; + System.arraycopy(result, 0, newResult, 0, size); + result = newResult; + } + return result; + } + + private void initSimilarityMaps(Iterator<UserUserSimilarity> similarities) { + while (similarities.hasNext()) { + UserUserSimilarity uuc = similarities.next(); + long similarityUser1 = uuc.getUserID1(); + long similarityUser2 = uuc.getUserID2(); + if (similarityUser1 != similarityUser2) { + // Order them -- first key should be the "smaller" one + long user1; + long user2; + if (similarityUser1 < similarityUser2) { + user1 = similarityUser1; + user2 = similarityUser2; + } else { + user1 = similarityUser2; + user2 = similarityUser1; + } + FastByIDMap<Double> map = similarityMaps.get(user1); + if (map == null) { + map = new FastByIDMap<>(); + similarityMaps.put(user1, map); + } + map.put(user2, uuc.getValue()); + } + // else similarity between user and itself already assumed to be 1.0 + } + } + + @Override + public double userSimilarity(long userID1, long userID2) { + if (userID1 == userID2) { + return 1.0; + } + long first; + long second; + if (userID1 < userID2) { + first = userID1; + second = userID2; + } else { + first = userID2; + second = userID1; + } + FastByIDMap<Double> nextMap = similarityMaps.get(first); + if (nextMap == null) { + return Double.NaN; + } + Double similarity = nextMap.get(second); + return similarity == null ? 
Double.NaN : similarity; + } + + @Override + public void setPreferenceInferrer(PreferenceInferrer inferrer) { + throw new UnsupportedOperationException(); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + // Do nothing + } + + public static final class UserUserSimilarity implements Comparable<UserUserSimilarity> { + + private final long userID1; + private final long userID2; + private final double value; + + public UserUserSimilarity(long userID1, long userID2, double value) { + Preconditions.checkArgument(value >= -1.0 && value <= 1.0, "Illegal value: " + value + ". Must be: -1.0 <= value <= 1.0"); + this.userID1 = userID1; + this.userID2 = userID2; + this.value = value; + } + + public long getUserID1() { + return userID1; + } + + public long getUserID2() { + return userID2; + } + + public double getValue() { + return value; + } + + @Override + public String toString() { + return "UserUserSimilarity[" + userID1 + ',' + userID2 + ':' + value + ']'; + } + + /** Defines an ordering from highest similarity to lowest. */ + @Override + public int compareTo(UserUserSimilarity other) { + double otherValue = other.getValue(); + return value > otherValue ? -1 : value < otherValue ? 
1 : 0; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof UserUserSimilarity)) { + return false; + } + UserUserSimilarity otherSimilarity = (UserUserSimilarity) other; + return otherSimilarity.getUserID1() == userID1 + && otherSimilarity.getUserID2() == userID2 + && otherSimilarity.getValue() == value; + } + + @Override + public int hashCode() { + return (int) userID1 ^ (int) userID2 ^ RandomUtils.hashDouble(value); + } + + } + + private static final class DataModelSimilaritiesIterator extends AbstractIterator<UserUserSimilarity> { + + private final UserSimilarity otherSimilarity; + private final long[] itemIDs; + private int i; + private long itemID1; + private int j; + + private DataModelSimilaritiesIterator(UserSimilarity otherSimilarity, long[] itemIDs) { + this.otherSimilarity = otherSimilarity; + this.itemIDs = itemIDs; + i = 0; + itemID1 = itemIDs[0]; + j = 1; + } + + @Override + protected UserUserSimilarity computeNext() { + int size = itemIDs.length; + while (i < size - 1) { + long itemID2 = itemIDs[j]; + double similarity; + try { + similarity = otherSimilarity.userSimilarity(itemID1, itemID2); + } catch (TasteException te) { + // ugly: + throw new IllegalStateException(te); + } + if (!Double.isNaN(similarity)) { + return new UserUserSimilarity(itemID1, itemID2, similarity); + } + if (++j == size) { + itemID1 = itemIDs[++i]; + j = i + 1; + } + } + return endOfData(); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java new file mode 100644 index 0000000..3084c8f --- /dev/null +++ 
b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java @@ -0,0 +1,121 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import java.util.Collection; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.common.RefreshHelper; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.similarity.PreferenceInferrer; +import org.apache.mahout.cf.taste.similarity.UserSimilarity; +import org.apache.mahout.math.stats.LogLikelihood; + +/** + * See <a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.14.5962"> + * http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.14.5962</a> and + * <a href="http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html"> + * http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html</a>. 
+ */ +public final class LogLikelihoodSimilarity extends AbstractItemSimilarity implements UserSimilarity { + + public LogLikelihoodSimilarity(DataModel dataModel) { + super(dataModel); + } + + /** + * @throws UnsupportedOperationException + */ + @Override + public void setPreferenceInferrer(PreferenceInferrer inferrer) { + throw new UnsupportedOperationException(); + } + + @Override + public double userSimilarity(long userID1, long userID2) throws TasteException { + + DataModel dataModel = getDataModel(); + FastIDSet prefs1 = dataModel.getItemIDsFromUser(userID1); + FastIDSet prefs2 = dataModel.getItemIDsFromUser(userID2); + + long prefs1Size = prefs1.size(); + long prefs2Size = prefs2.size(); + long intersectionSize = + prefs1Size < prefs2Size ? prefs2.intersectionSize(prefs1) : prefs1.intersectionSize(prefs2); + if (intersectionSize == 0) { + return Double.NaN; + } + long numItems = dataModel.getNumItems(); + double logLikelihood = + LogLikelihood.logLikelihoodRatio(intersectionSize, + prefs2Size - intersectionSize, + prefs1Size - intersectionSize, + numItems - prefs1Size - prefs2Size + intersectionSize); + return 1.0 - 1.0 / (1.0 + logLikelihood); + } + + @Override + public double itemSimilarity(long itemID1, long itemID2) throws TasteException { + DataModel dataModel = getDataModel(); + long preferring1 = dataModel.getNumUsersWithPreferenceFor(itemID1); + long numUsers = dataModel.getNumUsers(); + return doItemSimilarity(itemID1, itemID2, preferring1, numUsers); + } + + @Override + public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException { + DataModel dataModel = getDataModel(); + long preferring1 = dataModel.getNumUsersWithPreferenceFor(itemID1); + long numUsers = dataModel.getNumUsers(); + int length = itemID2s.length; + double[] result = new double[length]; + for (int i = 0; i < length; i++) { + result[i] = doItemSimilarity(itemID1, itemID2s[i], preferring1, numUsers); + } + return result; + } + + private double 
doItemSimilarity(long itemID1, long itemID2, long preferring1, long numUsers) throws TasteException { + DataModel dataModel = getDataModel(); + long preferring1and2 = dataModel.getNumUsersWithPreferenceFor(itemID1, itemID2); + if (preferring1and2 == 0) { + return Double.NaN; + } + long preferring2 = dataModel.getNumUsersWithPreferenceFor(itemID2); + double logLikelihood = + LogLikelihood.logLikelihoodRatio(preferring1and2, + preferring2 - preferring1and2, + preferring1 - preferring1and2, + numUsers - preferring1 - preferring2 + preferring1and2); + return 1.0 - 1.0 / (1.0 + logLikelihood); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + alreadyRefreshed = RefreshHelper.buildRefreshed(alreadyRefreshed); + RefreshHelper.maybeRefresh(alreadyRefreshed, getDataModel()); + } + + @Override + public String toString() { + return "LogLikelihoodSimilarity[dataModel:" + getDataModel() + ']'; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LongPairMatchPredicate.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LongPairMatchPredicate.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LongPairMatchPredicate.java new file mode 100644 index 0000000..48dc4e0 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LongPairMatchPredicate.java @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.similarity; + +import org.apache.mahout.cf.taste.impl.common.Cache; +import org.apache.mahout.common.LongPair; + +/** + * A {@link Cache.MatchPredicate} which will match an ID against either element of a + * {@link LongPair}. + */ +final class LongPairMatchPredicate implements Cache.MatchPredicate<LongPair> { + + private final long id; + + LongPairMatchPredicate(long id) { + this.id = id; + } + + @Override + public boolean matches(LongPair pair) { + return pair.getFirst() == id || pair.getSecond() == id; + } + +}
