Author: dfilimon
Date: Fri May 17 09:04:12 2013
New Revision: 1483702

URL: http://svn.apache.org/r1483702
Log:
MAHOUT-1216: Add locality sensitive hashing and a LocalitySensitiveHash searcher

This issue tackles the LocalitySensitiveHashSearch, which was initially supposed
to be part of MAHOUT-1156.
It adds HashedVector (the class that attaches an LSH value to vectors) and a new
searcher (although a better implementation is possible), and adds support for it
in the existing tests and in the new StreamingKMeans infrastructure.


Added:
    
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.java
Modified:
    mahout/trunk/CHANGELOG
    
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchQualityTest.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchSanityTest.java

Modified: mahout/trunk/CHANGELOG
URL: 
http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Fri May 17 09:04:12 2013
@@ -2,6 +2,8 @@ Mahout Change Log
 
 Release 0.8 - unreleased
 
+__MAHOUT-1216: Add locality sensitive hashing and a LocalitySensitiveHash 
searcher (dfilimon)
+
 __MAHOUT-1181: Adding StreamingKMeans MapReduce classes (dfilimon)
 
   MAHOUT-1212: Incorrect classify-20newsgroups.sh file description (Julian 
Ortega via smarthi)

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
 Fri May 17 09:04:12 2013
@@ -20,6 +20,7 @@ import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.neighborhood.BruteSearch;
 import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
 import org.apache.mahout.math.neighborhood.ProjectionSearch;
 import org.apache.mahout.math.neighborhood.UpdatableSearcher;
 
@@ -54,6 +55,10 @@ public class StreamingKMeansUtilsMR {
       return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class,
           new Class[]{DistanceMeasure.class, int.class, int.class},
           new Object[]{distanceMeasure, numProjections, searchSize});
+    } else if 
(searcherClass.equals(LocalitySensitiveHashSearch.class.getName())) {
+      return ClassUtils.instantiateAs(searcherClass, 
LocalitySensitiveHashSearch.class,
+          new Class[]{DistanceMeasure.class, int.class},
+          new Object[]{distanceMeasure, searchSize});
     } else {
       throw new IllegalStateException("Unknown class instantiation requested");
     }

