Author: srowen
Date: Thu Mar 31 19:00:25 2011
New Revision: 1087409

URL: http://svn.apache.org/viewvc?rev=1087409&view=rev
Log:
MAHOUT-645 add vector benchmarks for Elkan optimization

Modified:
    
mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java

Modified: 
mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java?rev=1087409&r1=1087408&r2=1087409&view=diff
==============================================================================
--- 
mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java
 (original)
+++ 
mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java
 Thu Mar 31 19:00:25 2011
@@ -59,6 +59,7 @@ import org.apache.mahout.common.iterator
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.SparseMatrix;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
@@ -70,6 +71,8 @@ public class VectorBenchmarks implements
   private static final Pattern TAB_PATTERN = Pattern.compile("\t");
 
   private final Vector[][] vectors;
+  private final Vector[] clusters;
+  private final SparseMatrix clusterDistances;
   private final List<Vector> randomVectors = new ArrayList<Vector>();
   private final List<int[]> randomVectorIndices = new ArrayList<int[]>();
   private final List<double[]> randomVectorValues = new ArrayList<double[]>();
@@ -80,13 +83,14 @@ public class VectorBenchmarks implements
   private final int opsPerUnit;
   private final Map<String,Integer> implType = new HashMap<String,Integer>();
   private final Map<String,List<String[]>> statsMap = new 
HashMap<String,List<String[]>>();
- 
+  private final int numClusters;
   
