// http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java
// ----------------------------------------------------------------------
// diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java
// new file mode 100644
// index 0000000..e88f98a
// --- /dev/null
// +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java
// @@ -0,0 +1,35 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.cf.taste.impl.common;

import java.util.Iterator;

/**
 * An {@link Iterator} that can jump forward over several elements in one call. Implementations may be
 * able to do this more cheaply than by invoking {@link #next()} once per element.
 *
 * @param <V> type of the elements produced by this iterator
 */
public interface SkippingIterator<V> extends Iterator<V> {

  /**
   * Discards the next {@code n} elements of this iteration. When fewer than {@code n} elements remain,
   * all of them are discarded and the iteration is left exhausted.
   * <p>
   * The effect is the same as calling {@link #next()} {@code n} times, except that this method never
   * throws {@link java.util.NoSuchElementException}.
   *
   * @param n number of elements to skip
   */
  void skip(int n);

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java new file mode 100644 index 0000000..76e5239 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.common; + +import java.io.Serializable; + +import com.google.common.base.Preconditions; + +public class WeightedRunningAverage implements RunningAverage, Serializable { + + private double totalWeight; + private double average; + + public WeightedRunningAverage() { + totalWeight = 0.0; + average = Double.NaN; + } + + @Override + public synchronized void addDatum(double datum) { + addDatum(datum, 1.0); + } + + public synchronized void addDatum(double datum, double weight) { + double oldTotalWeight = totalWeight; + totalWeight += weight; + if (oldTotalWeight <= 0.0) { + average = datum; + } else { + average = average * oldTotalWeight / totalWeight + datum * weight / totalWeight; + } + } + + @Override + public synchronized void removeDatum(double datum) { + removeDatum(datum, 1.0); + } + + public synchronized void removeDatum(double datum, double weight) { + double oldTotalWeight = totalWeight; + totalWeight -= weight; + if (totalWeight <= 0.0) { + average = Double.NaN; + totalWeight = 0.0; + } else { + average = average * oldTotalWeight / totalWeight - datum * weight / totalWeight; + } + } + + @Override + public synchronized void changeDatum(double delta) { + changeDatum(delta, 1.0); + } + + public synchronized void changeDatum(double delta, double weight) { + Preconditions.checkArgument(weight <= totalWeight, "weight must be <= totalWeight"); + average += delta * weight / totalWeight; + } + + public synchronized double getTotalWeight() { + return totalWeight; + } + + /** @return {@link #getTotalWeight()} */ + @Override + public synchronized int getCount() { + return (int) totalWeight; + } + + @Override + public synchronized double getAverage() { + return average; + } + + @Override + public RunningAverage inverse() { + return new InvertedRunningAverage(this); + } + + @Override + public synchronized String toString() { + return String.valueOf(average); + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java new file mode 100644 index 0000000..bed5812 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java @@ -0,0 +1,89 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common; + +/** + * This subclass also provides for a weighted estimate of the sample standard deviation. + * See <a href="http://en.wikipedia.org/wiki/Mean_square_weighted_deviation">estimate formulae here</a>. 
+ */ +public final class WeightedRunningAverageAndStdDev extends WeightedRunningAverage implements RunningAverageAndStdDev { + + private double totalSquaredWeight; + private double totalWeightedData; + private double totalWeightedSquaredData; + + public WeightedRunningAverageAndStdDev() { + totalSquaredWeight = 0.0; + totalWeightedData = 0.0; + totalWeightedSquaredData = 0.0; + } + + @Override + public synchronized void addDatum(double datum, double weight) { + super.addDatum(datum, weight); + totalSquaredWeight += weight * weight; + double weightedData = datum * weight; + totalWeightedData += weightedData; + totalWeightedSquaredData += weightedData * datum; + } + + @Override + public synchronized void removeDatum(double datum, double weight) { + super.removeDatum(datum, weight); + totalSquaredWeight -= weight * weight; + if (totalSquaredWeight <= 0.0) { + totalSquaredWeight = 0.0; + } + double weightedData = datum * weight; + totalWeightedData -= weightedData; + if (totalWeightedData <= 0.0) { + totalWeightedData = 0.0; + } + totalWeightedSquaredData -= weightedData * datum; + if (totalWeightedSquaredData <= 0.0) { + totalWeightedSquaredData = 0.0; + } + } + + /** + * @throws UnsupportedOperationException + */ + @Override + public synchronized void changeDatum(double delta, double weight) { + throw new UnsupportedOperationException(); + } + + + @Override + public synchronized double getStandardDeviation() { + double totalWeight = getTotalWeight(); + return Math.sqrt((totalWeightedSquaredData * totalWeight - totalWeightedData * totalWeightedData) + / (totalWeight * totalWeight - totalSquaredWeight)); + } + + @Override + public RunningAverageAndStdDev inverse() { + return new InvertedRunningAverageAndStdDev(this); + } + + @Override + public synchronized String toString() { + return String.valueOf(String.valueOf(getAverage()) + ',' + getStandardDeviation()); + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java new file mode 100644 index 0000000..d1e93ab --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common.jdbc; + +import javax.naming.Context; +import javax.naming.InitialContext; +import javax.naming.NamingException; +import javax.sql.DataSource; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; + +/** + * A helper class with common elements for several JDBC-related components. 
+ */ +public abstract class AbstractJDBCComponent { + + private static final Logger log = LoggerFactory.getLogger(AbstractJDBCComponent.class); + + private static final int DEFAULT_FETCH_SIZE = 1000; // A max, "big" number of rows to buffer at once + protected static final String DEFAULT_DATASOURCE_NAME = "jdbc/taste"; + + protected static void checkNotNullAndLog(String argName, Object value) { + Preconditions.checkArgument(value != null && !value.toString().isEmpty(), + argName + " is null or empty"); + log.debug("{}: {}", argName, value); + } + + protected static void checkNotNullAndLog(String argName, Object[] values) { + Preconditions.checkArgument(values != null && values.length != 0, argName + " is null or zero-length"); + for (Object value : values) { + checkNotNullAndLog(argName, value); + } + } + + /** + * <p> + * Looks up a {@link DataSource} by name from JNDI. "java:comp/env/" is prepended to the argument before + * looking up the name in JNDI. + * </p> + * + * @param dataSourceName + * JNDI name where a {@link DataSource} is bound (e.g. 
"jdbc/taste") + * @return {@link DataSource} under that JNDI name + * @throws TasteException + * if a JNDI error occurs + */ + public static DataSource lookupDataSource(String dataSourceName) throws TasteException { + Context context = null; + try { + context = new InitialContext(); + return (DataSource) context.lookup("java:comp/env/" + dataSourceName); + } catch (NamingException ne) { + throw new TasteException(ne); + } finally { + if (context != null) { + try { + context.close(); + } catch (NamingException ne) { + log.warn("Error while closing Context; continuing...", ne); + } + } + } + } + + protected int getFetchSize() { + return DEFAULT_FETCH_SIZE; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java new file mode 100644 index 0000000..3f024bc --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java @@ -0,0 +1,92 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common.jdbc; + +import javax.sql.DataSource; +import java.io.Closeable; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import com.google.common.collect.AbstractIterator; +import org.apache.mahout.common.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Provides an {@link java.util.Iterator} over the result of an SQL query, as an iteration over the {@link ResultSet}. + * While the same object will be returned from the iteration each time, it will be returned once for each row + * of the result. + */ +final class EachRowIterator extends AbstractIterator<ResultSet> implements Closeable { + + private static final Logger log = LoggerFactory.getLogger(EachRowIterator.class); + + private final Connection connection; + private final PreparedStatement statement; + private final ResultSet resultSet; + + EachRowIterator(DataSource dataSource, String sqlQuery) throws SQLException { + try { + connection = dataSource.getConnection(); + statement = connection.prepareStatement(sqlQuery, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + statement.setFetchDirection(ResultSet.FETCH_FORWARD); + //statement.setFetchSize(getFetchSize()); + log.debug("Executing SQL query: {}", sqlQuery); + resultSet = statement.executeQuery(); + } catch (SQLException sqle) { + close(); + throw sqle; + } + } + + @Override + protected ResultSet computeNext() { + try { + if (resultSet.next()) { + return resultSet; + } else { + close(); + return null; + } + } catch (SQLException sqle) { + close(); + throw new IllegalStateException(sqle); + } + } + + public void skip(int n) throws SQLException { + try { + resultSet.relative(n); + } catch (SQLException sqle) { + // Can't use relative on MySQL Connector/J; try advancing manually + int i = 0; 
+ while (i < n && resultSet.next()) { + i++; + } + } + } + + @Override + public void close() { + IOUtils.quietClose(resultSet, statement, connection); + endOfData(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java new file mode 100644 index 0000000..273ebd5 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.common.jdbc; + +import javax.sql.DataSource; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Iterator; + +import com.google.common.base.Function; +import com.google.common.collect.ForwardingIterator; +import com.google.common.collect.Iterators; + +public abstract class ResultSetIterator<T> extends ForwardingIterator<T> { + + private final Iterator<T> delegate; + private final EachRowIterator rowDelegate; + + protected ResultSetIterator(DataSource dataSource, String sqlQuery) throws SQLException { + this.rowDelegate = new EachRowIterator(dataSource, sqlQuery); + delegate = Iterators.transform(rowDelegate, + new Function<ResultSet, T>() { + @Override + public T apply(ResultSet from) { + try { + return parseElement(from); + } catch (SQLException sqle) { + throw new IllegalStateException(sqle); + } + } + }); + } + + @Override + protected Iterator<T> delegate() { + return delegate; + } + + protected abstract T parseElement(ResultSet resultSet) throws SQLException; + + public void skip(int n) { + if (n >= 1) { + try { + rowDelegate.skip(n); + } catch (SQLException sqle) { + throw new IllegalStateException(sqle); + } + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java new file mode 100644 index 0000000..f926f18 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java @@ -0,0 +1,276 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.eval; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import com.google.common.base.Preconditions; +import org.apache.mahout.cf.taste.common.NoSuchItemException; +import org.apache.mahout.cf.taste.common.NoSuchUserException; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.eval.DataModelBuilder; +import org.apache.mahout.cf.taste.eval.RecommenderBuilder; +import org.apache.mahout.cf.taste.eval.RecommenderEvaluator; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev; +import org.apache.mahout.cf.taste.impl.model.GenericDataModel; +import 
org.apache.mahout.cf.taste.impl.model.GenericPreference; +import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.cf.taste.recommender.Recommender; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Abstract superclass of a couple implementations, providing shared functionality. + */ +public abstract class AbstractDifferenceRecommenderEvaluator implements RecommenderEvaluator { + + private static final Logger log = LoggerFactory.getLogger(AbstractDifferenceRecommenderEvaluator.class); + + private final Random random; + private float maxPreference; + private float minPreference; + + protected AbstractDifferenceRecommenderEvaluator() { + random = RandomUtils.getRandom(); + maxPreference = Float.NaN; + minPreference = Float.NaN; + } + + @Override + public final float getMaxPreference() { + return maxPreference; + } + + @Override + public final void setMaxPreference(float maxPreference) { + this.maxPreference = maxPreference; + } + + @Override + public final float getMinPreference() { + return minPreference; + } + + @Override + public final void setMinPreference(float minPreference) { + this.minPreference = minPreference; + } + + @Override + public double evaluate(RecommenderBuilder recommenderBuilder, + DataModelBuilder dataModelBuilder, + DataModel dataModel, + double trainingPercentage, + double evaluationPercentage) throws TasteException { + Preconditions.checkNotNull(recommenderBuilder); + Preconditions.checkNotNull(dataModel); + Preconditions.checkArgument(trainingPercentage >= 0.0 && trainingPercentage <= 1.0, + "Invalid trainingPercentage: " + trainingPercentage + ". 
Must be: 0.0 <= trainingPercentage <= 1.0"); + Preconditions.checkArgument(evaluationPercentage >= 0.0 && evaluationPercentage <= 1.0, + "Invalid evaluationPercentage: " + evaluationPercentage + ". Must be: 0.0 <= evaluationPercentage <= 1.0"); + + log.info("Beginning evaluation using {} of {}", trainingPercentage, dataModel); + + int numUsers = dataModel.getNumUsers(); + FastByIDMap<PreferenceArray> trainingPrefs = new FastByIDMap<>( + 1 + (int) (evaluationPercentage * numUsers)); + FastByIDMap<PreferenceArray> testPrefs = new FastByIDMap<>( + 1 + (int) (evaluationPercentage * numUsers)); + + LongPrimitiveIterator it = dataModel.getUserIDs(); + while (it.hasNext()) { + long userID = it.nextLong(); + if (random.nextDouble() < evaluationPercentage) { + splitOneUsersPrefs(trainingPercentage, trainingPrefs, testPrefs, userID, dataModel); + } + } + + DataModel trainingModel = dataModelBuilder == null ? new GenericDataModel(trainingPrefs) + : dataModelBuilder.buildDataModel(trainingPrefs); + + Recommender recommender = recommenderBuilder.buildRecommender(trainingModel); + + double result = getEvaluation(testPrefs, recommender); + log.info("Evaluation result: {}", result); + return result; + } + + private void splitOneUsersPrefs(double trainingPercentage, + FastByIDMap<PreferenceArray> trainingPrefs, + FastByIDMap<PreferenceArray> testPrefs, + long userID, + DataModel dataModel) throws TasteException { + List<Preference> oneUserTrainingPrefs = null; + List<Preference> oneUserTestPrefs = null; + PreferenceArray prefs = dataModel.getPreferencesFromUser(userID); + int size = prefs.length(); + for (int i = 0; i < size; i++) { + Preference newPref = new GenericPreference(userID, prefs.getItemID(i), prefs.getValue(i)); + if (random.nextDouble() < trainingPercentage) { + if (oneUserTrainingPrefs == null) { + oneUserTrainingPrefs = new ArrayList<>(3); + } + oneUserTrainingPrefs.add(newPref); + } else { + if (oneUserTestPrefs == null) { + oneUserTestPrefs = new ArrayList<>(3); + 
} + oneUserTestPrefs.add(newPref); + } + } + if (oneUserTrainingPrefs != null) { + trainingPrefs.put(userID, new GenericUserPreferenceArray(oneUserTrainingPrefs)); + if (oneUserTestPrefs != null) { + testPrefs.put(userID, new GenericUserPreferenceArray(oneUserTestPrefs)); + } + } + } + + private float capEstimatedPreference(float estimate) { + if (estimate > maxPreference) { + return maxPreference; + } + if (estimate < minPreference) { + return minPreference; + } + return estimate; + } + + private double getEvaluation(FastByIDMap<PreferenceArray> testPrefs, Recommender recommender) + throws TasteException { + reset(); + Collection<Callable<Void>> estimateCallables = new ArrayList<>(); + AtomicInteger noEstimateCounter = new AtomicInteger(); + for (Map.Entry<Long,PreferenceArray> entry : testPrefs.entrySet()) { + estimateCallables.add( + new PreferenceEstimateCallable(recommender, entry.getKey(), entry.getValue(), noEstimateCounter)); + } + log.info("Beginning evaluation of {} users", estimateCallables.size()); + RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev(); + execute(estimateCallables, noEstimateCounter, timing); + return computeFinalEvaluation(); + } + + protected static void execute(Collection<Callable<Void>> callables, + AtomicInteger noEstimateCounter, + RunningAverageAndStdDev timing) throws TasteException { + + Collection<Callable<Void>> wrappedCallables = wrapWithStatsCallables(callables, noEstimateCounter, timing); + int numProcessors = Runtime.getRuntime().availableProcessors(); + ExecutorService executor = Executors.newFixedThreadPool(numProcessors); + log.info("Starting timing of {} tasks in {} threads", wrappedCallables.size(), numProcessors); + try { + List<Future<Void>> futures = executor.invokeAll(wrappedCallables); + // Go look for exceptions here, really + for (Future<Void> future : futures) { + future.get(); + } + + } catch (InterruptedException ie) { + throw new TasteException(ie); + } catch (ExecutionException ee) { + throw 
new TasteException(ee.getCause()); + } + + executor.shutdown(); + try { + executor.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new TasteException(e.getCause()); + } + } + + private static Collection<Callable<Void>> wrapWithStatsCallables(Iterable<Callable<Void>> callables, + AtomicInteger noEstimateCounter, + RunningAverageAndStdDev timing) { + Collection<Callable<Void>> wrapped = new ArrayList<>(); + int count = 0; + for (Callable<Void> callable : callables) { + boolean logStats = count++ % 1000 == 0; // log every 1000 or so iterations + wrapped.add(new StatsCallable(callable, logStats, timing, noEstimateCounter)); + } + return wrapped; + } + + protected abstract void reset(); + + protected abstract void processOneEstimate(float estimatedPreference, Preference realPref); + + protected abstract double computeFinalEvaluation(); + + public final class PreferenceEstimateCallable implements Callable<Void> { + + private final Recommender recommender; + private final long testUserID; + private final PreferenceArray prefs; + private final AtomicInteger noEstimateCounter; + + public PreferenceEstimateCallable(Recommender recommender, + long testUserID, + PreferenceArray prefs, + AtomicInteger noEstimateCounter) { + this.recommender = recommender; + this.testUserID = testUserID; + this.prefs = prefs; + this.noEstimateCounter = noEstimateCounter; + } + + @Override + public Void call() throws TasteException { + for (Preference realPref : prefs) { + float estimatedPreference = Float.NaN; + try { + estimatedPreference = recommender.estimatePreference(testUserID, realPref.getItemID()); + } catch (NoSuchUserException nsue) { + // It's possible that an item exists in the test data but not training data in which case + // NSEE will be thrown. Just ignore it and move on. 
+ log.info("User exists in test data but not training data: {}", testUserID); + } catch (NoSuchItemException nsie) { + log.info("Item exists in test data but not training data: {}", realPref.getItemID()); + } + if (Float.isNaN(estimatedPreference)) { + noEstimateCounter.incrementAndGet(); + } else { + estimatedPreference = capEstimatedPreference(estimatedPreference); + processOneEstimate(estimatedPreference, realPref); + } + } + return null; + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java new file mode 100644 index 0000000..4dad040 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.eval; + +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.model.Preference; + +/** + * <p> + * A {@link org.apache.mahout.cf.taste.eval.RecommenderEvaluator} which computes the average absolute + * difference between predicted and actual ratings for users. + * </p> + * + * <p> + * This algorithm is also called "mean average error". + * </p> + */ +public final class AverageAbsoluteDifferenceRecommenderEvaluator extends + AbstractDifferenceRecommenderEvaluator { + + private RunningAverage average; + + @Override + protected void reset() { + average = new FullRunningAverage(); + } + + @Override + protected void processOneEstimate(float estimatedPreference, Preference realPref) { + average.addDatum(Math.abs(realPref.getValue() - estimatedPreference)); + } + + @Override + protected double computeFinalEvaluation() { + return average.getAverage(); + } + + @Override + public String toString() { + return "AverageAbsoluteDifferenceRecommenderEvaluator"; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java new file mode 100644 index 0000000..0e121d1 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java @@ -0,0 +1,237 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import java.util.List;
+import java.util.Random;
+
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.DataModelBuilder;
+import org.apache.mahout.cf.taste.eval.IRStatistics;
+import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
+import org.apache.mahout.cf.taste.eval.RecommenderIRStatsEvaluator;
+import org.apache.mahout.cf.taste.eval.RelevantItemsDataSplitter;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * For each user, this implementation determines the top {@code n} preferences, then evaluates the IR
+ * statistics based on a {@link DataModel} that does not have these values. This number {@code n} is the
+ * "at" value, as in "precision at 5". For example, this would mean precision evaluated by removing the top 5
+ * preferences for a user and then finding the percentage of those 5 items included in the top 5
+ * recommendations for that user.
+ * </p>
+ */
+public final class GenericRecommenderIRStatsEvaluator implements RecommenderIRStatsEvaluator {
+
+  private static final Logger log = LoggerFactory.getLogger(GenericRecommenderIRStatsEvaluator.class);
+
+  /** Natural logarithm of 2; {@link #log2(double)} divides by it to change the log base. */
+  private static final double LOG2 = Math.log(2.0);
+
+  /**
+   * Pass as "relevanceThreshold" argument to
+   * {@link #evaluate(RecommenderBuilder, DataModelBuilder, DataModel, IDRescorer, int, double, double)} to
+   * have it attempt to compute a reasonable threshold. Note that this will impact performance.
+   * (evaluate tests the threshold with {@code Double.isNaN}, so any NaN value has the same effect.)
+   */
+  public static final double CHOOSE_THRESHOLD = Double.NaN;
+
+  private final Random random;
+  private final RelevantItemsDataSplitter dataSplitter;
+
+  /** Creates an evaluator that uses a {@link GenericRelevantItemsDataSplitter}. */
+  public GenericRecommenderIRStatsEvaluator() {
+    this(new GenericRelevantItemsDataSplitter());
+  }
+
+  /**
+   * @param dataSplitter picks each evaluated user's relevant items and decides how every user's
+   *     preferences enter the training data; must not be null
+   */
+  public GenericRecommenderIRStatsEvaluator(RelevantItemsDataSplitter dataSplitter) {
+    Preconditions.checkNotNull(dataSplitter);
+    random = RandomUtils.getRandom();
+    this.dataSplitter = dataSplitter;
+  }
+
+  /**
+   * Computes precision, recall, fall-out, nDCG and reach "at" {@code at}. Each user is evaluated with
+   * probability {@code evaluationPercentage}. For an evaluated user a training model is rebuilt from
+   * all users' preferences as filtered by the data splitter, and a recommender is built on it.
+   * Users without relevant items, users missing from the training model, and users with fewer than
+   * {@code 2 * at} preferences in total are skipped.
+   */
+  @Override
+  public IRStatistics evaluate(RecommenderBuilder recommenderBuilder,
+                               DataModelBuilder dataModelBuilder,
+                               DataModel dataModel,
+                               IDRescorer rescorer,
+                               int at,
+                               double relevanceThreshold,
+                               double evaluationPercentage) throws TasteException {
+
+    Preconditions.checkArgument(recommenderBuilder != null, "recommenderBuilder is null");
+    Preconditions.checkArgument(dataModel != null, "dataModel is null");
+    Preconditions.checkArgument(at >= 1, "at must be at least 1");
+    Preconditions.checkArgument(evaluationPercentage > 0.0 && evaluationPercentage <= 1.0,
+        "Invalid evaluationPercentage: " + evaluationPercentage + ". Must be: 0.0 < evaluationPercentage <= 1.0");
+
+    int numItems = dataModel.getNumItems();
+    RunningAverage precision = new FullRunningAverage();
+    RunningAverage recall = new FullRunningAverage();
+    RunningAverage fallOut = new FullRunningAverage();
+    RunningAverage nDCG = new FullRunningAverage();
+    int numUsersRecommendedFor = 0;
+    int numUsersWithRecommendations = 0;
+
+    LongPrimitiveIterator it = dataModel.getUserIDs();
+    while (it.hasNext()) {
+
+      long userID = it.nextLong();
+
+      // Random sampling: a user is evaluated with probability evaluationPercentage.
+      if (random.nextDouble() >= evaluationPercentage) {
+        // Skipped
+        continue;
+      }
+
+      long start = System.currentTimeMillis();
+
+      PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
+
+      // List some most-preferred items that would count as (most) "relevant" results
+      // A NaN threshold (e.g. CHOOSE_THRESHOLD) means: derive one from this user's own values.
+      double theRelevanceThreshold = Double.isNaN(relevanceThreshold) ? computeThreshold(prefs) : relevanceThreshold;
+      FastIDSet relevantItemIDs = dataSplitter.getRelevantItemsIDs(userID, at, theRelevanceThreshold, dataModel);
+
+      int numRelevantItems = relevantItemIDs.size();
+      if (numRelevantItems <= 0) {
+        continue;
+      }
+
+      // Rebuild training data from every user, as filtered by the splitter (which is expected to
+      // withhold this user's relevant items).
+      FastByIDMap<PreferenceArray> trainingUsers = new FastByIDMap<>(dataModel.getNumUsers());
+      LongPrimitiveIterator it2 = dataModel.getUserIDs();
+      while (it2.hasNext()) {
+        dataSplitter.processOtherUser(userID, relevantItemIDs, trainingUsers, it2.nextLong(), dataModel);
+      }
+
+      DataModel trainingModel = dataModelBuilder == null ? new GenericDataModel(trainingUsers)
+          : dataModelBuilder.buildDataModel(trainingUsers);
+      try {
+        trainingModel.getPreferencesFromUser(userID);
+      } catch (NoSuchUserException nsee) {
+        continue; // Oops we excluded all prefs for the user -- just move on
+      }
+
+      // Relevant (withheld) items plus the items this user still has in the training model.
+      int size = numRelevantItems + trainingModel.getItemIDsFromUser(userID).size();
+      if (size < 2 * at) {
+        // Really not enough prefs to meaningfully evaluate this user
+        continue;
+      }
+
+      Recommender recommender = recommenderBuilder.buildRecommender(trainingModel);
+
+      // Count how many of the top-"at" recommendations are among the relevant items.
+      int intersectionSize = 0;
+      List<RecommendedItem> recommendedItems = recommender.recommend(userID, at, rescorer);
+      for (RecommendedItem recommendedItem : recommendedItems) {
+        if (relevantItemIDs.contains(recommendedItem.getItemID())) {
+          intersectionSize++;
+        }
+      }
+
+      int numRecommendedItems = recommendedItems.size();
+
+      // Precision
+      // (undefined when nothing was recommended, so no datum is recorded then)
+      if (numRecommendedItems > 0) {
+        precision.addDatum((double) intersectionSize / (double) numRecommendedItems);
+      }
+
+      // Recall
+      recall.addDatum((double) intersectionSize / (double) numRelevantItems);
+
+      // Fall-out
+      // Only recorded when the user has training items besides the relevant ones.
+      if (numRelevantItems < size) {
+        fallOut.addDatum((double) (numRecommendedItems - intersectionSize)
+            / (double) (numItems - numRelevantItems));
+      }
+
+      // nDCG
+      // In computing, assume relevant IDs have relevance 1 and others 0
+      double cumulativeGain = 0.0;
+      double idealizedGain = 0.0;
+      for (int i = 0; i < numRecommendedItems; i++) {
+        RecommendedItem item = recommendedItems.get(i);
+        double discount = 1.0 / log2(i + 2.0); // Classical formulation says log(i+1), but i is 0-based here
+        if (relevantItemIDs.contains(item.getItemID())) {
+          cumulativeGain += discount;
+        }
+        // otherwise we're multiplying discount by relevance 0 so it doesn't do anything
+
+        // Ideally results would be ordered with all relevant ones first, so this theoretical
+        // ideal list starts with number of relevant items equal to the total number of relevant items
+        if (i < numRelevantItems) {
+          idealizedGain += discount;
+        }
+      }
+      if (idealizedGain > 0.0) {
+        nDCG.addDatum(cumulativeGain / idealizedGain);
+      }
+
+      // Reach
+      numUsersRecommendedFor++;
+      if (numRecommendedItems > 0) {
+        numUsersWithRecommendations++;
+      }
+
+      long end = System.currentTimeMillis();
+
+      log.info("Evaluated with user {} in {}ms", userID, end - start);
+      log.info("Precision/recall/fall-out/nDCG/reach: {} / {} / {} / {} / {}",
+          precision.getAverage(), recall.getAverage(), fallOut.getAverage(), nDCG.getAverage(),
+          (double) numUsersWithRecommendations / (double) numUsersRecommendedFor);
+    }
+
+    // If no user was evaluated, reach is 0.0 / 0.0 = NaN; IRStatisticsImpl accepts NaN values.
+    return new IRStatisticsImpl(
+        precision.getAverage(),
+        recall.getAverage(),
+        fallOut.getAverage(),
+        nDCG.getAverage(),
+        (double) numUsersWithRecommendations / (double) numUsersRecommendedFor);
+  }
+
+  /**
+   * Derives a per-user relevance threshold: the mean plus one standard deviation of the user's values.
+   *
+   * @return {@link Double#NEGATIVE_INFINITY} (every item passes) when there are fewer than 2 values
+   */
+  private static double computeThreshold(PreferenceArray prefs) {
+    if (prefs.length() < 2) {
+      // Not enough data points -- return a threshold that allows everything
+      return Double.NEGATIVE_INFINITY;
+    }
+    RunningAverageAndStdDev stdDev = new FullRunningAverageAndStdDev();
+    int size = prefs.length();
+    for (int i = 0; i < size; i++) {
+      stdDev.addDatum(prefs.getValue(i));
+    }
+    return stdDev.getAverage() + stdDev.getStandardDeviation();
+  }
+
+  /** Base-2 logarithm, used for the nDCG position discount. */
+  private static double log2(double value) {
+    return Math.log(value) / LOG2;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java new file mode 100644 index 0000000..f4e4522 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.eval; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.eval.RelevantItemsDataSplitter; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.cf.taste.model.PreferenceArray; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * Picks relevant items to be those with the strongest preference, and + * includes the other users' preferences in full. + */ +public final class GenericRelevantItemsDataSplitter implements RelevantItemsDataSplitter { + + @Override + public FastIDSet getRelevantItemsIDs(long userID, + int at, + double relevanceThreshold, + DataModel dataModel) throws TasteException { + PreferenceArray prefs = dataModel.getPreferencesFromUser(userID); + FastIDSet relevantItemIDs = new FastIDSet(at); + prefs.sortByValueReversed(); + for (int i = 0; i < prefs.length() && relevantItemIDs.size() < at; i++) { + if (prefs.getValue(i) >= relevanceThreshold) { + relevantItemIDs.add(prefs.getItemID(i)); + } + } + return relevantItemIDs; + } + + @Override + public void processOtherUser(long userID, + FastIDSet relevantItemIDs, + FastByIDMap<PreferenceArray> trainingUsers, + long otherUserID, + DataModel dataModel) throws TasteException { + PreferenceArray prefs2Array = dataModel.getPreferencesFromUser(otherUserID); + // If we're dealing with the very user that we're evaluating for precision/recall, + if (userID == otherUserID) { + // then must remove all the test IDs, the "relevant" item IDs + List<Preference> prefs2 = new ArrayList<>(prefs2Array.length()); + for (Preference pref : prefs2Array) { + prefs2.add(pref); + } + for (Iterator<Preference> iterator = prefs2.iterator(); 
iterator.hasNext();) { + Preference pref = iterator.next(); + if (relevantItemIDs.contains(pref.getItemID())) { + iterator.remove(); + } + } + if (!prefs2.isEmpty()) { + trainingUsers.put(otherUserID, new GenericUserPreferenceArray(prefs2)); + } + } else { + // otherwise just add all those other user's prefs + trainingUsers.put(otherUserID, prefs2Array); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java new file mode 100644 index 0000000..2838b08 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java @@ -0,0 +1,95 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.eval; + +import java.io.Serializable; + +import org.apache.mahout.cf.taste.eval.IRStatistics; + +import com.google.common.base.Preconditions; + +public final class IRStatisticsImpl implements IRStatistics, Serializable { + + private final double precision; + private final double recall; + private final double fallOut; + private final double ndcg; + private final double reach; + + IRStatisticsImpl(double precision, double recall, double fallOut, double ndcg, double reach) { + Preconditions.checkArgument(Double.isNaN(precision) || (precision >= 0.0 && precision <= 1.0), + "Illegal precision: " + precision + ". Must be: 0.0 <= precision <= 1.0 or NaN"); + Preconditions.checkArgument(Double.isNaN(recall) || (recall >= 0.0 && recall <= 1.0), + "Illegal recall: " + recall + ". Must be: 0.0 <= recall <= 1.0 or NaN"); + Preconditions.checkArgument(Double.isNaN(fallOut) || (fallOut >= 0.0 && fallOut <= 1.0), + "Illegal fallOut: " + fallOut + ". Must be: 0.0 <= fallOut <= 1.0 or NaN"); + Preconditions.checkArgument(Double.isNaN(ndcg) || (ndcg >= 0.0 && ndcg <= 1.0), + "Illegal nDCG: " + ndcg + ". Must be: 0.0 <= nDCG <= 1.0 or NaN"); + Preconditions.checkArgument(Double.isNaN(reach) || (reach >= 0.0 && reach <= 1.0), + "Illegal reach: " + reach + ". Must be: 0.0 <= reach <= 1.0 or NaN"); + this.precision = precision; + this.recall = recall; + this.fallOut = fallOut; + this.ndcg = ndcg; + this.reach = reach; + } + + @Override + public double getPrecision() { + return precision; + } + + @Override + public double getRecall() { + return recall; + } + + @Override + public double getFallOut() { + return fallOut; + } + + @Override + public double getF1Measure() { + return getFNMeasure(1.0); + } + + @Override + public double getFNMeasure(double b) { + double b2 = b * b; + double sum = b2 * precision + recall; + return sum == 0.0 ? 
Double.NaN : (1.0 + b2) * precision * recall / sum; + } + + @Override + public double getNormalizedDiscountedCumulativeGain() { + return ndcg; + } + + @Override + public double getReach() { + return reach; + } + + @Override + public String toString() { + return "IRStatisticsImpl[precision:" + precision + ",recall:" + recall + ",fallOut:" + + fallOut + ",nDCG:" + ndcg + ",reach:" + reach + ']'; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java new file mode 100644 index 0000000..213f7f9 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.eval; + +import org.apache.mahout.cf.taste.recommender.Recommender; + +import java.util.concurrent.Callable; + +final class LoadCallable implements Callable<Void> { + + private final Recommender recommender; + private final long userID; + + LoadCallable(Recommender recommender, long userID) { + this.recommender = recommender; + this.userID = userID; + } + + @Override + public Void call() throws Exception { + recommender.recommend(userID, 10); + return null; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java new file mode 100644 index 0000000..2d27a37 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.eval; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev; +import org.apache.mahout.cf.taste.impl.common.SamplingLongPrimitiveIterator; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.recommender.Recommender; + +/** + * Simple helper class for running load on a Recommender. + */ +public final class LoadEvaluator { + + private LoadEvaluator() { } + + public static LoadStatistics runLoad(Recommender recommender) throws TasteException { + return runLoad(recommender, 10); + } + + public static LoadStatistics runLoad(Recommender recommender, int howMany) throws TasteException { + DataModel dataModel = recommender.getDataModel(); + int numUsers = dataModel.getNumUsers(); + double sampleRate = 1000.0 / numUsers; + LongPrimitiveIterator userSampler = + SamplingLongPrimitiveIterator.maybeWrapIterator(dataModel.getUserIDs(), sampleRate); + recommender.recommend(userSampler.next(), howMany); // Warm up + Collection<Callable<Void>> callables = new ArrayList<>(); + while (userSampler.hasNext()) { + callables.add(new LoadCallable(recommender, userSampler.next())); + } + AtomicInteger noEstimateCounter = new AtomicInteger(); + RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev(); + AbstractDifferenceRecommenderEvaluator.execute(callables, noEstimateCounter, timing); + return new LoadStatistics(timing); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java 
---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java new file mode 100644 index 0000000..f89160c --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.eval; + +import org.apache.mahout.cf.taste.impl.common.RunningAverage; + +public final class LoadStatistics { + + private final RunningAverage timing; + + LoadStatistics(RunningAverage timing) { + this.timing = timing; + } + + public RunningAverage getTiming() { + return timing; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/OrderBasedRecommenderEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/OrderBasedRecommenderEvaluator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/OrderBasedRecommenderEvaluator.java new file mode 100644 index 0000000..e267a39 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/OrderBasedRecommenderEvaluator.java @@ -0,0 +1,431 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.eval; + +import java.util.Arrays; +import java.util.List; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; +import org.apache.mahout.cf.taste.recommender.Recommender; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Evaluate recommender by comparing order of all raw prefs with order in + * recommender's output for that user. Can also compare data models. + */ +public final class OrderBasedRecommenderEvaluator { + + private static final Logger log = LoggerFactory.getLogger(OrderBasedRecommenderEvaluator.class); + + private OrderBasedRecommenderEvaluator() { + } + + public static void evaluate(Recommender recommender1, + Recommender recommender2, + int samples, + RunningAverage tracker, + String tag) throws TasteException { + printHeader(); + LongPrimitiveIterator users = recommender1.getDataModel().getUserIDs(); + + while (users.hasNext()) { + long userID = users.nextLong(); + List<RecommendedItem> recs1 = recommender1.recommend(userID, samples); + List<RecommendedItem> recs2 = recommender2.recommend(userID, samples); + FastIDSet commonSet = new FastIDSet(); + long maxItemID = setBits(commonSet, recs1, samples); + FastIDSet otherSet = new FastIDSet(); + maxItemID = Math.max(maxItemID, setBits(otherSet, recs2, samples)); + int max = mask(commonSet, otherSet, maxItemID); + max = Math.min(max, samples); + if (max < 2) { + continue; + } + Long[] items1 = getCommonItems(commonSet, recs1, max); + Long[] items2 = getCommonItems(commonSet, recs2, max); + double variance = scoreCommonSubset(tag, userID, samples, max, items1, items2); + 
tracker.addDatum(variance);
+    }
+  }
+
+  /**
+   * Scores, user by user, how closely the ranked output of {@code recommender}
+   * agrees with the preferences held in {@code model}.
+   *
+   * For every user ID supplied by the recommender's own data model, the
+   * recommendations are compared with that user's preferences from
+   * {@code model} (ordered with {@code sortByValueReversed()}). Only the first
+   * {@code samples} entries of each list are considered. Users whose two lists
+   * share fewer than two such items are skipped; for every other user a CSV row
+   * is logged and the resulting score is added to {@code tracker}.
+   *
+   * @param recommender recommender under test; also supplies the user IDs
+   * @param model       data model holding the reference preferences
+   * @param samples     number of leading entries taken from each list
+   * @param tracker     receives one score per evaluated user
+   * @param tag         label written as the first column of each logged row
+   * @throws TasteException if the recommender or the model fails
+   */
+  public static void evaluate(Recommender recommender,
+                              DataModel model,
+                              int samples,
+                              RunningAverage tracker,
+                              String tag) throws TasteException {
+    printHeader();
+    LongPrimitiveIterator users = recommender.getDataModel().getUserIDs();
+    while (users.hasNext()) {
+      long userID = users.nextLong();
+      List<RecommendedItem> recs1 = recommender.recommend(userID, model.getNumItems());
+      PreferenceArray prefs2 = model.getPreferencesFromUser(userID);
+      prefs2.sortByValueReversed();
+      FastIDSet commonSet = new FastIDSet();
+      long maxItemID = setBits(commonSet, recs1, samples);
+      FastIDSet otherSet = new FastIDSet();
+      maxItemID = Math.max(maxItemID, setBits(otherSet, prefs2, samples));
+      int max = mask(commonSet, otherSet, maxItemID);
+      max = Math.min(max, samples);
+      if (max < 2) {
+        continue;
+      }
+      Long[] items1 = getCommonItems(commonSet, recs1, max);
+      Long[] items2 = getCommonItems(commonSet, prefs2, max);
+      double variance = scoreCommonSubset(tag, userID, samples, max, items1, items2);
+      tracker.addDatum(variance);
+    }
+  }
+
+  /**
+   * Scores, user by user, how closely the preference orderings of two data
+   * models agree.
+   *
+   * For every user ID in {@code model1}, that user's preferences from both
+   * models are ordered with {@code sortByValueReversed()}. The first
+   * {@code samples} entries of each are then compared. Users whose two lists
+   * share fewer than two such items are skipped; for every other user a CSV row
+   * is logged and the resulting score is added to {@code tracker}.
+   *
+   * @param model1  first data model; also supplies the user IDs
+   * @param model2  second data model
+   * @param samples number of leading entries taken from each list
+   * @param tracker receives one score per evaluated user
+   * @param tag     label written as the first column of each logged row
+   * @throws TasteException if either model fails
+   */
+  public static void evaluate(DataModel model1,
+                              DataModel model2,
+                              int samples,
+                              RunningAverage tracker,
+                              String tag) throws TasteException {
+    printHeader();
+    LongPrimitiveIterator users = model1.getUserIDs();
+    while (users.hasNext()) {
+      long userID = users.nextLong();
+      PreferenceArray prefs1 = model1.getPreferencesFromUser(userID);
+      PreferenceArray prefs2 = model2.getPreferencesFromUser(userID);
+      prefs1.sortByValueReversed();
+      prefs2.sortByValueReversed();
+      FastIDSet commonSet = new FastIDSet();
+      long maxItemID = setBits(commonSet, prefs1, samples);
+      FastIDSet otherSet = new FastIDSet();
+      maxItemID = Math.max(maxItemID, setBits(otherSet, prefs2, samples));
+      int max = mask(commonSet, otherSet, maxItemID);
+      max = Math.min(max, samples);
+      if (max < 2) {
+        continue;
+      }
+      Long[] items1 = getCommonItems(commonSet, prefs1, max);
+      Long[] items2 = getCommonItems(commonSet, prefs2, max);
+      double variance = scoreCommonSubset(tag, userID, samples, max, items1, items2);
+      tracker.addDatum(variance);
+    }
+  }
+
+  /**
+   * Reduces {@code commonSet} to its intersection with {@code otherSet} and
+   * returns the number of item IDs that remain.
+   *
+   * This exists because FastIDSet has 'retainAll' as MASK, but there is
+   * no count of the number of items in the set. size() is supposed to do
+   * this but does not work.
+   *
+   * The members of {@code commonSet} are iterated directly. An earlier version
+   * probed every ID in [0, maxItemID] with an {@code int} counter, which had
+   * three problems:
+   * <ul>
+   *   <li>it overflowed (and never terminated) for item IDs above Integer.MAX_VALUE;</li>
+   *   <li>it cost O(maxItemID) regardless of how small the sets were;</li>
+   *   <li>it never visited negative item IDs, so those were neither counted nor removed.</li>
+   * </ul>
+   *
+   * @param commonSet set to be reduced in place to the intersection
+   * @param otherSet  set to intersect with; not modified
+   * @param maxItemID no longer needed; kept so existing callers are unchanged
+   * @return number of item IDs present in both sets
+   */
+  private static int mask(FastIDSet commonSet, FastIDSet otherSet, long maxItemID) {
+    int count = 0;
+    // Collect removals first so commonSet is not modified while it is being iterated.
+    FastIDSet notShared = new FastIDSet();
+    LongPrimitiveIterator members = commonSet.iterator();
+    while (members.hasNext()) {
+      long itemID = members.nextLong();
+      if (otherSet.contains(itemID)) {
+        count++;
+      } else {
+        notShared.add(itemID);
+      }
+    }
+    LongPrimitiveIterator toRemove = notShared.iterator();
+    while (toRemove.hasNext()) {
+      commonSet.remove(toRemove.nextLong());
+    }
+    return count;
+  }
+
+  /**
+   * Returns, in list order, the first {@code max} recommended item IDs that are
+   * members of {@code commonSet}. Array slots stay {@code null} if fewer than
+   * {@code max} such items are found.
+   */
+  private static Long[] getCommonItems(FastIDSet commonSet, Iterable<RecommendedItem> recs, int max) {
+    Long[] commonItems = new Long[max];
+    int index = 0;
+    for (RecommendedItem rec : recs) {
+      Long item = rec.getItemID();
+      if (commonSet.contains(item)) {
+        commonItems[index++] = item;
+      }
+      if (index == max) {
+        break;
+      }
+    }
+    return commonItems;
+  }
+
+  /**
+   * Returns, in array order, the first {@code max} preference item IDs that are
+   * members of {@code commonSet}. Array slots stay {@code null} if fewer than
+   * {@code max} such items are found.
+   */
+  private static Long[] getCommonItems(FastIDSet commonSet, PreferenceArray prefs1, int max) {
+    Long[] commonItems = new Long[max];
+    int index = 0;
+    for (int i = 0; i < prefs1.length(); i++) {
+      Long item = prefs1.getItemID(i);
+      if (commonSet.contains(item)) {
+        commonItems[index++] = item;
+      }
+      if (index == max) {
+        break;
+      }
+    }
+    return commonItems;
+  }
+
+  /**
+   * Adds the item IDs of the first {@code max} recommendations to
+   * {@code modelSet}.
+   *
+   * @return the largest item ID added, or -1 if nothing was added
+   */
+  private static long setBits(FastIDSet modelSet, List<RecommendedItem> items, int max) {
+    long maxItem = -1;
+    for (int i = 0; i < items.size() && i < max; i++) {
+      long itemID = items.get(i).getItemID();
+      modelSet.add(itemID);
+      if (itemID > maxItem) {
+        maxItem = itemID;
+      }
+    }
+    return maxItem;
+  }
+
+  /**
+   * Adds the item IDs of the first {@code max} preferences to {@code modelSet}.
+   *
+   * @return the largest item ID added, or -1 if nothing was added
+   */
+  private static long setBits(FastIDSet modelSet, PreferenceArray prefs, int max) {
+    long maxItem = -1;
+    for (int i = 0; i < prefs.length() && i < max; i++) {
+      long itemID = prefs.getItemID(i);
+      modelSet.add(itemID);
+      if (itemID > maxItem) {
+        maxItem = itemID;
+      }
+    }
+    return maxItem;
+  }
+
+  /** Logs the CSV column header matching the rows written by scoreCommonSubset. */
+  private static void printHeader() {
+    log.info("tag,user,samples,common,hamming,bubble,rank,normal,score");
+  }
+
+  /**
+   * Common Subset Scoring
+   *
+   * These measurements are given the set of results that are common to both
+   * recommendation lists. They only get ordered lists.
+   *
+   * These measures all return raw numbers that do not correlate among the tests.
+   * The numbers are not corrected against the total number of samples or the
+   * number of common items.
+   * The intended contract is that measures are 0 for an exact match and an
+   * increasing positive number as differences increase.
+   * NOTE(review): slidingWindowHamming counts agreeing positions, so that column
+   * is largest, not 0, for an exact match.
+   *
+   * All measures are logged as one CSV row. Only the square root of the mean
+   * absolute position displacement is returned.
+   */
+  private static double scoreCommonSubset(String tag,
+                                          long userID,
+                                          int samples,
+                                          int subset,
+                                          Long[] itemsL,
+                                          Long[] itemsR) {
+    int[] vectorZ = new int[subset];
+    int[] vectorZabs = new int[subset];
+
+    long bubble = sort(itemsL, itemsR);
+    int hamming = slidingWindowHamming(itemsR, itemsL);
+    if (hamming > samples) {
+      throw new IllegalStateException();
+    }
+    getVectorZ(itemsR, itemsL, vectorZ, vectorZabs);
+    double normalW = normalWilcoxon(vectorZ, vectorZabs);
+    double meanRank = getMeanRank(vectorZabs);
+    // the other measures are only logged; this one is the returned score
+    double variance = Math.sqrt(meanRank);
+    log.info("{},{},{},{},{},{},{},{},{}",
+             tag, userID, samples, subset, hamming, bubble, meanRank, normalW, variance);
+    return variance;
+  }
+
+  /**
+   * Simple sliding-window agreement count: a position counts when the two lists
+   * hold the same item there or one place away (a[i or plus/minus 1] == b[i]).
+   * At the two ends the window is clipped.
+   *
+   * NOTE(review): despite the name this counts matches, not mismatches.
+   * Requires both arrays to have length >= 2 (indices 1 and length-2 are read
+   * unconditionally).
+   */
+  private static int slidingWindowHamming(Long[] itemsR, Long[] itemsL) {
+    int count = 0;
+    int samples = itemsR.length;
+
+    if (itemsR[0].equals(itemsL[0]) || itemsR[0].equals(itemsL[1])) {
+      count++;
+    }
+    for (int i = 1; i < samples - 1; i++) {
+      long itemID = itemsL[i];
+      if (itemsR[i] == itemID || itemsR[i - 1] == itemID || itemsR[i + 1] == itemID) {
+        count++;
+      }
+    }
+    if (itemsR[samples - 1].equals(itemsL[samples - 1]) || itemsR[samples - 1].equals(itemsL[samples - 2])) {
+      count++;
+    }
+    return count;
+  }
+
+  /**
+   * Normal-distribution probability value for matched sets of values.
+   * Based upon:
+   * http://comp9.psych.cornell.edu/Darlington/normscor.htm
+   *
+   * The Standard Wilcoxon is not used because it requires a lookup table.
+   *
+   * NOTE(review): the value actually returned is the smaller of the mean
+   * positive signed rank and the mean (magnitude of) negative signed rank, not
+   * a probability.
+   */
+  static double normalWilcoxon(int[] vectorZ, int[] vectorZabs) {
+    int nitems = vectorZ.length;
+
+    double[] ranks = new double[nitems];
+    double[] ranksAbs = new double[nitems];
+    wilcoxonRanks(vectorZ, vectorZabs, ranks, ranksAbs);
+    return Math.min(getMeanWplus(ranks), getMeanWminus(ranks));
+  }
+
+  /**
+   * vector Z is a list of distances between the correct value and the recommended value:
+   * Z[i] = i - (position j at which itemsR[i] appears in itemsL),
+   * which can be positive or negative. The smaller the magnitude, the closer
+   * the recommendations. vectorZabs receives |Z[i]|.
+   * Both arrays are the same length and hold the same set of items.
+   *
+   * Destructive to itemsL: a match found strictly inside the current
+   * [bottom, top] search window is set to null, and a match at either end of
+   * the window narrows it, so later scans skip items already matched.
+   */
+  private static void getVectorZ(Long[] itemsR, Long[] itemsL, int[] vectorZ, int[] vectorZabs) {
+    int nitems = itemsR.length;
+    int bottom = 0;
+    int top = nitems - 1;
+    for (int i = 0; i < nitems; i++) {
+      long itemID = itemsR[i];
+      for (int j = bottom; j <= top; j++) {
+        if (itemsL[j] == null) {
+          continue;
+        }
+        long test = itemsL[j];
+        if (itemID == test) {
+          vectorZ[i] = i - j;
+          vectorZabs[i] = Math.abs(i - j);
+          if (j == bottom) {
+            bottom++;
+          } else if (j == top) {
+            top--;
+          } else {
+            itemsL[j] = null;
+          }
+          break;
+        }
+      }
+    }
+  }
+
+  /**
+   * Computes signed ranks of the displacements.
+   *
+   * The rank of |Z[i]| is its 1-based position in ascending order once the
+   * zero entries are excluded. Tied values receive the average of their
+   * positions. The rank is then given the sign of Z[i]. Entries with Z[i] == 0
+   * are left at 0 in both output arrays.
+   */
+  private static void wilcoxonRanks(int[] vectorZ, int[] vectorZabs, double[] ranks, double[] ranksAbs) {
+    int nitems = vectorZ.length;
+    int[] sorted = vectorZabs.clone();
+    Arrays.sort(sorted);
+    int zeros = 0;
+    for (; zeros < nitems; zeros++) {
+      if (sorted[zeros] > 0) {
+        break;
+      }
+    }
+    for (int i = 0; i < nitems; i++) {
+      double rank = 0.0;
+      int count = 0;
+      int score = vectorZabs[i];
+      for (int j = 0; j < nitems; j++) {
+        if (score == sorted[j]) {
+          rank += j + 1 - zeros;
+          count++;
+        } else if (score < sorted[j]) {
+          break;
+        }
+      }
+      if (vectorZ[i] != 0) {
+        ranks[i] = (rank / count) * (vectorZ[i] < 0 ? -1 : 1); // better be at least 1
+        ranksAbs[i] = Math.abs(ranks[i]);
+      }
+    }
+  }
+
+  /** Arithmetic mean of the given values (NaN for an empty array). */
+  private static double getMeanRank(int[] ranks) {
+    int nitems = ranks.length;
+    double sum = 0.0;
+    for (int rank : ranks) {
+      sum += rank;
+    }
+    return sum / nitems;
+  }
+
+  /** Sum of the positive ranks divided by the total number of ranks. */
+  private static double getMeanWplus(double[] ranks) {
+    int nitems = ranks.length;
+    double sum = 0.0;
+    for (double rank : ranks) {
+      if (rank > 0) {
+        sum += rank;
+      }
+    }
+    return sum / nitems;
+  }
+
+  /** Sum of the magnitudes of the negative ranks divided by the total number of ranks. */
+  private static double getMeanWminus(double[] ranks) {
+    int nitems = ranks.length;
+    double sum = 0.0;
+    for (double rank : ranks) {
+      if (rank < 0) {
+        sum -= rank;
+      }
+    }
+    return sum / nitems;
+  }
+
+  /**
+   * Do bubble sort and return number of swaps needed to match preference lists.
+   * Sort itemsR using itemsL as the reference order.
+   * Works on primitive copies, so neither input array is modified.
+   */
+  static long sort(Long[] itemsL, Long[] itemsR) {
+    int length = itemsL.length;
+    if (length < 2) {
+      return 0;
+    }
+    if (length == 2) {
+      return itemsL[0].longValue() == itemsR[0].longValue() ? 0 : 1;
+    }
+    // 1) avoid changing originals; 2) primitive type is more efficient
+    long[] reference = new long[length];
+    long[] sortable = new long[length];
+    for (int i = 0; i < length; i++) {
+      reference[i] = itemsL[i];
+      sortable[i] = itemsR[i];
+    }
+    int sorted = 0;
+    long swaps = 0;
+    while (sorted < length - 1) {
+      // opportunistically trim back the top
+      while (length > 0 && reference[length - 1] == sortable[length - 1]) {
+        length--;
+      }
+      if (length == 0) {
+        break;
+      }
+      if (reference[sorted] == sortable[sorted]) {
+        sorted++;
+      } else {
+        for (int j = sorted; j < length - 1; j++) {
+          // do not swap anything already in place
+          int jump = 1;
+          if (reference[j] == sortable[j]) {
+            while (j + jump < length && reference[j + jump] == sortable[j + jump]) {
+              jump++;
+            }
+          }
+          if (j + jump < length && !(reference[j] == sortable[j] && reference[j + jump] == sortable[j + jump])) {
+            long tmp = sortable[j];
+            sortable[j] = sortable[j + 1];
+            sortable[j + 1] = tmp;
+            swaps++;
+          }
+        }
+      }
+    }
+    return swaps;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/RMSRecommenderEvaluator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/RMSRecommenderEvaluator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/RMSRecommenderEvaluator.java
new file mode 100644
index 0000000..97eda10
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/RMSRecommenderEvaluator.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.eval; + +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.model.Preference; + +/** + * <p> + * A {@link org.apache.mahout.cf.taste.eval.RecommenderEvaluator} which computes the "root mean squared" + * difference between predicted and actual ratings for users. This is the square root of the average of this + * difference, squared. 
+ * </p> + */ +public final class RMSRecommenderEvaluator extends AbstractDifferenceRecommenderEvaluator { + + private RunningAverage average; + + @Override + protected void reset() { + average = new FullRunningAverage(); + } + + @Override + protected void processOneEstimate(float estimatedPreference, Preference realPref) { + double diff = realPref.getValue() - estimatedPreference; + average.addDatum(diff * diff); + } + + @Override + protected double computeFinalEvaluation() { + return Math.sqrt(average.getAverage()); + } + + @Override + public String toString() { + return "RMSRecommenderEvaluator"; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/StatsCallable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/StatsCallable.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/StatsCallable.java new file mode 100644 index 0000000..036d0b4 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/StatsCallable.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.eval; + +import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicInteger; + +final class StatsCallable implements Callable<Void> { + + private static final Logger log = LoggerFactory.getLogger(StatsCallable.class); + + private final Callable<Void> delegate; + private final boolean logStats; + private final RunningAverageAndStdDev timing; + private final AtomicInteger noEstimateCounter; + + StatsCallable(Callable<Void> delegate, + boolean logStats, + RunningAverageAndStdDev timing, + AtomicInteger noEstimateCounter) { + this.delegate = delegate; + this.logStats = logStats; + this.timing = timing; + this.noEstimateCounter = noEstimateCounter; + } + + @Override + public Void call() throws Exception { + long start = System.currentTimeMillis(); + delegate.call(); + long end = System.currentTimeMillis(); + timing.addDatum(end - start); + if (logStats) { + Runtime runtime = Runtime.getRuntime(); + int average = (int) timing.getAverage(); + log.info("Average time per recommendation: {}ms", average); + long totalMemory = runtime.totalMemory(); + long memory = totalMemory - runtime.freeMemory(); + log.info("Approximate memory used: {}MB / {}MB", memory / 1000000L, totalMemory / 1000000L); + log.info("Unable to recommend in {} cases", noEstimateCounter.get()); + } + return null; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractDataModel.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractDataModel.java 
b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractDataModel.java new file mode 100644 index 0000000..a1a2a1f --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractDataModel.java @@ -0,0 +1,53 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.model; + +import org.apache.mahout.cf.taste.model.DataModel; + +/** + * Contains some features common to all implementations. 
+ */ +public abstract class AbstractDataModel implements DataModel { + + private float maxPreference; + private float minPreference; + + protected AbstractDataModel() { + maxPreference = Float.NaN; + minPreference = Float.NaN; + } + + @Override + public float getMaxPreference() { + return maxPreference; + } + + protected void setMaxPreference(float maxPreference) { + this.maxPreference = maxPreference; + } + + @Override + public float getMinPreference() { + return minPreference; + } + + protected void setMinPreference(float minPreference) { + this.minPreference = minPreference; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractIDMigrator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractIDMigrator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractIDMigrator.java new file mode 100644 index 0000000..6efa6fa --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractIDMigrator.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.model; + +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Collection; + +import org.apache.commons.io.Charsets; +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.model.IDMigrator; + +public abstract class AbstractIDMigrator implements IDMigrator { + + private final MessageDigest md5Digest; + + protected AbstractIDMigrator() { + try { + md5Digest = MessageDigest.getInstance("MD5"); + } catch (NoSuchAlgorithmException nsae) { + // Can't happen + throw new IllegalStateException(nsae); + } + } + + /** + * @return most significant 8 bytes of the MD5 hash of the string, as a long + */ + protected final long hash(String value) { + byte[] md5hash; + synchronized (md5Digest) { + md5hash = md5Digest.digest(value.getBytes(Charsets.UTF_8)); + md5Digest.reset(); + } + long hash = 0L; + for (int i = 0; i < 8; i++) { + hash = hash << 8 | md5hash[i] & 0x00000000000000FFL; + } + return hash; + } + + @Override + public long toLongID(String stringID) { + return hash(stringID); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + } + +}
