http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsReducer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsReducer.java new file mode 100644 index 0000000..f74511b --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsReducer.java @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.preparation; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.math.VectorWritable; + +import java.io.IOException; +/** + * Reducer that collapses every {@link VectorWritable} received for one {@link IntWritable} key into a single + * vector (via {@code VectorWritable.mergeToVector}) and writes that vector back out under the same key. + */ +public class ToItemVectorsReducer extends Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> { + /** Output holder that is reused across {@code reduce} calls instead of allocating a new writable per key. */ + private final VectorWritable merged = new VectorWritable(); + /** + * Combines all vectors of {@code row} and emits the result. + * + * @param row key shared by all incoming vectors; written back unchanged + * @param vectors the vectors to combine + * @param ctx context the combined vector is written to + * @throws IOException as declared by the overridden {@link Reducer} method + * @throws InterruptedException as declared by the overridden {@link Reducer} method + */ + @Override + protected void reduce(IntWritable row, Iterable<VectorWritable> vectors, Context ctx) + throws IOException, InterruptedException { + /* flag the holder for lax-precision writes, then swap in the combined vector before emitting it */ + merged.setWritesLaxPrecision(true); + merged.set(VectorWritable.mergeToVector(vectors.iterator())); + ctx.write(row, merged); + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJob.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJob.java new file mode 100644 index 0000000..c50fa20 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJob.java @@ -0,0 +1,233 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.similarity.item; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; + +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.cf.taste.hadoop.EntityEntityWritable; +import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils; +import org.apache.mahout.cf.taste.hadoop.preparation.PreparePreferenceMatrixJob; +import org.apache.mahout.cf.taste.similarity.precompute.SimilarItem; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.similarity.cooccurrence.RowSimilarityJob; +import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasures; +import org.apache.mahout.math.map.OpenIntLongHashMap; + +/** + * <p>Distributed precomputation of the item-item-similarities for Itembased Collaborative Filtering</p> + * + * <p>Preferences in the input file should look like {@code userID,itemID[,preferencevalue]}</p> + * + * <p> + * Preference value is optional to accommodate applications that have no notion of a preference value (that is, the user + * simply expresses a preference for an item, but no degree of preference). + * </p> + * + * <p> + * The preference value is assumed to be parseable as a {@code double}. The user IDs and item IDs are + * parsed as {@code long}s. 
+ * </p> + * + * <p>Command line arguments specific to this class are:</p> + * + * <ol> + * <li>--input (path): Directory containing one or more text files with the preference data</li> + * <li>--output (path): output path where similarity data should be written</li> + * <li>--similarityClassname (classname): Name of distributed similarity measure class to instantiate or a predefined + * similarity from {@link org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure}</li> + * <li>--maxSimilaritiesPerItem (integer): Maximum number of similarities considered per item (100)</li> + * <li>--maxPrefs (integer): max number of preferences to consider per user or item, users or items with more + * preferences will be sampled down (500)</li> + * <li>--minPrefsPerUser (integer): ignore users with less preferences than this (1)</li> + * <li>--booleanData (boolean): Treat input data as having no pref values (false)</li> + * <li>--threshold (double): discard item pairs with a similarity value below this</li> + * <li>--randomSeed (long): use this seed for sampling</li> + * </ol> + * + * <p>General command line options are documented in {@link AbstractJob}.</p> + * + * <p>Note that because of how Hadoop parses arguments, all "-D" arguments must appear before all other arguments.</p> + */ +public final class ItemSimilarityJob extends AbstractJob { + + /** Configuration key under which the item ID index path is handed to {@link MostSimilarItemPairsMapper}. */ + public static final String ITEM_ID_INDEX_PATH_STR = ItemSimilarityJob.class.getName() + ".itemIDIndexPathStr"; + /** Configuration key under which the per-item similarity cap is handed to {@link MostSimilarItemPairsMapper}. */ + public static final String MAX_SIMILARITIES_PER_ITEM = ItemSimilarityJob.class.getName() + ".maxSimilarItemsPerItem"; + + private static final int DEFAULT_MAX_SIMILAR_ITEMS_PER_ITEM = 100; + private static final int DEFAULT_MAX_PREFS = 500; + private static final int DEFAULT_MIN_PREFS_PER_USER = 1; + + public static void main(String[] args) throws Exception { + ToolRunner.run(new ItemSimilarityJob(), args); + } + + /** + * Parses the command line and runs up to three phases: preparing the preference matrix, computing the + * row similarities, and writing the most similar item pairs as text. + * + * @param args command line arguments, see the class documentation + * @return 0 on success; -1 if the arguments could not be parsed or any phase reported a failure + */ + @Override + public int run(String[] args) throws Exception { + + addInputOption(); + addOutputOption(); + addOption("similarityClassname", "s", "Name 
of distributed similarity measures class to instantiate, " + + "alternatively use one of the predefined similarities (" + VectorSimilarityMeasures.list() + ')'); + addOption("maxSimilaritiesPerItem", "m", "try to cap the number of similar items per item to this number " + + "(default: " + DEFAULT_MAX_SIMILAR_ITEMS_PER_ITEM + ')', + String.valueOf(DEFAULT_MAX_SIMILAR_ITEMS_PER_ITEM)); + addOption("maxPrefs", "mppu", "max number of preferences to consider per user or item, " + + "users or items with more preferences will be sampled down (default: " + DEFAULT_MAX_PREFS + ')', + String.valueOf(DEFAULT_MAX_PREFS)); + addOption("minPrefsPerUser", "mp", "ignore users with less preferences than this " + + "(default: " + DEFAULT_MIN_PREFS_PER_USER + ')', String.valueOf(DEFAULT_MIN_PREFS_PER_USER)); + addOption("booleanData", "b", "Treat input as without pref values", String.valueOf(Boolean.FALSE)); + addOption("threshold", "tr", "discard item pairs with a similarity value below this", false); + addOption("randomSeed", null, "use this seed for sampling", false); + + Map<String,List<String>> parsedArgs = parseArguments(args); + if (parsedArgs == null) { + return -1; + } + + String similarityClassName = getOption("similarityClassname"); + int maxSimilarItemsPerItem = Integer.parseInt(getOption("maxSimilaritiesPerItem")); + int maxPrefs = Integer.parseInt(getOption("maxPrefs")); + int minPrefsPerUser = Integer.parseInt(getOption("minPrefsPerUser")); + boolean booleanData = Boolean.valueOf(getOption("booleanData")); + + double threshold = hasOption("threshold") + ? Double.parseDouble(getOption("threshold")) : RowSimilarityJob.NO_THRESHOLD; + long randomSeed = hasOption("randomSeed") + ? 
Long.parseLong(getOption("randomSeed")) : RowSimilarityJob.NO_FIXED_RANDOM_SEED; + + Path similarityMatrixPath = getTempPath("similarityMatrix"); + Path prepPath = getTempPath("prepareRatingMatrix"); + + AtomicInteger currentPhase = new AtomicInteger(); + + if (shouldRunNextPhase(parsedArgs, currentPhase)) { + int prepareExitCode = ToolRunner.run(getConf(), new PreparePreferenceMatrixJob(), new String[] { + "--input", getInputPath().toString(), + "--output", prepPath.toString(), + "--minPrefsPerUser", String.valueOf(minPrefsPerUser), + "--booleanData", String.valueOf(booleanData), + "--tempDir", getTempPath().toString(), + }); + /* the later phases read this job's output, so stop here instead of continuing after a failure */ + if (prepareExitCode != 0) { + return -1; + } + } + + if (shouldRunNextPhase(parsedArgs, currentPhase)) { + int numberOfUsers = HadoopUtil.readInt(new Path(prepPath, PreparePreferenceMatrixJob.NUM_USERS), getConf()); + + int similarityExitCode = ToolRunner.run(getConf(), new RowSimilarityJob(), new String[] { + "--input", new Path(prepPath, PreparePreferenceMatrixJob.RATING_MATRIX).toString(), + "--output", similarityMatrixPath.toString(), + "--numberOfColumns", String.valueOf(numberOfUsers), + "--similarityClassname", similarityClassName, + "--maxObservationsPerRow", String.valueOf(maxPrefs), + "--maxObservationsPerColumn", String.valueOf(maxPrefs), + "--maxSimilaritiesPerRow", String.valueOf(maxSimilarItemsPerItem), + "--excludeSelfSimilarity", String.valueOf(Boolean.TRUE), + "--threshold", String.valueOf(threshold), + "--randomSeed", String.valueOf(randomSeed), + "--tempDir", getTempPath().toString(), + }); + /* the final phase reads the similarity matrix, so a failed similarity job must not be ignored */ + if (similarityExitCode != 0) { + return -1; + } + } + + if (shouldRunNextPhase(parsedArgs, currentPhase)) { + Job mostSimilarItems = prepareJob(similarityMatrixPath, getOutputPath(), SequenceFileInputFormat.class, + MostSimilarItemPairsMapper.class, EntityEntityWritable.class, DoubleWritable.class, + MostSimilarItemPairsReducer.class, EntityEntityWritable.class, DoubleWritable.class, TextOutputFormat.class); + Configuration mostSimilarItemsConf = mostSimilarItems.getConfiguration(); + mostSimilarItemsConf.set(ITEM_ID_INDEX_PATH_STR, + new Path(prepPath, 
PreparePreferenceMatrixJob.ITEMID_INDEX).toString()); + mostSimilarItemsConf.setInt(MAX_SIMILARITIES_PER_ITEM, maxSimilarItemsPerItem); + boolean succeeded = mostSimilarItems.waitForCompletion(true); + if (!succeeded) { + return -1; + } + } + + return 0; + } + + /** + * For each row of the similarity matrix, keeps the entries with the highest similarity values (at most + * {@code maxSimilarItemsPerItem} of them) and emits each as an item pair with the smaller item ID first. + */ + public static class MostSimilarItemPairsMapper + extends Mapper<IntWritable,VectorWritable,EntityEntityWritable,DoubleWritable> { + + private OpenIntLongHashMap indexItemIDMap; + private int maxSimilarItemsPerItem; + + /** Reads the similarity cap and the index-to-item-ID map from the job configuration. */ + @Override + protected void setup(Context ctx) { + Configuration conf = ctx.getConfiguration(); + maxSimilarItemsPerItem = conf.getInt(MAX_SIMILARITIES_PER_ITEM, -1); + indexItemIDMap = TasteHadoopUtils.readIDIndexMap(conf.get(ITEM_ID_INDEX_PATH_STR), conf); + + Preconditions.checkArgument(maxSimilarItemsPerItem > 0, "maxSimilarItemsPerItem must be greater then 0!"); + } + + @Override + protected void map(IntWritable itemIDIndexWritable, VectorWritable similarityVector, Context ctx) + throws IOException, InterruptedException { + + int itemIDIndex = itemIDIndexWritable.get(); + + TopSimilarItemsQueue topKMostSimilarItems = new TopSimilarItemsQueue(maxSimilarItemsPerItem); + + /* overwrite the current weakest queue entry whenever a stronger candidate shows up */ + for (Vector.Element element : similarityVector.get().nonZeroes()) { + SimilarItem top = topKMostSimilarItems.top(); + double candidateSimilarity = element.get(); + if (candidateSimilarity > top.getSimilarity()) { + top.set(indexItemIDMap.get(element.index()), candidateSimilarity); + topKMostSimilarItems.updateTop(); + } + } + + long itemID = indexItemIDMap.get(itemIDIndex); + for (SimilarItem similarItem : topKMostSimilarItems.getTopItems()) { + long otherItemID = similarItem.getItemID(); + /* always put the smaller ID first so both directions of a pair share one key */ + if (itemID < otherItemID) { + ctx.write(new EntityEntityWritable(itemID, otherItemID), new DoubleWritable(similarItem.getSimilarity())); + } else { + ctx.write(new EntityEntityWritable(otherItemID, itemID), new DoubleWritable(similarItem.getSimilarity())); + } + } + } + } + + /** Writes every item pair exactly once, using the first similarity value received for it. */ + public static class MostSimilarItemPairsReducer + extends 
Reducer<EntityEntityWritable,DoubleWritable,EntityEntityWritable,DoubleWritable> { + @Override + protected void reduce(EntityEntityWritable pair, Iterable<DoubleWritable> values, Context ctx) + throws IOException, InterruptedException { + ctx.write(pair, values.iterator().next()); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/TopSimilarItemsQueue.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/TopSimilarItemsQueue.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/TopSimilarItemsQueue.java new file mode 100644 index 0000000..acb6392 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/TopSimilarItemsQueue.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.similarity.item; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.apache.lucene.util.PriorityQueue; +import org.apache.mahout.cf.taste.similarity.precompute.SimilarItem; + +/** + * Bounded queue of {@link SimilarItem}s ordered by similarity, with the least similar entry on top. + * Placeholder ("sentinel") entries carry the ID {@code Long.MIN_VALUE} and are dropped again by + * {@link #getTopItems()}. + */ +public class TopSimilarItemsQueue extends PriorityQueue<SimilarItem> { + + private static final long SENTINEL_ID = Long.MIN_VALUE; + + private final int maxSize; + + public TopSimilarItemsQueue(int maxSize) { + super(maxSize); + this.maxSize = maxSize; + } + + /** + * Empties the queue and returns its real entries, most similar first. + * + * @return the non-sentinel items in descending order of similarity + */ + public List<SimilarItem> getTopItems() { + List<SimilarItem> items = new ArrayList<>(maxSize); + while (size() > 0) { + SimilarItem topItem = pop(); + /* filter out "sentinel" objects necessary for maintaining an efficient priority queue */ + if (topItem.getItemID() != SENTINEL_ID) { + items.add(topItem); + } + } + /* pop() yields the least similar entry first, so flip the list */ + Collections.reverse(items); + return items; + } + + @Override + protected boolean lessThan(SimilarItem one, SimilarItem two) { + return one.getSimilarity() < two.getSimilarity(); + } + + /** + * Creates a placeholder that ranks below every real similarity value. + * {@code Double.MIN_VALUE} must not be used here: it is the smallest positive double, so zero and + * negative similarities would never rank above the placeholder. + */ + @Override + protected SimilarItem getSentinelObject() { + return new SimilarItem(SENTINEL_ID, Double.NEGATIVE_INFINITY); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/AbstractLongPrimitiveIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/AbstractLongPrimitiveIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/AbstractLongPrimitiveIterator.java new file mode 100644 index 0000000..f46785c --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/AbstractLongPrimitiveIterator.java @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common; +/** + * Base class for {@link LongPrimitiveIterator} implementations: it supplies the boxed {@code next()} by + * delegating to {@code nextLong()}, which this class does not define itself. + */ +public abstract class AbstractLongPrimitiveIterator implements LongPrimitiveIterator { + /** Returns the result of {@code nextLong()}, boxed as a {@link Long}. */ + @Override + public Long next() { + return nextLong(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/BitSet.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/BitSet.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/BitSet.java new file mode 100644 index 0000000..c46b4b6 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/BitSet.java @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common; + +import java.io.Serializable; +import java.util.Arrays; + +/** A simplified and streamlined version of {@link java.util.BitSet}. */ +final class BitSet implements Serializable, Cloneable { + /** Backing words: bit {@code i} is stored in word {@code i >>> 6} at bit position {@code i & 0x3F}. */ + private final long[] bits; + /** Allocates {@code numBits >>> 6} words, plus one more when {@code numBits} is not a multiple of 64. */ + BitSet(int numBits) { + int numLongs = numBits >>> 6; + if ((numBits & 0x3F) != 0) { + numLongs++; + } + bits = new long[numLongs]; + } + /** Wraps the given word array directly, without copying it; used by {@link #clone()}. */ + private BitSet(long[] bits) { + this.bits = bits; + } + /** Reports whether the bit at {@code index} is set. */ + boolean get(int index) { + // skipping range check for speed + return (bits[index >>> 6] & 1L << (index & 0x3F)) != 0L; + } + /** Sets the bit at {@code index}. */ + void set(int index) { + // skipping range check for speed + bits[index >>> 6] |= 1L << (index & 0x3F); + } + /** Unsets the bit at {@code index}. */ + void clear(int index) { + // skipping range check for speed + bits[index >>> 6] &= ~(1L << (index & 0x3F)); + } + /** Zeroes every word, which unsets all bits. */ + void clear() { + int length = bits.length; + for (int i = 0; i < length; i++) { + bits[i] = 0L; + } + } + /** Returns a new instance backed by a copy of the word array. */ + @Override + public BitSet clone() { + return new BitSet(bits.clone()); + } + /** Hash of the word array contents, matching the comparison done in {@link #equals(Object)}. */ + @Override + public int hashCode() { + return Arrays.hashCode(bits); + } + /** Two instances are equal when their word arrays have equal contents. */ + @Override + public boolean equals(Object o) { + if (!(o instanceof BitSet)) { + return false; + } + BitSet other = (BitSet) o; + return Arrays.equals(bits, other.bits); + } + /** + * Renders each word as 64 characters, {@code '0'} or {@code '1'}, lowest bit first, followed by one space. + */ + @Override + public String toString() { + StringBuilder result = new StringBuilder(64 * bits.length); + for (long l : bits) { + for (int j = 0; j < 64; j++) { + result.append((l & 1L << j) == 0 ? 
'0' : '1'); + } + result.append(' '); + } + return result.toString(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Cache.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Cache.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Cache.java new file mode 100755 index 0000000..b2d9b36 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Cache.java @@ -0,0 +1,178 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common; + +import com.google.common.base.Preconditions; +import org.apache.mahout.cf.taste.common.TasteException; + +import java.util.Iterator; + +/** + * <p> + * An efficient Map-like class which caches values for keys. Values are not "put" into a {@link Cache}; + * instead the caller supplies the instance with an implementation of {@link Retriever} which can load the + * value for a given key. + * </p> + * + * <p> + * The cache does not support {@code null} keys. 
+ * </p> + * + * <p> + * Thanks to Amila Jayasooriya for helping evaluate performance of the rewrite of this class, as part of a + * Google Summer of Code 2007 project. + * </p> + */ +public final class Cache<K,V> implements Retriever<K,V> { + + /** Marker stored in place of a {@code null} retrieved value, so such a key still counts as cached. */ + private static final Object NULL = new Object(); + + private final FastMap<K,V> cache; + private final Retriever<? super K,? extends V> retriever; + + /** + * <p> + * Creates a new cache based on the given {@link Retriever}. + * </p> + * + * @param retriever + * object which can retrieve values for keys + */ + public Cache(Retriever<? super K,? extends V> retriever) { + this(retriever, FastMap.NO_MAX_SIZE); + } + + /** + * <p> + * Creates a new cache based on the given {@link Retriever} and with given maximum size. + * </p> + * + * @param retriever + * object which can retrieve values for keys + * @param maxEntries + * maximum number of entries the cache will store before evicting some + */ + public Cache(Retriever<? super K,? extends V> retriever, int maxEntries) { + Preconditions.checkArgument(retriever != null, "retriever is null"); + Preconditions.checkArgument(maxEntries >= 1, "maxEntries must be at least 1"); + cache = new FastMap<>(11, maxEntries); + this.retriever = retriever; + } + + /** + * <p> + * Returns cached value for a key. If it does not exist, it is loaded using a {@link Retriever}. + * </p> + * + * @param key + * cache key + * @return value for that key + * @throws TasteException + * if an exception occurs while retrieving a new cached value + */ + @Override + public V get(K key) throws TasteException { + V value; + synchronized (cache) { + value = cache.get(key); + } + if (value == null) { + return getAndCacheValue(key); + } + return value == NULL ? null : value; + } + + /** + * <p> + * Uncaches any existing value for a given key. 
+ * </p> + * + * @param key + * cache key + */ + public void remove(K key) { + synchronized (cache) { + cache.remove(key); + } + } + + /** + * Clears all cache entries whose key matches the given predicate. + */ + public void removeKeysMatching(MatchPredicate<K> predicate) { + synchronized (cache) { + Iterator<K> it = cache.keySet().iterator(); + while (it.hasNext()) { + K key = it.next(); + if (predicate.matches(key)) { + it.remove(); + } + } + } + } + + /** + * Clears all cache entries whose value matches the given predicate. + */ + public void removeValueMatching(MatchPredicate<V> predicate) { + synchronized (cache) { + Iterator<V> it = cache.values().iterator(); + while (it.hasNext()) { + V value = it.next(); + /* NOTE(review): a cached null is stored as the NULL marker, which reaches the predicate as-is - confirm intended */ + if (predicate.matches(value)) { + it.remove(); + } + } + } + } + + /** + * <p> + * Clears the cache. + * </p> + */ + public void clear() { + synchronized (cache) { + cache.clear(); + } + } + + /** + * Loads the value for {@code key} from the retriever, stores it in the cache and returns it. + * A {@code null} result is stored as the internal marker, but the caller still receives {@code null}. + */ + private V getAndCacheValue(K key) throws TasteException { + V value = retriever.get(key); + /* only the stored entry uses the marker; returning it would leak a plain Object typed as V */ + @SuppressWarnings("unchecked") + V cachedValue = value == null ? (V) NULL : value; + synchronized (cache) { + cache.put(key, cachedValue); + } + return value; + } + + @Override + public String toString() { + return "Cache[retriever:" + retriever + ']'; + } + + /** + * Used by {@link #removeKeysMatching(MatchPredicate)} and {@link #removeValueMatching(MatchPredicate)} to + * decide which entries are matching. 
+ */ + public interface MatchPredicate<T> { + boolean matches(T thing); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastByIDMap.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastByIDMap.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastByIDMap.java new file mode 100644 index 0000000..fde8958 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastByIDMap.java @@ -0,0 +1,661 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.impl.common; + +import java.io.Serializable; +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +import org.apache.mahout.common.RandomUtils; + +import com.google.common.base.Preconditions; + +/** + * @see FastMap + * @see FastIDSet + */ +public final class FastByIDMap<V> implements Serializable, Cloneable { + + public static final int NO_MAX_SIZE = Integer.MAX_VALUE; + private static final float DEFAULT_LOAD_FACTOR = 1.5f; + + /** Dummy object used to represent a key that has been removed. */ + private static final long REMOVED = Long.MAX_VALUE; + private static final long NULL = Long.MIN_VALUE; + + private long[] keys; + private V[] values; + private float loadFactor; + private int numEntries; + private int numSlotsUsed; + private final int maxSize; + private BitSet recentlyAccessed; + private final boolean countingAccesses; + + /** Creates a new {@link FastByIDMap} with default capacity. */ + public FastByIDMap() { + this(2, NO_MAX_SIZE); + } + + public FastByIDMap(int size) { + this(size, NO_MAX_SIZE); + } + + public FastByIDMap(int size, float loadFactor) { + this(size, NO_MAX_SIZE, loadFactor); + } + + public FastByIDMap(int size, int maxSize) { + this(size, maxSize, DEFAULT_LOAD_FACTOR); + } + + /** + * Creates a new {@link FastByIDMap} whose capacity can accommodate the given number of entries without rehash. 
+ * + * @param size desired capacity + * @param maxSize max capacity + * @param loadFactor ratio of internal hash table size to current size + * @throws IllegalArgumentException if size is less than 0, maxSize is less than 1 + * or at least half of {@link RandomUtils#MAX_INT_SMALLER_TWIN_PRIME}, or + * loadFactor is less than 1 + */ + public FastByIDMap(int size, int maxSize, float loadFactor) { + Preconditions.checkArgument(size >= 0, "size must be at least 0"); + Preconditions.checkArgument(loadFactor >= 1.0f, "loadFactor must be at least 1.0"); + this.loadFactor = loadFactor; + int max = (int) (RandomUtils.MAX_INT_SMALLER_TWIN_PRIME / loadFactor); + Preconditions.checkArgument(size < max, "size must be less than " + max); + Preconditions.checkArgument(maxSize >= 1, "maxSize must be at least 1"); + int hashSize = RandomUtils.nextTwinPrime((int) (loadFactor * size)); + keys = new long[hashSize]; + Arrays.fill(keys, NULL); + values = (V[]) new Object[hashSize]; + this.maxSize = maxSize; + this.countingAccesses = maxSize != Integer.MAX_VALUE; + this.recentlyAccessed = countingAccesses ? new BitSet(hashSize) : null; + } + + /** + * @see #findForAdd(long) + */ + private int find(long key) { + int theHashCode = (int) key & 0x7FFFFFFF; // make sure it's positive + long[] keys = this.keys; + int hashSize = keys.length; + int jump = 1 + theHashCode % (hashSize - 2); + int index = theHashCode % hashSize; + long currentKey = keys[index]; + while (currentKey != NULL && key != currentKey) { + index -= index < jump ? 
jump - hashSize : jump; + currentKey = keys[index]; + } + return index; + } + + /** + * @see #find(long) + */ + private int findForAdd(long key) { + int theHashCode = (int) key & 0x7FFFFFFF; // make sure it's positive + long[] keys = this.keys; + int hashSize = keys.length; + int jump = 1 + theHashCode % (hashSize - 2); + int index = theHashCode % hashSize; + long currentKey = keys[index]; + while (currentKey != NULL && currentKey != REMOVED && key != currentKey) { + index -= index < jump ? jump - hashSize : jump; + currentKey = keys[index]; + } + if (currentKey != REMOVED) { + return index; + } + // If we're adding, it's here, but, the key might have a value already later + int addIndex = index; + while (currentKey != NULL && key != currentKey) { + index -= index < jump ? jump - hashSize : jump; + currentKey = keys[index]; + } + return key == currentKey ? index : addIndex; + } + + public V get(long key) { + if (key == NULL) { + return null; + } + int index = find(key); + if (countingAccesses) { + recentlyAccessed.set(index); + } + return values[index]; + } + + public int size() { + return numEntries; + } + + public boolean isEmpty() { + return numEntries == 0; + } + + public boolean containsKey(long key) { + return key != NULL && key != REMOVED && keys[find(key)] != NULL; + } + + public boolean containsValue(Object value) { + if (value == null) { + return false; + } + for (V theValue : values) { + if (theValue != null && value.equals(theValue)) { + return true; + } + } + return false; + } + + public V put(long key, V value) { + Preconditions.checkArgument(key != NULL && key != REMOVED); + Preconditions.checkNotNull(value); + // If less than half the slots are open, let's clear it up + if (numSlotsUsed * loadFactor >= keys.length) { + // If over half the slots used are actual entries, let's grow + if (numEntries * loadFactor >= numSlotsUsed) { + growAndRehash(); + } else { + // Otherwise just rehash to clear REMOVED entries and don't grow + rehash(); + } + } + // 
Here we may later consider implementing Brent's variation described on page 532 + int index = findForAdd(key); + long keyIndex = keys[index]; + if (keyIndex == key) { + V oldValue = values[index]; + values[index] = value; + return oldValue; + } + // If size is limited, + if (countingAccesses && numEntries >= maxSize) { + // and we're too large, clear some old-ish entry + clearStaleEntry(index); + } + keys[index] = key; + values[index] = value; + numEntries++; + if (keyIndex == NULL) { + numSlotsUsed++; + } + return null; + } + + private void clearStaleEntry(int index) { + while (true) { + long currentKey; + do { + if (index == 0) { + index = keys.length - 1; + } else { + index--; + } + currentKey = keys[index]; + } while (currentKey == NULL || currentKey == REMOVED); + if (recentlyAccessed.get(index)) { + recentlyAccessed.clear(index); + } else { + break; + } + } + // Delete the entry + keys[index] = REMOVED; + numEntries--; + values[index] = null; + } + + public V remove(long key) { + if (key == NULL || key == REMOVED) { + return null; + } + int index = find(key); + if (keys[index] == NULL) { + return null; + } else { + keys[index] = REMOVED; + numEntries--; + V oldValue = values[index]; + values[index] = null; + // don't decrement numSlotsUsed + return oldValue; + } + // Could un-set recentlyAccessed's bit but doesn't matter + } + + public void clear() { + numEntries = 0; + numSlotsUsed = 0; + Arrays.fill(keys, NULL); + Arrays.fill(values, null); + if (countingAccesses) { + recentlyAccessed.clear(); + } + } + + public LongPrimitiveIterator keySetIterator() { + return new KeyIterator(); + } + + public Set<Map.Entry<Long,V>> entrySet() { + return new EntrySet(); + } + + public Collection<V> values() { + return new ValueCollection(); + } + + public void rehash() { + rehash(RandomUtils.nextTwinPrime((int) (loadFactor * numEntries))); + } + + private void growAndRehash() { + if (keys.length * loadFactor >= RandomUtils.MAX_INT_SMALLER_TWIN_PRIME) { + throw new 
IllegalStateException("Can't grow any more"); + } + rehash(RandomUtils.nextTwinPrime((int) (loadFactor * keys.length))); + } + + private void rehash(int newHashSize) { + long[] oldKeys = keys; + V[] oldValues = values; + numEntries = 0; + numSlotsUsed = 0; + if (countingAccesses) { + recentlyAccessed = new BitSet(newHashSize); + } + keys = new long[newHashSize]; + Arrays.fill(keys, NULL); + values = (V[]) new Object[newHashSize]; + int length = oldKeys.length; + for (int i = 0; i < length; i++) { + long key = oldKeys[i]; + if (key != NULL && key != REMOVED) { + put(key, oldValues[i]); + } + } + } + + void iteratorRemove(int lastNext) { + if (lastNext >= values.length) { + throw new NoSuchElementException(); + } + if (lastNext < 0) { + throw new IllegalStateException(); + } + values[lastNext] = null; + keys[lastNext] = REMOVED; + numEntries--; + } + + @Override + public FastByIDMap<V> clone() { + FastByIDMap<V> clone; + try { + clone = (FastByIDMap<V>) super.clone(); + } catch (CloneNotSupportedException cnse) { + throw new AssertionError(); + } + clone.keys = keys.clone(); + clone.values = values.clone(); + clone.recentlyAccessed = countingAccesses ? 
new BitSet(keys.length) : null; + return clone; + } + + @Override + public String toString() { + if (isEmpty()) { + return "{}"; + } + StringBuilder result = new StringBuilder(); + result.append('{'); + for (int i = 0; i < keys.length; i++) { + long key = keys[i]; + if (key != NULL && key != REMOVED) { + result.append(key).append('=').append(values[i]).append(','); + } + } + result.setCharAt(result.length() - 1, '}'); + return result.toString(); + } + + @Override + public int hashCode() { + int hash = 0; + long[] keys = this.keys; + int max = keys.length; + for (int i = 0; i < max; i++) { + long key = keys[i]; + if (key != NULL && key != REMOVED) { + hash = 31 * hash + ((int) (key >> 32) ^ (int) key); + hash = 31 * hash + values[i].hashCode(); + } + } + return hash; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof FastByIDMap)) { + return false; + } + FastByIDMap<V> otherMap = (FastByIDMap<V>) other; + long[] otherKeys = otherMap.keys; + V[] otherValues = otherMap.values; + int length = keys.length; + int otherLength = otherKeys.length; + int max = Math.min(length, otherLength); + + int i = 0; + while (i < max) { + long key = keys[i]; + long otherKey = otherKeys[i]; + if (key == NULL || key == REMOVED) { + if (otherKey != NULL && otherKey != REMOVED) { + return false; + } + } else { + if (key != otherKey || !values[i].equals(otherValues[i])) { + return false; + } + } + i++; + } + while (i < length) { + long key = keys[i]; + if (key != NULL && key != REMOVED) { + return false; + } + i++; + } + while (i < otherLength) { + long key = otherKeys[i]; + if (key != NULL && key != REMOVED) { + return false; + } + i++; + } + return true; + } + + private final class KeyIterator extends AbstractLongPrimitiveIterator { + + private int position; + private int lastNext = -1; + + @Override + public boolean hasNext() { + goToNext(); + return position < keys.length; + } + + @Override + public long nextLong() { + goToNext(); + lastNext = position; + 
if (position >= keys.length) { + throw new NoSuchElementException(); + } + return keys[position++]; + } + + @Override + public long peek() { + goToNext(); + if (position >= keys.length) { + throw new NoSuchElementException(); + } + return keys[position]; + } + + private void goToNext() { + int length = values.length; + while (position < length && values[position] == null) { + position++; + } + } + + @Override + public void remove() { + iteratorRemove(lastNext); + } + + @Override + public void skip(int n) { + position += n; + } + + } + + private final class EntrySet extends AbstractSet<Map.Entry<Long,V>> { + + @Override + public int size() { + return FastByIDMap.this.size(); + } + + @Override + public boolean isEmpty() { + return FastByIDMap.this.isEmpty(); + } + + @Override + public boolean contains(Object o) { + return containsKey((Long) o); + } + + @Override + public Iterator<Map.Entry<Long,V>> iterator() { + return new EntryIterator(); + } + + @Override + public boolean add(Map.Entry<Long,V> t) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean remove(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean addAll(Collection<? 
extends Map.Entry<Long,V>> ts) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean retainAll(Collection<?> objects) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection<?> objects) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + FastByIDMap.this.clear(); + } + + private final class MapEntry implements Map.Entry<Long,V> { + + private final int index; + + private MapEntry(int index) { + this.index = index; + } + + @Override + public Long getKey() { + return keys[index]; + } + + @Override + public V getValue() { + return values[index]; + } + + @Override + public V setValue(V value) { + Preconditions.checkArgument(value != null); + + V oldValue = values[index]; + values[index] = value; + return oldValue; + } + } + + private final class EntryIterator implements Iterator<Map.Entry<Long,V>> { + + private int position; + private int lastNext = -1; + + @Override + public boolean hasNext() { + goToNext(); + return position < keys.length; + } + + @Override + public Map.Entry<Long,V> next() { + goToNext(); + lastNext = position; + if (position >= keys.length) { + throw new NoSuchElementException(); + } + return new MapEntry(position++); + } + + private void goToNext() { + int length = values.length; + while (position < length && values[position] == null) { + position++; + } + } + + @Override + public void remove() { + iteratorRemove(lastNext); + } + } + + } + + private final class ValueCollection extends AbstractCollection<V> { + + @Override + public int size() { + return FastByIDMap.this.size(); + } + + @Override + public boolean isEmpty() { + return FastByIDMap.this.isEmpty(); + } + + @Override + public boolean contains(Object o) { + return containsValue(o); + } + + @Override + public Iterator<V> iterator() { + return new ValueIterator(); + } + + @Override + public boolean add(V v) { + throw new UnsupportedOperationException(); + } + + @Override + public 
boolean remove(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean addAll(Collection<? extends V> vs) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection<?> objects) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean retainAll(Collection<?> objects) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + FastByIDMap.this.clear(); + } + + private final class ValueIterator implements Iterator<V> { + + private int position; + private int lastNext = -1; + + @Override + public boolean hasNext() { + goToNext(); + return position < values.length; + } + + @Override + public V next() { + goToNext(); + lastNext = position; + if (position >= values.length) { + throw new NoSuchElementException(); + } + return values[position++]; + } + + private void goToNext() { + int length = values.length; + while (position < length && values[position] == null) { + position++; + } + } + + @Override + public void remove() { + iteratorRemove(lastNext); + } + + } + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastIDSet.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastIDSet.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastIDSet.java new file mode 100644 index 0000000..5908270 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastIDSet.java @@ -0,0 +1,426 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.common;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

import org.apache.mahout.common.RandomUtils;

import com.google.common.base.Preconditions;

/**
 * A set of {@code long} IDs kept in a single {@code long[]} hash table. {@link Long#MIN_VALUE} and
 * {@link Long#MAX_VALUE} are reserved as the "empty" and "removed" slot markers and therefore cannot
 * be stored as members.
 *
 * @see FastByIDMap
 */
public final class FastIDSet implements Serializable, Cloneable, Iterable<Long> {

  private static final float DEFAULT_LOAD_FACTOR = 1.5f;

  /** Dummy object used to represent a key that has been removed. */
  private static final long REMOVED = Long.MAX_VALUE;
  /** Marks a slot that has never held a key. */
  private static final long NULL = Long.MIN_VALUE;

  private long[] keys;
  private float loadFactor;
  /** Number of live keys. */
  private int numEntries;
  /** Number of slots that are not NULL, i.e. live keys plus REMOVED markers. */
  private int numSlotsUsed;

  /** Creates a new {@link FastIDSet} with default capacity. */
  public FastIDSet() {
    this(2);
  }

  /** Creates a set sized for, and containing, the given keys. */
  public FastIDSet(long[] initialKeys) {
    this(initialKeys.length);
    addAll(initialKeys);
  }

  public FastIDSet(int size) {
    this(size, DEFAULT_LOAD_FACTOR);
  }

  /**
   * @param size desired capacity
   * @param loadFactor ratio of internal hash table size to current size
   * @throws IllegalArgumentException if size is negative or too large for the load factor, or
   *  loadFactor is less than 1
   */
  public FastIDSet(int size, float loadFactor) {
    Preconditions.checkArgument(size >= 0, "size must be at least 0");
    Preconditions.checkArgument(loadFactor >= 1.0f, "loadFactor must be at least 1.0");
    this.loadFactor = loadFactor;
    int max = (int) (RandomUtils.MAX_INT_SMALLER_TWIN_PRIME / loadFactor);
    // Guava's message template only substitutes %s placeholders.
    Preconditions.checkArgument(size < max, "size must be less than %s", max);
    int hashSize = RandomUtils.nextTwinPrime((int) (loadFactor * size));
    keys = new long[hashSize];
    Arrays.fill(keys, NULL);
  }

  /**
   * Probes the table for {@code key}. The probe starts at {@code hash % length} and steps backwards by
   * {@code 1 + hash % (length - 2)}, wrapping around the end of the table.
   *
   * @return index of the slot holding {@code key}, or of the first NULL slot reached when it is absent
   * @see #findForAdd(long)
   */
  private int find(long key) {
    int theHashCode = (int) key & 0x7FFFFFFF; // make sure it's positive
    long[] keys = this.keys;
    int hashSize = keys.length;
    int jump = 1 + theHashCode % (hashSize - 2);
    int index = theHashCode % hashSize;
    long currentKey = keys[index];
    while (currentKey != NULL && key != currentKey) { // note: true when currentKey == REMOVED
      index -= index < jump ? jump - hashSize : jump;
      currentKey = keys[index];
    }
    return index;
  }

  /**
   * Like {@link #find(long)}, but returns a slot suitable for inserting {@code key}: the slot that
   * already holds the key if there is one, otherwise the first REMOVED or NULL slot on the probe path.
   *
   * @see #find(long)
   */
  private int findForAdd(long key) {
    int theHashCode = (int) key & 0x7FFFFFFF; // make sure it's positive
    long[] keys = this.keys;
    int hashSize = keys.length;
    int jump = 1 + theHashCode % (hashSize - 2);
    int index = theHashCode % hashSize;
    long currentKey = keys[index];
    while (currentKey != NULL && currentKey != REMOVED && key != currentKey) {
      index -= index < jump ? jump - hashSize : jump;
      currentKey = keys[index];
    }
    if (currentKey != REMOVED) {
      return index;
    }
    // If we're adding, it's here, but, the key might have a value already later
    int addIndex = index;
    while (currentKey != NULL && key != currentKey) {
      index -= index < jump ? jump - hashSize : jump;
      currentKey = keys[index];
    }
    return key == currentKey ? index : addIndex;
  }

  /** @return number of keys in the set */
  public int size() {
    return numEntries;
  }

  public boolean isEmpty() {
    return numEntries == 0;
  }

  /** @return true if {@code key} is a member; the two sentinel values are never members */
  public boolean contains(long key) {
    return key != NULL && key != REMOVED && keys[find(key)] != NULL;
  }

  /**
   * Adds {@code key}, rehashing or growing the table first if it is too full.
   *
   * @return true if the key was not already present
   * @throws IllegalArgumentException if {@code key} is one of the sentinel values
   */
  public boolean add(long key) {
    Preconditions.checkArgument(key != NULL && key != REMOVED);

    // If less than half the slots are open, let's clear it up
    if (numSlotsUsed * loadFactor >= keys.length) {
      // If over half the slots used are actual entries, let's grow
      if (numEntries * loadFactor >= numSlotsUsed) {
        growAndRehash();
      } else {
        // Otherwise just rehash to clear REMOVED entries and don't grow
        rehash();
      }
    }
    // Here we may later consider implementing Brent's variation described on page 532
    int index = findForAdd(key);
    long keyIndex = keys[index];
    if (keyIndex != key) {
      keys[index] = key;
      numEntries++;
      // Re-using a REMOVED slot does not consume a new slot.
      if (keyIndex == NULL) {
        numSlotsUsed++;
      }
      return true;
    }
    return false;
  }

  @Override
  public LongPrimitiveIterator iterator() {
    return new KeyIterator();
  }

  /** @return a new array holding every key, in table-slot order */
  public long[] toArray() {
    long[] result = new long[numEntries];
    for (int i = 0, position = 0; i < result.length; i++) {
      while (keys[position] == NULL || keys[position] == REMOVED) {
        position++;
      }
      result[i] = keys[position++];
    }
    return result;
  }

  /** @return true if {@code key} was present and has been removed */
  public boolean remove(long key) {
    if (key == NULL || key == REMOVED) {
      return false;
    }
    int index = find(key);
    if (keys[index] == NULL) {
      return false;
    } else {
      keys[index] = REMOVED;
      numEntries--;
      return true;
    }
  }

  /** @return true if at least one key was newly added */
  public boolean addAll(long[] c) {
    boolean changed = false;
    for (long k : c) {
      if (add(k)) {
        changed = true;
      }
    }
    return changed;
  }

  /** @return true if at least one key was newly added */
  public boolean addAll(FastIDSet c) {
    boolean changed = false;
    for (long k : c.keys) {
      if (k != NULL && k != REMOVED && add(k)) {
        changed = true;
      }
    }
    return changed;
  }

  /** @return true if at least one key was removed */
  public boolean removeAll(long[] c) {
    boolean changed = false;
    for (long o : c) {
      if (remove(o)) {
        changed = true;
      }
    }
    return changed;
  }

  /** @return true if at least one key was removed */
  public boolean removeAll(FastIDSet c) {
    boolean changed = false;
    for (long k : c.keys) {
      if (k != NULL && k != REMOVED && remove(k)) {
        changed = true;
      }
    }
    return changed;
  }

  /** Removes every key that is not in {@code c}. @return true if at least one key was removed */
  public boolean retainAll(FastIDSet c) {
    boolean changed = false;
    for (int i = 0; i < keys.length; i++) {
      long k = keys[i];
      if (k != NULL && k != REMOVED && !c.contains(k)) {
        keys[i] = REMOVED;
        numEntries--;
        changed = true;
      }
    }
    return changed;
  }

  public void clear() {
    numEntries = 0;
    numSlotsUsed = 0;
    Arrays.fill(keys, NULL);
  }

  /** @throws IllegalStateException if the table cannot be enlarged further */
  private void growAndRehash() {
    if (keys.length * loadFactor >= RandomUtils.MAX_INT_SMALLER_TWIN_PRIME) {
      throw new IllegalStateException("Can't grow any more");
    }
    rehash(RandomUtils.nextTwinPrime((int) (loadFactor * keys.length)));
  }

  /** Rebuilds the table at a size derived from the current entry count, dropping REMOVED markers. */
  public void rehash() {
    rehash(RandomUtils.nextTwinPrime((int) (loadFactor * numEntries)));
  }

  /** Allocates a fresh table of {@code newHashSize} slots and re-inserts every live key. */
  private void rehash(int newHashSize) {
    long[] oldKeys = keys;
    numEntries = 0;
    numSlotsUsed = 0;
    keys = new long[newHashSize];
    Arrays.fill(keys, NULL);
    for (long key : oldKeys) {
      if (key != NULL && key != REMOVED) {
        add(key);
      }
    }
  }

  /**
   * Convenience method to quickly compute just the size of the intersection with another {@link FastIDSet}.
   *
   * @param other
   *          {@link FastIDSet} to intersect with
   * @return number of elements in intersection
   */
  public int intersectionSize(FastIDSet other) {
    int count = 0;
    for (long key : other.keys) {
      if (key != NULL && key != REMOVED && keys[find(key)] != NULL) {
        count++;
      }
    }
    return count;
  }

  @Override
  public FastIDSet clone() {
    FastIDSet clone;
    try {
      clone = (FastIDSet) super.clone();
    } catch (CloneNotSupportedException cnse) {
      throw new AssertionError();
    }
    clone.keys = keys.clone();
    return clone;
  }

  /**
   * Sum of per-key hashes, so the result does not depend on table size or slot order and stays
   * consistent with {@link #equals(Object)}.
   */
  @Override
  public int hashCode() {
    int hash = 0;
    long[] keys = this.keys;
    for (long key : keys) {
      if (key != NULL && key != REMOVED) {
        hash += (int) (key >> 32) ^ (int) key;
      }
    }
    return hash;
  }

  /**
   * Two sets are equal when they contain exactly the same keys, regardless of table size, removal
   * history or insertion order.
   */
  @Override
  public boolean equals(Object other) {
    if (other == this) {
      return true;
    }
    if (!(other instanceof FastIDSet)) {
      return false;
    }
    FastIDSet otherSet = (FastIDSet) other;
    if (numEntries != otherSet.numEntries) {
      return false;
    }
    // Same size, so it is enough that every key here is also a member there.
    for (long key : keys) {
      if (key != NULL && key != REMOVED && !otherSet.contains(key)) {
        return false;
      }
    }
    return true;
  }

  @Override
  public String toString() {
    if (isEmpty()) {
      return "[]";
    }
    StringBuilder result = new StringBuilder();
    result.append('[');
    for (long key : keys) {
      if (key != NULL && key != REMOVED) {
        result.append(key).append(',');
      }
    }
    // Overwrite the trailing comma
    result.setCharAt(result.length() - 1, ']');
    return result.toString();
  }

  /** Iterates over live keys in slot order. */
  private final class KeyIterator extends AbstractLongPrimitiveIterator {

    private int position;
    private int lastNext = -1;

    @Override
    public boolean hasNext() {
      goToNext();
      return position < keys.length;
    }

    @Override
    public long nextLong() {
      goToNext();
      lastNext = position;
      if (position >= keys.length) {
        throw new NoSuchElementException();
      }
      return keys[position++];
    }

    @Override
    public long peek() {
      goToNext();
      if (position >= keys.length) {
        throw new NoSuchElementException();
      }
      return keys[position];
    }

    /** Advances to the next slot holding a live key. */
    private void goToNext() {
      int length = keys.length;
      while (position < length
             && (keys[position] == NULL || keys[position] == REMOVED)) {
        position++;
      }
    }

    /**
     * @throws NoSuchElementException if the last call to next ran past the end of the table
     * @throws IllegalStateException if no key has been returned yet, or the slot no longer holds a
     *  key (for example because {@code remove()} was already called for it)
     */
    @Override
    public void remove() {
      if (lastNext >= keys.length) {
        throw new NoSuchElementException();
      }
      if (lastNext < 0) {
        throw new IllegalStateException();
      }
      if (keys[lastNext] == NULL || keys[lastNext] == REMOVED) {
        // Already removed: decrementing numEntries a second time would corrupt the size.
        throw new IllegalStateException();
      }
      keys[lastNext] = REMOVED;
      numEntries--;
    }

    public Iterator<Long> iterator() {
      return new KeyIterator();
    }

    // NOTE(review): advances by n raw slots, not n live keys - confirm against the interface contract
    @Override
    public void skip(int n) {
      position += n;
    }

  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastMap.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastMap.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastMap.java new file mode 100644 index 0000000..7c64b44 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastMap.java @@ -0,0 +1,729 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.cf.taste.impl.common;

import java.io.Serializable;
import java.util.AbstractCollection;
import java.util.AbstractSet;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;

import org.apache.mahout.common.RandomUtils;

import com.google.common.base.Preconditions;

/**
 * <p>
 * This is an optimized {@link Map} implementation, based on algorithms described in Knuth's "Art of Computer
 * Programming", Vol. 3, p. 529.
 * </p>
 *
 * <p>
 * It should be faster than {@link java.util.HashMap} in some cases, but not all. Its main feature is a
 * "max size" and the ability to transparently, efficiently and semi-intelligently evict old entries when max
 * size is exceeded.
 * </p>
 *
 * <p>
 * This class is not thread-safe.
 * </p>
 *
 * <p>
 * This implementation does not allow {@code null} as a key or value.
 * </p>
 */
public final class FastMap<K,V> implements Map<K,V>, Serializable, Cloneable {

  public static final int NO_MAX_SIZE = Integer.MAX_VALUE;
  private static final float DEFAULT_LOAD_FACTOR = 1.5f;

  /** Dummy object used to represent a key that has been removed. */
  private static final Object REMOVED = new Object();

  private K[] keys;
  private V[] values;
  private float loadFactor;
  /** Number of live entries. */
  private int numEntries;
  /** Number of slots that are not null, i.e. live entries plus REMOVED markers. */
  private int numSlotsUsed;
  private final int maxSize;
  private BitSet recentlyAccessed;
  private final boolean countingAccesses;

  /** Creates a new {@link FastMap} with default capacity. */
  public FastMap() {
    this(2, NO_MAX_SIZE);
  }

  public FastMap(int size) {
    this(size, NO_MAX_SIZE);
  }

  /** Creates a map sized for, and containing, the mappings of {@code other}. */
  public FastMap(Map<K,V> other) {
    this(other.size());
    putAll(other);
  }

  public FastMap(int size, float loadFactor) {
    this(size, NO_MAX_SIZE, loadFactor);
  }

  public FastMap(int size, int maxSize) {
    this(size, maxSize, DEFAULT_LOAD_FACTOR);
  }

  /**
   * Creates a new {@link FastMap} whose capacity can accommodate the given number of entries without rehash.
   *
   * @param size desired capacity
   * @param maxSize max capacity
   * @param loadFactor ratio of internal hash table size to current size
   * @throws IllegalArgumentException if size is less than 0, maxSize is less than 1
   *  or at least half of {@link RandomUtils#MAX_INT_SMALLER_TWIN_PRIME}, or
   *  loadFactor is less than 1
   */
  public FastMap(int size, int maxSize, float loadFactor) {
    Preconditions.checkArgument(size >= 0, "size must be at least 0");
    Preconditions.checkArgument(loadFactor >= 1.0f, "loadFactor must be at least 1.0");
    this.loadFactor = loadFactor;
    int max = (int) (RandomUtils.MAX_INT_SMALLER_TWIN_PRIME / loadFactor);
    Preconditions.checkArgument(size < max, "size must be less than " + max);
    Preconditions.checkArgument(maxSize >= 1, "maxSize must be at least 1");
    int hashSize = RandomUtils.nextTwinPrime((int) (loadFactor * size));
    keys = (K[]) new Object[hashSize];
    values = (V[]) new Object[hashSize];
    this.maxSize = maxSize;
    // Access tracking is only needed when a size limit can trigger eviction.
    this.countingAccesses = maxSize != Integer.MAX_VALUE;
    this.recentlyAccessed = countingAccesses ? new BitSet(hashSize) : null;
  }

  /**
   * Probes the table for a key equal to {@code key}. The probe starts at {@code hash % length} and
   * steps backwards by {@code 1 + hash % (length - 2)}, wrapping around the end of the table.
   * REMOVED slots do not stop the probe.
   *
   * @return index of the slot holding an equal key, or of the first null slot reached when absent
   */
  private int find(Object key) {
    int theHashCode = key.hashCode() & 0x7FFFFFFF; // make sure it's positive
    K[] keys = this.keys;
    int hashSize = keys.length;
    int jump = 1 + theHashCode % (hashSize - 2);
    int index = theHashCode % hashSize;
    K currentKey = keys[index];
    while (currentKey != null && !key.equals(currentKey)) {
      index -= index < jump ? jump - hashSize : jump;
      currentKey = keys[index];
    }
    return index;
  }

  /**
   * Like {@link #find(Object)}, but returns a slot suitable for inserting {@code key}: the slot that
   * already holds an equal key if there is one, otherwise the first REMOVED or null slot on the
   * probe path. Keys are compared with {@code equals}, consistently with {@link #find(Object)}.
   */
  private int findForAdd(Object key) {
    int theHashCode = key.hashCode() & 0x7FFFFFFF; // make sure it's positive
    K[] keys = this.keys;
    int hashSize = keys.length;
    int jump = 1 + theHashCode % (hashSize - 2);
    int index = theHashCode % hashSize;
    K currentKey = keys[index];
    while (currentKey != null && currentKey != REMOVED && !key.equals(currentKey)) {
      index -= index < jump ? jump - hashSize : jump;
      currentKey = keys[index];
    }
    if (currentKey != REMOVED) {
      return index;
    }
    // If we're adding, it's here, but, the key might have a value already later
    int addIndex = index;
    while (currentKey != null && (currentKey == REMOVED || !key.equals(currentKey))) {
      index -= index < jump ? jump - hashSize : jump;
      currentKey = keys[index];
    }
    // The loop ends either on a null slot (key absent) or on the slot holding an equal key.
    return currentKey == null ? addIndex : index;
  }

  /**
   * @return the value mapped to {@code key}, or {@code null} if there is none. When access counting
   *  is enabled the probed slot is flagged as recently accessed.
   */
  @Override
  public V get(Object key) {
    if (key == null) {
      return null;
    }
    int index = find(key);
    if (countingAccesses) {
      recentlyAccessed.set(index);
    }
    return values[index];
  }

  @Override
  public int size() {
    return numEntries;
  }

  @Override
  public boolean isEmpty() {
    return numEntries == 0;
  }

  @Override
  public boolean containsKey(Object key) {
    return key != null && keys[find(key)] != null;
  }

  /** Linear scan over all slots; {@code null} is never contained. */
  @Override
  public boolean containsValue(Object value) {
    if (value == null) {
      return false;
    }
    for (V theValue : values) {
      if (theValue != null && value.equals(theValue)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Maps {@code key} to {@code value}, rehashing or growing first if the table is too full and
   * evicting a not-recently-accessed entry if the size limit has been reached.
   *
   * @return the previous value for an equal key, or {@code null} if there was none
   * @throws NullPointerException
   *           if key or value is null
   */
  @Override
  public V put(K key, V value) {
    Preconditions.checkNotNull(key);
    Preconditions.checkNotNull(value);
    // If less than half the slots are open, let's clear it up
    if (numSlotsUsed * loadFactor >= keys.length) {
      // If over half the slots used are actual entries, let's grow
      if (numEntries * loadFactor >= numSlotsUsed) {
        growAndRehash();
      } else {
        // Otherwise just rehash to clear REMOVED entries and don't grow
        rehash();
      }
    }
    // Here we may later consider implementing Brent's variation described on page 532
    int index = findForAdd(key);
    K existingKey = keys[index];
    if (existingKey != null && existingKey != REMOVED) {
      // findForAdd only returns a live slot when it holds a key equal to the one being added
      V oldValue = values[index];
      values[index] = value;
      return oldValue;
    }
    // If size is limited,
    if (countingAccesses && numEntries >= maxSize) {
      // and we're too large, clear some old-ish entry
      clearStaleEntry(index);
    }
    keys[index] = key;
    values[index] = value;
    numEntries++;
    // Re-using a REMOVED slot does not consume a new slot.
    if (existingKey == null) {
      numSlotsUsed++;
    }
    return null;
  }

  /**
   * Evicts one entry: walks backwards (with wrap-around) from {@code index} over live slots, clearing
   * the recently-accessed flag of each one it passes, and removes the first whose flag was not set.
   */
  private void clearStaleEntry(int index) {
    while (true) {
      K currentKey;
      do {
        if (index == 0) {
          index = keys.length - 1;
        } else {
          index--;
        }
        currentKey = keys[index];
      } while (currentKey == null || currentKey == REMOVED);
      if (recentlyAccessed.get(index)) {
        recentlyAccessed.clear(index);
      } else {
        break;
      }
    }
    // Delete the entry
    ((Object[])keys)[index] = REMOVED;
    numEntries--;
    values[index] = null;
  }

  @Override
  public void putAll(Map<? extends K,? extends V> map) {
    for (Entry<? extends K,? extends V> entry : map.entrySet()) {
      put(entry.getKey(), entry.getValue());
    }
  }

  @Override
  public V remove(Object key) {
    if (key == null) {
      return null;
    }
    int index = find(key);
    if (keys[index] == null) {
      return null;
    } else {
      ((Object[])keys)[index] = REMOVED;
      numEntries--;
      V oldValue = values[index];
      values[index] = null;
      // don't decrement numSlotsUsed
      return oldValue;
    }
    // Could un-set recentlyAccessed's bit but doesn't matter
  }

  @Override
  public void clear() {
    numEntries = 0;
    numSlotsUsed = 0;
    Arrays.fill(keys, null);
    Arrays.fill(values, null);
    if (countingAccesses) {
      recentlyAccessed.clear();
    }
  }

  @Override
  public Set<K> keySet() {
    return new KeySet();
  }

  @Override
  public Collection<V> values() {
    return new ValueCollection();
  }

  @Override
  public Set<Entry<K,V>> entrySet() {
    return new EntrySet();
  }

  /** Rebuilds the table at a size derived from the current entry count, dropping REMOVED markers. */
  public void rehash() {
    rehash(RandomUtils.nextTwinPrime((int) (loadFactor * numEntries)));
  }

  /** @throws IllegalStateException if the table cannot be enlarged further */
  private void growAndRehash() {
    if (keys.length * loadFactor >= RandomUtils.MAX_INT_SMALLER_TWIN_PRIME) {
      throw new IllegalStateException("Can't grow any more");
    }
    rehash(RandomUtils.nextTwinPrime((int) (loadFactor * keys.length)));
  }

  /** Allocates fresh arrays of {@code newHashSize} slots and re-inserts every live entry. */
  private void rehash(int newHashSize) {
    K[] oldKeys = keys;
    V[] oldValues = values;
    numEntries = 0;
    numSlotsUsed = 0;
    if (countingAccesses) {
      recentlyAccessed = new BitSet(newHashSize);
    }
    keys = (K[]) new Object[newHashSize];
    values = (V[]) new Object[newHashSize];
    int length = oldKeys.length;
    for (int i = 0; i < length; i++) {
      K key = oldKeys[i];
      if (key != null && key != REMOVED) {
        put(key, oldValues[i]);
      }
    }
  }

  /**
   * Removes the entry in slot {@code lastNext} on behalf of an iterator.
   *
   * @throws NoSuchElementException if {@code lastNext} is past the end of the table
   * @throws IllegalStateException if no element has been returned yet, or the slot no longer holds
   *  an entry (for example because {@code remove()} was already called for it)
   */
  void iteratorRemove(int lastNext) {
    if (lastNext >= values.length) {
      throw new NoSuchElementException();
    }
    if (lastNext < 0) {
      throw new IllegalStateException();
    }
    if (values[lastNext] == null) {
      // Already removed: decrementing numEntries a second time would corrupt the size.
      throw new IllegalStateException();
    }
    values[lastNext] = null;
    ((Object[])keys)[lastNext] = REMOVED;
    numEntries--;
  }

  /** The copy shares key and value objects with this map; its access flags start out cleared. */
  @Override
  public FastMap<K,V> clone() {
    FastMap<K,V> clone;
    try {
      clone = (FastMap<K,V>) super.clone();
    } catch (CloneNotSupportedException cnse) {
      throw new AssertionError();
    }
    clone.keys = keys.clone();
    clone.values = values.clone();
    clone.recentlyAccessed = countingAccesses ? new BitSet(keys.length) : null;
    return clone;
  }

  /**
   * Sum over all entries of {@code key.hashCode() ^ value.hashCode()}, as specified by
   * {@link Map#hashCode()}; independent of table size and slot order.
   */
  @Override
  public int hashCode() {
    int hash = 0;
    K[] keys = this.keys;
    int max = keys.length;
    for (int i = 0; i < max; i++) {
      K key = keys[i];
      if (key != null && key != REMOVED) {
        hash += key.hashCode() ^ values[i].hashCode();
      }
    }
    return hash;
  }

  /**
   * Equal to any {@link Map} holding the same mappings, as specified by {@link Map#equals(Object)}.
   * Keys and values are compared with {@code equals}.
   */
  @Override
  public boolean equals(Object other) {
    if (other == this) {
      return true;
    }
    if (!(other instanceof Map)) {
      return false;
    }
    Map<?,?> otherMap = (Map<?,?>) other;
    if (otherMap.size() != numEntries) {
      return false;
    }
    try {
      for (int i = 0; i < keys.length; i++) {
        K key = keys[i];
        if (key != null && key != REMOVED && !values[i].equals(valueIn(otherMap, key))) {
          return false;
        }
      }
    } catch (ClassCastException cce) {
      // the other map cannot hold keys of this type, so it cannot contain the same mappings
      return false;
    } catch (NullPointerException npe) {
      return false;
    }
    return true;
  }

  /** Looks {@code key} up in {@code map} without marking the entry as accessed when it is a FastMap. */
  private static Object valueIn(Map<?,?> map, Object key) {
    if (map instanceof FastMap) {
      FastMap<?,?> fastMap = (FastMap<?,?>) map;
      return fastMap.values[fastMap.find(key)];
    }
    return map.get(key);
  }

  @Override
  public String toString() {
    if (isEmpty()) {
      return "{}";
    }
    StringBuilder result = new StringBuilder();
    result.append('{');
    for (int i = 0; i < keys.length; i++) {
      K key = keys[i];
      if (key != null && key != REMOVED) {
        result.append(key).append('=').append(values[i]).append(',');
      }
    }
    // Overwrite the trailing comma
    result.setCharAt(result.length() - 1, '}');
    return result.toString();
  }

  /** Live view of the map's entries; supports iteration, iterator removal and clear only. */
  private final class EntrySet extends AbstractSet<Entry<K,V>> {

    @Override
    public int size() {
      return FastMap.this.size();
    }

    @Override
    public boolean isEmpty() {
      return FastMap.this.isEmpty();
    }

    /**
     * @return true if {@code o} is an {@link Entry} whose key this map maps to a value equal to the
     *  entry's value; false for anything else, including {@code null}
     */
    @Override
    public boolean contains(Object o) {
      if (!(o instanceof Entry)) {
        return false;
      }
      Entry<?,?> entry = (Entry<?,?>) o;
      Object key = entry.getKey();
      if (key == null) {
        return false;
      }
      // Probe directly so that a membership test does not count as an access.
      int index = find(key);
      return keys[index] != null && values[index].equals(entry.getValue());
    }

    @Override
    public Iterator<Entry<K,V>> iterator() {
      return new EntryIterator();
    }

    @Override
    public boolean add(Entry<K,V> t) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean remove(Object o) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean addAll(Collection<? extends Entry<K,V>> ts) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean retainAll(Collection<?> objects) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean removeAll(Collection<?> objects) {
      throw new UnsupportedOperationException();
    }

    @Override
    public void clear() {
      FastMap.this.clear();
    }

    /** Entry backed by a table slot; reads and writes go straight to the map's arrays. */
    private final class MapEntry implements Entry<K,V> {

      private final int index;

      private MapEntry(int index) {
        this.index = index;
      }

      @Override
      public K getKey() {
        return keys[index];
      }

      @Override
      public V getValue() {
        return values[index];
      }

      @Override
      public V setValue(V value) {
        Preconditions.checkArgument(value != null);
        V oldValue = values[index];
        values[index] = value;
        return oldValue;
      }
    }

    private final class EntryIterator implements Iterator<Entry<K,V>> {

      private int position;
      private int lastNext = -1;

      @Override
      public boolean hasNext() {
        goToNext();
        return position < keys.length;
      }

      @Override
      public Entry<K,V> next() {
        goToNext();
        lastNext = position;
        if (position >= keys.length) {
          throw new NoSuchElementException();
        }
        return new MapEntry(position++);
      }

      /** Advances to the next slot with a non-null value, i.e. the next live entry. */
      private void goToNext() {
        int length = values.length;
        while (position < length && values[position] == null) {
          position++;
        }
      }

      @Override
      public void remove() {
        iteratorRemove(lastNext);
      }
    }

  }

  /** Live view of the map's keys; supports iteration, iterator removal and clear only. */
  private final class KeySet extends AbstractSet<K> {

    @Override
    public int size() {
      return FastMap.this.size();
    }

    @Override
    public boolean isEmpty() {
      return FastMap.this.isEmpty();
    }

    @Override
    public boolean contains(Object o) {
      return containsKey(o);
    }

    @Override
    public Iterator<K> iterator() {
      return new KeyIterator();
    }

    @Override
    public boolean add(K t) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean remove(Object o) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean addAll(Collection<? extends K> ts) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean retainAll(Collection<?> objects) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean removeAll(Collection<?> objects) {
      throw new UnsupportedOperationException();
    }

    @Override
    public void clear() {
      FastMap.this.clear();
    }

    private final class KeyIterator implements Iterator<K> {

      private int position;
      private int lastNext = -1;

      @Override
      public boolean hasNext() {
        goToNext();
        return position < keys.length;
      }

      @Override
      public K next() {
        goToNext();
        lastNext = position;
        if (position >= keys.length) {
          throw new NoSuchElementException();
        }
        return keys[position++];
      }

      private void goToNext() {
        int length = values.length;
        while (position < length && values[position] == null) {
          position++;
        }
      }

      @Override
      public void remove() {
        iteratorRemove(lastNext);
      }
    }

  }

  /** Live view of the map's values; supports iteration, iterator removal and clear only. */
  private final class ValueCollection extends AbstractCollection<V> {

    @Override
    public int size() {
      return FastMap.this.size();
    }

    @Override
    public boolean isEmpty() {
      return FastMap.this.isEmpty();
    }

    @Override
    public boolean contains(Object o) {
      return containsValue(o);
    }

    @Override
    public Iterator<V> iterator() {
      return new ValueIterator();
    }

    @Override
    public boolean add(V v) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean remove(Object o) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean addAll(Collection<? extends V> vs) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean removeAll(Collection<?> objects) {
      throw new UnsupportedOperationException();
    }

    @Override
    public boolean retainAll(Collection<?> objects) {
      throw new UnsupportedOperationException();
    }

    @Override
    public void clear() {
      FastMap.this.clear();
    }

    private final class ValueIterator implements Iterator<V> {

      private int position;
      private int lastNext = -1;

      @Override
      public boolean hasNext() {
        goToNext();
        return position < values.length;
      }

      @Override
      public V next() {
        goToNext();
        lastNext = position;
        if (position >= values.length) {
          throw new NoSuchElementException();
        }
        return values[position++];
      }

      private void goToNext() {
        int length = values.length;
        while (position < length && values[position] == null) {
          position++;
        }
      }

      @Override
      public void remove() {
        iteratorRemove(lastNext);
      }

    }

  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java new file mode 100644 index 0000000..1863d2b ---
/dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common; + +import java.io.Serializable; + +/** + * <p> + * A simple class that represents a fixed value of an average and count. This is useful + * when an API needs to return {@link RunningAverage} but is not in a position to accept + * updates to it. 
+ * </p> + */ +public class FixedRunningAverage implements RunningAverage, Serializable { + + private final double average; + private final int count; + + public FixedRunningAverage(double average, int count) { + this.average = average; + this.count = count; + } + + /** + * @throws UnsupportedOperationException + */ + @Override + public synchronized void addDatum(double datum) { + throw new UnsupportedOperationException(); + } + + /** + * @throws UnsupportedOperationException + */ + @Override + public synchronized void removeDatum(double datum) { + throw new UnsupportedOperationException(); + } + + /** + * @throws UnsupportedOperationException + */ + @Override + public synchronized void changeDatum(double delta) { + throw new UnsupportedOperationException(); + } + + @Override + public synchronized int getCount() { + return count; + } + + @Override + public synchronized double getAverage() { + return average; + } + + @Override + public RunningAverage inverse() { + return new InvertedRunningAverage(this); + } + + @Override + public synchronized String toString() { + return String.valueOf(average); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java new file mode 100644 index 0000000..619b6b7 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common; + +/** + * <p> + * A simple class that represents a fixed value of an average, count and standard deviation. This is useful + * when an API needs to return {@link RunningAverageAndStdDev} but is not in a position to accept + * updates to it. + * </p> + */ +public final class FixedRunningAverageAndStdDev extends FixedRunningAverage implements RunningAverageAndStdDev { + + private final double stdDev; + + public FixedRunningAverageAndStdDev(double average, double stdDev, int count) { + super(average, count); + this.stdDev = stdDev; + } + + @Override + public RunningAverageAndStdDev inverse() { + return new InvertedRunningAverageAndStdDev(this); + } + + @Override + public synchronized String toString() { + return super.toString() + ',' + stdDev; + } + + @Override + public double getStandardDeviation() { + return stdDev; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java 
b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java new file mode 100644 index 0000000..00d828f --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java @@ -0,0 +1,109 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common; + +import java.io.Serializable; + +/** + * <p> + * A simple class that can keep track of a running average of a series of numbers. One can add to or remove + * from the series, as well as update a datum in the series. The class does not actually keep track of the + * series of values, just its running average, so it doesn't even matter if you remove/change a value that + * wasn't added. 
+ * </p> + */ +public class FullRunningAverage implements RunningAverage, Serializable { + + private int count; + private double average; + + public FullRunningAverage() { + this(0, Double.NaN); + } + + public FullRunningAverage(int count, double average) { + this.count = count; + this.average = average; + } + + /** + * @param datum + * new item to add to the running average + */ + @Override + public synchronized void addDatum(double datum) { + if (++count == 1) { + average = datum; + } else { + average = average * (count - 1) / count + datum / count; + } + } + + /** + * @param datum + * item to remove to the running average + * @throws IllegalStateException + * if count is 0 + */ + @Override + public synchronized void removeDatum(double datum) { + if (count == 0) { + throw new IllegalStateException(); + } + if (--count == 0) { + average = Double.NaN; + } else { + average = average * (count + 1) / count - datum / count; + } + } + + /** + * @param delta + * amount by which to change a datum in the running average + * @throws IllegalStateException + * if count is 0 + */ + @Override + public synchronized void changeDatum(double delta) { + if (count == 0) { + throw new IllegalStateException(); + } + average += delta / count; + } + + @Override + public synchronized int getCount() { + return count; + } + + @Override + public synchronized double getAverage() { + return average; + } + + @Override + public RunningAverage inverse() { + return new InvertedRunningAverage(this); + } + + @Override + public synchronized String toString() { + return String.valueOf(average); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java 
b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java new file mode 100644 index 0000000..6212e66 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java @@ -0,0 +1,107 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common; + +/** + * <p> + * Extends {@link FullRunningAverage} to add a running standard deviation computation. 
+ * Uses Welford's method, as described at http://www.johndcook.com/standard_deviation.html + * </p> + */ +public final class FullRunningAverageAndStdDev extends FullRunningAverage implements RunningAverageAndStdDev { + + private double stdDev; + private double mk; + private double sk; + + public FullRunningAverageAndStdDev() { + mk = 0.0; + sk = 0.0; + recomputeStdDev(); + } + + public FullRunningAverageAndStdDev(int count, double average, double mk, double sk) { + super(count, average); + this.mk = mk; + this.sk = sk; + recomputeStdDev(); + } + + public double getMk() { + return mk; + } + + public double getSk() { + return sk; + } + + @Override + public synchronized double getStandardDeviation() { + return stdDev; + } + + @Override + public synchronized void addDatum(double datum) { + super.addDatum(datum); + int count = getCount(); + if (count == 1) { + mk = datum; + sk = 0.0; + } else { + double oldmk = mk; + double diff = datum - oldmk; + mk += diff / count; + sk += diff * (datum - mk); + } + recomputeStdDev(); + } + + @Override + public synchronized void removeDatum(double datum) { + int oldCount = getCount(); + super.removeDatum(datum); + double oldmk = mk; + mk = (oldCount * oldmk - datum) / (oldCount - 1); + sk -= (datum - mk) * (datum - oldmk); + recomputeStdDev(); + } + + /** + * @throws UnsupportedOperationException + */ + @Override + public void changeDatum(double delta) { + throw new UnsupportedOperationException(); + } + + private synchronized void recomputeStdDev() { + int count = getCount(); + stdDev = count > 1 ? Math.sqrt(sk / (count - 1)) : Double.NaN; + } + + @Override + public RunningAverageAndStdDev inverse() { + return new InvertedRunningAverageAndStdDev(this); + } + + @Override + public synchronized String toString() { + return String.valueOf(String.valueOf(getAverage()) + ',' + stdDev); + } + +}