Added: 
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java?rev=1483702&view=auto
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
 (added)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
 Fri May 17 09:04:12 2013
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Iterator;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+
+/**
+ * Decorates a weighted vector with a locality sensitive hash.
+ *
+ * The LSH function implemented is the random hyperplane based hash function.
+ * See "Similarity Estimation Techniques from Rounding Algorithms" by Moses S. Charikar, section 3.
+ * http://www.cs.princeton.edu/courses/archive/spring04/cos598B/bib/CharikarEstim.pdf
+ */
+public class HashedVector extends WeightedVector {
+  /**
+   * Index to use for vectors that have no meaningful index.
+   */
+  protected static final int INVALID_INDEX = -1;
+
+  /**
+   * Mask that keeps every bit of the 64-bit hash.
+   */
+  private static final long ALL_BITS = -1L;
+
+  /**
+   * Value of the locality sensitive hash. It is 64 bit (already masked when computed here).
+   */
+  private final long hash;
+
+  /**
+   * Wraps {@code vector} with weight 1 and a precomputed hash.
+   *
+   * @param vector the vector to decorate
+   * @param hash the hash value, stored as-is
+   * @param index the index reported by getIndex()
+   */
+  public HashedVector(Vector vector, long hash, int index) {
+    super(vector, 1, index);
+    this.hash = hash;
+  }
+
+  /**
+   * Wraps {@code vector} with weight 1 and hashes it with {@code projection}.
+   *
+   * @param vector the vector to decorate and hash
+   * @param projection hyperplane normals, one per row (see computeHash64)
+   * @param index the index reported by getIndex()
+   * @param mask bits of the computed hash to keep
+   */
+  public HashedVector(Vector vector, Matrix projection, int index, long mask) {
+    super(vector, 1, index);
+    this.hash = mask & computeHash64(vector, projection);
+  }
+
+  /**
+   * Copies the vector, weight and index of {@code weightedVector} and hashes it with
+   * {@code projection}.
+   *
+   * @param weightedVector the weighted vector to decorate and hash
+   * @param projection hyperplane normals, one per row (see computeHash64)
+   * @param mask bits of the computed hash to keep
+   */
+  public HashedVector(WeightedVector weightedVector, Matrix projection, long mask) {
+    super(weightedVector.getVector(), weightedVector.getWeight(), weightedVector.getIndex());
+    this.hash = mask & computeHash64(weightedVector, projection);
+  }
+
+  /**
+   * Computes the random-hyperplane hash of {@code vector}.
+   *
+   * Bit i of the result is set iff entry i of {@code projection.times(vector)} is strictly
+   * positive. Only 64 bits are available, so {@code projection} should have at most 64 rows;
+   * Java masks long shift distances to 6 bits, so row 64 and above would alias lower bits.
+   *
+   * @param vector the vector to hash
+   * @param projection hyperplane normals, one per row
+   * @return the 64-bit hash
+   */
+  public static long computeHash64(Vector vector, Matrix projection) {
+    long hash = 0;
+    Iterator<Element> iterator = projection.times(vector).iterateNonZero();
+    while (iterator.hasNext()) {
+      Element element = iterator.next();
+      if (element.get() > 0) {
+        // Must be 1L: an int shift only uses the low 5 bits of the distance, so rows 32..63
+        // would wrap onto bits 0..31 (and 1 << 31 would sign-extend to a negative long).
+        hash |= 1L << element.index();
+      }
+    }
+    return hash;
+  }
+
+  /**
+   * Hashes {@code v} with {@code projection}, keeping all 64 bits of the hash.
+   * (Previously this passed a mask of 0, which made every resulting hash equal to 0.)
+   */
+  public static HashedVector hash(WeightedVector v, Matrix projection) {
+    return hash(v, projection, ALL_BITS);
+  }
+
+  /**
+   * Hashes {@code v} with {@code projection}, keeping only the bits set in {@code mask}.
+   */
+  public static HashedVector hash(WeightedVector v, Matrix projection, long mask) {
+    return new HashedVector(v, projection, mask);
+  }
+
+  /**
+   * @return the number of bit positions in which this vector's hash differs from {@code otherHash}
+   */
+  public int hammingDistance(long otherHash) {
+    return Long.bitCount(hash ^ otherHash);
+  }
+
+  public long getHash() {
+    return hash;
+  }
+
+  @Override
+  public String toString() {
+    return String.format("index=%d, hash=%08x, v=%s", getIndex(), hash, getVector());
+  }
+
+  /**
+   * Two HashedVectors are equal when their hashes match and their entries are identical (zero
+   * L1 norm of the difference). A plain Vector is equal when its entries are identical.
+   *
+   * NOTE(review): equality with a plain Vector is likely not symmetric, and hashCode() mixes in
+   * the LSH hash, so a HashedVector and an equal plain Vector may report different hash codes.
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (!(o instanceof HashedVector)) {
+      return o instanceof Vector && this.minus((Vector) o).norm(1) == 0;
+    } else {
+      HashedVector v = (HashedVector) o;
+      return v.hash == this.hash && this.minus(v).norm(1) == 0;
+    }
+  }
+
+  @Override
+  public int hashCode() {
+    int result = super.hashCode();
+    result = 31 * result + (int) (hash ^ (hash >>> 32));
+    return result;
+  }
+}

Added: 
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java?rev=1483702&view=auto
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
 (added)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
 Fri May 17 09:04:12 2013
@@ -0,0 +1,273 @@
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.random.RandomProjector;
+import org.apache.mahout.math.random.WeightedThing;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/**
+ * Implements a Searcher that uses locality sensitivity hash as a first pass 
approximation
+ * to estimate distance without floating point math.  The clever bit about 
this implementation
+ * is that it does an adaptive cutoff for the cutoff on the bitwise distance.  
Making this
+ * cutoff adaptive means that we only needs to make a single pass through the 
data.
+ */
+public class LocalitySensitiveHashSearch extends UpdatableSearcher implements 
Iterable<Vector> {
+  /**
+   * Number of bits in the locality sensitive hash. 64 bits fix neatly into a 
long.
+   */
+  private static final int BITS = 64;
+
+  /**
+   * Bit mask for the computed hash. Currently, it's 0xffffffffffff.
+   */
+  private static final long BIT_MASK = -1L;
+
+  /**
+   * The maximum Hamming distance between two hashes that the hash limit can 
grow back to.
+   * It starts at BITS and decreases as more points than are needed are added 
to the candidate priority queue.
+   * But, after the observed distribution of distances becomes too good (we're 
seeing less than some percentage of the
+   * total number of points; using the hash strategy somewhere less than 25%) 
the limit is increased to compute
+   * more distances.
+   * This is because
+   */
+  private static final int MAX_HASH_LIMIT = 32;
+
+  /**
+   * Minimum number of points with a given Hamming from the query that must be 
observed to consider raising the minimum
+   * distance for a candidate.
+   */
+  private static final int MIN_DISTRIBUTION_COUNT = 10;
+
+  private Multiset<HashedVector> trainingVectors = HashMultiset.create();
+
+  /**
+   * This matrix of BITS random vectors is used to compute the Locality 
Sensitive Hash
+   * we compute the dot product with these vectors using a matrix 
multiplication and then use just
+   * sign of each result as one bit in the hash
+   */
+  private Matrix projection;
+
+  /**
+   * The search size determines how many top results we retain.  We do this 
because the hash distance
+   * isn't guaranteed to be entirely monotonic with respect to the real 
distance.  To the extent that
+   * actual distance is well approximated by hash distance, then the 
searchSize can be decreased to
+   * roughly the number of results that you want.
+   */
+  private int searchSize;
+
+  /**
+   * Controls how the hash limit is raised. 0 means use minimum of 
distribution, 1 means use first quartile.
+   * Intermediate values indicate an interpolation should be used. Negative 
values mean to never increase.
+   */
+  private double hashLimitStrategy = 0.9;
+
+  /**
+   * Number of evaluations of the full distance between two points that was 
required.
+   */
+  private int distanceEvaluations = 0;
+
+  /**
+   * Whether the projection matrix was initialized. This has to be deferred 
until the size of the vectors is known,
+   * effectively until the first vector is added.
+   */
+  private boolean initialized = false;
+
+  public LocalitySensitiveHashSearch(DistanceMeasure distanceMeasure, int 
searchSize) {
+    super(distanceMeasure);
+    this.searchSize = searchSize;
+    this.projection = null;
+  }
+
+  private void initialize(int numDimensions) {
+    if (initialized) {
+      return;
+    }
+    initialized = true;
+    projection = RandomProjector.generateBasisNormal(BITS, numDimensions);
+  }
+
+  private PriorityQueue<WeightedThing<Vector>> searchInternal(Vector query) {
+    long queryHash = HashedVector.computeHash64(query, projection);
+
+    // We keep an approximation of the closest vectors here.
+    PriorityQueue<WeightedThing<Vector>> top = 
Searcher.getCandidateQueue(getSearchSize());
+
+    // We keep the counts of the hash distances here.  This lets us accurately
+    // judge what hash distance cutoff we should use.
+    int[] hashCounts = new int[BITS + 1];
+
+    // We scan the vectors using bit counts as an approximation of the dot 
product so we can do as few
+    // full distance computations as possible.  Our goal is to only do full 
distance computations for
+    // vectors with hash distance at most as large as the searchSize biggest 
hash distance seen so far.
+
+    OnlineSummarizer[] distribution = new OnlineSummarizer[BITS + 1];
+    for (int i = 0; i < BITS + 1; i++) {
+      distribution[i] = new OnlineSummarizer();
+    }
+
+    // Maximum number of different bits to still consider a vector a candidate 
for nearest neighbor.
+    // Starts at the maximum number of bits, but decreases and can increase.
+    int hashLimit = BITS;
+    int limitCount = 0;
+    double distanceLimit = Double.POSITIVE_INFINITY;
+    distanceEvaluations = 0;
+
+    // In this loop, we have the invariants that:
+    //
+    // limitCount = sum_{i<hashLimit} hashCount[i]
+    // and
+    // limitCount >= searchSize && limitCount - hashCount[hashLimit-1] < 
searchSize
+    for (HashedVector vector : trainingVectors) {
+      // This computes the Hamming Distance between the vector's hash and the 
query's hash.
+      // The result is correlated with the angle between the vectors.
+      int bitDot = vector.hammingDistance(queryHash);
+      if (bitDot <= hashLimit) {
+        distanceEvaluations++;
+
+        double distance = distanceMeasure.distance(query, vector);
+        distribution[bitDot].add(distance);
+
+        if (distance < distanceLimit) {
+          top.insertWithOverflow(new WeightedThing<Vector>(vector, distance));
+          if (top.size() == searchSize) {
+            distanceLimit = top.top().getWeight();
+          }
+
+          hashCounts[bitDot]++;
+          limitCount++;
+          while (hashLimit > 0 && limitCount - hashCounts[hashLimit - 1] > 
searchSize) {
+            hashLimit--;
+            limitCount -= hashCounts[hashLimit];
+          }
+
+          if (hashLimitStrategy >= 0) {
+            while (hashLimit < MAX_HASH_LIMIT && 
distribution[hashLimit].getCount() > MIN_DISTRIBUTION_COUNT
+                && ((1 - hashLimitStrategy) * 
distribution[hashLimit].getQuartile(0)
+                + hashLimitStrategy * distribution[hashLimit].getQuartile(1)) 
< distanceLimit) {
+              limitCount += hashCounts[hashLimit];
+              hashLimit++;
+            }
+          }
+        }
+      }
+    }
+    return top;
+  }
+
+  @Override
+  public List<WeightedThing<Vector>> search(Vector query, int limit) {
+    PriorityQueue<WeightedThing<Vector>> top = searchInternal(query);
+    List<WeightedThing<Vector>> results = 
Lists.newArrayListWithExpectedSize(limit);
+    while (limit > 0 && top.size() != 0) {
+      WeightedThing<Vector> wv = top.pop();
+      results.add(new WeightedThing<Vector>(((HashedVector) 
wv.getValue()).getVector(), wv.getWeight()));
+    }
+    Collections.reverse(results);
+    return results;
+  }
+
+  /**
+   * Returns the closest vector to the query.
+   * When only one the nearest vector is needed, use this method, NOT 
search(query, limit) because
+   * it's faster (less overhead).
+   * This is nearly the same as search().
+   *
+   * @param query the vector to search for
+   * @param differentThanQuery if true, returns the closest vector different 
than the query (this
+   *                           only matters if the query is among the searched 
vectors), otherwise,
+   *                           returns the closest vector to the query (even 
the same vector).
+   * @return the weighted vector closest to the query
+   */
+  @Override
+  public WeightedThing<Vector> searchFirst(Vector query, boolean 
differentThanQuery) {
+    // We get the top searchSize neighbors.
+    PriorityQueue<WeightedThing<Vector>> top = searchInternal(query);
+    // We then cut the number down to just the best 2.
+    while (top.size() > 2) {
+      top.pop();
+    }
+    // If there are fewer than 2 results, we just return the one we have.
+    if (top.size() < 2) {
+      return removeHash(top.pop());
+    }
+    // There are exactly 2 results.
+    WeightedThing<Vector> secondBest = top.pop();
+    WeightedThing<Vector> best = top.pop();
+    // If the best result is the same as the query, but we don't want to 
return the query.
+    if (differentThanQuery && best.getValue().equals(query)) {
+      best = secondBest;
+    }
+    return removeHash(best);
+  }
+
+  protected WeightedThing<Vector> removeHash(WeightedThing<Vector> input) {
+    return new WeightedThing<Vector>(((HashedVector) 
input.getValue()).getVector(), input.getWeight());
+  }
+
+  @Override
+  public void add(Vector vector) {
+    initialize(vector.size());
+    trainingVectors.add(new HashedVector(vector, projection, 
HashedVector.INVALID_INDEX, BIT_MASK));
+  }
+
+  public int size() {
+    return trainingVectors.size();
+  }
+
+  public int getSearchSize() {
+    return searchSize;
+  }
+
+  public void setSearchSize(int size) {
+    searchSize = size;
+  }
+
+  public void setRaiseHashLimitStrategy(double strategy) {
+    hashLimitStrategy = strategy;
+  }
+
+  /**
+   * This is only for testing.
+   * @return the number of times the actual distance between two vectors was 
computed.
+   */
+  public int resetEvaluationCount() {
+    int result = distanceEvaluations;
+    distanceEvaluations = 0;
+    return result;
+  }
+
+  @Override
+  public Iterator<Vector> iterator() {
+    return Iterators.transform(trainingVectors.iterator(), new 
Function<HashedVector, Vector>() {
+      @Override
+      public Vector apply(org.apache.mahout.math.neighborhood.HashedVector 
input) {
+        Preconditions.checkNotNull(input);
+        //noinspection ConstantConditions
+        return input.getVector();
+      }
+    });
+  }
+
+  @Override
+  public boolean remove(Vector v, double epsilon) {
+    return trainingVectors.remove(new HashedVector(v, projection, 
HashedVector.INVALID_INDEX, BIT_MASK));
+  }
+
+  @Override
+  public void clear() {
+    trainingVectors.clear();
+  }
+}

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
 Fri May 17 09:04:12 2013
@@ -29,6 +29,7 @@ import org.apache.mahout.math.Centroid;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.neighborhood.BruteSearch;
 import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
 import org.apache.mahout.math.neighborhood.ProjectionSearch;
 import org.apache.mahout.math.neighborhood.Searcher;
 import org.apache.mahout.math.neighborhood.UpdatableSearcher;
@@ -66,9 +67,11 @@ public class StreamingKMeansTest {
         {new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), 
NUM_PROJECTIONS, SEARCH_SIZE), true},
         {new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(), 
NUM_PROJECTIONS, SEARCH_SIZE),
             true},
