Author: dfilimon
Date: Fri May 17 09:04:12 2013
New Revision: 1483702
URL: http://svn.apache.org/r1483702
Log:
MAHOUT-1216: Add locality sensitive hashing and a LocalitySensitiveHash searcher
This issue tackles the LocalitySensitiveHashSearch, which was initially supposed
to be part of MAHOUT-1156.
It adds HashedVector, the class that decorates vectors with a locality sensitive
hash, and a new searcher, LocalitySensitiveHashSearch (although a better
implementation is possible). It also adds support for the new searcher in the
existing tests and in the new StreamingKMeans infrastructure.
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.java
Modified:
mahout/trunk/CHANGELOG
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchQualityTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchSanityTest.java
Modified: mahout/trunk/CHANGELOG
URL:
http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Fri May 17 09:04:12 2013
@@ -2,6 +2,8 @@ Mahout Change Log
Release 0.8 - unreleased
+__MAHOUT-1216: Add locality sensitive hashing and a LocalitySensitiveHash
searcher (dfilimon)
+
__MAHOUT-1181: Adding StreamingKMeans MapReduce classes (dfilimon)
MAHOUT-1212: Incorrect classify-20newsgroups.sh file description (Julian
Ortega via smarthi)
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
Fri May 17 09:04:12 2013
@@ -20,6 +20,7 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
@@ -54,6 +55,10 @@ public class StreamingKMeansUtilsMR {
return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class,
new Class[]{DistanceMeasure.class, int.class, int.class},
new Object[]{distanceMeasure, numProjections, searchSize});
+ } else if
(searcherClass.equals(LocalitySensitiveHashSearch.class.getName())) {
+ return ClassUtils.instantiateAs(searcherClass,
LocalitySensitiveHashSearch.class,
+ new Class[]{DistanceMeasure.class, int.class},
+ new Object[]{distanceMeasure, searchSize});
} else {
throw new IllegalStateException("Unknown class instantiation requested");
}
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java?rev=1483702&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
Fri May 17 09:04:12 2013
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Iterator;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+
+/**
+ * Decorates a weighted vector with a locality sensitive hash.
+ *
+ * The LSH function implemented is the random hyperplane based hash function.
+ * See "Similarity Estimation Techniques from Rounding Algorithms" by Moses S. Charikar, section 3.
+ * http://www.cs.princeton.edu/courses/archive/spring04/cos598B/bib/CharikarEstim.pdf
+ *
+ * Bit i of the hash is set when the dot product of the vector with row i of the projection
+ * matrix is positive, so the projection matrix may have at most 64 rows.
+ */
+public class HashedVector extends WeightedVector {
+  /** Index recorded for vectors that have no meaningful position in an input collection. */
+  protected static final int INVALID_INDEX = -1;
+
+  /** Maximum number of hash bits: one per projection row, packed into a long. */
+  private static final int MAX_BITS = Long.SIZE;
+
+  /**
+   * Value of the locality sensitive hash. It is 64 bit.
+   */
+  private final long hash;
+
+  /**
+   * Wraps a vector (with weight 1) together with an already computed hash.
+   *
+   * @param vector the vector to decorate
+   * @param hash   the precomputed locality sensitive hash
+   * @param index  the index stored in the underlying WeightedVector
+   */
+  public HashedVector(Vector vector, long hash, int index) {
+    super(vector, 1, index);
+    this.hash = hash;
+  }
+
+  /**
+   * Wraps a vector (with weight 1) and hashes it with the given projection.
+   *
+   * @param vector     the vector to decorate
+   * @param projection matrix whose rows are the random hyperplanes (at most 64 rows)
+   * @param index      the index stored in the underlying WeightedVector
+   * @param mask       bits of the computed hash to keep; -1L keeps all of them
+   */
+  public HashedVector(Vector vector, Matrix projection, int index, long mask) {
+    super(vector, 1, index);
+    this.hash = mask & computeHash64(vector, projection);
+  }
+
+  /**
+   * Wraps the delegate of a weighted vector, keeping its weight and index, and hashes it.
+   *
+   * @param weightedVector the weighted vector whose delegate, weight and index are reused
+   * @param projection     matrix whose rows are the random hyperplanes (at most 64 rows)
+   * @param mask           bits of the computed hash to keep; -1L keeps all of them
+   */
+  public HashedVector(WeightedVector weightedVector, Matrix projection, long mask) {
+    super(weightedVector.getVector(), weightedVector.getWeight(), weightedVector.getIndex());
+    this.hash = mask & computeHash64(weightedVector, projection);
+  }
+
+  /**
+   * Computes the random hyperplane hash of a vector: bit i is set iff the i-th entry of
+   * {@code projection.times(vector)} is positive.
+   *
+   * @param vector     the vector to hash
+   * @param projection matrix whose rows are the random hyperplanes
+   * @return the 64-bit hash
+   * @throws IllegalArgumentException if the projection has more than 64 rows
+   */
+  public static long computeHash64(Vector vector, Matrix projection) {
+    if (projection.rowSize() > MAX_BITS) {
+      throw new IllegalArgumentException(String.format(
+          "Projection has %d rows, but a 64-bit hash supports at most %d", projection.rowSize(), MAX_BITS));
+    }
+    long hash = 0;
+    Iterator<Element> iterator = projection.times(vector).iterateNonZero();
+    while (iterator.hasNext()) {
+      Element element = iterator.next();
+      if (element.get() > 0) {
+        // Must be a long shift: an int shift distance is masked to 5 bits, which would fold
+        // bits 32..63 onto 0..31 (and 1 << 31 would sign-extend into the upper half).
+        hash |= 1L << element.index();
+      }
+    }
+    return hash;
+  }
+
+  /**
+   * Hashes a weighted vector, keeping all 64 bits of the hash.
+   */
+  public static HashedVector hash(WeightedVector v, Matrix projection) {
+    // A mask of 0 would zero every hash; -1L keeps all bits.
+    return hash(v, projection, -1L);
+  }
+
+  /**
+   * Hashes a weighted vector, keeping only the bits of the hash selected by {@code mask}.
+   */
+  public static HashedVector hash(WeightedVector v, Matrix projection, long mask) {
+    return new HashedVector(v, projection, mask);
+  }
+
+  /**
+   * @param otherHash another 64-bit hash
+   * @return the number of bit positions in which this vector's hash differs from {@code otherHash}
+   */
+  public int hammingDistance(long otherHash) {
+    return Long.bitCount(hash ^ otherHash);
+  }
+
+  public long getHash() {
+    return hash;
+  }
+
+  @Override
+  public String toString() {
+    // Pad to 16 hex digits: the hash is a full 64-bit value.
+    return String.format("index=%d, hash=%016x, v=%s", getIndex(), hash, getVector());
+  }
+
+  /**
+   * A HashedVector equals any Vector of the same size with identical entries. When the other
+   * object is also a HashedVector, the hashes must match as well.
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (!(o instanceof Vector)) {
+      return false;
+    }
+    Vector other = (Vector) o;
+    if (other.size() != size()) {
+      // minus() would throw a CardinalityException instead of reporting inequality.
+      return false;
+    }
+    if (other instanceof HashedVector && ((HashedVector) other).hash != hash) {
+      return false;
+    }
+    return minus(other).norm(1) == 0;
+  }
+
+  @Override
+  public int hashCode() {
+    int result = super.hashCode();
+    result = 31 * result + (int) (hash ^ (hash >>> 32));
+    return result;
+  }
+}
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java?rev=1483702&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
Fri May 17 09:04:12 2013
@@ -0,0 +1,273 @@
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.random.RandomProjector;
+import org.apache.mahout.math.random.WeightedThing;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/**
+ * Implements a Searcher that uses locality sensitivity hash as a first pass
approximation
+ * to estimate distance without floating point math. The clever bit about
this implementation
+ * is that it does an adaptive cutoff for the cutoff on the bitwise distance.
Making this
+ * cutoff adaptive means that we only needs to make a single pass through the
data.
+ */
+public class LocalitySensitiveHashSearch extends UpdatableSearcher implements
Iterable<Vector> {
+ /**
+ * Number of bits in the locality sensitive hash. 64 bits fix neatly into a
long.
+ */
+ private static final int BITS = 64;
+
+ /**
+ * Bit mask for the computed hash. Currently, it's 0xffffffffffff.
+ */
+ private static final long BIT_MASK = -1L;
+
+ /**
+ * The maximum Hamming distance between two hashes that the hash limit can
grow back to.
+ * It starts at BITS and decreases as more points than are needed are added
to the candidate priority queue.
+ * But, after the observed distribution of distances becomes too good (we're
seeing less than some percentage of the
+ * total number of points; using the hash strategy somewhere less than 25%)
the limit is increased to compute
+ * more distances.
+ * This is because
+ */
+ private static final int MAX_HASH_LIMIT = 32;
+
+ /**
+ * Minimum number of points with a given Hamming from the query that must be
observed to consider raising the minimum
+ * distance for a candidate.
+ */
+ private static final int MIN_DISTRIBUTION_COUNT = 10;
+
+ private Multiset<HashedVector> trainingVectors = HashMultiset.create();
+
+ /**
+ * This matrix of BITS random vectors is used to compute the Locality
Sensitive Hash
+ * we compute the dot product with these vectors using a matrix
multiplication and then use just
+ * sign of each result as one bit in the hash
+ */
+ private Matrix projection;
+
+ /**
+ * The search size determines how many top results we retain. We do this
because the hash distance
+ * isn't guaranteed to be entirely monotonic with respect to the real
distance. To the extent that
+ * actual distance is well approximated by hash distance, then the
searchSize can be decreased to
+ * roughly the number of results that you want.
+ */
+ private int searchSize;
+
+ /**
+ * Controls how the hash limit is raised. 0 means use minimum of
distribution, 1 means use first quartile.
+ * Intermediate values indicate an interpolation should be used. Negative
values mean to never increase.
+ */
+ private double hashLimitStrategy = 0.9;
+
+ /**
+ * Number of evaluations of the full distance between two points that was
required.
+ */
+ private int distanceEvaluations = 0;
+
+ /**
+ * Whether the projection matrix was initialized. This has to be deferred
until the size of the vectors is known,
+ * effectively until the first vector is added.
+ */
+ private boolean initialized = false;
+
+ public LocalitySensitiveHashSearch(DistanceMeasure distanceMeasure, int
searchSize) {
+ super(distanceMeasure);
+ this.searchSize = searchSize;
+ this.projection = null;
+ }
+
+ private void initialize(int numDimensions) {
+ if (initialized) {
+ return;
+ }
+ initialized = true;
+ projection = RandomProjector.generateBasisNormal(BITS, numDimensions);
+ }
+
+ private PriorityQueue<WeightedThing<Vector>> searchInternal(Vector query) {
+ long queryHash = HashedVector.computeHash64(query, projection);
+
+ // We keep an approximation of the closest vectors here.
+ PriorityQueue<WeightedThing<Vector>> top =
Searcher.getCandidateQueue(getSearchSize());
+
+ // We keep the counts of the hash distances here. This lets us accurately
+ // judge what hash distance cutoff we should use.
+ int[] hashCounts = new int[BITS + 1];
+
+ // We scan the vectors using bit counts as an approximation of the dot
product so we can do as few
+ // full distance computations as possible. Our goal is to only do full
distance computations for
+ // vectors with hash distance at most as large as the searchSize biggest
hash distance seen so far.
+
+ OnlineSummarizer[] distribution = new OnlineSummarizer[BITS + 1];
+ for (int i = 0; i < BITS + 1; i++) {
+ distribution[i] = new OnlineSummarizer();
+ }
+
+ // Maximum number of different bits to still consider a vector a candidate
for nearest neighbor.
+ // Starts at the maximum number of bits, but decreases and can increase.
+ int hashLimit = BITS;
+ int limitCount = 0;
+ double distanceLimit = Double.POSITIVE_INFINITY;
+ distanceEvaluations = 0;
+
+ // In this loop, we have the invariants that:
+ //
+ // limitCount = sum_{i<hashLimit} hashCount[i]
+ // and
+ // limitCount >= searchSize && limitCount - hashCount[hashLimit-1] <
searchSize
+ for (HashedVector vector : trainingVectors) {
+ // This computes the Hamming Distance between the vector's hash and the
query's hash.
+ // The result is correlated with the angle between the vectors.
+ int bitDot = vector.hammingDistance(queryHash);
+ if (bitDot <= hashLimit) {
+ distanceEvaluations++;
+
+ double distance = distanceMeasure.distance(query, vector);
+ distribution[bitDot].add(distance);
+
+ if (distance < distanceLimit) {
+ top.insertWithOverflow(new WeightedThing<Vector>(vector, distance));
+ if (top.size() == searchSize) {
+ distanceLimit = top.top().getWeight();
+ }
+
+ hashCounts[bitDot]++;
+ limitCount++;
+ while (hashLimit > 0 && limitCount - hashCounts[hashLimit - 1] >
searchSize) {
+ hashLimit--;
+ limitCount -= hashCounts[hashLimit];
+ }
+
+ if (hashLimitStrategy >= 0) {
+ while (hashLimit < MAX_HASH_LIMIT &&
distribution[hashLimit].getCount() > MIN_DISTRIBUTION_COUNT
+ && ((1 - hashLimitStrategy) *
distribution[hashLimit].getQuartile(0)
+ + hashLimitStrategy * distribution[hashLimit].getQuartile(1))
< distanceLimit) {
+ limitCount += hashCounts[hashLimit];
+ hashLimit++;
+ }
+ }
+ }
+ }
+ }
+ return top;
+ }
+
+ @Override
+ public List<WeightedThing<Vector>> search(Vector query, int limit) {
+ PriorityQueue<WeightedThing<Vector>> top = searchInternal(query);
+ List<WeightedThing<Vector>> results =
Lists.newArrayListWithExpectedSize(limit);
+ while (limit > 0 && top.size() != 0) {
+ WeightedThing<Vector> wv = top.pop();
+ results.add(new WeightedThing<Vector>(((HashedVector)
wv.getValue()).getVector(), wv.getWeight()));
+ }
+ Collections.reverse(results);
+ return results;
+ }
+
+ /**
+ * Returns the closest vector to the query.
+ * When only one the nearest vector is needed, use this method, NOT
search(query, limit) because
+ * it's faster (less overhead).
+ * This is nearly the same as search().
+ *
+ * @param query the vector to search for
+ * @param differentThanQuery if true, returns the closest vector different
than the query (this
+ * only matters if the query is among the searched
vectors), otherwise,
+ * returns the closest vector to the query (even
the same vector).
+ * @return the weighted vector closest to the query
+ */
+ @Override
+ public WeightedThing<Vector> searchFirst(Vector query, boolean
differentThanQuery) {
+ // We get the top searchSize neighbors.
+ PriorityQueue<WeightedThing<Vector>> top = searchInternal(query);
+ // We then cut the number down to just the best 2.
+ while (top.size() > 2) {
+ top.pop();
+ }
+ // If there are fewer than 2 results, we just return the one we have.
+ if (top.size() < 2) {
+ return removeHash(top.pop());
+ }
+ // There are exactly 2 results.
+ WeightedThing<Vector> secondBest = top.pop();
+ WeightedThing<Vector> best = top.pop();
+ // If the best result is the same as the query, but we don't want to
return the query.
+ if (differentThanQuery && best.getValue().equals(query)) {
+ best = secondBest;
+ }
+ return removeHash(best);
+ }
+
+ protected WeightedThing<Vector> removeHash(WeightedThing<Vector> input) {
+ return new WeightedThing<Vector>(((HashedVector)
input.getValue()).getVector(), input.getWeight());
+ }
+
+ @Override
+ public void add(Vector vector) {
+ initialize(vector.size());
+ trainingVectors.add(new HashedVector(vector, projection,
HashedVector.INVALID_INDEX, BIT_MASK));
+ }
+
+ public int size() {
+ return trainingVectors.size();
+ }
+
+ public int getSearchSize() {
+ return searchSize;
+ }
+
+ public void setSearchSize(int size) {
+ searchSize = size;
+ }
+
+ public void setRaiseHashLimitStrategy(double strategy) {
+ hashLimitStrategy = strategy;
+ }
+
+ /**
+ * This is only for testing.
+ * @return the number of times the actual distance between two vectors was
computed.
+ */
+ public int resetEvaluationCount() {
+ int result = distanceEvaluations;
+ distanceEvaluations = 0;
+ return result;
+ }
+
+ @Override
+ public Iterator<Vector> iterator() {
+ return Iterators.transform(trainingVectors.iterator(), new
Function<HashedVector, Vector>() {
+ @Override
+ public Vector apply(org.apache.mahout.math.neighborhood.HashedVector
input) {
+ Preconditions.checkNotNull(input);
+ //noinspection ConstantConditions
+ return input.getVector();
+ }
+ });
+ }
+
+ @Override
+ public boolean remove(Vector v, double epsilon) {
+ return trainingVectors.remove(new HashedVector(v, projection,
HashedVector.INVALID_INDEX, BIT_MASK));
+ }
+
+ @Override
+ public void clear() {
+ trainingVectors.clear();
+ }
+}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
Fri May 17 09:04:12 2013
@@ -29,6 +29,7 @@ import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.neighborhood.Searcher;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
@@ -66,9 +67,11 @@ public class StreamingKMeansTest {
{new ProjectionSearch(new SquaredEuclideanDistanceMeasure(),
NUM_PROJECTIONS, SEARCH_SIZE), true},
{new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(),
NUM_PROJECTIONS, SEARCH_SIZE),
true},
+ {new LocalitySensitiveHashSearch(new
SquaredEuclideanDistanceMeasure(), SEARCH_SIZE), true},
{new ProjectionSearch(new SquaredEuclideanDistanceMeasure(),
NUM_PROJECTIONS, SEARCH_SIZE), false},
{new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(),
NUM_PROJECTIONS, SEARCH_SIZE),
false},
+ {new LocalitySensitiveHashSearch(new
SquaredEuclideanDistanceMeasure(), SEARCH_SIZE), false},
});
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
Fri May 17 09:04:12 2013
@@ -44,6 +44,7 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Test;
@@ -90,6 +91,7 @@ public class StreamingKMeansTestMR {
return Arrays.asList(new Object[][]{
{ProjectionSearch.class.getName(),
SquaredEuclideanDistanceMeasure.class.getName()},
{FastProjectionSearch.class.getName(),
SquaredEuclideanDistanceMeasure.class.getName()},
+ {LocalitySensitiveHashSearch.class.getName(),
SquaredEuclideanDistanceMeasure.class.getName()},
});
}
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.java?rev=1483702&view=auto
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.java
(added)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.java
Fri May 17 09:04:12 2013
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.neighborhood;
+
+import java.util.BitSet;
+import java.util.List;
+
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.random.Normal;
+import org.apache.mahout.math.random.WeightedThing;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+public class LocalitySensitiveHashSearchTest {
+
+ @Test
+ public void testNormal() {
+ Matrix testData = new DenseMatrix(100000, 10);
+ final Normal gen = new Normal();
+ testData.assign(gen);
+
+ final EuclideanDistanceMeasure distance = new EuclideanDistanceMeasure();
+ BruteSearch ref = new BruteSearch(distance);
+ ref.addAllMatrixSlicesAsWeightedVectors(testData);
+
+ LocalitySensitiveHashSearch cut = new
LocalitySensitiveHashSearch(distance, 10);
+ cut.addAllMatrixSlicesAsWeightedVectors(testData);
+
+ cut.setSearchSize(200);
+ cut.resetEvaluationCount();
+
+ System.out.printf("speedup,q1,q2,q3\n");
+
+ for (int i = 0; i < 12; i++) {
+ double strategy = (i - 1.0) / 10.0;
+ cut.setRaiseHashLimitStrategy(strategy);
+ OnlineSummarizer t1 = evaluateStrategy(testData, ref, cut);
+ int evals = cut.resetEvaluationCount();
+ final double speedup = 10e6 / evals;
+ System.out.printf("%.1f,%.2f,%.2f,%.2f\n", speedup, t1.getQuartile(1),
+ t1.getQuartile(2), t1.getQuartile(3));
+ assertTrue(t1.getQuartile(2) > 0.45);
+ assertTrue(speedup > 4 || t1.getQuartile(2) > 0.9);
+ assertTrue(speedup > 15 || t1.getQuartile(2) > 0.8);
+ }
+ }
+
+ private OnlineSummarizer evaluateStrategy(Matrix testData, BruteSearch ref,
+ LocalitySensitiveHashSearch cut) {
+ OnlineSummarizer t1 = new OnlineSummarizer();
+
+ for (int i = 0; i < 100; i++) {
+ final Vector q = testData.viewRow(i);
+ List<WeightedThing<Vector>> v1 = cut.search(q, 150);
+ BitSet b1 = new BitSet();
+ for (WeightedThing<Vector> v : v1) {
+ b1.set(((WeightedVector)v.getValue()).getIndex());
+ }
+
+ List<WeightedThing<Vector>> v2 = ref.search(q, 100);
+ BitSet b2 = new BitSet();
+ for (WeightedThing<Vector> v : v2) {
+ b2.set(((WeightedVector)v.getValue()).getIndex());
+ }
+
+ b1.and(b2);
+ t1.add(b1.cardinality());
+ }
+ return t1;
+ }
+
+ @Test
+ public void testDotCorrelation() {
+ final Normal gen = new Normal();
+
+ Matrix projection = new DenseMatrix(64, 10);
+ projection.assign(gen);
+
+ Vector query = new DenseVector(10);
+ query.assign(gen);
+ long qhash = HashedVector.computeHash64(query, projection);
+
+ int count[] = new int[65];
+ Vector v = new DenseVector(10);
+ for (int i = 0; i <500000; i++) {
+ v.assign(gen);
+ long hash = HashedVector.computeHash64(v, projection);
+ final int bitDot = Long.bitCount(qhash ^ hash);
+ count[bitDot]++;
+ if (count[bitDot] < 200) {
+ System.out.printf("%d, %.3f\n", bitDot, v.dot(query) /
Math.sqrt(v.getLengthSquared() * query.getLengthSquared()));
+ }
+ }
+ for (int i = 0; i < 65; ++i) {
+ System.out.printf("%d, %d\n", i, count[i]);
+ }
+ }
+}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchQualityTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchQualityTest.java?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchQualityTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchQualityTest.java
Fri May 17 09:04:12 2013
@@ -59,6 +59,10 @@ public class SearchQualityTest {
// NUM_PROJECTIONS = 5
// SEARCH_SIZE = 5
{new ProjectionSearch(distanceMeasure, 5, 5), dataPoints, queries,
reference, referenceSearchFirst},
+ {new FastProjectionSearch(distanceMeasure, 5, 5), dataPoints, queries,
reference, referenceSearchFirst},
+ {new LocalitySensitiveHashSearch(distanceMeasure, 5), dataPoints,
queries, reference, referenceSearchFirst},
+ // SEARCH_SIZE = 2
+ {new LocalitySensitiveHashSearch(distanceMeasure, 2), dataPoints,
queries, reference, referenceSearchFirst},
}
);
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchSanityTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchSanityTest.java?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchSanityTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchSanityTest.java
Fri May 17 09:04:12 2013
@@ -61,6 +61,7 @@ public class SearchSanityTest extends Ma
{new ProjectionSearch(new EuclideanDistanceMeasure(), NUM_PROJECTIONS,
SEARCH_SIZE), dataPoints},
{new FastProjectionSearch(new EuclideanDistanceMeasure(),
NUM_PROJECTIONS, SEARCH_SIZE),
dataPoints},
+ {new LocalitySensitiveHashSearch(new EuclideanDistanceMeasure(),
SEARCH_SIZE), dataPoints},
});
}
@@ -185,4 +186,21 @@ public class SearchSanityTest extends Ma
}
}
}
+
+ @Test
+ public void testSearchFirst() {
+ searcher.clear();
+ searcher.addAll(dataPoints);
+ for (Vector datapoint : dataPoints) {
+ WeightedThing<Vector> first = searcher.searchFirst(datapoint, false);
+ WeightedThing<Vector> second = searcher.searchFirst(datapoint, true);
+ List<WeightedThing<Vector>> firstTwo = searcher.search(datapoint, 2);
+
+ assertEquals("First isn't self", 0, first.getWeight(), 0);
+ assertEquals("First isn't self", datapoint, first.getValue());
+ assertEquals("First doesn't match", first, firstTwo.get(0));
+ assertEquals(String.format("Second doesn't match got %f expected %f",
second.getWeight(), firstTwo.get(1).getWeight()),
+ second, firstTwo.get(1));
+ }
+ }
}