http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverage.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverage.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverage.java new file mode 100644 index 0000000..0f94c22 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverage.java @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * Read-only view of a {@link RunningAverage} that reports the negated average of the wrapped instance.
+ * The count is passed through unchanged; every mutator throws {@link UnsupportedOperationException}.
+ */
+public final class InvertedRunningAverage implements RunningAverage {
+
+  /** Running average whose sign this view flips; every read goes straight through to it. */
+  private final RunningAverage wrapped;
+
+  public InvertedRunningAverage(RunningAverage delegate) {
+    wrapped = delegate;
+  }
+
+  /** Builds the exception thrown by every mutator of this read-only view. */
+  private static UnsupportedOperationException readOnly() {
+    return new UnsupportedOperationException();
+  }
+
+  @Override
+  public int getCount() {
+    return wrapped.getCount();
+  }
+
+  @Override
+  public double getAverage() {
+    double wrappedAverage = wrapped.getAverage();
+    return -wrappedAverage;
+  }
+
+  /** Inverting twice is the identity, so the wrapped instance itself is handed back. */
+  @Override
+  public RunningAverage inverse() {
+    return wrapped;
+  }
+
+  @Override
+  public void addDatum(double datum) {
+    throw readOnly();
+  }
+
+  @Override
+  public void removeDatum(double datum) {
+    throw readOnly();
+  }
+
+  @Override
+  public void changeDatum(double delta) {
+    throw readOnly();
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageAndStdDev.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageAndStdDev.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageAndStdDev.java new file mode 100644 index 0000000..147012d --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageAndStdDev.java @@ -0,0 +1,63 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * Read-only view of a {@link RunningAverageAndStdDev} whose average is the negation of the wrapped
+ * instance's. Count and standard deviation are passed through unchanged (negating every datum does not
+ * change the spread); every mutator throws {@link UnsupportedOperationException}.
+ */
+public final class InvertedRunningAverageAndStdDev implements RunningAverageAndStdDev {
+
+  /** Instance whose average this view negates; every read goes straight through to it. */
+  private final RunningAverageAndStdDev wrapped;
+
+  public InvertedRunningAverageAndStdDev(RunningAverageAndStdDev delegate) {
+    wrapped = delegate;
+  }
+
+  /** Builds the exception thrown by every mutator of this read-only view. */
+  private static UnsupportedOperationException readOnly() {
+    return new UnsupportedOperationException();
+  }
+
+  @Override
+  public int getCount() {
+    return wrapped.getCount();
+  }
+
+  @Override
+  public double getAverage() {
+    double wrappedAverage = wrapped.getAverage();
+    return -wrappedAverage;
+  }
+
+  @Override
+  public double getStandardDeviation() {
+    return wrapped.getStandardDeviation();
+  }
+
+  /** Inverting twice is the identity, so the wrapped instance itself is handed back. */
+  @Override
+  public RunningAverageAndStdDev inverse() {
+    return wrapped;
+  }
+
+  @Override
+  public void addDatum(double datum) {
+    throw readOnly();
+  }
+
+  @Override
+  public void removeDatum(double datum) {
+    throw readOnly();
+  }
+
+  @Override
+  public void changeDatum(double delta) {
+    throw readOnly();
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIterator.java
new file mode 100644
index 0000000..5127df0
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIterator.java
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.util.NoSuchElementException;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Adapts a primitive {@code long[]} to {@link LongPrimitiveIterator}. Java arrays are not {@link Iterable} at
+ * all, let alone {@code Iterable<Long>}; this adapter supplies the missing iterator, and {@link #nextLong()}
+ * reads elements without boxing.
+ */
+public final class LongPrimitiveArrayIterator implements LongPrimitiveIterator {
+
+  private final long[] array;
+  /** Index of the next element to return; kept within [0, max]. */
+  private int position;
+  /** Exclusive end of the iteration (the array length). */
+  private final int max;
+
+  /**
+   * <p>
+   * Creates a {@link LongPrimitiveArrayIterator} over an entire array.
+   * </p>
+   *
+   * @param array
+   *          array to iterate over; it is not copied, so later writes to it are visible to this iterator
+   * @throws NullPointerException
+   *           if {@code array} is {@code null}
+   */
+  public LongPrimitiveArrayIterator(long[] array) {
+    this.array = Preconditions.checkNotNull(array); // yeah, not going to copy the array here, for performance
+    this.position = 0;
+    this.max = array.length;
+  }
+
+  @Override
+  public boolean hasNext() {
+    return position < max;
+  }
+
+  /** Boxed variant of {@link #nextLong()}. */
+  @Override
+  public Long next() {
+    return nextLong();
+  }
+
+  /**
+   * @throws NoSuchElementException
+   *           if the iteration is exhausted
+   */
+  @Override
+  public long nextLong() {
+    if (position >= max) {
+      throw new NoSuchElementException();
+    }
+    return array[position++];
+  }
+
+  /**
+   * @throws NoSuchElementException
+   *           if the iteration is exhausted
+   */
+  @Override
+  public long peek() {
+    if (position >= max) {
+      throw new NoSuchElementException();
+    }
+    return array[position];
+  }
+
+  /**
+   * @throws UnsupportedOperationException
+   *           always
+   */
+  @Override
+  public void remove() {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Advances past up to {@code n} elements, stopping at the end of the array. Non-positive {@code n} is a
+   * no-op.
+   */
+  @Override
+  public void skip(int n) {
+    if (n > 0) {
+      // Compare with the remaining count instead of computing position + n: that sum can overflow int,
+      // leaving position negative so hasNext() reports true and the next read indexes out of bounds.
+      int remaining = max - position;
+      position = n >= remaining ? max : position + n;
+    }
+  }
+
+  @Override
+  public String toString() {
+    return "LongPrimitiveArrayIterator";
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveIterator.java
new file mode 100644
index 0000000..0840749
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveIterator.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * Adds notion of iterating over {@code long} primitives in the style of an {@link java.util.Iterator} -- as
+ * opposed to iterating over {@link Long}. Because this interface extends {@link SkippingIterator} of
+ * {@link Long}, every implementation is also a {@link java.util.Iterator} over boxed {@link Long} values for
+ * convenience. This interface does not extend {@link Iterable}.
+ */
+public interface LongPrimitiveIterator extends SkippingIterator<Long> {
+
+  /**
+   * @return next {@code long} in iteration
+   * @throws java.util.NoSuchElementException
+   *           if no more elements exist in the iteration
+   */
+  long nextLong();
+
+  /**
+   * @return next {@code long} in iteration without advancing iteration
+   * @throws java.util.NoSuchElementException
+   *           if no more elements exist in the iteration
+   */
+  long peek();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RefreshHelper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RefreshHelper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RefreshHelper.java
new file mode 100644
index 0000000..3e03108
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RefreshHelper.java
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor
license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.locks.ReentrantLock;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A helper class for implementing {@link Refreshable}. This object is typically included in an implementation
+ * of {@link Refreshable} to implement {@link Refreshable#refresh(Collection)}. It executes the class's own
+ * supplied update logic, after updating all the object's dependencies. This also ensures that dependencies
+ * are not updated multiple times.
+ */
+public final class RefreshHelper implements Refreshable {
+
+  private static final Logger log = LoggerFactory.getLogger(RefreshHelper.class);
+
+  private final List<Refreshable> dependencies;
+  private final ReentrantLock refreshLock;
+  private final Callable<?> refreshRunnable;
+
+  /**
+   * @param refreshRunnable
+   *          encapsulates the containing object's own refresh logic; may be {@code null}, in which case only
+   *          dependencies are refreshed
+   */
+  public RefreshHelper(Callable<?> refreshRunnable) {
+    this.dependencies = new ArrayList<>(3);
+    this.refreshLock = new ReentrantLock();
+    this.refreshRunnable = refreshRunnable;
+  }
+
+  /** Add a dependency to be refreshed first when the encapsulating object does. {@code null} is ignored. */
+  public void addDependency(Refreshable refreshable) {
+    if (refreshable != null) {
+      dependencies.add(refreshable);
+    }
+  }
+
+  /** Remove a previously added dependency. {@code null} is ignored. */
+  public void removeDependency(Refreshable refreshable) {
+    if (refreshable != null) {
+      dependencies.remove(refreshable);
+    }
+  }
+
+  /**
+   * Typically this is called in {@link Refreshable#refresh(java.util.Collection)} and is the entire body of
+   * that method. Refreshes each dependency not yet in {@code alreadyRefreshed}, then runs this object's own
+   * refresh logic. If the refresh lock is held by another thread, returns immediately without doing anything.
+   * Exceptions thrown by the refresh logic are logged, not propagated.
+   */
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    if (refreshLock.tryLock()) {
+      try {
+        Collection<Refreshable> refreshed = buildRefreshed(alreadyRefreshed);
+        for (Refreshable dependency : dependencies) {
+          maybeRefresh(refreshed, dependency);
+        }
+        if (refreshRunnable != null) {
+          try {
+            refreshRunnable.call();
+          } catch (InterruptedException ie) {
+            // Still best-effort, but restore the interrupt flag so the calling thread can observe it.
+            Thread.currentThread().interrupt();
+            log.warn("Interrupted while refreshing", ie);
+          } catch (Exception e) {
+            log.warn("Unexpected exception while refreshing", e);
+          }
+        }
+      } finally {
+        refreshLock.unlock();
+      }
+    }
+  }
+
+  /**
+   * Creates a new and empty {@link Collection} if the method parameter is {@code null}.
+   *
+   * @param currentAlreadyRefreshed
+   *          {@link Refreshable}s that have already been refreshed, or {@code null}
+   * @return an empty {@link Collection} if the method param was {@code null} or the unmodified method
+   *         param.
+   */
+  public static Collection<Refreshable> buildRefreshed(Collection<Refreshable> currentAlreadyRefreshed) {
+    return currentAlreadyRefreshed == null ? new HashSet<Refreshable>(3) : currentAlreadyRefreshed;
+  }
+
+  /**
+   * Adds the specified {@link Refreshable} to the given collection of {@link Refreshable}s if it is not
+   * already there and immediately refreshes it.
+   *
+   * @param alreadyRefreshed
+   *          the collection of {@link Refreshable}s
+   * @param refreshable
+   *          the {@link Refreshable} to potentially add and refresh
+   */
+  public static void maybeRefresh(Collection<Refreshable> alreadyRefreshed, Refreshable refreshable) {
+    if (!alreadyRefreshed.contains(refreshable)) {
+      // Record it before refreshing so cyclic dependency graphs do not recurse back into it.
+      alreadyRefreshed.add(refreshable);
+      log.info("Added refreshable: {}", refreshable);
+      refreshable.refresh(alreadyRefreshed);
+      // Log the item just refreshed, not the entire (growing) collection.
+      log.info("Refreshed: {}", refreshable);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Retriever.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Retriever.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Retriever.java
new file mode 100644
index 0000000..40da9de
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Retriever.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * Looks up a value for a given key; where the value comes from is up to the implementation.
+ *
+ * @param <K> key type
+ * @param <V> value type
+ */
+public interface Retriever<K, V> {
+
+  /**
+   * Fetches the value associated with {@code key}.
+   *
+   * @param key the key to look up
+   * @return the value retrieved for {@code key}
+   * @throws TasteException if retrieving the value fails
+   */
+  V get(K key) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverage.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverage.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverage.java
new file mode 100644
index 0000000..bf8e39c
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverage.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * <p>
+ * Interface for classes that can keep track of a running average of a series of numbers. One can add to or
+ * remove from the series, as well as update a datum in the series. The class does not actually keep track of
+ * the series of values, just its running average, so it doesn't even matter if you remove/change a value that
+ * wasn't added.
+ * </p>
+ */
+public interface RunningAverage {
+
+  /**
+   * @param datum
+   *          new item to add to the running average
+   * @throws IllegalArgumentException
+   *           if datum is {@link Double#NaN}
+   */
+  void addDatum(double datum);
+
+  /**
+   * @param datum
+   *          item to remove from the running average
+   * @throws IllegalArgumentException
+   *           if datum is {@link Double#NaN}
+   * @throws IllegalStateException
+   *           if count is 0
+   */
+  void removeDatum(double datum);
+
+  /**
+   * @param delta
+   *          amount by which to change a datum in the running average
+   * @throws IllegalArgumentException
+   *           if delta is {@link Double#NaN}
+   * @throws IllegalStateException
+   *           if count is 0
+   */
+  void changeDatum(double delta);
+
+  /**
+   * @return number of data currently in the running average; implementations that weight their data may
+   *         report a truncated total weight instead
+   */
+  int getCount();
+
+  /**
+   * @return the current running average; implementations may return {@link Double#NaN} when no data is held
+   */
+  double getAverage();
+
+  /**
+   * @return a (possibly immutable) object whose average is the negative of this object's
+   */
+  RunningAverage inverse();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDev.java
----------------------------------------------------------------------
diff --git
a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDev.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDev.java
new file mode 100644
index 0000000..4ac6108
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDev.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * A {@link RunningAverage} that additionally tracks the standard deviation of its data.
+ */
+public interface RunningAverageAndStdDev extends RunningAverage {
+
+  /**
+   * @return the standard deviation of the data seen so far
+   */
+  double getStandardDeviation();
+
+  /**
+   * Narrows {@link RunningAverage#inverse()} so that the inverted view also exposes a standard deviation.
+   *
+   * @return a (possibly immutable) object whose average is the negative of this object's
+   */
+  @Override
+  RunningAverageAndStdDev inverse();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java
new file mode 100644
index 0000000..6da709d
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java
@@ -0,0 +1,111 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.util.NoSuchElementException;
+
+import com.google.common.base.Preconditions;
+import org.apache.commons.math3.distribution.PascalDistribution;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+
+/**
+ * Wraps a {@link LongPrimitiveIterator} and returns only some subset of the elements that it would,
+ * as determined by a sampling rate parameter. The number of delegate elements passed over before each
+ * returned element is drawn from a geometric distribution, which is equivalent to keeping each delegate
+ * element independently with probability {@code samplingRate}.
+ */
+public final class SamplingLongPrimitiveIterator extends AbstractLongPrimitiveIterator {
+
+  private final PascalDistribution geometricDistribution;
+  private final LongPrimitiveIterator delegate;
+  /** Element that {@link #nextLong()} returns next; only meaningful while {@code hasNext} is true. */
+  private long next;
+  private boolean hasNext;
+
+  public SamplingLongPrimitiveIterator(LongPrimitiveIterator delegate, double samplingRate) {
+    this(RandomUtils.getRandom(), delegate, samplingRate);
+  }
+
+  /**
+   * @param random source of randomness for the sampling decisions
+   * @param delegate iterator to sample from
+   * @param samplingRate probability of keeping each element; must satisfy 0.0 &lt; samplingRate &lt;= 1.0
+   * @throws NullPointerException if {@code delegate} is {@code null}
+   * @throws IllegalArgumentException if {@code samplingRate} is outside (0.0, 1.0]
+   */
+  public SamplingLongPrimitiveIterator(RandomWrapper random, LongPrimitiveIterator delegate, double samplingRate) {
+    Preconditions.checkNotNull(delegate);
+    Preconditions.checkArgument(samplingRate > 0.0 && samplingRate <= 1.0, "Must be: 0.0 < samplingRate <= 1.0");
+    // Geometric distribution is special case of negative binomial (aka Pascal) with r=1:
+    geometricDistribution = new PascalDistribution(random.getRandomGenerator(), 1, samplingRate);
+    this.delegate = delegate;
+    this.hasNext = true;
+    doNext();
+  }
+
+  @Override
+  public boolean hasNext() {
+    return hasNext;
+  }
+
+  @Override
+  public long nextLong() {
+    if (hasNext) {
+      long result = next;
+      doNext();
+      return result;
+    }
+    throw new NoSuchElementException();
+  }
+
+  @Override
+  public long peek() {
+    if (hasNext) {
+      return next;
+    }
+    throw new NoSuchElementException();
+  }
+
+  /** Passes over a geometrically distributed gap in the delegate, then buffers the element after it. */
+  private void doNext() {
+    advance(geometricDistribution.sample());
+  }
+
+  /**
+   * Skips {@code toSkip} delegate elements and buffers the following one as {@code next}, or marks this
+   * iterator exhausted when the delegate runs out.
+   */
+  private void advance(int toSkip) {
+    delegate.skip(toSkip);
+    if (delegate.hasNext()) {
+      // nextLong() avoids the boxing round trip through Iterator.next().
+      next = delegate.nextLong();
+    } else {
+      hasNext = false;
+    }
+  }
+
+  /**
+   * @throws UnsupportedOperationException
+   */
+  @Override
+  public void remove() {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Discards the next {@code n} sampled elements, with the same effect as calling {@link #nextLong()}
+   * {@code n} times but without throwing when fewer remain. Non-positive {@code n} is a no-op.
+   */
+  @Override
+  public void skip(int n) {
+    if (n <= 0 || !hasNext) {
+      // Nothing to discard; in particular skip(0) must leave the buffered element in place.
+      return;
+    }
+    // The buffered element 'next' has already been taken from the delegate. Discarding n sampled elements
+    // means passing over n geometric gaps plus the n - 1 sampled elements lying between them; the element
+    // after the last gap becomes the new 'next'. Accumulate in a long and cap at Integer.MAX_VALUE so the
+    // total cannot overflow int.
+    long toSkip = n - 1L;
+    for (int i = 0; i < n && toSkip < Integer.MAX_VALUE; i++) {
+      toSkip += geometricDistribution.sample();
+    }
+    advance((int) Math.min(toSkip, Integer.MAX_VALUE));
+  }
+
+  /**
+   * @return {@code delegate} itself when {@code samplingRate >= 1.0}, otherwise a sampling wrapper around it
+   */
+  public static LongPrimitiveIterator maybeWrapIterator(LongPrimitiveIterator delegate, double samplingRate) {
+    return samplingRate >= 1.0 ? delegate : new SamplingLongPrimitiveIterator(delegate, samplingRate);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java
new file mode 100644
index 0000000..e88f98a
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java
@@ -0,0 +1,35 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.util.Iterator;
+
+/**
+ * An {@link Iterator} that can jump ahead several elements at once, which an implementation may be able to
+ * do more cheaply than by calling {@link #next()} repeatedly.
+ *
+ * @param <V> element type
+ */
+public interface SkippingIterator<V> extends Iterator<V> {
+
+  /**
+   * Skips the next {@code n} elements of this {@link Iterator}. When fewer than {@code n} elements remain,
+   * all of them are skipped. The effect is the same as calling {@link #next()} {@code n} times, except that
+   * this method never throws {@link java.util.NoSuchElementException}.
+   *
+   * @param n number of elements to skip
+   */
+  void skip(int n);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
new file mode 100644
index 0000000..76e5239
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
@@ -0,0 +1,100 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.io.Serializable;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * A {@link RunningAverage} in which every datum carries a weight; the single-argument methods use weight 1.0.
+ * Only the running average and the total weight are stored, never the data themselves. Mutators and
+ * accessors are {@code synchronized} on this instance.
+ */
+public class WeightedRunningAverage implements RunningAverage, Serializable {
+
+  /** Sum of the weights currently included; reset to 0.0 whenever removals drive it to zero or below. */
+  private double totalWeight;
+  /** Current weighted average, or {@link Double#NaN} while empty. */
+  private double average;
+
+  public WeightedRunningAverage() {
+    this.totalWeight = 0.0;
+    this.average = Double.NaN;
+  }
+
+  @Override
+  public synchronized void addDatum(double datum) {
+    addDatum(datum, 1.0);
+  }
+
+  /**
+   * Adds {@code datum} with the given {@code weight}. If the previous total weight was not positive, the
+   * average simply becomes {@code datum}.
+   */
+  public synchronized void addDatum(double datum, double weight) {
+    double priorWeight = totalWeight;
+    totalWeight = priorWeight + weight;
+    if (priorWeight <= 0.0) {
+      average = datum;
+      return;
+    }
+    average = average * priorWeight / totalWeight + datum * weight / totalWeight;
+  }
+
+  @Override
+  public synchronized void removeDatum(double datum) {
+    removeDatum(datum, 1.0);
+  }
+
+  /**
+   * Removes {@code datum} with the given {@code weight}. If that leaves no positive weight, the object
+   * returns to its empty state (total weight 0.0, average NaN).
+   */
+  public synchronized void removeDatum(double datum, double weight) {
+    double priorWeight = totalWeight;
+    totalWeight = priorWeight - weight;
+    if (totalWeight <= 0.0) {
+      totalWeight = 0.0;
+      average = Double.NaN;
+      return;
+    }
+    average = average * priorWeight / totalWeight - datum * weight / totalWeight;
+  }
+
+  @Override
+  public synchronized void changeDatum(double delta) {
+    changeDatum(delta, 1.0);
+  }
+
+  /**
+   * Shifts the average as if a datum of the given {@code weight} had changed by {@code delta}.
+   *
+   * @throws IllegalArgumentException if {@code weight} exceeds the current total weight
+   */
+  public synchronized void changeDatum(double delta, double weight) {
+    Preconditions.checkArgument(weight <= totalWeight, "weight must be <= totalWeight");
+    double shift = delta * weight / totalWeight;
+    average += shift;
+  }
+
+  public synchronized double getTotalWeight() {
+    return totalWeight;
+  }
+
+  /** @return {@link #getTotalWeight()} truncated to an {@code int} */
+  @Override
+  public synchronized int getCount() {
+    return (int) totalWeight;
+  }
+
+  @Override
+  public synchronized double getAverage() {
+    return average;
+  }
+
+  @Override
+  public RunningAverage inverse() {
+    return new InvertedRunningAverage(this);
+  }
+
+  @Override
+  public synchronized String toString() {
+    return Double.toString(average);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java
new file mode 100644
index 0000000..bed5812
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * This subclass also provides for a weighted estimate of the sample standard deviation.
+ * See <a href="http://en.wikipedia.org/wiki/Mean_square_weighted_deviation">estimate formulae here</a>.
+ */ +public final class WeightedRunningAverageAndStdDev extends WeightedRunningAverage implements RunningAverageAndStdDev { + + private double totalSquaredWeight; + private double totalWeightedData; + private double totalWeightedSquaredData; + + public WeightedRunningAverageAndStdDev() { + totalSquaredWeight = 0.0; + totalWeightedData = 0.0; + totalWeightedSquaredData = 0.0; + } + + @Override + public synchronized void addDatum(double datum, double weight) { + super.addDatum(datum, weight); + totalSquaredWeight += weight * weight; + double weightedData = datum * weight; + totalWeightedData += weightedData; + totalWeightedSquaredData += weightedData * datum; + } + + @Override + public synchronized void removeDatum(double datum, double weight) { + super.removeDatum(datum, weight); + totalSquaredWeight -= weight * weight; + if (totalSquaredWeight <= 0.0) { + totalSquaredWeight = 0.0; + } + double weightedData = datum * weight; + totalWeightedData -= weightedData; + if (totalWeightedData <= 0.0) { + totalWeightedData = 0.0; + } + totalWeightedSquaredData -= weightedData * datum; + if (totalWeightedSquaredData <= 0.0) { + totalWeightedSquaredData = 0.0; + } + } + + /** + * @throws UnsupportedOperationException + */ + @Override + public synchronized void changeDatum(double delta, double weight) { + throw new UnsupportedOperationException(); + } + + + @Override + public synchronized double getStandardDeviation() { + double totalWeight = getTotalWeight(); + return Math.sqrt((totalWeightedSquaredData * totalWeight - totalWeightedData * totalWeightedData) + / (totalWeight * totalWeight - totalSquaredWeight)); + } + + @Override + public RunningAverageAndStdDev inverse() { + return new InvertedRunningAverageAndStdDev(this); + } + + @Override + public synchronized String toString() { + return String.valueOf(String.valueOf(getAverage()) + ',' + getStandardDeviation()); + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java new file mode 100644 index 0000000..d1e93ab --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common.jdbc; + +import javax.naming.Context; +import javax.naming.InitialContext; +import javax.naming.NamingException; +import javax.sql.DataSource; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; + +/** + * A helper class with common elements for several JDBC-related components. 
+ */ +public abstract class AbstractJDBCComponent { + + private static final Logger log = LoggerFactory.getLogger(AbstractJDBCComponent.class); + + private static final int DEFAULT_FETCH_SIZE = 1000; // A max, "big" number of rows to buffer at once + protected static final String DEFAULT_DATASOURCE_NAME = "jdbc/taste"; + + protected static void checkNotNullAndLog(String argName, Object value) { + Preconditions.checkArgument(value != null && !value.toString().isEmpty(), + argName + " is null or empty"); + log.debug("{}: {}", argName, value); + } + + protected static void checkNotNullAndLog(String argName, Object[] values) { + Preconditions.checkArgument(values != null && values.length != 0, argName + " is null or zero-length"); + for (Object value : values) { + checkNotNullAndLog(argName, value); + } + } + + /** + * <p> + * Looks up a {@link DataSource} by name from JNDI. "java:comp/env/" is prepended to the argument before + * looking up the name in JNDI. + * </p> + * + * @param dataSourceName + * JNDI name where a {@link DataSource} is bound (e.g. 
"jdbc/taste") + * @return {@link DataSource} under that JNDI name + * @throws TasteException + * if a JNDI error occurs + */ + public static DataSource lookupDataSource(String dataSourceName) throws TasteException { + Context context = null; + try { + context = new InitialContext(); + return (DataSource) context.lookup("java:comp/env/" + dataSourceName); + } catch (NamingException ne) { + throw new TasteException(ne); + } finally { + if (context != null) { + try { + context.close(); + } catch (NamingException ne) { + log.warn("Error while closing Context; continuing...", ne); + } + } + } + } + + protected int getFetchSize() { + return DEFAULT_FETCH_SIZE; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java new file mode 100644 index 0000000..3f024bc --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java @@ -0,0 +1,92 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common.jdbc; + +import javax.sql.DataSource; +import java.io.Closeable; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import com.google.common.collect.AbstractIterator; +import org.apache.mahout.common.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Provides an {@link java.util.Iterator} over the result of an SQL query, as an iteration over the {@link ResultSet}. + * While the same object will be returned from the iteration each time, it will be returned once for each row + * of the result. 
+ */ +final class EachRowIterator extends AbstractIterator<ResultSet> implements Closeable { + + private static final Logger log = LoggerFactory.getLogger(EachRowIterator.class); + + private final Connection connection; + private final PreparedStatement statement; + private final ResultSet resultSet; + + EachRowIterator(DataSource dataSource, String sqlQuery) throws SQLException { + try { + connection = dataSource.getConnection(); + statement = connection.prepareStatement(sqlQuery, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + statement.setFetchDirection(ResultSet.FETCH_FORWARD); + //statement.setFetchSize(getFetchSize()); + log.debug("Executing SQL query: {}", sqlQuery); + resultSet = statement.executeQuery(); + } catch (SQLException sqle) { + close(); + throw sqle; + } + } + + @Override + protected ResultSet computeNext() { + try { + if (resultSet.next()) { + return resultSet; + } else { + close(); + return null; + } + } catch (SQLException sqle) { + close(); + throw new IllegalStateException(sqle); + } + } + + public void skip(int n) throws SQLException { + try { + resultSet.relative(n); + } catch (SQLException sqle) { + // Can't use relative on MySQL Connector/J; try advancing manually + int i = 0; + while (i < n && resultSet.next()) { + i++; + } + } + } + + @Override + public void close() { + IOUtils.quietClose(resultSet, statement, connection); + endOfData(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java new file mode 100644 index 0000000..273ebd5 --- /dev/null +++ 
b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common.jdbc; + +import javax.sql.DataSource; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Iterator; + +import com.google.common.base.Function; +import com.google.common.collect.ForwardingIterator; +import com.google.common.collect.Iterators; + +public abstract class ResultSetIterator<T> extends ForwardingIterator<T> { + + private final Iterator<T> delegate; + private final EachRowIterator rowDelegate; + + protected ResultSetIterator(DataSource dataSource, String sqlQuery) throws SQLException { + this.rowDelegate = new EachRowIterator(dataSource, sqlQuery); + delegate = Iterators.transform(rowDelegate, + new Function<ResultSet, T>() { + @Override + public T apply(ResultSet from) { + try { + return parseElement(from); + } catch (SQLException sqle) { + throw new IllegalStateException(sqle); + } + } + }); + } + + @Override + protected Iterator<T> delegate() { + return delegate; + } + + protected abstract T parseElement(ResultSet resultSet) throws SQLException; + + public void 
skip(int n) { + if (n >= 1) { + try { + rowDelegate.skip(n); + } catch (SQLException sqle) { + throw new IllegalStateException(sqle); + } + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java new file mode 100644 index 0000000..f926f18 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java @@ -0,0 +1,276 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.eval; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import com.google.common.base.Preconditions; +import org.apache.mahout.cf.taste.common.NoSuchItemException; +import org.apache.mahout.cf.taste.common.NoSuchUserException; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.eval.DataModelBuilder; +import org.apache.mahout.cf.taste.eval.RecommenderBuilder; +import org.apache.mahout.cf.taste.eval.RecommenderEvaluator; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev; +import org.apache.mahout.cf.taste.impl.model.GenericDataModel; +import org.apache.mahout.cf.taste.impl.model.GenericPreference; +import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.cf.taste.recommender.Recommender; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Abstract superclass of a couple implementations, providing shared functionality. 
+ */ +public abstract class AbstractDifferenceRecommenderEvaluator implements RecommenderEvaluator { + + private static final Logger log = LoggerFactory.getLogger(AbstractDifferenceRecommenderEvaluator.class); + + private final Random random; + private float maxPreference; + private float minPreference; + + protected AbstractDifferenceRecommenderEvaluator() { + random = RandomUtils.getRandom(); + maxPreference = Float.NaN; + minPreference = Float.NaN; + } + + @Override + public final float getMaxPreference() { + return maxPreference; + } + + @Override + public final void setMaxPreference(float maxPreference) { + this.maxPreference = maxPreference; + } + + @Override + public final float getMinPreference() { + return minPreference; + } + + @Override + public final void setMinPreference(float minPreference) { + this.minPreference = minPreference; + } + + @Override + public double evaluate(RecommenderBuilder recommenderBuilder, + DataModelBuilder dataModelBuilder, + DataModel dataModel, + double trainingPercentage, + double evaluationPercentage) throws TasteException { + Preconditions.checkNotNull(recommenderBuilder); + Preconditions.checkNotNull(dataModel); + Preconditions.checkArgument(trainingPercentage >= 0.0 && trainingPercentage <= 1.0, + "Invalid trainingPercentage: " + trainingPercentage + ". Must be: 0.0 <= trainingPercentage <= 1.0"); + Preconditions.checkArgument(evaluationPercentage >= 0.0 && evaluationPercentage <= 1.0, + "Invalid evaluationPercentage: " + evaluationPercentage + ". 
Must be: 0.0 <= evaluationPercentage <= 1.0"); + + log.info("Beginning evaluation using {} of {}", trainingPercentage, dataModel); + + int numUsers = dataModel.getNumUsers(); + FastByIDMap<PreferenceArray> trainingPrefs = new FastByIDMap<>( + 1 + (int) (evaluationPercentage * numUsers)); + FastByIDMap<PreferenceArray> testPrefs = new FastByIDMap<>( + 1 + (int) (evaluationPercentage * numUsers)); + + LongPrimitiveIterator it = dataModel.getUserIDs(); + while (it.hasNext()) { + long userID = it.nextLong(); + if (random.nextDouble() < evaluationPercentage) { + splitOneUsersPrefs(trainingPercentage, trainingPrefs, testPrefs, userID, dataModel); + } + } + + DataModel trainingModel = dataModelBuilder == null ? new GenericDataModel(trainingPrefs) + : dataModelBuilder.buildDataModel(trainingPrefs); + + Recommender recommender = recommenderBuilder.buildRecommender(trainingModel); + + double result = getEvaluation(testPrefs, recommender); + log.info("Evaluation result: {}", result); + return result; + } + + private void splitOneUsersPrefs(double trainingPercentage, + FastByIDMap<PreferenceArray> trainingPrefs, + FastByIDMap<PreferenceArray> testPrefs, + long userID, + DataModel dataModel) throws TasteException { + List<Preference> oneUserTrainingPrefs = null; + List<Preference> oneUserTestPrefs = null; + PreferenceArray prefs = dataModel.getPreferencesFromUser(userID); + int size = prefs.length(); + for (int i = 0; i < size; i++) { + Preference newPref = new GenericPreference(userID, prefs.getItemID(i), prefs.getValue(i)); + if (random.nextDouble() < trainingPercentage) { + if (oneUserTrainingPrefs == null) { + oneUserTrainingPrefs = new ArrayList<>(3); + } + oneUserTrainingPrefs.add(newPref); + } else { + if (oneUserTestPrefs == null) { + oneUserTestPrefs = new ArrayList<>(3); + } + oneUserTestPrefs.add(newPref); + } + } + if (oneUserTrainingPrefs != null) { + trainingPrefs.put(userID, new GenericUserPreferenceArray(oneUserTrainingPrefs)); + if (oneUserTestPrefs != null) { 
+ testPrefs.put(userID, new GenericUserPreferenceArray(oneUserTestPrefs)); + } + } + } + + private float capEstimatedPreference(float estimate) { + if (estimate > maxPreference) { + return maxPreference; + } + if (estimate < minPreference) { + return minPreference; + } + return estimate; + } + + private double getEvaluation(FastByIDMap<PreferenceArray> testPrefs, Recommender recommender) + throws TasteException { + reset(); + Collection<Callable<Void>> estimateCallables = new ArrayList<>(); + AtomicInteger noEstimateCounter = new AtomicInteger(); + for (Map.Entry<Long,PreferenceArray> entry : testPrefs.entrySet()) { + estimateCallables.add( + new PreferenceEstimateCallable(recommender, entry.getKey(), entry.getValue(), noEstimateCounter)); + } + log.info("Beginning evaluation of {} users", estimateCallables.size()); + RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev(); + execute(estimateCallables, noEstimateCounter, timing); + return computeFinalEvaluation(); + } + + protected static void execute(Collection<Callable<Void>> callables, + AtomicInteger noEstimateCounter, + RunningAverageAndStdDev timing) throws TasteException { + + Collection<Callable<Void>> wrappedCallables = wrapWithStatsCallables(callables, noEstimateCounter, timing); + int numProcessors = Runtime.getRuntime().availableProcessors(); + ExecutorService executor = Executors.newFixedThreadPool(numProcessors); + log.info("Starting timing of {} tasks in {} threads", wrappedCallables.size(), numProcessors); + try { + List<Future<Void>> futures = executor.invokeAll(wrappedCallables); + // Go look for exceptions here, really + for (Future<Void> future : futures) { + future.get(); + } + + } catch (InterruptedException ie) { + throw new TasteException(ie); + } catch (ExecutionException ee) { + throw new TasteException(ee.getCause()); + } + + executor.shutdown(); + try { + executor.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new 
TasteException(e.getCause()); + } + } + + private static Collection<Callable<Void>> wrapWithStatsCallables(Iterable<Callable<Void>> callables, + AtomicInteger noEstimateCounter, + RunningAverageAndStdDev timing) { + Collection<Callable<Void>> wrapped = new ArrayList<>(); + int count = 0; + for (Callable<Void> callable : callables) { + boolean logStats = count++ % 1000 == 0; // log every 1000 or so iterations + wrapped.add(new StatsCallable(callable, logStats, timing, noEstimateCounter)); + } + return wrapped; + } + + protected abstract void reset(); + + protected abstract void processOneEstimate(float estimatedPreference, Preference realPref); + + protected abstract double computeFinalEvaluation(); + + public final class PreferenceEstimateCallable implements Callable<Void> { + + private final Recommender recommender; + private final long testUserID; + private final PreferenceArray prefs; + private final AtomicInteger noEstimateCounter; + + public PreferenceEstimateCallable(Recommender recommender, + long testUserID, + PreferenceArray prefs, + AtomicInteger noEstimateCounter) { + this.recommender = recommender; + this.testUserID = testUserID; + this.prefs = prefs; + this.noEstimateCounter = noEstimateCounter; + } + + @Override + public Void call() throws TasteException { + for (Preference realPref : prefs) { + float estimatedPreference = Float.NaN; + try { + estimatedPreference = recommender.estimatePreference(testUserID, realPref.getItemID()); + } catch (NoSuchUserException nsue) { + // It's possible that an item exists in the test data but not training data in which case + // NSEE will be thrown. Just ignore it and move on. 
+ log.info("User exists in test data but not training data: {}", testUserID); + } catch (NoSuchItemException nsie) { + log.info("Item exists in test data but not training data: {}", realPref.getItemID()); + } + if (Float.isNaN(estimatedPreference)) { + noEstimateCounter.incrementAndGet(); + } else { + estimatedPreference = capEstimatedPreference(estimatedPreference); + processOneEstimate(estimatedPreference, realPref); + } + } + return null; + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java new file mode 100644 index 0000000..4dad040 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.eval; + +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.model.Preference; + +/** + * <p> + * A {@link org.apache.mahout.cf.taste.eval.RecommenderEvaluator} which computes the average absolute + * difference between predicted and actual ratings for users. + * </p> + * + * <p> + * This algorithm is also called "mean average error". + * </p> + */ +public final class AverageAbsoluteDifferenceRecommenderEvaluator extends + AbstractDifferenceRecommenderEvaluator { + + private RunningAverage average; + + @Override + protected void reset() { + average = new FullRunningAverage(); + } + + @Override + protected void processOneEstimate(float estimatedPreference, Preference realPref) { + average.addDatum(Math.abs(realPref.getValue() - estimatedPreference)); + } + + @Override + protected double computeFinalEvaluation() { + return average.getAverage(); + } + + @Override + public String toString() { + return "AverageAbsoluteDifferenceRecommenderEvaluator"; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java new file mode 100644 index 0000000..0e121d1 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java @@ -0,0 +1,237 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.eval; + +import java.util.List; +import java.util.Random; + +import org.apache.mahout.cf.taste.common.NoSuchUserException; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.eval.DataModelBuilder; +import org.apache.mahout.cf.taste.eval.IRStatistics; +import org.apache.mahout.cf.taste.eval.RecommenderBuilder; +import org.apache.mahout.cf.taste.eval.RecommenderIRStatsEvaluator; +import org.apache.mahout.cf.taste.eval.RelevantItemsDataSplitter; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev; +import org.apache.mahout.cf.taste.impl.model.GenericDataModel; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.cf.taste.recommender.IDRescorer; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; 
+import org.apache.mahout.cf.taste.recommender.Recommender; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; + +/** + * <p> + * For each user, these implementation determine the top {@code n} preferences, then evaluate the IR + * statistics based on a {@link DataModel} that does not have these values. This number {@code n} is the + * "at" value, as in "precision at 5". For example, this would mean precision evaluated by removing the top 5 + * preferences for a user and then finding the percentage of those 5 items included in the top 5 + * recommendations for that user. + * </p> + */ +public final class GenericRecommenderIRStatsEvaluator implements RecommenderIRStatsEvaluator { + + private static final Logger log = LoggerFactory.getLogger(GenericRecommenderIRStatsEvaluator.class); + + private static final double LOG2 = Math.log(2.0); + + /** + * Pass as "relevanceThreshold" argument to + * {@link #evaluate(RecommenderBuilder, DataModelBuilder, DataModel, IDRescorer, int, double, double)} to + * have it attempt to compute a reasonable threshold. Note that this will impact performance. 
+   */
+  public static final double CHOOSE_THRESHOLD = Double.NaN;
+
+  // Decides, independently for each user, whether that user is sampled for evaluation.
+  private final Random random;
+  // Chooses each user's relevant items and assembles the per-user training data.
+  private final RelevantItemsDataSplitter dataSplitter;
+
+  /**
+   * Creates an evaluator that uses a {@link GenericRelevantItemsDataSplitter} to choose relevant items.
+   */
+  public GenericRecommenderIRStatsEvaluator() {
+    this(new GenericRelevantItemsDataSplitter());
+  }
+
+  /**
+   * @param dataSplitter chooses each user's relevant items and assembles the training data used
+   *                     while that user is evaluated; must not be null
+   * @throws NullPointerException if {@code dataSplitter} is null
+   */
+  public GenericRecommenderIRStatsEvaluator(RelevantItemsDataSplitter dataSplitter) {
+    Preconditions.checkNotNull(dataSplitter);
+    random = RandomUtils.getRandom();
+    this.dataSplitter = dataSplitter;
+  }
+
+  /**
+   * Computes precision, recall, fall-out, nDCG and reach for the top {@code at} recommendations.
+   *
+   * <p>Each user is included with probability {@code evaluationPercentage}. For an included user the
+   * data splitter picks the relevant items and builds training data from every user's preferences;
+   * a recommender built on that training model is asked for {@code at} recommendations, which are
+   * scored against the relevant items. A user is skipped when it has no relevant items, has no
+   * training preferences left, or has fewer than {@code 2 * at} preferences in total.
+   *
+   * <p>A training model and a recommender are rebuilt for every evaluated user, and each rebuild
+   * visits every user of {@code dataModel}.
+   *
+   * @param recommenderBuilder builds the recommender under test from each training model
+   * @param dataModelBuilder builds each training model; if null, a {@code GenericDataModel} is used
+   * @param dataModel full data set to evaluate against
+   * @param rescorer passed through to {@code Recommender.recommend(userID, at, rescorer)}
+   * @param at number of recommendations requested per user; must be at least 1
+   * @param relevanceThreshold minimum preference value for an item to be relevant; if NaN (for
+   *                           example {@link #CHOOSE_THRESHOLD}) a per-user threshold of mean plus
+   *                           one standard deviation of that user's preference values is used
+   * @param evaluationPercentage probability of evaluating each user; must be in (0.0, 1.0]
+   * @return the averaged statistics; reach is NaN when no user was evaluated (0.0 / 0.0)
+   * @throws TasteException propagated from the data model, the builders, the splitter or the recommender
+   * @throws IllegalArgumentException if an argument precondition is violated
+   */
+  @Override
+  public IRStatistics evaluate(RecommenderBuilder recommenderBuilder,
+                               DataModelBuilder dataModelBuilder,
+                               DataModel dataModel,
+                               IDRescorer rescorer,
+                               int at,
+                               double relevanceThreshold,
+                               double evaluationPercentage) throws TasteException {
+
+    Preconditions.checkArgument(recommenderBuilder != null, "recommenderBuilder is null");
+    Preconditions.checkArgument(dataModel != null, "dataModel is null");
+    Preconditions.checkArgument(at >= 1, "at must be at least 1");
+    Preconditions.checkArgument(evaluationPercentage > 0.0 && evaluationPercentage <= 1.0,
+        "Invalid evaluationPercentage: " + evaluationPercentage + ". Must be: 0.0 < evaluationPercentage <= 1.0");
+
+    int numItems = dataModel.getNumItems();
+    RunningAverage precision = new FullRunningAverage();
+    RunningAverage recall = new FullRunningAverage();
+    RunningAverage fallOut = new FullRunningAverage();
+    RunningAverage nDCG = new FullRunningAverage();
+    int numUsersRecommendedFor = 0;
+    int numUsersWithRecommendations = 0;
+
+    LongPrimitiveIterator it = dataModel.getUserIDs();
+    while (it.hasNext()) {
+
+      long userID = it.nextLong();
+
+      // Keep each user with probability evaluationPercentage (one random draw per user).
+      if (random.nextDouble() >= evaluationPercentage) {
+        // Skipped
+        continue;
+      }
+
+      long start = System.currentTimeMillis();
+
+      PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
+
+      // List some most-preferred items that would count as (most) "relevant" results
+      // A NaN threshold (such as CHOOSE_THRESHOLD) means: derive one from this user's own values.
+      double theRelevanceThreshold = Double.isNaN(relevanceThreshold) ? computeThreshold(prefs) : relevanceThreshold;
+      FastIDSet relevantItemIDs = dataSplitter.getRelevantItemsIDs(userID, at, theRelevanceThreshold, dataModel);
+
+      int numRelevantItems = relevantItemIDs.size();
+      if (numRelevantItems <= 0) {
+        continue;
+      }
+
+      // Rebuild the training data from every user; the splitter decides what is kept for each one.
+      FastByIDMap<PreferenceArray> trainingUsers = new FastByIDMap<>(dataModel.getNumUsers());
+      LongPrimitiveIterator it2 = dataModel.getUserIDs();
+      while (it2.hasNext()) {
+        dataSplitter.processOtherUser(userID, relevantItemIDs, trainingUsers, it2.nextLong(), dataModel);
+      }
+
+      DataModel trainingModel = dataModelBuilder == null ? new GenericDataModel(trainingUsers)
+          : dataModelBuilder.buildDataModel(trainingUsers);
+      // Probe only (result discarded): detects a user with no training preferences left.
+      try {
+        trainingModel.getPreferencesFromUser(userID);
+      } catch (NoSuchUserException nsee) {
+        continue; // Oops we excluded all prefs for the user -- just move on
+      }
+
+      // Total preferences considered for this user: relevant items plus remaining training items.
+      int size = numRelevantItems + trainingModel.getItemIDsFromUser(userID).size();
+      if (size < 2 * at) {
+        // Really not enough prefs to meaningfully evaluate this user
+        continue;
+      }
+
+      Recommender recommender = recommenderBuilder.buildRecommender(trainingModel);
+
+      // Count how many of the recommended items are among the relevant items.
+      int intersectionSize = 0;
+      List<RecommendedItem> recommendedItems = recommender.recommend(userID, at, rescorer);
+      for (RecommendedItem recommendedItem : recommendedItems) {
+        if (relevantItemIDs.contains(recommendedItem.getItemID())) {
+          intersectionSize++;
+        }
+      }
+
+      int numRecommendedItems = recommendedItems.size();
+
+      // Precision
+      // (not recorded when nothing was recommended, since it would be 0 / 0)
+      if (numRecommendedItems > 0) {
+        precision.addDatum((double) intersectionSize / (double) numRecommendedItems);
+      }
+
+      // Recall
+      recall.addDatum((double) intersectionSize / (double) numRelevantItems);
+
+      // Fall-out
+      // Recommended non-relevant items divided by all items that are not relevant for this user;
+      // recorded only when the user has training items besides the relevant ones.
+      if (numRelevantItems < size) {
+        fallOut.addDatum((double) (numRecommendedItems - intersectionSize)
+            / (double) (numItems - numRelevantItems));
+      }
+
+      // nDCG
+      // In computing, assume relevant IDs have relevance 1 and others 0
+      // NOTE(review): the ideal gain below is summed only over the positions actually returned
+      // (i < numRecommendedItems), so a recommender returning fewer than `at` items is not
+      // penalized for the missing positions. Confirm this is the intended nDCG@at definition.
+      double cumulativeGain = 0.0;
+      double idealizedGain = 0.0;
+      for (int i = 0; i < numRecommendedItems; i++) {
+        RecommendedItem item = recommendedItems.get(i);
+        double discount = 1.0 / log2(i + 2.0); // Classical formulation says log(i+1), but i is 0-based here
+        if (relevantItemIDs.contains(item.getItemID())) {
+          cumulativeGain += discount;
+        }
+        // otherwise we're multiplying discount by relevance 0 so it doesn't do anything
+
+        // Ideally results would be ordered with all relevant ones first, so this theoretical
+        // ideal list starts with number of relevant items equal to the total number of relevant items
+        if (i < numRelevantItems) {
+          idealizedGain += discount;
+        }
+      }
+      if (idealizedGain > 0.0) {
+        nDCG.addDatum(cumulativeGain / idealizedGain);
+      }
+
+      // Reach
+      // Fraction of evaluated users who received at least one recommendation.
+      numUsersRecommendedFor++;
+      if (numRecommendedItems > 0) {
+        numUsersWithRecommendations++;
+      }
+
+      long end = System.currentTimeMillis();
+
+      // Per-user progress: elapsed time and the running averages so far.
+      log.info("Evaluated with user {} in {}ms", userID, end - start);
+      log.info("Precision/recall/fall-out/nDCG/reach: {} / {} / {} / {} / {}",
+               precision.getAverage(), recall.getAverage(), fallOut.getAverage(), nDCG.getAverage(),
+               (double) numUsersWithRecommendations / (double) numUsersRecommendedFor);
+    }
+
+    // Reach evaluates to 0.0 / 0.0 = NaN when no user was evaluated.
+    return new IRStatisticsImpl(
+        precision.getAverage(),
+        recall.getAverage(),
+        fallOut.getAverage(),
+        nDCG.getAverage(),
+        (double) numUsersWithRecommendations / (double) numUsersRecommendedFor);
+  }
+
+  /**
+   * Per-user relevance threshold: mean plus one standard deviation of the user's preference values.
+   * With fewer than two preferences, returns negative infinity so that every preference qualifies.
+   */
+  private static double computeThreshold(PreferenceArray prefs) {
+    if (prefs.length() < 2) {
+      // Not enough data points -- return a threshold that allows everything
+      return Double.NEGATIVE_INFINITY;
+    }
+    RunningAverageAndStdDev stdDev = new FullRunningAverageAndStdDev();
+    int size = prefs.length();
+    for (int i = 0; i < size; i++) {
+      stdDev.addDatum(prefs.getValue(i));
+    }
+    return stdDev.getAverage() + stdDev.getStandardDeviation();
+  }
+
+  // Base-2 logarithm via change of base. LOG2 is declared elsewhere in this class;
+  // presumably Math.log(2.0) -- confirm.
+  private static double log2(double value) {
+    return Math.log(value) / LOG2;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java new file mode 100644 index 0000000..f4e4522 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.RelevantItemsDataSplitter;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Picks relevant items to be those with the strongest preference, and
+ * includes the other users' preferences in full.
+ */
+public final class GenericRelevantItemsDataSplitter implements RelevantItemsDataSplitter {
+
+  /**
+   * Returns up to {@code at} of the user's items whose preference value is at least
+   * {@code relevanceThreshold}, taking the highest-valued preferences first.
+   *
+   * <p>The user's preferences are copied before sorting, so the {@link PreferenceArray} held by
+   * {@code dataModel} keeps its original order.
+   *
+   * @param userID user whose relevant items are chosen
+   * @param at maximum number of relevant items to return
+   * @param relevanceThreshold minimum preference value for an item to count as relevant
+   * @param dataModel source of the user's preferences
+   * @return IDs of the chosen items; empty if no preference reaches the threshold
+   * @throws TasteException if the user's preferences cannot be read
+   */
+  @Override
+  public FastIDSet getRelevantItemsIDs(long userID,
+                                       int at,
+                                       double relevanceThreshold,
+                                       DataModel dataModel) throws TasteException {
+    // Sort a private copy: sorting the DataModel's own array in place would reorder the model's
+    // data as a side effect of evaluation (other code may rely on that order).
+    PreferenceArray prefs = dataModel.getPreferencesFromUser(userID).clone();
+    FastIDSet relevantItemIDs = new FastIDSet(at);
+    prefs.sortByValueReversed();
+    for (int i = 0; i < prefs.length() && relevantItemIDs.size() < at; i++) {
+      if (prefs.getValue(i) >= relevanceThreshold) {
+        relevantItemIDs.add(prefs.getItemID(i));
+      }
+    }
+    return relevantItemIDs;
+  }
+
+  /**
+   * Adds {@code otherUserID}'s preferences to {@code trainingUsers}.
+   *
+   * <p>For the user under evaluation ({@code otherUserID == userID}) the relevant items are removed
+   * first, and the user is omitted entirely if no preference remains. Every other user's
+   * {@link PreferenceArray} is added as-is; the data model's array is shared, not copied.
+   *
+   * @throws TasteException if {@code otherUserID}'s preferences cannot be read
+   */
+  @Override
+  public void processOtherUser(long userID,
+                               FastIDSet relevantItemIDs,
+                               FastByIDMap<PreferenceArray> trainingUsers,
+                               long otherUserID,
+                               DataModel dataModel) throws TasteException {
+    PreferenceArray prefs2Array = dataModel.getPreferencesFromUser(otherUserID);
+    // If we're dealing with the very user that we're evaluating for precision/recall,
+    if (userID == otherUserID) {
+      // then must remove all the test IDs, the "relevant" item IDs
+      List<Preference> prefs2 = new ArrayList<>(prefs2Array.length());
+      for (Preference pref : prefs2Array) {
+        prefs2.add(pref);
+      }
+      for (Iterator<Preference> iterator = prefs2.iterator(); iterator.hasNext();) {
+        Preference pref = iterator.next();
+        if (relevantItemIDs.contains(pref.getItemID())) {
+          iterator.remove();
+        }
+      }
+      if (!prefs2.isEmpty()) {
+        trainingUsers.put(otherUserID, new GenericUserPreferenceArray(prefs2));
+      }
+    } else {
+      // otherwise just add all those other user's prefs
+      trainingUsers.put(otherUserID, prefs2Array);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java
new file mode 100644
index 0000000..2838b08
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import java.io.Serializable;
+
+import org.apache.mahout.cf.taste.eval.IRStatistics;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Immutable holder for the information-retrieval statistics of one evaluation run.
+ * Each statistic is either NaN or a value in [0.0, 1.0].
+ */
+public final class IRStatisticsImpl implements IRStatistics, Serializable {
+
+  private final double precision;
+  private final double recall;
+  private final double fallOut;
+  private final double ndcg;
+  private final double reach;
+
+  /**
+   * @throws IllegalArgumentException if any statistic is neither NaN nor within [0.0, 1.0]
+   */
+  IRStatisticsImpl(double precision, double recall, double fallOut, double ndcg, double reach) {
+    requireUnitIntervalOrNaN("precision", precision);
+    requireUnitIntervalOrNaN("recall", recall);
+    requireUnitIntervalOrNaN("fallOut", fallOut);
+    requireUnitIntervalOrNaN("nDCG", ndcg);
+    requireUnitIntervalOrNaN("reach", reach);
+    this.precision = precision;
+    this.recall = recall;
+    this.fallOut = fallOut;
+    this.ndcg = ndcg;
+    this.reach = reach;
+  }
+
+  /** Rejects a value that is neither NaN nor within [0.0, 1.0]; {@code label} names it in the message. */
+  private static void requireUnitIntervalOrNaN(String label, double value) {
+    boolean acceptable = Double.isNaN(value) || (value >= 0.0 && value <= 1.0);
+    Preconditions.checkArgument(acceptable,
+        "Illegal " + label + ": " + value + ". Must be: 0.0 <= " + label + " <= 1.0 or NaN");
+  }
+
+  @Override
+  public double getPrecision() {
+    return this.precision;
+  }
+
+  @Override
+  public double getRecall() {
+    return this.recall;
+  }
+
+  @Override
+  public double getFallOut() {
+    return this.fallOut;
+  }
+
+  /** F1 is the F-beta measure with beta = 1. */
+  @Override
+  public double getF1Measure() {
+    return getFNMeasure(1.0);
+  }
+
+  /**
+   * F-beta measure with beta = {@code b}: (1 + b^2) * P * R / (b^2 * P + R).
+   * Returns NaN when the denominator is zero.
+   */
+  @Override
+  public double getFNMeasure(double b) {
+    double betaSquared = b * b;
+    double denominator = betaSquared * precision + recall;
+    if (denominator == 0.0) {
+      return Double.NaN;
+    }
+    return (1.0 + betaSquared) * precision * recall / denominator;
+  }
+
+  @Override
+  public double getNormalizedDiscountedCumulativeGain() {
+    return this.ndcg;
+  }
+
+  @Override
+  public double getReach() {
+    return this.reach;
+  }
+
+  @Override
+  public String toString() {
+    return new StringBuilder("IRStatisticsImpl[")
+        .append("precision:").append(precision)
+        .append(",recall:").append(recall)
+        .append(",fallOut:").append(fallOut)
+        .append(",nDCG:").append(ndcg)
+        .append(",reach:").append(reach)
+        .append(']')
+        .toString();
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java
new file mode 100644
index 0000000..213f7f9
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import org.apache.mahout.cf.taste.recommender.Recommender;
+
+import java.util.concurrent.Callable;
+
+/**
+ * Task that asks a {@link Recommender} for recommendations for a single user; used to generate load.
+ */
+final class LoadCallable implements Callable<Void> {
+
+  private final Recommender recommender;
+  private final long userID;
+  private final int howMany;
+
+  /**
+   * Creates a task that requests 10 recommendations for {@code userID}.
+   */
+  LoadCallable(Recommender recommender, long userID) {
+    this(recommender, userID, 10);
+  }
+
+  /**
+   * @param recommender recommender to call
+   * @param userID user to request recommendations for
+   * @param howMany number of recommendations to request
+   */
+  LoadCallable(Recommender recommender, long userID, int howMany) {
+    this.recommender = recommender;
+    this.userID = userID;
+    this.howMany = howMany;
+  }
+
+  @Override
+  public Void call() throws Exception {
+    recommender.recommend(userID, howMany);
+    return null;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java
new file mode 100644
index 0000000..2d27a37
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.concurrent.Callable;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.SamplingLongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+
+/**
+ * Simple helper class for running load on a Recommender.
+ */
+public final class LoadEvaluator {
+
+  private LoadEvaluator() { }
+
+  /**
+   * Runs load requesting 10 recommendations per sampled user.
+   *
+   * @see #runLoad(Recommender, int)
+   */
+  public static LoadStatistics runLoad(Recommender recommender) throws TasteException {
+    return runLoad(recommender, 10);
+  }
+
+  /**
+   * Times {@code recommender.recommend(userID, howMany)} over a sample of the recommender's users,
+   * drawn at rate {@code 1000.0 / numUsers} (about 1000 users). The first sampled user is used for
+   * one untimed warm-up call.
+   *
+   * @param recommender recommender under load
+   * @param howMany number of recommendations requested per call, for the warm-up and timed calls alike
+   * @return timing statistics; empty when there is no user to sample
+   * @throws TasteException if the data model, the recommender or the timed execution fails
+   */
+  public static LoadStatistics runLoad(Recommender recommender, int howMany) throws TasteException {
+    DataModel dataModel = recommender.getDataModel();
+    int numUsers = dataModel.getNumUsers();
+    double sampleRate = 1000.0 / numUsers;
+    LongPrimitiveIterator userSampler =
+        SamplingLongPrimitiveIterator.maybeWrapIterator(dataModel.getUserIDs(), sampleRate);
+    RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev();
+    if (!userSampler.hasNext()) {
+      // Nothing to sample (e.g. no users): the warm-up below would otherwise fail on an empty iterator.
+      return new LoadStatistics(timing);
+    }
+    recommender.recommend(userSampler.nextLong(), howMany); // Warm up (not timed)
+    Collection<Callable<Void>> callables = new ArrayList<>();
+    while (userSampler.hasNext()) {
+      // Time the same request size the caller asked for.
+      callables.add(new LoadCallable(recommender, userSampler.nextLong(), howMany));
+    }
+    AtomicInteger noEstimateCounter = new AtomicInteger();
+    AbstractDifferenceRecommenderEvaluator.execute(callables, noEstimateCounter, timing);
+    return new LoadStatistics(timing);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java
new file mode 100644
index 0000000..f89160c
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+
+/**
+ * Outcome of a load run: the timing statistics gathered while the recommender was under load.
+ */
+public final class LoadStatistics {
+
+  /** Timing statistics collected during the run. */
+  private final RunningAverage timingAverage;
+
+  LoadStatistics(RunningAverage timing) {
+    timingAverage = timing;
+  }
+
+  /**
+   * @return the timing statistics supplied at construction (the same instance, not a copy)
+   */
+  public RunningAverage getTiming() {
+    return timingAverage;
+  }
+
+}
