http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java new file mode 100644 index 0000000..006f4b6 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.neighborhood; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Set; + +import com.google.common.base.Preconditions; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.random.RandomProjector; +import org.apache.mahout.math.random.WeightedThing; + +/** + * Does approximate nearest neighbor search by projecting the vectors similar to ProjectionSearch. + * The main difference between this class and the ProjectionSearch is the use of sorted arrays + * instead of binary search trees to implement the sets of scalar projections. + * + * Instead of taking log n time to add a vector to each of the vectors, * the pending additions are + * kept separate and are searched using a brute search. When there are "enough" pending additions, + * they're committed into the main pool of vectors. + */ +public class FastProjectionSearch extends UpdatableSearcher { + // The list of vectors that have not yet been projected (that are pending). + private final List<Vector> pendingAdditions = Lists.newArrayList(); + + // The list of basis vectors. Populated when the first vector's dimension is know by calling + // initialize once. + private Matrix basisMatrix = null; + + // The list of sorted lists of scalar projections. The outer list has one entry for each basis + // vector that all the other vectors will be projected on. + // For each basis vector, the inner list has an entry for each vector that has been projected. + // These entries are WeightedThing<Vector> where the weight is the value of the scalar + // projection and the value is the vector begin referred to. 
+ private List<List<WeightedThing<Vector>>> scalarProjections; + + // The number of projection used for approximating the distance. + private final int numProjections; + + // The number of elements to keep on both sides of the closest estimated distance as possible + // candidates for the best actual distance. + private final int searchSize; + + // Initially, the dimension of the vectors searched by this searcher is unknown. After adding + // the first vector, the basis will be initialized. This marks whether initialization has + // happened or not so we only do it once. + private boolean initialized = false; + + // Removing vectors from the searcher is done lazily to avoid the linear time cost of removing + // elements from an array. This member keeps track of the number of removed vectors (marked as + // "impossible" values in the array) so they can be removed when updating the structure. + private int numPendingRemovals = 0; + + private static final double ADDITION_THRESHOLD = 0.05; + private static final double REMOVAL_THRESHOLD = 0.02; + + public FastProjectionSearch(DistanceMeasure distanceMeasure, int numProjections, int searchSize) { + super(distanceMeasure); + Preconditions.checkArgument(numProjections > 0 && numProjections < 100, + "Unreasonable value for number of projections. Must be: 0 < numProjections < 100"); + this.numProjections = numProjections; + this.searchSize = searchSize; + scalarProjections = Lists.newArrayListWithCapacity(numProjections); + for (int i = 0; i < numProjections; ++i) { + scalarProjections.add(Lists.<WeightedThing<Vector>>newArrayList()); + } + } + + private void initialize(int numDimensions) { + if (initialized) { + return; + } + basisMatrix = RandomProjector.generateBasisNormal(numProjections, numDimensions); + initialized = true; + } + + /** + * Add a new Vector to the Searcher that will be checked when getting + * the nearest neighbors. + * <p/> + * The vector IS NOT CLONED. 
Do not modify the vector externally otherwise the internal + * Searcher data structures could be invalidated. + */ + @Override + public void add(Vector vector) { + initialize(vector.size()); + pendingAdditions.add(vector); + } + + /** + * Returns the number of WeightedVectors being searched for nearest neighbors. + */ + @Override + public int size() { + return pendingAdditions.size() + scalarProjections.get(0).size() - numPendingRemovals; + } + + /** + * When querying the Searcher for the closest vectors, a list of WeightedThing<Vector>s is + * returned. The value of the WeightedThing is the neighbor and the weight is the + * the distance (calculated by some metric - see a concrete implementation) between the query + * and neighbor. + * The actual type of vector in the pair is the same as the vector added to the Searcher. + */ + @Override + public List<WeightedThing<Vector>> search(Vector query, int limit) { + reindex(false); + + Set<Vector> candidates = Sets.newHashSet(); + Vector projection = basisMatrix.times(query); + for (int i = 0; i < basisMatrix.numRows(); ++i) { + List<WeightedThing<Vector>> currProjections = scalarProjections.get(i); + int middle = Collections.binarySearch(currProjections, + new WeightedThing<Vector>(projection.get(i))); + if (middle < 0) { + middle = -(middle + 1); + } + for (int j = Math.max(0, middle - searchSize); + j < Math.min(currProjections.size(), middle + searchSize + 1); ++j) { + if (currProjections.get(j).getValue() == null) { + continue; + } + candidates.add(currProjections.get(j).getValue()); + } + } + + List<WeightedThing<Vector>> top = + Lists.newArrayListWithCapacity(candidates.size() + pendingAdditions.size()); + for (Vector candidate : Iterables.concat(candidates, pendingAdditions)) { + top.add(new WeightedThing<>(candidate, distanceMeasure.distance(candidate, query))); + } + Collections.sort(top); + + return top.subList(0, Math.min(top.size(), limit)); + } + + /** + * Returns the closest vector to the query. 
+ * When only one the nearest vector is needed, use this method, NOT search(query, limit) because + * it's faster (less overhead). + * + * @param query the vector to search for + * @param differentThanQuery if true, returns the closest vector different than the query (this + * only matters if the query is among the searched vectors), otherwise, + * returns the closest vector to the query (even the same vector). + * @return the weighted vector closest to the query + */ + @Override + public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) { + reindex(false); + + double bestDistance = Double.POSITIVE_INFINITY; + Vector bestVector = null; + + Vector projection = basisMatrix.times(query); + for (int i = 0; i < basisMatrix.numRows(); ++i) { + List<WeightedThing<Vector>> currProjections = scalarProjections.get(i); + int middle = Collections.binarySearch(currProjections, + new WeightedThing<Vector>(projection.get(i))); + if (middle < 0) { + middle = -(middle + 1); + } + for (int j = Math.max(0, middle - searchSize); + j < Math.min(currProjections.size(), middle + searchSize + 1); ++j) { + if (currProjections.get(j).getValue() == null) { + continue; + } + Vector vector = currProjections.get(j).getValue(); + double distance = distanceMeasure.distance(vector, query); + if (distance < bestDistance && (!differentThanQuery || !vector.equals(query))) { + bestDistance = distance; + bestVector = vector; + } + } + } + + for (Vector vector : pendingAdditions) { + double distance = distanceMeasure.distance(vector, query); + if (distance < bestDistance && (!differentThanQuery || !vector.equals(query))) { + bestDistance = distance; + bestVector = vector; + } + } + + return new WeightedThing<>(bestVector, bestDistance); + } + + @Override + public boolean remove(Vector vector, double epsilon) { + WeightedThing<Vector> closestPair = searchFirst(vector, false); + if (distanceMeasure.distance(closestPair.getValue(), vector) > epsilon) { + return false; + } + + 
boolean isProjected = true; + Vector projection = basisMatrix.times(vector); + for (int i = 0; i < basisMatrix.numRows(); ++i) { + List<WeightedThing<Vector>> currProjections = scalarProjections.get(i); + WeightedThing<Vector> searchedThing = new WeightedThing<>(projection.get(i)); + int middle = Collections.binarySearch(currProjections, searchedThing); + if (middle < 0) { + isProjected = false; + break; + } + // Elements to be removed are kept in the sorted array until the next reindex, but their inner vector + // is set to null. + scalarProjections.get(i).set(middle, searchedThing); + } + if (isProjected) { + ++numPendingRemovals; + return true; + } + + for (int i = 0; i < pendingAdditions.size(); ++i) { + if (pendingAdditions.get(i).equals(vector)) { + pendingAdditions.remove(i); + break; + } + } + return true; + } + + private void reindex(boolean force) { + int numProjected = scalarProjections.get(0).size(); + if (force || pendingAdditions.size() > ADDITION_THRESHOLD * numProjected + || numPendingRemovals > REMOVAL_THRESHOLD * numProjected) { + + // We only need to copy the first list because when iterating we use only that list for the Vector + // references. + // see public Iterator<Vector> iterator() + List<List<WeightedThing<Vector>>> scalarProjections = Lists.newArrayListWithCapacity(numProjections); + for (int i = 0; i < numProjections; ++i) { + if (i == 0) { + scalarProjections.add(Lists.newArrayList(this.scalarProjections.get(i))); + } else { + scalarProjections.add(this.scalarProjections.get(i)); + } + } + + // Project every pending vector onto every basis vector. 
+ for (Vector pending : pendingAdditions) { + Vector projection = basisMatrix.times(pending); + for (int i = 0; i < numProjections; ++i) { + scalarProjections.get(i).add(new WeightedThing<>(pending, projection.get(i))); + } + } + pendingAdditions.clear(); + // For each basis vector, sort the resulting list (for binary search) and remove the number + // of pending removals (it's the same for every basis vector) at the end (the weights are + // set to Double.POSITIVE_INFINITY when removing). + for (int i = 0; i < numProjections; ++i) { + List<WeightedThing<Vector>> currProjections = scalarProjections.get(i); + for (WeightedThing<Vector> v : currProjections) { + if (v.getValue() == null) { + v.setWeight(Double.POSITIVE_INFINITY); + } + } + Collections.sort(currProjections); + for (int j = 0; j < numPendingRemovals; ++j) { + currProjections.remove(currProjections.size() - 1); + } + } + numPendingRemovals = 0; + + this.scalarProjections = scalarProjections; + } + } + + @Override + public void clear() { + pendingAdditions.clear(); + for (int i = 0; i < numProjections; ++i) { + scalarProjections.get(i).clear(); + } + numPendingRemovals = 0; + } + + /** + * This iterates on the snapshot of the contents first instantiated regardless of any future modifications. + * Changes done after the iterator is created will not be visible to the iterator but will be visible + * when searching. + * @return iterator through the vectors in this searcher. + */ + @Override + public Iterator<Vector> iterator() { + reindex(true); + return new AbstractIterator<Vector>() { + private final Iterator<WeightedThing<Vector>> data = scalarProjections.get(0).iterator(); + @Override + protected Vector computeNext() { + do { + if (!data.hasNext()) { + return endOfData(); + } + WeightedThing<Vector> next = data.next(); + if (next.getValue() != null) { + return next.getValue(); + } + } while (true); + } + }; + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java new file mode 100644 index 0000000..eb91813 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java @@ -0,0 +1,103 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.neighborhood;

import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.WeightedVector;

/**
 * Decorates a weighted vector with a locality sensitive hash.
 *
 * The LSH function implemented is the random hyperplane based hash function.
 * See "Similarity Estimation Techniques from Rounding Algorithms" by Moses S. Charikar, section 3.
 * http://www.cs.princeton.edu/courses/archive/spring04/cos598B/bib/CharikarEstim.pdf
 */
public class HashedVector extends WeightedVector {
  protected static final int INVALID_INDEX = -1;

  /**
   * Mask that keeps every bit of the 64-bit hash.
   */
  private static final long ALL_BITS = -1L;

  /**
   * Value of the locality sensitive hash. It is 64 bit.
   */
  private final long hash;

  /**
   * Wraps {@code vector} (weight 1) with a precomputed hash.
   */
  public HashedVector(Vector vector, long hash, int index) {
    super(vector, 1, index);
    this.hash = hash;
  }

  /**
   * Wraps {@code vector} (weight 1) and hashes it with {@code projection}, keeping only the
   * hash bits selected by {@code mask}.
   */
  public HashedVector(Vector vector, Matrix projection, int index, long mask) {
    super(vector, 1, index);
    this.hash = mask & computeHash64(vector, projection);
  }

  /**
   * Wraps the vector of {@code weightedVector}, keeping its weight and index, and hashes it with
   * {@code projection}, keeping only the hash bits selected by {@code mask}.
   */
  public HashedVector(WeightedVector weightedVector, Matrix projection, long mask) {
    super(weightedVector.getVector(), weightedVector.getWeight(), weightedVector.getIndex());
    this.hash = mask & computeHash64(weightedVector, projection);
  }

  /**
   * Computes the random-hyperplane hash of {@code vector}: bit i is set when the projection of
   * the vector on row i of {@code projection} is strictly positive.
   * NOTE(review): only 64 rows can be represented; because Java masks a long shift distance to
   * its low 6 bits, rows 64 and above would alias onto lower bits.
   */
  public static long computeHash64(Vector vector, Matrix projection) {
    long hash = 0;
    for (Element element : projection.times(vector).nonZeroes()) {
      if (element.get() > 0) {
        hash |= 1L << element.index();
      }
    }
    return hash;
  }

  /**
   * Hashes {@code v} keeping all 64 hash bits. (Previously this passed a mask of 0, which zeroed
   * every hash and made all vectors collide.)
   */
  public static HashedVector hash(WeightedVector v, Matrix projection) {
    return hash(v, projection, ALL_BITS);
  }

  /**
   * Hashes {@code v}, keeping only the hash bits selected by {@code mask}.
   */
  public static HashedVector hash(WeightedVector v, Matrix projection, long mask) {
    return new HashedVector(v, projection, mask);
  }

  /**
   * Returns the number of bit positions in which this vector's hash differs from
   * {@code otherHash}.
   */
  public int hammingDistance(long otherHash) {
    return Long.bitCount(hash ^ otherHash);
  }

  public long getHash() {
    return hash;
  }

  @Override
  public String toString() {
    return String.format("index=%d, hash=%08x, v=%s", getIndex(), hash, getVector());
  }

  /**
   * Two HashedVectors are equal when their hashes match and their values are identical (L1 norm
   * of the difference is 0). A HashedVector also compares equal to any plain Vector with the same
   * values; for example, LocalitySensitiveHashSearch.searchFirst compares a stored HashedVector
   * against a raw query this way. That cross-type equality is not symmetric, and hashCode() mixes
   * in the hash, so do not mix HashedVector and plain Vector keys in one hash-based collection.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof HashedVector)) {
      return o instanceof Vector && this.minus((Vector) o).norm(1) == 0;
    }
    HashedVector v = (HashedVector) o;
    return v.hash == this.hash && this.minus(v).norm(1) == 0;
  }

  @Override
  public int hashCode() {
    int result = super.hashCode();
    result = 31 * result + (int) (hash ^ (hash >>> 32));
    return result;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java new file mode 100644 index 0000000..aa1f103 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.math.neighborhood; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.Multiset; +import org.apache.lucene.util.PriorityQueue; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.random.RandomProjector; +import org.apache.mahout.math.random.WeightedThing; +import org.apache.mahout.math.stats.OnlineSummarizer; + +/** + * Implements a Searcher that uses locality sensitivity hash as a first pass approximation + * to estimate distance without floating point math. The clever bit about this implementation + * is that it does an adaptive cutoff for the cutoff on the bitwise distance. Making this + * cutoff adaptive means that we only needs to make a single pass through the data. + */ +public class LocalitySensitiveHashSearch extends UpdatableSearcher { + /** + * Number of bits in the locality sensitive hash. 64 bits fix neatly into a long. + */ + private static final int BITS = 64; + + /** + * Bit mask for the computed hash. Currently, it's 0xffffffffffff. + */ + private static final long BIT_MASK = -1L; + + /** + * The maximum Hamming distance between two hashes that the hash limit can grow back to. + * It starts at BITS and decreases as more points than are needed are added to the candidate priority queue. + * But, after the observed distribution of distances becomes too good (we're seeing less than some percentage of the + * total number of points; using the hash strategy somewhere less than 25%) the limit is increased to compute + * more distances. 
+ * This is because + */ + private static final int MAX_HASH_LIMIT = 32; + + /** + * Minimum number of points with a given Hamming from the query that must be observed to consider raising the minimum + * distance for a candidate. + */ + private static final int MIN_DISTRIBUTION_COUNT = 10; + + private final Multiset<HashedVector> trainingVectors = HashMultiset.create(); + + /** + * This matrix of BITS random vectors is used to compute the Locality Sensitive Hash + * we compute the dot product with these vectors using a matrix multiplication and then use just + * sign of each result as one bit in the hash + */ + private Matrix projection; + + /** + * The search size determines how many top results we retain. We do this because the hash distance + * isn't guaranteed to be entirely monotonic with respect to the real distance. To the extent that + * actual distance is well approximated by hash distance, then the searchSize can be decreased to + * roughly the number of results that you want. + */ + private int searchSize; + + /** + * Controls how the hash limit is raised. 0 means use minimum of distribution, 1 means use first quartile. + * Intermediate values indicate an interpolation should be used. Negative values mean to never increase. + */ + private double hashLimitStrategy = 0.9; + + /** + * Number of evaluations of the full distance between two points that was required. + */ + private int distanceEvaluations = 0; + + /** + * Whether the projection matrix was initialized. This has to be deferred until the size of the vectors is known, + * effectively until the first vector is added. 
+ */ + private boolean initialized = false; + + public LocalitySensitiveHashSearch(DistanceMeasure distanceMeasure, int searchSize) { + super(distanceMeasure); + this.searchSize = searchSize; + this.projection = null; + } + + private void initialize(int numDimensions) { + if (initialized) { + return; + } + initialized = true; + projection = RandomProjector.generateBasisNormal(BITS, numDimensions); + } + + private PriorityQueue<WeightedThing<Vector>> searchInternal(Vector query) { + long queryHash = HashedVector.computeHash64(query, projection); + + // We keep an approximation of the closest vectors here. + PriorityQueue<WeightedThing<Vector>> top = Searcher.getCandidateQueue(getSearchSize()); + + // We scan the vectors using bit counts as an approximation of the dot product so we can do as few + // full distance computations as possible. Our goal is to only do full distance computations for + // vectors with hash distance at most as large as the searchSize biggest hash distance seen so far. + + OnlineSummarizer[] distribution = new OnlineSummarizer[BITS + 1]; + for (int i = 0; i < BITS + 1; i++) { + distribution[i] = new OnlineSummarizer(); + } + + distanceEvaluations = 0; + + // We keep the counts of the hash distances here. This lets us accurately + // judge what hash distance cutoff we should use. + int[] hashCounts = new int[BITS + 1]; + + // Maximum number of different bits to still consider a vector a candidate for nearest neighbor. + // Starts at the maximum number of bits, but decreases and can increase. + int hashLimit = BITS; + int limitCount = 0; + double distanceLimit = Double.POSITIVE_INFINITY; + + // In this loop, we have the invariants that: + // + // limitCount = sum_{i<hashLimit} hashCount[i] + // and + // limitCount >= searchSize && limitCount - hashCount[hashLimit-1] < searchSize + for (HashedVector vector : trainingVectors) { + // This computes the Hamming Distance between the vector's hash and the query's hash. 
+ // The result is correlated with the angle between the vectors. + int bitDot = vector.hammingDistance(queryHash); + if (bitDot <= hashLimit) { + distanceEvaluations++; + + double distance = distanceMeasure.distance(query, vector); + distribution[bitDot].add(distance); + + if (distance < distanceLimit) { + top.insertWithOverflow(new WeightedThing<Vector>(vector, distance)); + if (top.size() == searchSize) { + distanceLimit = top.top().getWeight(); + } + + hashCounts[bitDot]++; + limitCount++; + while (hashLimit > 0 && limitCount - hashCounts[hashLimit - 1] > searchSize) { + hashLimit--; + limitCount -= hashCounts[hashLimit]; + } + + if (hashLimitStrategy >= 0) { + while (hashLimit < MAX_HASH_LIMIT && distribution[hashLimit].getCount() > MIN_DISTRIBUTION_COUNT + && ((1 - hashLimitStrategy) * distribution[hashLimit].getQuartile(0) + + hashLimitStrategy * distribution[hashLimit].getQuartile(1)) < distanceLimit) { + limitCount += hashCounts[hashLimit]; + hashLimit++; + } + } + } + } + } + return top; + } + + @Override + public List<WeightedThing<Vector>> search(Vector query, int limit) { + PriorityQueue<WeightedThing<Vector>> top = searchInternal(query); + List<WeightedThing<Vector>> results = Lists.newArrayListWithExpectedSize(top.size()); + while (top.size() != 0) { + WeightedThing<Vector> wv = top.pop(); + results.add(new WeightedThing<>(((HashedVector) wv.getValue()).getVector(), wv.getWeight())); + } + Collections.reverse(results); + if (limit < results.size()) { + results = results.subList(0, limit); + } + return results; + } + + /** + * Returns the closest vector to the query. + * When only one the nearest vector is needed, use this method, NOT search(query, limit) because + * it's faster (less overhead). + * This is nearly the same as search(). 
+ * + * @param query the vector to search for + * @param differentThanQuery if true, returns the closest vector different than the query (this + * only matters if the query is among the searched vectors), otherwise, + * returns the closest vector to the query (even the same vector). + * @return the weighted vector closest to the query + */ + @Override + public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) { + // We get the top searchSize neighbors. + PriorityQueue<WeightedThing<Vector>> top = searchInternal(query); + // We then cut the number down to just the best 2. + while (top.size() > 2) { + top.pop(); + } + // If there are fewer than 2 results, we just return the one we have. + if (top.size() < 2) { + return removeHash(top.pop()); + } + // There are exactly 2 results. + WeightedThing<Vector> secondBest = top.pop(); + WeightedThing<Vector> best = top.pop(); + // If the best result is the same as the query, but we don't want to return the query. + if (differentThanQuery && best.getValue().equals(query)) { + best = secondBest; + } + return removeHash(best); + } + + protected static WeightedThing<Vector> removeHash(WeightedThing<Vector> input) { + return new WeightedThing<>(((HashedVector) input.getValue()).getVector(), input.getWeight()); + } + + @Override + public void add(Vector vector) { + initialize(vector.size()); + trainingVectors.add(new HashedVector(vector, projection, HashedVector.INVALID_INDEX, BIT_MASK)); + } + + @Override + public int size() { + return trainingVectors.size(); + } + + public int getSearchSize() { + return searchSize; + } + + public void setSearchSize(int size) { + searchSize = size; + } + + public void setRaiseHashLimitStrategy(double strategy) { + hashLimitStrategy = strategy; + } + + /** + * This is only for testing. + * @return the number of times the actual distance between two vectors was computed. 
+ */ + public int resetEvaluationCount() { + int result = distanceEvaluations; + distanceEvaluations = 0; + return result; + } + + @Override + public Iterator<Vector> iterator() { + return Iterators.transform(trainingVectors.iterator(), new Function<HashedVector, Vector>() { + @Override + public Vector apply(org.apache.mahout.math.neighborhood.HashedVector input) { + Preconditions.checkNotNull(input); + //noinspection ConstantConditions + return input.getVector(); + } + }); + } + + @Override + public boolean remove(Vector v, double epsilon) { + return trainingVectors.remove(new HashedVector(v, projection, HashedVector.INVALID_INDEX, BIT_MASK)); + } + + @Override + public void clear() { + trainingVectors.clear(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java new file mode 100644 index 0000000..61a9f56 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.BoundType;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import com.google.common.collect.TreeMultiset;
+import org.apache.mahout.math.random.RandomProjector;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Does approximate nearest neighbor search by projecting the data onto random basis vectors.
+ *
+ * For every basis vector the added vectors are kept in a multiset ordered by their scalar
+ * projection. A query is projected the same way, and exact distances are only computed for the
+ * vectors whose projections are among the searchSize closest on either side of the query's
+ * projection, for some basis vector.
+ */
+public class ProjectionSearch extends UpdatableSearcher {
+
+  /**
+   * One multiset of scalar projections per basis vector. Each element of the multiset at index i
+   * is a WeightedThing<Vector> whose value is an added vector and whose weight is the scalar
+   * projection of that vector on the i-th basis vector of basisMatrix. Null until the first
+   * vector is added.
+   */
+  private List<TreeMultiset<WeightedThing<Vector>>> scalarProjections;
+
+  /**
+   * The random projection vectors, generated by RandomProjector.generateBasisNormal. The multiset
+   * at index i in scalarProjections corresponds to the i-th vector of this matrix. Null until the
+   * first vector is added.
+   */
+  private Matrix basisMatrix;
+
+  /**
+   * The number of elements to consider on each side of the query's projection in every multiset
+   * of scalarProjections.
+   */
+  private final int searchSize;
+
+  private final int numProjections;
+  private boolean initialized = false;
+
+  /**
+   * Creates the random basis and the empty projection sets. Only the first call has an effect, so
+   * the dimension of the first added vector fixes the basis.
+   *
+   * @param numDimensions the dimension of the vectors that will be projected.
+   */
+  private void initialize(int numDimensions) {
+    if (initialized) {
+      return;
+    }
+    initialized = true;
+    basisMatrix = RandomProjector.generateBasisNormal(numProjections, numDimensions);
+    scalarProjections = Lists.newArrayList();
+    for (int i = 0; i < numProjections; ++i) {
+      scalarProjections.add(TreeMultiset.<WeightedThing<Vector>>create());
+    }
+  }
+
+  /**
+   * @param distanceMeasure the measure used to compute exact distances to candidate vectors.
+   * @param numProjections the number of random projections; must satisfy 0 < numProjections < 100.
+   * @param searchSize the number of candidates taken on each side of the query in each projection.
+   * @throws IllegalArgumentException if numProjections is out of range.
+   */
+  public ProjectionSearch(DistanceMeasure distanceMeasure, int numProjections, int searchSize) {
+    super(distanceMeasure);
+    Preconditions.checkArgument(numProjections > 0 && numProjections < 100,
+        "Unreasonable value for number of projections. Must be: 0 < numProjections < 100");
+
+    this.searchSize = searchSize;
+    this.numProjections = numProjections;
+  }
+
+  /**
+   * Adds a vector into the set of projections for later searching. The vector is not copied.
+   *
+   * @param vector the vector to add.
+   */
+  @Override
+  public void add(Vector vector) {
+    initialize(vector.size());
+    Vector projection = basisMatrix.times(vector);
+    // Add the new vector, weighted by its projection, to each set separately. TreeMultiset keeps
+    // every set ordered, so no per-add validation scan over the stored vectors is needed.
+    int i = 0;
+    for (TreeMultiset<WeightedThing<Vector>> s : scalarProjections) {
+      s.add(new WeightedThing<>(vector, projection.get(i++)));
+    }
+  }
+
+  /**
+   * @return the number of vectors added to the search so far (0 before the first add).
+   */
+  @Override
+  public int size() {
+    if (scalarProjections == null) {
+      return 0;
+    }
+    return scalarProjections.get(0).size();
+  }
+
+  /**
+   * Searches for the query vector returning the closest limit referenceVectors.
+   *
+   * @param query the vector to search for.
+   * @param limit the number of results to return.
+   * @return a list of Vectors wrapped in WeightedThings where the "thing"'s weight is the
+   * distance; empty if nothing has been added yet.
+   */
+  @Override
+  public List<WeightedThing<Vector>> search(Vector query, int limit) {
+    if (scalarProjections == null) {
+      // Nothing has been added yet, so there is no basis to project the query onto.
+      return Lists.newArrayList();
+    }
+    Set<Vector> candidates = Sets.newHashSet();
+
+    Iterator<? extends Vector> projections = basisMatrix.iterator();
+    for (TreeMultiset<WeightedThing<Vector>> v : scalarProjections) {
+      Vector basisVector = projections.next();
+      WeightedThing<Vector> projectedQuery = new WeightedThing<>(query,
+          query.dot(basisVector));
+      // Up to searchSize neighbors at or above the query's projection, plus up to searchSize below.
+      for (WeightedThing<Vector> candidate : Iterables.concat(
+          Iterables.limit(v.tailMultiset(projectedQuery, BoundType.CLOSED), searchSize),
+          Iterables.limit(v.headMultiset(projectedQuery, BoundType.OPEN).descendingMultiset(), searchSize))) {
+        candidates.add(candidate.getValue());
+      }
+    }
+
+    // If searchSize * scalarProjections.size() is small enough not to cause much memory pressure,
+    // this is probably just as fast as a priority queue here.
+    List<WeightedThing<Vector>> top = Lists.newArrayList();
+    for (Vector candidate : candidates) {
+      top.add(new WeightedThing<>(candidate, distanceMeasure.distance(query, candidate)));
+    }
+    Collections.sort(top);
+    return top.subList(0, Math.min(limit, top.size()));
+  }
+
+  /**
+   * Returns the closest vector to the query.
+   * When only one the nearest vector is needed, use this method, NOT search(query, limit) because
+   * it's faster (less overhead).
+   *
+   * @param query the vector to search for
+   * @param differentThanQuery if true, returns the closest vector different than the query (this
+   *                           only matters if the query is among the searched vectors), otherwise,
+   *                           returns the closest vector to the query (even the same vector).
+   * @return the weighted vector closest to the query; a null vector with infinite distance if no
+   * candidate was found (including when nothing has been added yet).
+   */
+  @Override
+  public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
+    if (scalarProjections == null) {
+      return new WeightedThing<Vector>(null, Double.POSITIVE_INFINITY);
+    }
+    double bestDistance = Double.POSITIVE_INFINITY;
+    Vector bestVector = null;
+
+    Iterator<? extends Vector> projections = basisMatrix.iterator();
+    for (TreeMultiset<WeightedThing<Vector>> v : scalarProjections) {
+      Vector basisVector = projections.next();
+      WeightedThing<Vector> projectedQuery = new WeightedThing<>(query, query.dot(basisVector));
+      for (WeightedThing<Vector> candidate : Iterables.concat(
+          Iterables.limit(v.tailMultiset(projectedQuery, BoundType.CLOSED), searchSize),
+          Iterables.limit(v.headMultiset(projectedQuery, BoundType.OPEN).descendingMultiset(), searchSize))) {
+        double distance = distanceMeasure.distance(query, candidate.getValue());
+        if (distance < bestDistance && (!differentThanQuery || !candidate.getValue().equals(query))) {
+          bestDistance = distance;
+          bestVector = candidate.getValue();
+        }
+      }
+    }
+
+    return new WeightedThing<>(bestVector, bestDistance);
+  }
+
+  /**
+   * @return an iterator over the added vectors (in the order of the first projection); empty if
+   * nothing has been added yet.
+   */
+  @Override
+  public Iterator<Vector> iterator() {
+    if (scalarProjections == null) {
+      return Collections.<Vector>emptyIterator();
+    }
+    return new AbstractIterator<Vector>() {
+      private final Iterator<WeightedThing<Vector>> projected = scalarProjections.get(0).iterator();
+      @Override
+      protected Vector computeNext() {
+        if (!projected.hasNext()) {
+          return endOfData();
+        }
+        return projected.next().getValue();
+      }
+    };
+  }
+
+  /**
+   * Removes the stored vector found closest to the given one, provided it is within epsilon.
+   *
+   * @param vector the vector to remove, or a vector within epsilon of it.
+   * @param epsilon the distance below which the closest stored vector is removed.
+   * @return true if a vector was removed, false otherwise.
+   */
+  @Override
+  public boolean remove(Vector vector, double epsilon) {
+    if (scalarProjections == null) {
+      return false;
+    }
+    WeightedThing<Vector> toRemove = searchFirst(vector, false);
+    Vector found = toRemove.getValue();
+    if (found == null || !(toRemove.getWeight() < epsilon)) {
+      return false;
+    }
+    // Remove the vector that was actually found: it may differ from the argument by up to
+    // epsilon, in which case its projections differ from the argument's as well.
+    // NOTE(review): TreeMultiset presumably matches via WeightedThing.compareTo (weight only), so a
+    // different vector with an identical projection could be removed from a set instead.
+    Iterator<? extends Vector> basisVectors = basisMatrix.iterator();
+    for (TreeMultiset<WeightedThing<Vector>> projection : scalarProjections) {
+      if (!projection.remove(new WeightedThing<>(found, found.dot(basisVectors.next())))) {
+        throw new RuntimeException("Internal inconsistency in ProjectionSearch");
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Removes all vectors; the random basis is kept.
+   */
+  @Override
+  public void clear() {
+    if (scalarProjections == null) {
+      return;
+    }
+    for (TreeMultiset<WeightedThing<Vector>> set : scalarProjections) {
+      set.clear();
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java
new file mode 100644
index 0000000..dd387b5
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.List;
+
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Base class for nearest-neighbor searches over a collection of vectors.
+ *
+ * Any kind of vector (weighted, sparse, ...) may be added, but only the vector values take part
+ * in the search; extra attributes such as weights or indices are ignored.
+ *
+ * Iterating over a Searcher yields the vectors that were added to it.
+ */
+public abstract class Searcher implements Iterable<Vector> {
+  protected DistanceMeasure distanceMeasure;
+
+  protected Searcher(DistanceMeasure distanceMeasure) {
+    this.distanceMeasure = distanceMeasure;
+  }
+
+  /**
+   * @return the measure used to compute distances between vectors.
+   */
+  public DistanceMeasure getDistanceMeasure() {
+    return distanceMeasure;
+  }
+
+  /**
+   * Adds a vector that will be considered by later nearest-neighbor queries.
+   *
+   * The vector is stored as-is, NOT cloned: modifying it afterwards may corrupt the searcher's
+   * internal data structures.
+   */
+  public abstract void add(Vector vector);
+
+  /**
+   * @return the number of vectors currently being searched.
+   */
+  public abstract int size();
+
+  /**
+   * Finds the vectors closest to a query. Each result wraps a neighbor (of the same type that was
+   * added) together with its distance to the query, as computed by the concrete implementation.
+   *
+   * @param query the vector to search for
+   * @param limit the number of results to return
+   * @return the list of weighted vectors closest to the query
+   */
+  public abstract List<WeightedThing<Vector>> search(Vector query, int limit);
+
+  /**
+   * Runs {@link #search(Vector, int)} for every query, keeping the queries' iteration order.
+   */
+  public List<List<WeightedThing<Vector>>> search(Iterable<? extends Vector> queries, int limit) {
+    int expected = Iterables.size(queries);
+    List<List<WeightedThing<Vector>>> answers = Lists.newArrayListWithExpectedSize(expected);
+    for (Vector q : queries) {
+      answers.add(search(q, limit));
+    }
+    return answers;
+  }
+
+  /**
+   * Finds the single vector closest to a query. Prefer this over search(query, 1) when only one
+   * neighbor is needed, since it has less overhead.
+   *
+   * @param query the vector to search for
+   * @param differentThanQuery when true, a vector equal to the query is skipped (relevant only if
+   *                           the query itself was added); when false, the query itself may be
+   *                           returned.
+   * @return the weighted vector closest to the query
+   */
+  public abstract WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery);
+
+  /**
+   * Runs {@link #searchFirst(Vector, boolean)} for every query, keeping the queries' iteration order.
+   */
+  public List<WeightedThing<Vector>> searchFirst(Iterable<? extends Vector> queries, boolean differentThanQuery) {
+    int expected = Iterables.size(queries);
+    List<WeightedThing<Vector>> answers = Lists.newArrayListWithExpectedSize(expected);
+    for (Vector q : queries) {
+      answers.add(searchFirst(q, differentThanQuery));
+    }
+    return answers;
+  }
+
+  /**
+   * Adds every vector of the given iterable to the Searcher.
+   *
+   * @param data the vectors to add.
+   */
+  public void addAll(Iterable<? extends Vector> data) {
+    for (Vector v : data) {
+      add(v);
+    }
+  }
+
+  /**
+   * Adds the vector of every matrix slice to the Searcher.
+   *
+   * @param data an iterable of MatrixSlices to add.
+   */
+  public void addAllMatrixSlices(Iterable<MatrixSlice> data) {
+    for (MatrixSlice row : data) {
+      add(row.vector());
+    }
+  }
+
+  /**
+   * Adds every matrix slice as a WeightedVector with weight 1 and the slice's index.
+   *
+   * @param data an iterable of MatrixSlices to add.
+   */
+  public void addAllMatrixSlicesAsWeightedVectors(Iterable<MatrixSlice> data) {
+    for (MatrixSlice row : data) {
+      Vector weighted = new WeightedVector(row.vector(), 1, row.index());
+      add(weighted);
+    }
+  }
+
+  /**
+   * Unsupported by default; see UpdatableSearcher.
+   *
+   * @throws UnsupportedOperationException always.
+   */
+  public boolean remove(Vector v, double epsilon) {
+    String owner = getClass().getName();
+    throw new UnsupportedOperationException("Can't remove a vector from a " + owner);
+  }
+
+  /**
+   * Unsupported by default; see UpdatableSearcher.
+   *
+   * @throws UnsupportedOperationException always.
+   */
+  public void clear() {
+    String owner = getClass().getName();
+    throw new UnsupportedOperationException("Can't remove vectors from a " + owner);
+  }
+
+  /**
+   * Creates a bounded priority queue for collecting nearest-neighbor candidates. The ordering is
+   * reversed: lessThan treats the candidate with the larger distance (weight) as "less", so the
+   * head of the queue is the worst candidate kept so far.
+   *
+   * @param limit maximum size of the heap.
+   * @return the priority queue.
+   */
+  public static PriorityQueue<WeightedThing<Vector>> getCandidateQueue(int limit) {
+    return new PriorityQueue<WeightedThing<Vector>>(limit) {
+      @Override
+      protected boolean lessThan(WeightedThing<Vector> left, WeightedThing<Vector> right) {
+        return right.getWeight() < left.getWeight();
+      }
+    };
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java
new file mode 100644
index 0000000..68365c7
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Vector;
+
+/**
+ * A Searcher that also supports removing vectors. A class should extend UpdatableSearcher only if
+ * it can actually implement {@link #remove(Vector, double)} and {@link #clear()}, which Searcher
+ * itself leaves unsupported.
+ */
+public abstract class UpdatableSearcher extends Searcher {
+
+  protected UpdatableSearcher(DistanceMeasure measure) {
+    super(measure);
+  }
+
+  /**
+   * Removes a previously added vector that is within epsilon of the given one.
+   *
+   * @return true if a vector was removed.
+   */
+  @Override
+  public abstract boolean remove(Vector v, double epsilon);
+
+  /**
+   * Removes every vector from the searcher.
+   */
+  @Override
+  public abstract void clear();
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/random/AbstractSamplerFunction.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/random/AbstractSamplerFunction.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/AbstractSamplerFunction.java
new file mode 100644
index 0000000..d657fd9
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/AbstractSamplerFunction.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Shim that lets a sampler be used wherever a {@link DoubleFunction} is expected, for instance to
+ * initialize the entries of a vector with random samples.
+ */
+public abstract class AbstractSamplerFunction extends DoubleFunction implements Sampler<Double> {
+  /**
+   * Ignores the argument and returns a fresh sample instead.
+   *
+   * @param ignored Ignored argument
+   * @return A sample from this distribution.
+   */
+  @Override
+  public double apply(double ignored) {
+    Double draw = sample();
+    return draw;
+  }
+
+  /**
+   * @return a new sample from this distribution.
+   */
+  @Override
+  public abstract Double sample();
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/random/ChineseRestaurant.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/random/ChineseRestaurant.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/ChineseRestaurant.java
new file mode 100644
index 0000000..8127b92
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/ChineseRestaurant.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.list.DoubleArrayList;
+
+import java.util.Random;
+
+/**
+ * Generates samples from a generalized Chinese restaurant process (or Pitman-Yor process).
+ *
+ * Each sample is the index of a value. An existing value j is returned with probability
+ * (w_j - discount) / (alpha + w), and a new value with probability (alpha + discount * t) /
+ * (alpha + w). Here w_j is the number of times j has been returned so far, w is the total number
+ * of draws and t is the number of distinct values.
+ *
+ * As the total number of draws T increases without bound, the fraction of distinct values that
+ * were drawn exactly once approaches the discount parameter. The number of unique values sampled
+ * will increase as O(alpha * log T) if discount = 0 or O(alpha * T^discount) for discount > 0.
+ */
+public final class ChineseRestaurant implements Sampler<Integer> {
+
+  private final double alpha;
+  // total number of draws so far; always equal to the sum of the entries of weights
+  private double weight = 0;
+  private final double discount;
+  // weights.get(j) is the number of times value j has been returned
+  private final DoubleArrayList weights = new DoubleArrayList();
+  private final Random rand = RandomUtils.getRandom();
+
+  /**
+   * Constructs a Dirichlet process sampler. This is done by setting discount = 0.
+   * @param alpha The strength parameter for the Dirichlet process.
+   */
+  public ChineseRestaurant(double alpha) {
+    this(alpha, 0);
+  }
+
+  /**
+   * Constructs a Pitman-Yor sampler.
+   *
+   * @param alpha The strength parameter that drives the number of unique values as a function of draws.
+   * @param discount The discount parameter that drives the percentage of values that occur once in a large sample.
+   * @throws IllegalArgumentException if alpha <= 0 or discount is outside [0, 1].
+   */
+  public ChineseRestaurant(double alpha, double discount) {
+    Preconditions.checkArgument(alpha > 0, "Strength Parameter, alpha must be greater then 0!");
+    Preconditions.checkArgument(discount >= 0 && discount <= 1, "Must be: 0 <= discount <= 1");
+    this.alpha = alpha;
+    this.discount = discount;
+  }
+
+  /**
+   * @return the index of the sampled value; indices of new values are assigned consecutively from 0.
+   */
+  @Override
+  public Integer sample() {
+    // u is uniform over the total unnormalized mass alpha + w
+    double u = rand.nextDouble() * (alpha + weight);
+    for (int j = 0; j < weights.size(); j++) {
+      // select existing option j with probability (w_j - d) / (alpha + w)
+      if (u < weights.get(j) - discount) {
+        weights.set(j, weights.get(j) + 1);
+        weight++;
+        return j;
+      }
+      u -= weights.get(j) - discount;
+    }
+
+    // if no existing item was selected, pick a new item; the mass left over is alpha + d*t, so
+    // this happens with probability (alpha + d*t) / (alpha + w), where t is the number of
+    // pre-existing items
+    weights.add(1);
+    weight++;
+    return weights.size() - 1;
+  }
+
+  /**
+   * @return the number of unique values that have been returned.
+   */
+  public int size() {
+    return weights.size();
+  }
+
+  /**
+   * @return the number draws so far.
+   */
+  public int count() {
+    return (int) weight;
+  }
+
+  /**
+   * @param j Which value to test.
+   * @return The number of times that j has been returned so far.
+   * @throws IllegalArgumentException if j is negative.
+   */
+  public int count(int j) {
+    Preconditions.checkArgument(j >= 0);
+
+    if (j < weights.size()) {
+      return (int) weights.get(j);
+    } else {
+      return 0;
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Empirical.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Empirical.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Empirical.java
new file mode 100644
index 0000000..78bfec5
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Empirical.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.common.RandomUtils;
+
+import java.util.Random;
+
+/**
+ * Samples from an empirical cumulative distribution.
+ */
+public final class Empirical extends AbstractSamplerFunction {
+  private final Random gen;
+  private final boolean exceedMinimum;
+  private final boolean exceedMaximum;
+
+  // x[i] is the cumulative probability of the i-th known point (after squeezing room for the
+  // tails) and y[i] its value; both arrays are strictly increasing.
+  private final double[] x;
+  private final double[] y;
+  private final int n;
+
+  /**
+   * Sets up a sampler for a specified empirical cumulative distribution function. The distribution
+   * can have optional exponential tails on either or both ends, but otherwise does a linear
+   * interpolation between known points.
+   *
+   * @param exceedMinimum Should we generate samples less than the smallest quantile (i.e. generate a left tail)?
+   * @param exceedMaximum Should we generate samples greater than the largest observed quantile (i.e. generate a right
+   *                      tail)?
+   * @param samples       The number of samples observed to get the quantiles.
+   * @param ecdf          Alternating values that represent which percentile (in the [0..1] range)
+   *                      and values.  For instance, if you have the min, median and max of 1, 3, 10
+   *                      you should pass 0.0, 1, 0.5, 3, 1.0, 10.  Note that the list must include
+   *                      the 0-th (1.0-th) quantile if the left (right) tail is not allowed, and
+   *                      must contain at least two (quantile, value) pairs.
+   * @throws IllegalArgumentException if ecdf has an odd number of entries or fewer than two pairs,
+   *                      if samples < 3, if a required end quantile is missing, or if quantiles
+   *                      are outside [0,1] or quantiles or values are not strictly increasing.
+   */
+  public Empirical(boolean exceedMinimum, boolean exceedMaximum, int samples, double... ecdf) {
+    Preconditions.checkArgument(ecdf.length % 2 == 0, "ecdf must have an even count of values");
+    Preconditions.checkArgument(samples >= 3, "Sample size must be >= 3");
+    // Interpolation and both tails always use two adjacent points, so at least two are needed.
+    Preconditions.checkArgument(ecdf.length >= 4, "ecdf must contain at least two (quantile, value) pairs");
+
+    // if we can't exceed the observed bounds, then we have to be given the bounds.
+    Preconditions.checkArgument(exceedMinimum || ecdf[0] == 0);
+    Preconditions.checkArgument(exceedMaximum || ecdf[ecdf.length - 2] == 1);
+
+    gen = RandomUtils.getRandom();
+
+    n = ecdf.length / 2;
+    x = new double[n];
+    y = new double[n];
+
+    for (int k = 0; k < n; k++) {
+      double quantile = ecdf[2 * k];
+      double value = ecdf[2 * k + 1];
+
+      // values have to be monotonic increasing
+      Preconditions.checkArgument(k == 0 || value > y[k - 1]);
+      y[k] = value;
+
+      // quantiles have to be in [0,1] and be monotonic increasing
+      Preconditions.checkArgument(quantile >= 0 && quantile <= 1);
+      Preconditions.checkArgument(k == 0 || quantile > x[k - 1]);
+      x[k] = quantile;
+    }
+
+    // squeeze a bit to allow for unobserved tails
+    double x0 = exceedMinimum ? 0.5 / samples : 0;
+    double x1 = 1 - (exceedMaximum ? 0.5 / samples : 0);
+    for (int i = 0; i < n; i++) {
+      x[i] = x[i] * (x1 - x0) + x0;
+    }
+
+    this.exceedMinimum = exceedMinimum;
+    this.exceedMaximum = exceedMaximum;
+  }
+
+  @Override
+  public Double sample() {
+    return sample(gen.nextDouble());
+  }
+
+  /**
+   * Maps a cumulative probability to a value. Below the first or above the last known point an
+   * exponential tail is used (when enabled); in between, values are linearly interpolated.
+   *
+   * @param u a cumulative probability; {@link #sample()} passes a uniform draw from [0, 1).
+   * @return the value at cumulative probability u.
+   * @throws RuntimeException if u lies above the last known point while the right tail is
+   *                          disabled, or if u is NaN.
+   */
+  public double sample(double u) {
+    if (exceedMinimum && u < x[0]) {
+      // generate from left tail
+      if (u == 0) {
+        u = 1.0e-16;
+      }
+      return y[0] + Math.log(u / x[0]) * x[0] * (y[1] - y[0]) / (x[1] - x[0]);
+    } else if (exceedMaximum && u > x[n - 1]) {
+      if (u == 1) {
+        u = 1 - 1.0e-16;
+      }
+      // generate from right tail
+      double dy = y[n - 1] - y[n - 2];
+      double dx = x[n - 1] - x[n - 2];
+      return y[n - 1] - Math.log((1 - u) / (1 - x[n - 1])) * (1 - x[n - 1]) * dy / dx;
+    } else {
+      // linear interpolation
+      for (int i = 1; i < n; i++) {
+        if (x[i] > u) {
+          double dy = y[i] - y[i - 1];
+          double dx = x[i] - x[i - 1];
+          return y[i - 1] + (u - x[i - 1]) * dy / dx;
+        }
+      }
+      // u sits exactly on the last known point (e.g. u == 1 when the right tail is disabled)
+      if (u == x[n - 1]) {
+        return y[n - 1];
+      }
+      throw new RuntimeException(String.format("Can't happen (%.3f is not in [%.3f,%.3f])", u, x[0], x[n - 1]));
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/random/IndianBuffet.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/random/IndianBuffet.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/IndianBuffet.java
new file mode 100644
index 0000000..27b5d84
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/IndianBuffet.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.io.LineProcessor;
+import com.google.common.io.Resources;
+import org.apache.mahout.common.RandomUtils;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Samples a "document" from an IndianBuffet process.
+ *
+ * See http://mlg.eng.cam.ac.uk/zoubin/talks/turin09.pdf for details
+ */
+public final class IndianBuffet<T> implements Sampler<List<T>> {
+  // count.get(k) is the number of documents generated so far that contain item k
+  private final List<Integer> count = Lists.newArrayList();
+  // number of documents generated so far
+  private int documents = 0;
+  private final double alpha;
+  private final WordFunction<T> converter;
+  private final Random gen;
+
+  public IndianBuffet(double alpha, WordFunction<T> converter) {
+    this.alpha = alpha;
+    this.converter = converter;
+    this.gen = RandomUtils.getRandom();
+  }
+
+  /** Creates a sampler whose items are the integers 0, 1, 2, ... */
+  public static IndianBuffet<Integer> createIntegerDocumentSampler(double alpha) {
+    return new IndianBuffet<>(alpha, new IdentityConverter());
+  }
+
+  /** Creates a sampler whose items are English words (or w_N tokens for large item numbers). */
+  public static IndianBuffet<String> createTextDocumentSampler(double alpha) {
+    return new IndianBuffet<>(alpha, new WordConverter());
+  }
+
+  /**
+   * Generates the next document. The first document holds a number of new items drawn from
+   * new PoissonSampler(alpha). The d-th later document holds each existing item k with
+   * probability count(k) / d, followed by a number of brand-new items drawn from
+   * new PoissonSampler(alpha / d).
+   *
+   * @return the items of the new document, converted by the configured converter.
+   */
+  @Override
+  public List<T> sample() {
+    List<T> document = Lists.newArrayList();
+    if (documents > 0) {
+      documents++;
+      int item = 0;
+      for (; item < count.size(); item++) {
+        int popularity = count.get(item);
+        if (gen.nextDouble() < (double) popularity / documents) {
+          document.add(converter.convert(item));
+          count.set(item, popularity + 1);
+        }
+      }
+      int fresh = new PoissonSampler(alpha / documents).sample().intValue();
+      for (int k = 0; k < fresh; k++) {
+        document.add(converter.convert(item + k));
+        count.add(1);
+      }
+    } else {
+      double initial = new PoissonSampler(alpha).sample();
+      for (int item = 0; item < initial; item++) {
+        document.add(converter.convert(item));
+        count.add(1);
+      }
+      documents++;
+    }
+    return document;
+  }
+
+  /** Maps an item number to the value placed in a document. */
+  private interface WordFunction<T> {
+    T convert(int i);
+  }
+
+  /**
+   * Just converts to an integer.
+   */
+  public static class IdentityConverter implements WordFunction<Integer> {
+    @Override
+    public Integer convert(int i) {
+      return Integer.valueOf(i);
+    }
+  }
+
+  /**
+   * Converts to a string.
+   */
+  public static class StringConverter implements WordFunction<String> {
+    @Override
+    public String convert(int i) {
+      return Integer.toString(i);
+    }
+  }
+
+  /**
+   * Converts to one of a list of common English words (read from the words.txt resource) for
+   * reasonably small integers and to a token like w_92463 for big integers.
+   */
+  public static final class WordConverter implements WordFunction<String> {
+    private final Splitter onSpace = Splitter.on(CharMatcher.WHITESPACE).omitEmptyStrings().trimResults();
+    private final List<String> words;
+
+    public WordConverter() {
+      // Collects every whitespace-separated token of every line, in order.
+      LineProcessor<List<String>> tokenizer = new LineProcessor<List<String>>() {
+        private final List<String> collected = Lists.newArrayList();
+
+        @Override
+        public boolean processLine(String line) {
+          Iterables.addAll(collected, onSpace.split(line));
+          return true;
+        }
+
+        @Override
+        public List<String> getResult() {
+          return collected;
+        }
+      };
+      try {
+        words = Resources.readLines(Resources.getResource("words.txt"), Charsets.UTF_8, tokenizer);
+      } catch (IOException e) {
+        throw new ImpossibleException(e);
+      }
+    }
+
+    @Override
+    public String convert(int i) {
+      return i < words.size() ? words.get(i) : "w_" + i;
+    }
+  }
+
+  /** Wraps an I/O failure reading the bundled word list, which is not expected to happen. */
+  public static class ImpossibleException extends RuntimeException {
+    public ImpossibleException(Throwable e) {
+      super(e);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Missing.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Missing.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Missing.java
new file mode 100644
index 0000000..8141a71
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Missing.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import java.util.Random;
+
+import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Models data with missing values: each sample is the missing marker with probability p and
+ * otherwise a sample drawn from the delegate.
+ *
+ * The decision is driven by a private generator created from the seed (the constructors without
+ * a seed all use seed 1). Variables with the same seed and the same fraction of missing values
+ * therefore have the same sequence of missing values. Similarly, if two such variables have
+ * missing probabilities of p1 > p2, then all of the p2 missing values will also be missing for p1.
+ */
+public final class Missing<T> implements Sampler<T> {
+  private final Random gen;
+  private final double p;
+  private final Sampler<T> delegate;
+  private final T missingMarker;
+
+  /**
+   * @param seed seed of the generator that decides which samples are missing.
+   * @param p probability that a sample is missing.
+   * @param delegate sampler used for the samples that are not missing.
+   * @param missingMarker value returned in place of a missing sample.
+   */
+  public Missing(int seed, double p, Sampler<T> delegate, T missingMarker) {
+    this.gen = RandomUtils.getRandom(seed);
+    this.p = p;
+    this.delegate = delegate;
+    this.missingMarker = missingMarker;
+  }
+
+  /** Same as the seeded constructor with seed 1. */
+  public Missing(double p, Sampler<T> delegate, T missingMarker) {
+    this(1, p, delegate, missingMarker);
+  }
+
+  /** Same as the seeded constructor with seed 1 and null as the missing marker. */
+  public Missing(double p, Sampler<T> delegate) {
+    this(1, p, delegate, null);
+  }
+
+  @Override
+  public T sample() {
+    boolean present = gen.nextDouble() >= p;
+    return present ? delegate.sample() : missingMarker;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/random/MultiNormal.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/random/MultiNormal.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/MultiNormal.java
new file mode 100644
index 0000000..748d4e8
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/MultiNormal.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.DiagonalMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+
+import java.util.Random;
+
+/**
+ * Samples from a multi-variate normal distribution.
+ * <p/>
+ * This is done by sampling from several independent unit normal distributions to get a vector u.
+ * The sample value that is returned is then A u + m where A is derived from the covariance matrix
+ * and m is the mean of the result.
+ * <p/>
+ * Since the covariance of A u is A A', if \Sigma is the desired covariance matrix you can use any
+ * value of A such that A A' = \Sigma.  The Cholesky decomposition \Sigma = L L' can be used to get
+ * A = L if \Sigma is positive definite.  Slightly more expensive is to use the SVD U S V' = \Sigma
+ * and then set A = U \sqrt{S}.
+ * <p/>
+ * Useful special cases occur when \Sigma is diagonal so that A = \sqrt(\Sigma) or where \Sigma = r I
+ * (so that A = \sqrt{r} I).
+ * <p/>
+ * Another special case is where m = 0.
+ */
+public class MultiNormal implements Sampler<Vector> {
+  private final Random gen;
+  // Length of the unit-normal vector u drawn for each sample.
+  private final int dimension;
+  // The matrix A applied to u; null means A is the identity.
+  private final Matrix scale;
+  // The mean m added to A u; null means m = 0.
+  private final Vector mean;
+
+  /**
+   * Constructs a zero-mean sampler with diagonal scale matrix.
+   * @param diagonal The diagonal elements of the scale matrix.
+   */
+  public MultiNormal(Vector diagonal) {
+    this(new DiagonalMatrix(diagonal), null);
+  }
+
+  /**
+   * Constructs a sampler with diagonal scale matrix and (potentially)
+   * non-zero mean.
+   * @param diagonal The scale matrix's principal diagonal.
+   * @param mean The desired mean.  Set to null if zero mean is desired.
+   */
+  public MultiNormal(Vector diagonal, Vector mean) {
+    this(new DiagonalMatrix(diagonal), mean);
+  }
+
+  /**
+   * Constructs a sampler with non-trivial scale matrix and mean.
+   * @param a The scale matrix A.  Samples are A u + m, so the covariance is A A'.  The number of
+   *          columns of A fixes the length of u.
+   * @param mean The desired mean.  Set to null if zero mean is desired.
+   */
+  public MultiNormal(Matrix a, Vector mean) {
+    this(a, mean, a.columnSize());
+  }
+
+  /**
+   * Constructs a sampler with identity scale matrix and zero mean.
+   * @param dimension The length of the sampled vectors.
+   */
+  public MultiNormal(int dimension) {
+    this(null, null, dimension);
+  }
+
+  /**
+   * Constructs a sampler whose scale matrix is {@code new DiagonalMatrix(radius, mean.size())},
+   * presumably radius * I, i.e. covariance radius^2 I.
+   * @param radius The value placed on the diagonal of the scale matrix.
+   * @param mean The desired mean; must not be null because its size fixes the dimension.
+   */
+  public MultiNormal(double radius, Vector mean) {
+    this(new DiagonalMatrix(radius, mean.size()), mean);
+  }
+
+  private MultiNormal(Matrix scale, Vector mean, int dimension) {
+    gen = RandomUtils.getRandom();
+    this.dimension = dimension;
+    this.scale = scale;
+    this.mean = mean;
+  }
+
+  /**
+   * Draws one sample A u + m.  The multiplication is skipped when A is the identity (null) and the
+   * addition when m = 0 (null).
+   */
+  @Override
+  public Vector sample() {
+    Vector v = new DenseVector(dimension).assign(
+      new DoubleFunction() {
+        @Override
+        public double apply(double ignored) {
+          return gen.nextGaussian();
+        }
+      }
+    );
+    Vector scaled = scale == null ? v : scale.times(v);
+    return mean == null ? scaled : scaled.plus(mean);
+  }
+
+  /**
+   * Returns the mean vector m, or null if this sampler has zero mean.
+   */
+  public Vector getMean() {
+    return mean;
+  }
+
+  /**
+   * Returns the mean vector m (or null), NOT the scale matrix.
+   * @deprecated Misnamed: this has always returned the mean.  The scale is a {@link Matrix} and
+   *             cannot be returned as a {@link Vector}.  Use {@link #getMean()} instead.
+   */
+  @Deprecated
+  public Vector getScale() {
+    return mean;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Multinomial.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Multinomial.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Multinomial.java
new file mode 100644
index 0000000..d79c32c
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Multinomial.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Multiset;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.list.DoubleArrayList;
+
+/**
+ * Multinomial sampler that allows updates to element probabilities.  The basic idea is that sampling is
+ * done by using a simple balanced tree.  Probabilities are kept in the tree so that we can navigate to
+ * any leaf in log N time.  Updates are simple because we can just propagate them upwards.
+ * <p/>
+ * In order to facilitate access by value, we maintain an additional map from value to tree node.
+ * <p/>
+ * Weights need not sum to one; values are sampled in proportion to their weights.  Values must be
+ * non-null and distinct among the live elements.
+ */
+public final class Multinomial<T> implements Sampler<T>, Iterable<T> {
+  // these lists use heap ordering.  Thus, the root is at location 1, first level children at 2 and 3, second level
+  // at 4, 5 and 6, 7.  Index 0 is an unused sentinel.  With s = weight.size(), the leaves are indices
+  // s/2 .. s-1.  Each internal node's weight is the sum of its children's weights, and its entry in
+  // values is a stale copy left behind when that node was split.
+  private final DoubleArrayList weight = new DoubleArrayList();
+  private final List<T> values = Lists.newArrayList();
+  // Maps each live value to the index of its leaf.  Deleted values are removed from this map but
+  // leave a zero-weight phantom leaf behind in the tree.
+  private final Map<T, Integer> items = Maps.newHashMap();
+  private final Random rand = RandomUtils.getRandom();
+
+  /** Creates an empty sampler. */
+  public Multinomial() {
+    weight.add(0);
+    values.add(null);
+  }
+
+  /**
+   * Creates a sampler whose weights are the counts of the given multiset.
+   * @throws IllegalArgumentException if counts is empty
+   */
+  public Multinomial(Multiset<T> counts) {
+    this();
+    Preconditions.checkArgument(!counts.isEmpty(), "Need some data to build sampler");
+    for (T t : counts.elementSet()) {
+      add(t, counts.count(t));
+    }
+  }
+
+  /** Creates a sampler from (value, weight) pairs. */
+  public Multinomial(Iterable<WeightedThing<T>> things) {
+    this();
+    for (WeightedThing<T> thing : things) {
+      add(thing.getValue(), thing.getWeight());
+    }
+  }
+
+  /**
+   * Adds a new value with the given (unnormalized) weight.
+   * @param value the value to add; must be non-null and not currently present
+   * @param w its weight
+   * @throws NullPointerException if value is null
+   * @throws IllegalArgumentException if value is already present
+   */
+  public void add(T value, double w) {
+    Preconditions.checkNotNull(value);
+    Preconditions.checkArgument(!items.containsKey(value));
+
+    int n = this.weight.size();
+    if (n == 1) {
+      // empty tree: the new value becomes the root leaf
+      weight.add(w);
+      values.add(value);
+      items.put(value, 1);
+    } else {
+      // The leaf at n / 2 becomes an internal node: its contents move down to the new left child
+      // at n and the new value becomes the right child at n + 1.
+      int parent = n / 2;
+      T moved = values.get(parent);
+      weight.add(weight.get(parent));
+      values.add(moved);
+      // Only re-point the map if the moved leaf is live.  A deleted (phantom) leaf must stay out of
+      // items, and a value that was deleted and re-added elsewhere must keep its live leaf.
+      Integer movedIndex = items.get(moved);
+      if (movedIndex != null && movedIndex == parent) {
+        items.put(moved, n);
+      }
+      n++;
+
+      // new item goes in
+      items.put(value, n);
+      this.weight.add(w);
+      values.add(value);
+
+      // parents get incremented all the way to the root
+      while (n > 1) {
+        n /= 2;
+        this.weight.set(n, this.weight.get(n) + w);
+      }
+    }
+  }
+
+  /** Returns the weight of value, or 0 if it is not (or no longer) present. */
+  public double getWeight(T value) {
+    Integer n = items.get(value);
+    return n == null ? 0 : weight.get(n);
+  }
+
+  /**
+   * Returns the weight of value divided by the total weight, or 0 if value is not present.
+   * If the total weight is zero the result is not a finite number.
+   */
+  public double getProbability(T value) {
+    Integer n = items.get(value);
+    return n == null ? 0 : weight.get(n) / weight.get(1);
+  }
+
+  /** Returns the total weight of all values, or 0 for an empty sampler. */
+  public double getWeight() {
+    if (weight.size() > 1) {
+      return weight.get(1);
+    } else {
+      return 0;
+    }
+  }
+
+  /**
+   * Removes value by setting its weight to zero.
+   * @throws IllegalArgumentException if value is not present
+   */
+  public void delete(T value) {
+    set(value, 0);
+  }
+
+  /**
+   * Changes the weight of an existing value.  A weight of zero or less deletes the value: it
+   * disappears from the iterator and from lookups, and its leaf keeps weight zero.
+   * @throws IllegalArgumentException if value is not present
+   */
+  public void set(T value, double newP) {
+    Preconditions.checkArgument(items.containsKey(value));
+    int n = items.get(value);
+    double p = newP;
+    if (p <= 0) {
+      // this makes the iterator not see such an element even though we leave a phantom in the tree
+      // Leaving the phantom behind simplifies tree maintenance and testing, but isn't really necessary.
+      items.remove(value);
+      // Clamp to zero: a negative leaf weight would corrupt the partial sums that sample() relies on.
+      p = 0;
+    }
+    double oldP = weight.get(n);
+    while (n > 0) {
+      weight.set(n, weight.get(n) - oldP + p);
+      n /= 2;
+    }
+  }
+
+  /**
+   * Draws a value with probability proportional to its weight.
+   * @throws IllegalStateException if nothing has ever been added
+   */
+  @Override
+  public T sample() {
+    return sample(rand.nextDouble());
+  }
+
+  /**
+   * Deterministically maps u (expected in [0, 1)) to a value by descending the tree, so that a
+   * uniformly distributed u yields values with probability proportional to their weights.
+   * NOTE: if every value has been deleted, the total weight is zero and a phantom may be returned.
+   * @throws IllegalStateException if nothing has ever been added
+   */
+  public T sample(double u) {
+    // weight always holds the sentinel at index 0, so the tree is empty iff size() == 1
+    Preconditions.checkState(weight.size() > 1, "Cannot sample from an empty Multinomial");
+    u *= weight.get(1);
+
+    int n = 1;
+    while (2 * n < weight.size()) {
+      // children are at 2n and 2n+1
+      double left = weight.get(2 * n);
+      if (u <= left) {
+        n = 2 * n;
+      } else {
+        u -= left;
+        n = 2 * n + 1;
+      }
+    }
+    return values.get(n);
+  }
+
+  /**
+   * Exposed for testing only.  Returns a list of the leaf weights.  These are in an
+   * order such that probing just before and after the cumulative sum of these weights
+   * will touch every element of the tree twice and thus will make it possible to test
+   * every possible left/right decision in navigating the tree.
+   */
+  List<Double> getWeights() {
+    List<Double> r = Lists.newArrayList();
+    int i = Integer.highestOneBit(weight.size());
+    while (i < weight.size()) {
+      r.add(weight.get(i));
+      i++;
+    }
+    i /= 2;
+    while (i < Integer.highestOneBit(weight.size())) {
+      r.add(weight.get(i));
+      i++;
+    }
+    return r;
+  }
+
+  /**
+   * Iterates over the live values, each exactly once, in leaf order.  A value whose weight is set
+   * to zero during iteration is skipped if it has not been reached yet.
+   */
+  @Override
+  public Iterator<T> iterator() {
+    return new AbstractIterator<T>() {
+      // Internal nodes keep stale copies of values, so only the leaves (s/2 .. s-1) are visited.
+      private int index = Math.max(1, values.size() / 2);
+      private final Iterator<T> valuesIterator = Iterables.skip(values, index).iterator();
+
+      @Override
+      protected T computeNext() {
+        while (valuesIterator.hasNext()) {
+          T next = valuesIterator.next();
+          int at = index++;
+          // Only yield a value at its live leaf; this skips phantoms of deleted values.
+          Integer leaf = items.get(next);
+          if (leaf != null && leaf == at) {
+            return next;
+          }
+        }
+        return endOfData();
+      }
+    };
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Normal.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Normal.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Normal.java
new file mode 100644
index 0000000..c162f26
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/Normal.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import org.apache.mahout.common.RandomUtils;
+
+import java.util.Random;
+
+/**
+ * Samples from a univariate normal distribution with a fixed mean and standard deviation.  The
+ * no-argument constructor gives mean 0 and standard deviation 1.
+ */
+public final class Normal extends AbstractSamplerFunction {
+  private final Random rand = RandomUtils.getRandom();
+  private final double mean;
+  private final double sd;
+
+  /** Creates a sampler with mean 0 and standard deviation 1. */
+  public Normal() {
+    this(0, 1);
+  }
+
+  /**
+   * @param mean the mean of the samples
+   * @param sd the standard deviation of the samples
+   */
+  public Normal(double mean, double sd) {
+    this.mean = mean;
+    this.sd = sd;
+  }
+
+  /** Returns mean + sd * z for a fresh standard normal deviate z. */
+  @Override
+  public Double sample() {
+    double z = rand.nextGaussian();
+    return mean + sd * z;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/random/PoissonSampler.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/random/PoissonSampler.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/PoissonSampler.java
new file mode 100644
index 0000000..e4e49f8
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/PoissonSampler.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.math3.distribution.PoissonDistribution;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+
+import java.util.List;
+
+/**
+ * Samples from a Poisson distribution.  Should probably not be used for lambda > 1000 or so.
+ * <p/>
+ * Sampling uses a cached {@link Multinomial} over 0, 1, ..., k plus one tail bucket (valued k + 1)
+ * that carries the probability mass not covered by the explicit terms, so results are approximate
+ * in the extreme tail.  The table is rebuilt, with a smaller tail, whenever the uniform deviate is
+ * below the current tail mass.
+ */
+public final class PoissonSampler extends AbstractSamplerFunction {
+
+  // Probability mass of the tail bucket of the current table; 1 before any table has been built.
+  private double limit;
+  // Cached sampling table; null until the first call to sample(double).
+  private Multinomial<Integer> partial;
+  private final RandomWrapper gen;
+  private final PoissonDistribution pd;
+
+  /**
+   * @param lambda the mean of the Poisson distribution
+   */
+  public PoissonSampler(double lambda) {
+    limit = 1;
+    gen = RandomUtils.getRandom();
+    pd = new PoissonDistribution(gen.getRandomGenerator(),
+                                 lambda,
+                                 PoissonDistribution.DEFAULT_EPSILON,
+                                 PoissonDistribution.DEFAULT_MAX_ITERATIONS);
+  }
+
+  @Override
+  public Double sample() {
+    return sample(gen.nextDouble());
+  }
+
+  /**
+   * Maps the uniform deviate u to a sample, first (re)building the table if needed.  A rebuild
+   * adds explicit terms until the tail mass is at most u / 20.
+   */
+  double sample(double u) {
+    if (partial == null || u < limit) {
+      List<WeightedThing<Integer>> steps = Lists.newArrayList();
+      limit = 1;
+      int i = 0;
+      while (u / 20 < limit) {
+        double pdf = pd.probability(i);
+        // Past the mode the pmf only decreases, so once it has underflowed to zero limit can never
+        // shrink again; stop here instead of looping forever (possible when u is close to 0).
+        if (pdf == 0 && i > pd.getMean()) {
+          break;
+        }
+        limit -= pdf;
+        steps.add(new WeightedThing<>(i, pdf));
+        i++;
+      }
+      // Rounding can leave limit slightly negative; a negative weight would corrupt the Multinomial.
+      steps.add(new WeightedThing<>(steps.size(), Math.max(limit, 0)));
+      partial = new Multinomial<>(steps);
+    }
+    return partial.sample(u);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java
new file mode 100644
index 0000000..79fe4b6
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.lang.math.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Static generators of random projection bases.  Every generator normalizes each basis vector
+ * with {@link Vector#normalize()} before returning it.
+ */
+public final class RandomProjector {
+  private RandomProjector() {
+  }
+
+  /**
+   * Generates a basis matrix of size projectedVectorSize x vectorSize.  Multiplying a vector by
+   * this matrix results in the projected vector.
+   *
+   * The entries are filled from a {@link Normal} sampler (mean 0, standard deviation 1), and then
+   * each row is normalized.
+   *
+   * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+   * @param vectorSize initial vector size
+   * @return a projection matrix
+   */
+  public static Matrix generateBasisNormal(int projectedVectorSize, int vectorSize) {
+    Matrix basisMatrix = new DenseMatrix(projectedVectorSize, vectorSize);
+    basisMatrix.assign(new Normal());
+    for (MatrixSlice row : basisMatrix) {
+      row.vector().assign(row.normalize());
+    }
+    return basisMatrix;
+  }
+
+  /**
+   * Generates a basis matrix of size projectedVectorSize x vectorSize.  Multiplying a vector by
+   * this matrix results in the projected vector.
+   *
+   * The entries of the matrix are sampled from a distribution where:
+   * - +1 has probability 1/2,
+   * - -1 has probability 1/2
+   * and then each row is normalized.
+   *
+   * See Achlioptas, D. (2003). Database-friendly random projections: Johnson-Lindenstrauss with binary coins.
+   * Journal of Computer and System Sciences, 66(4), 671-687. doi:10.1016/S0022-0000(03)00025-4
+   *
+   * NOTE(review): the signs come from commons-lang's RandomUtils rather than Mahout's, so this
+   * generator presumably does not follow Mahout's test seeding; confirm before relying on it.
+   *
+   * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+   * @param vectorSize initial vector size
+   * @return a projection matrix
+   */
+  public static Matrix generateBasisPlusMinusOne(int projectedVectorSize, int vectorSize) {
+    Matrix basisMatrix = new DenseMatrix(projectedVectorSize, vectorSize);
+    for (int i = 0; i < projectedVectorSize; ++i) {
+      for (int j = 0; j < vectorSize; ++j) {
+        basisMatrix.set(i, j, RandomUtils.nextInt(2) == 0 ? 1 : -1);
+      }
+    }
+    for (MatrixSlice row : basisMatrix) {
+      row.vector().assign(row.normalize());
+    }
+    return basisMatrix;
+  }
+
+  /**
+   * Generates a basis matrix of size projectedVectorSize x vectorSize.  Multiplying a vector by
+   * this matrix results in the projected vector.
+   *
+   * The entries of the matrix are sampled from a distribution where:
+   * - 0 has probability 2/3,
+   * - +sqrt(3) has probability 1/6,
+   * - -sqrt(3) has probability 1/6
+   * and then each row is normalized.  The sqrt(3) factor follows Achlioptas; after normalization
+   * only the pattern of zeros and signs matters.
+   *
+   * See Achlioptas, D. (2003). Database-friendly random projections: Johnson-Lindenstrauss with binary coins.
+   * Journal of Computer and System Sciences, 66(4), 671-687. doi:10.1016/S0022-0000(03)00025-4
+   *
+   * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+   * @param vectorSize initial vector size
+   * @return a projection matrix
+   */
+  public static Matrix generateBasisZeroPlusMinusOne(int projectedVectorSize, int vectorSize) {
+    Matrix basisMatrix = new DenseMatrix(projectedVectorSize, vectorSize);
+    Multinomial<Double> choice = new Multinomial<>();
+    choice.add(0.0, 2 / 3.0);
+    choice.add(Math.sqrt(3.0), 1 / 6.0);
+    choice.add(-Math.sqrt(3.0), 1 / 6.0);
+    for (int i = 0; i < projectedVectorSize; ++i) {
+      for (int j = 0; j < vectorSize; ++j) {
+        basisMatrix.set(i, j, choice.sample());
+      }
+    }
+    for (MatrixSlice row : basisMatrix) {
+      row.vector().assign(row.normalize());
+    }
+    return basisMatrix;
+  }
+
+  /**
+   * Generates a list of projectedVectorSize vectors, each of size vectorSize.  This looks like a
+   * matrix of size (projectedVectorSize, vectorSize).  Entries are drawn from a {@link Normal}
+   * sampler and then each vector is normalized.
+   *
+   * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+   * @param vectorSize initial vector size
+   * @return a list of projection vectors
+   */
+  public static List<Vector> generateVectorBasis(int projectedVectorSize, int vectorSize) {
+    DoubleFunction random = new Normal();
+    List<Vector> basisVectors = Lists.newArrayList();
+    for (int i = 0; i < projectedVectorSize; ++i) {
+      Vector basisVector = new DenseVector(vectorSize);
+      basisVector.assign(random);
+      // normalize() returns a normalized copy instead of scaling in place, so the result has to be
+      // written back, exactly as the matrix generators above do for their rows.
+      basisVector.assign(basisVector.normalize());
+      basisVectors.add(basisVector);
+    }
+    return basisVectors;
+  }
+}