+        {new LocalitySensitiveHashSearch(new 
SquaredEuclideanDistanceMeasure(), SEARCH_SIZE), true},
         {new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), 
NUM_PROJECTIONS, SEARCH_SIZE), false},
         {new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(), 
NUM_PROJECTIONS, SEARCH_SIZE),
             false},
+        {new LocalitySensitiveHashSearch(new 
SquaredEuclideanDistanceMeasure(), SEARCH_SIZE), false},
     });
   }
 

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
 Fri May 17 09:04:12 2013
@@ -44,6 +44,7 @@ import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.neighborhood.BruteSearch;
 import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
 import org.apache.mahout.math.neighborhood.ProjectionSearch;
 import org.apache.mahout.math.random.WeightedThing;
 import org.junit.Test;
@@ -90,6 +91,7 @@ public class StreamingKMeansTestMR {
     return Arrays.asList(new Object[][]{
         {ProjectionSearch.class.getName(), 
SquaredEuclideanDistanceMeasure.class.getName()},
         {FastProjectionSearch.class.getName(), 
SquaredEuclideanDistanceMeasure.class.getName()},
+        {LocalitySensitiveHashSearch.class.getName(), 
SquaredEuclideanDistanceMeasure.class.getName()},
     });
   }
 

Added: 
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.java?rev=1483702&view=auto
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.java
 (added)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearchTest.java
 Fri May 17 09:04:12 2013
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.neighborhood;
+
+import java.util.BitSet;
+import java.util.List;
+
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.random.Normal;
+import org.apache.mahout.math.random.WeightedThing;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+public class LocalitySensitiveHashSearchTest {
+  /** Number of random points searched in testNormal. */
+  private static final int NUM_POINTS = 100000;
+  /** Dimensionality of every generated vector. */
+  private static final int NUM_DIMENSIONS = 10;
+
+  /**
+   * Compares LSH search against brute force on Gaussian data for hash-limit strategies
+   * -0.1, 0.0, ..., 1.0, printing the speedup and quartiles of the result overlap.
+   *
+   * NOTE(review): the summarized overlap is a count in [0, 100] (BitSet cardinality), not a
+   * fraction, so the 0.45 / 0.9 / 0.8 thresholds below are very weak — confirm intent.
+   */
+  @Test
+  public void testNormal() {
+    Matrix points = new DenseMatrix(NUM_POINTS, NUM_DIMENSIONS);
+    Normal normal = new Normal();
+    points.assign(normal);
+
+    EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();
+    BruteSearch exact = new BruteSearch(measure);
+    exact.addAllMatrixSlicesAsWeightedVectors(points);
+
+    LocalitySensitiveHashSearch lsh = new LocalitySensitiveHashSearch(measure, 10);
+    lsh.addAllMatrixSlicesAsWeightedVectors(points);
+    lsh.setSearchSize(200);
+    lsh.resetEvaluationCount();
+
+    System.out.printf("speedup,q1,q2,q3\n");
+
+    for (int step = -1; step <= 10; step++) {
+      lsh.setRaiseHashLimitStrategy(step / 10.0);
+      OnlineSummarizer overlap = evaluateStrategy(points, exact, lsh);
+      int evaluations = lsh.resetEvaluationCount();
+      // 10e6 == 1e7: the cost of brute force for 100 queries over NUM_POINTS points.
+      double speedup = 10e6 / evaluations;
+      System.out.printf("%.1f,%.2f,%.2f,%.2f\n", speedup, overlap.getQuartile(1),
+          overlap.getQuartile(2), overlap.getQuartile(3));
+      double median = overlap.getQuartile(2);
+      assertTrue(median > 0.45);
+      assertTrue(speedup > 4 || median > 0.9);
+      assertTrue(speedup > 15 || median > 0.8);
+    }
+  }
+
+  /**
+   * For the first 100 rows of {@code testData}, counts how many of the 100 exact nearest
+   * neighbors appear among the 150 results of the LSH searcher.
+   */
+  private OnlineSummarizer evaluateStrategy(Matrix testData, BruteSearch ref,
+                                            LocalitySensitiveHashSearch cut) {
+    OnlineSummarizer overlap = new OnlineSummarizer();
+    for (int row = 0; row < 100; row++) {
+      Vector query = testData.viewRow(row);
+      BitSet approximate = indicesOf(cut.search(query, 150));
+      BitSet reference = indicesOf(ref.search(query, 100));
+      approximate.and(reference);
+      overlap.add(approximate.cardinality());
+    }
+    return overlap;
+  }
+
+  /** Collects the WeightedVector index of every search result into a BitSet. */
+  private static BitSet indicesOf(List<WeightedThing<Vector>> results) {
+    BitSet indices = new BitSet();
+    for (WeightedThing<Vector> result : results) {
+      indices.set(((WeightedVector) result.getValue()).getIndex());
+    }
+    return indices;
+  }
+
+  /**
+   * Exploratory output: prints (Hamming distance, cosine) samples for random vectors against a
+   * random query, then a histogram of Hamming distances. Makes no assertions.
+   */
+  @Test
+  public void testDotCorrelation() {
+    Normal normal = new Normal();
+
+    Matrix projection = new DenseMatrix(64, NUM_DIMENSIONS);
+    projection.assign(normal);
+
+    Vector query = new DenseVector(NUM_DIMENSIONS);
+    query.assign(normal);
+    long queryHash = HashedVector.computeHash64(query, projection);
+
+    int[] histogram = new int[65];
+    Vector sample = new DenseVector(NUM_DIMENSIONS);
+    for (int trial = 0; trial < 500000; trial++) {
+      sample.assign(normal);
+      int hamming = Long.bitCount(queryHash ^ HashedVector.computeHash64(sample, projection));
+      if (++histogram[hamming] < 200) {
+        double cosine = sample.dot(query) / Math.sqrt(sample.getLengthSquared() * query.getLengthSquared());
+        System.out.printf("%d, %.3f\n", hamming, cosine);
+      }
+    }
+    for (int bits = 0; bits < histogram.length; bits++) {
+      System.out.printf("%d, %d\n", bits, histogram[bits]);
+    }
+  }
+}

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchQualityTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchQualityTest.java?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchQualityTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchQualityTest.java
 Fri May 17 09:04:12 2013