-  public VectorBenchmarks(int cardinality, int sparsity, int numVectors, int 
loop, int opsPerUnit) {
+  public VectorBenchmarks(int cardinality, int sparsity, int numVectors, int 
numClusters, int loop, int opsPerUnit) {
     Random r = RandomUtils.getRandom();
     this.cardinality = cardinality;
     this.sparsity = sparsity;
     this.numVectors = numVectors;
+    this.numClusters = numClusters;
     this.loop = loop;
     this.opsPerUnit = opsPerUnit;
     for (int i = 0; i < numVectors; i++) {
@@ -110,7 +114,8 @@ public class VectorBenchmarks implements
       randomVectors.add(v);
     }
     vectors = new Vector[3][numVectors];
-    
+    clusters = new Vector[numClusters];
+    clusterDistances = new SparseMatrix(numClusters, numClusters);
   }
   
   private void printStats(TimingStatistics stats, String benchmarkName, String 
implName, String content) {
@@ -224,14 +229,23 @@ public class VectorBenchmarks implements
     }
     printStats(stats, "Create (incrementally)", "RandSparseVector");
 
+//    stats = new TimingStatistics();
+//    for (int l = 0; l < loop; l++) {
+//      for (int i = 0; i < numVectors; i++) {
+//        vectors[2][i] = new SequentialAccessSparseVector(cardinality);
+//        buildVectorIncrementally(stats, i, vectors[2][i], false);
+//      }
+//    }
+//    printStats(stats, "Create (incrementally)", "SeqSparseVector");
+    
     stats = new TimingStatistics();
     for (int l = 0; l < loop; l++) {
-      for (int i = 0; i < numVectors; i++) {
-        vectors[2][i] = new SequentialAccessSparseVector(cardinality);
-        buildVectorIncrementally(stats, i, vectors[2][i], false);
+      for (int i = 0; i < numClusters; i++) {
+        clusters[i] = new RandomAccessSparseVector(cardinality);
+        buildVectorIncrementally(stats, i, clusters[i], false);
       }
     }
-    printStats(stats, "Create (incrementally)", "SeqSparseVector");
+    printStats(stats, "Create (incrementally)", "Clusters");
   }
   
   public void cloneBenchmark() {
@@ -439,7 +453,73 @@ public class VectorBenchmarks implements
 
 
   }
-  
+
+
+  /**
+   * Times a nearest-centroid search over {@code vectors[1]} against {@code clusters},
+   * once by brute force and once with Elkan's centroid-to-centroid pruning, and
+   * records for each run how many calls to {@code measure.distance} were made.
+   *
+   * NOTE(review): the pruning bound relies on the triangle inequality; confirm it
+   * holds for every measure this method is called with.
+   *
+   * @param measure distance measure used for both the centroid-to-centroid table
+   *                and the vector-to-centroid comparisons
+   */
+  public void closestCentroidBenchmark(DistanceMeasure measure) {
+
+    // Pairwise centroid distances. The diagonal is set to infinity; it is never
+    // read below because the pruned loop skips k == closestCentroid.
+    for (int i = 0; i < numClusters; i++) {
+      for (int j = 0; j < numClusters; j++) {
+        double distance = Double.POSITIVE_INFINITY;
+        if (i != j) {
+          distance = measure.distance(clusters[i], clusters[j]);
+        }
+        clusterDistances.setQuick(i, j, distance);
+      }
+    }
+
+    // Baseline: compare every vector with every centroid.
+    long distanceCalculations = 0;
+    TimingStatistics stats = new TimingStatistics();
+    for (int l = 0; l < loop; l++) {
+      TimingStatistics.Call call = stats.newCall();
+      for (int i = 0; i < numVectors; i++) {
+        Vector vector = vectors[1][i];
+        double minDistance = Double.MAX_VALUE;
+        for (int k = 0; k < numClusters; k++) {
+          double distance = measure.distance(vector, clusters[k]);
+          distanceCalculations++;
+          if (distance < minDistance) {
+            minDistance = distance;
+          }
+        }
+      }
+      call.end();
+    }
+    printStats(stats,
+               measure.getClass().getName(),
+               "Closest center without Elkan's trick",
+               "distanceCalculations = " + distanceCalculations);
+
+
+    // Pruned search: start from a random centroid and only measure against a
+    // centroid k when d(k, best) < 2 * d(vector, best); otherwise k cannot be closer.
+    distanceCalculations = 0;
+    stats = new TimingStatistics();
+    Random rand = RandomUtils.getRandom();
+    for (int l = 0; l < loop; l++) {
+      TimingStatistics.Call call = stats.newCall();
+      for (int i = 0; i < numVectors; i++) {
+        Vector vector = vectors[1][i];
+        int closestCentroid = rand.nextInt(numClusters);
+        double dist = measure.distance(vector, clusters[closestCentroid]);
+        distanceCalculations++;
+        for (int k = 0; k < numClusters; k++) {
+          if (closestCentroid != k) {
+            double centroidDist = clusterDistances.getQuick(k, closestCentroid);
+            if (centroidDist < 2 * dist) {
+              double candidateDist = measure.distance(vector, clusters[k]);
+              distanceCalculations++;
+              // Passing the bound only means k *may* be closer; keep the current
+              // best unless the measured distance is actually smaller.
+              if (candidateDist < dist) {
+                dist = candidateDist;
+                closestCentroid = k;
+              }
+            }
+          }
+        }
+      }
+      call.end();
+    }
+    printStats(stats,
+               measure.getClass().getName(),
+               "Closest center with Elkan's trick",
+               "distanceCalculations = " + distanceCalculations);
+  }
+
   public void distanceMeasureBenchmark(DistanceMeasure measure) {
     double result = 0;
     TimingStatistics stats = new TimingStatistics();
@@ -620,12 +700,16 @@ public class VectorBenchmarks implements
     Option vectorSizeOpt = 
obuilder.withLongName("vectorSize").withRequired(false).withArgument(
       
abuilder.withName("vs").withMinimum(1).withMaximum(1).create()).withDescription(
       "Cardinality of the vector. Default 1000").withShortName("vs").create();
+    
     Option vectorSparsityOpt = 
obuilder.withLongName("sparsity").withRequired(false).withArgument(
       
abuilder.withName("sp").withMinimum(1).withMaximum(1).create()).withDescription(
       "Sparsity of the vector. Default 1000").withShortName("sp").create();
     Option numVectorsOpt = 
obuilder.withLongName("numVectors").withRequired(false).withArgument(
       
abuilder.withName("nv").withMinimum(1).withMaximum(1).create()).withDescription(
       "Number of Vectors to create. Default: 
100").withShortName("nv").create();
+    // Uses its own argument/short name ("nc"); "vs" already belongs to vectorSize.
+    Option numClustersOpt = obuilder.withLongName("numClusters").withRequired(false).withArgument(
+      abuilder.withName("nc").withMinimum(1).withMaximum(1).create()).withDescription(
+      "Number of clusters to create. Default: 25").withShortName("nc").create();
     Option loopOpt = 
obuilder.withLongName("loop").withRequired(false).withArgument(
       
abuilder.withName("loop").withMinimum(1).withMaximum(1).create()).withDescription(
       "Number of times to loop. Default: 200").withShortName("l").create();
@@ -654,6 +738,11 @@ public class VectorBenchmarks implements
       if (cmdLine.hasOption(vectorSizeOpt)) {
         cardinality = Integer.parseInt((String) 
cmdLine.getValue(vectorSizeOpt));
         
+      }    
+      
+      int numClusters=25;
+      if (cmdLine.hasOption(numClustersOpt)) {
+         numClusters = Integer.parseInt((String) 
cmdLine.getValue(numClustersOpt));          
       }
 
       int sparsity = 1000;
@@ -676,7 +765,7 @@ public class VectorBenchmarks implements
         numOps = Integer.parseInt((String) cmdLine.getValue(numOpsOpt));
         
       }
-      VectorBenchmarks mark = new VectorBenchmarks(cardinality, sparsity, 
numVectors, loop, numOps);
+      VectorBenchmarks mark = new VectorBenchmarks(cardinality, sparsity, 
numVectors, numClusters, loop, numOps);
       mark.createBenchmark();
       mark.incrementalCreateBenchmark();
       mark.cloneBenchmark();
@@ -689,6 +778,12 @@ public class VectorBenchmarks implements
       mark.distanceMeasureBenchmark(new ManhattanDistanceMeasure());
       mark.distanceMeasureBenchmark(new TanimotoDistanceMeasure());
       
+      mark.closestCentroidBenchmark(new CosineDistanceMeasure());
+      mark.closestCentroidBenchmark(new SquaredEuclideanDistanceMeasure());
+      mark.closestCentroidBenchmark(new EuclideanDistanceMeasure());
+      mark.closestCentroidBenchmark(new ManhattanDistanceMeasure());
+      mark.closestCentroidBenchmark(new TanimotoDistanceMeasure());
+      
       log.info("\n{}", mark.summarize());
     } catch (OptionException e) {
       CommandLineUtil.printHelp(group);
@@ -743,4 +838,4 @@ public class VectorBenchmarks implements
     return sb.toString();
   }
   
-}
+}
\ No newline at end of file


Reply via email to