@@ -59,6 +59,10 @@ public class SearchQualityTest {
         // NUM_PROJECTIONS = 5
         // SEARCH_SIZE = 5
         {new ProjectionSearch(distanceMeasure, 5, 5), dataPoints, queries, 
reference, referenceSearchFirst},
+        {new FastProjectionSearch(distanceMeasure, 5, 5), dataPoints, queries, 
reference, referenceSearchFirst},
+        {new LocalitySensitiveHashSearch(distanceMeasure, 5), dataPoints, 
queries, reference, referenceSearchFirst},
+        // SEARCH_SIZE = 2
+        {new LocalitySensitiveHashSearch(distanceMeasure, 2), dataPoints, 
queries, reference, referenceSearchFirst},
     }
     );
   }

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchSanityTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchSanityTest.java?rev=1483702&r1=1483701&r2=1483702&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchSanityTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/math/neighborhood/SearchSanityTest.java
 Fri May 17 09:04:12 2013
@@ -61,6 +61,7 @@ public class SearchSanityTest extends Ma
         {new ProjectionSearch(new EuclideanDistanceMeasure(), NUM_PROJECTIONS, 
SEARCH_SIZE), dataPoints},
         {new FastProjectionSearch(new EuclideanDistanceMeasure(), 
NUM_PROJECTIONS, SEARCH_SIZE),
             dataPoints},
+        {new LocalitySensitiveHashSearch(new EuclideanDistanceMeasure(), 
SEARCH_SIZE), dataPoints},
     });
   }
 
@@ -185,4 +186,21 @@ public class SearchSanityTest extends Ma
       }
     }
   }
+
+  @Test
+  public void testSearchFirst() {
+    searcher.clear();
+    searcher.addAll(dataPoints);
+    for (Vector datapoint : dataPoints) {
+      WeightedThing<Vector> first = searcher.searchFirst(datapoint, false);
+      WeightedThing<Vector> second = searcher.searchFirst(datapoint, true);
+      List<WeightedThing<Vector>> firstTwo = searcher.search(datapoint, 2);
+
+      assertEquals("First isn't self", 0, first.getWeight(), 0);
+      assertEquals("First isn't self", datapoint, first.getValue());
+      assertEquals("First doesn't match", first, firstTwo.get(0));
+      assertEquals(String.format("Second doesn't match got %f expected %f", 
second.getWeight(), firstTwo.get(1).getWeight()),
+          second, firstTwo.get(1));
+    }
+  }
 }


Reply via email to