Author: srowen
Date: Thu Mar 31 19:00:25 2011
New Revision: 1087409
URL: http://svn.apache.org/viewvc?rev=1087409&view=rev
Log:
MAHOUT-645 add vector benchmarks for Elkan optimization
Modified:
mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java
Modified:
mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java?rev=1087409&r1=1087408&r2=1087409&view=diff
==============================================================================
---
mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java
(original)
+++
mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java
Thu Mar 31 19:00:25 2011
@@ -59,6 +59,7 @@ import org.apache.mahout.common.iterator
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
@@ -70,6 +71,8 @@ public class VectorBenchmarks implements
private static final Pattern TAB_PATTERN = Pattern.compile("\t");
private final Vector[][] vectors;
+ private final Vector[] clusters;
+ private final SparseMatrix clusterDistances;
private final List<Vector> randomVectors = new ArrayList<Vector>();
private final List<int[]> randomVectorIndices = new ArrayList<int[]>();
private final List<double[]> randomVectorValues = new ArrayList<double[]>();
@@ -80,13 +83,14 @@ public class VectorBenchmarks implements
private final int opsPerUnit;
private final Map<String,Integer> implType = new HashMap<String,Integer>();
private final Map<String,List<String[]>> statsMap = new
HashMap<String,List<String[]>>();
-
+ private final int numClusters;
- public VectorBenchmarks(int cardinality, int sparsity, int numVectors, int
loop, int opsPerUnit) {
+ public VectorBenchmarks(int cardinality, int sparsity, int numVectors, int
numClusters, int loop, int opsPerUnit) {
Random r = RandomUtils.getRandom();
this.cardinality = cardinality;
this.sparsity = sparsity;
this.numVectors = numVectors;
+ this.numClusters = numClusters;
this.loop = loop;
this.opsPerUnit = opsPerUnit;
for (int i = 0; i < numVectors; i++) {
@@ -110,7 +114,8 @@ public class VectorBenchmarks implements
randomVectors.add(v);
}
vectors = new Vector[3][numVectors];
-
+ clusters = new Vector[numClusters];
+ clusterDistances = new SparseMatrix(numClusters, numClusters);
}
private void printStats(TimingStatistics stats, String benchmarkName, String
implName, String content) {
@@ -224,14 +229,23 @@ public class VectorBenchmarks implements
}
printStats(stats, "Create (incrementally)", "RandSparseVector");
+// stats = new TimingStatistics();
+// for (int l = 0; l < loop; l++) {
+// for (int i = 0; i < numVectors; i++) {
+// vectors[2][i] = new SequentialAccessSparseVector(cardinality);
+// buildVectorIncrementally(stats, i, vectors[2][i], false);
+// }
+// }
+// printStats(stats, "Create (incrementally)", "SeqSparseVector");
+
stats = new TimingStatistics();
for (int l = 0; l < loop; l++) {
- for (int i = 0; i < numVectors; i++) {
- vectors[2][i] = new SequentialAccessSparseVector(cardinality);
- buildVectorIncrementally(stats, i, vectors[2][i], false);
+ for (int i = 0; i < numClusters; i++) {
+ clusters[i] = new RandomAccessSparseVector(cardinality);
+ buildVectorIncrementally(stats, i, clusters[i], false);
}
}
- printStats(stats, "Create (incrementally)", "SeqSparseVector");
+ printStats(stats, "Create (incrementally)", "Clusters");
}
public void cloneBenchmark() {
@@ -439,7 +453,73 @@ public class VectorBenchmarks implements
}
-
+
+
+ /**
+  * Benchmarks nearest-cluster assignment for every vector in {@code vectors[1]}, first by
+  * brute force and then with the triangle-inequality pruning known as Elkan's trick.
+  * Both passes report their timing and number of distance evaluations via printStats.
+  * @param measure distance measure used for vector-cluster and cluster-cluster distances
+  */
+ public void closestCentroidBenchmark(DistanceMeasure measure) {
+   // Cache pairwise cluster distances (the diagonal is set to +infinity).
+   for (int i = 0; i < numClusters; i++) {
+     for (int j = 0; j < numClusters; j++) {
+       double distance = i == j ? Double.POSITIVE_INFINITY : measure.distance(clusters[i], clusters[j]);
+       clusterDistances.setQuick(i, j, distance);
+     }
+   }
+   // Baseline: compare each vector against every cluster.
+   long distanceCalculations = 0;
+   TimingStatistics stats = new TimingStatistics();
+   for (int l = 0; l < loop; l++) {
+     TimingStatistics.Call call = stats.newCall();
+     for (int i = 0; i < numVectors; i++) {
+       Vector vector = vectors[1][i];
+       double minDistance = Double.MAX_VALUE;
+       for (int k = 0; k < numClusters; k++) {
+         double distance = measure.distance(vector, clusters[k]);
+         distanceCalculations++;
+         if (distance < minDistance) {
+           minDistance = distance;
+         }
+       }
+     }
+     call.end();
+   }
+   printStats(stats, measure.getClass().getName(), "Closest center without Elkan's trick",
+       "distanceCalculations = " + distanceCalculations);
+
+   // Elkan: start from a random cluster and skip any cluster k with
+   // d(k, closest) >= 2 * d(vector, closest): by the triangle inequality k cannot be closer.
+   distanceCalculations = 0;
+   stats = new TimingStatistics();
+   Random rand = RandomUtils.getRandom();
+   for (int l = 0; l < loop; l++) {
+     TimingStatistics.Call call = stats.newCall();
+     for (int i = 0; i < numVectors; i++) {
+       Vector vector = vectors[1][i];
+       int closestCentroid = rand.nextInt(numClusters);
+       double dist = measure.distance(vector, clusters[closestCentroid]);
+       distanceCalculations++;
+       for (int k = 0; k < numClusters; k++) {
+         if (k != closestCentroid && clusterDistances.getQuick(k, closestCentroid) < 2 * dist) {
+           double candidateDist = measure.distance(vector, clusters[k]);
+           distanceCalculations++;
+           // Only switch when k is really closer; surviving the bound does not guarantee that.
+           if (candidateDist < dist) {
+             dist = candidateDist;
+             closestCentroid = k;
+           }
+         }
+       }
+     }
+     call.end();
+   }
+   printStats(stats, measure.getClass().getName(), "Closest center with Elkan's trick",
+       "distanceCalculations = " + distanceCalculations);
+ }
+
public void distanceMeasureBenchmark(DistanceMeasure measure) {
double result = 0;
TimingStatistics stats = new TimingStatistics();
@@ -620,12 +700,16 @@ public class VectorBenchmarks implements
Option vectorSizeOpt =
obuilder.withLongName("vectorSize").withRequired(false).withArgument(
abuilder.withName("vs").withMinimum(1).withMaximum(1).create()).withDescription(
"Cardinality of the vector. Default 1000").withShortName("vs").create();
+
Option vectorSparsityOpt =
obuilder.withLongName("sparsity").withRequired(false).withArgument(
abuilder.withName("sp").withMinimum(1).withMaximum(1).create()).withDescription(
"Sparsity of the vector. Default 1000").withShortName("sp").create();
Option numVectorsOpt =
obuilder.withLongName("numVectors").withRequired(false).withArgument(
abuilder.withName("nv").withMinimum(1).withMaximum(1).create()).withDescription(
"Number of Vectors to create. Default:
100").withShortName("nv").create();
+ Option numClustersOpt = obuilder.withLongName("numClusters").withRequired(false).withArgument(
+   abuilder.withName("nc").withMinimum(1).withMaximum(1).create()).withDescription(
+   "Number of clusters to create. Default: 25").withShortName("nc").create();
Option loopOpt =
obuilder.withLongName("loop").withRequired(false).withArgument(
abuilder.withName("loop").withMinimum(1).withMaximum(1).create()).withDescription(
"Number of times to loop. Default: 200").withShortName("l").create();
@@ -654,6 +738,11 @@ public class VectorBenchmarks implements
if (cmdLine.hasOption(vectorSizeOpt)) {
cardinality = Integer.parseInt((String)
cmdLine.getValue(vectorSizeOpt));
+ }
+
+ int numClusters=25;
+ if (cmdLine.hasOption(numClustersOpt)) {
+ numClusters = Integer.parseInt((String)
cmdLine.getValue(numClustersOpt));
}
int sparsity = 1000;
@@ -676,7 +765,7 @@ public class VectorBenchmarks implements
numOps = Integer.parseInt((String) cmdLine.getValue(numOpsOpt));
}
- VectorBenchmarks mark = new VectorBenchmarks(cardinality, sparsity,
numVectors, loop, numOps);
+ VectorBenchmarks mark = new VectorBenchmarks(cardinality, sparsity,
numVectors, numClusters, loop, numOps);
mark.createBenchmark();
mark.incrementalCreateBenchmark();
mark.cloneBenchmark();
@@ -689,6 +778,12 @@ public class VectorBenchmarks implements
mark.distanceMeasureBenchmark(new ManhattanDistanceMeasure());
mark.distanceMeasureBenchmark(new TanimotoDistanceMeasure());
+ mark.closestCentroidBenchmark(new CosineDistanceMeasure());
+ mark.closestCentroidBenchmark(new SquaredEuclideanDistanceMeasure());
+ mark.closestCentroidBenchmark(new EuclideanDistanceMeasure());
+ mark.closestCentroidBenchmark(new ManhattanDistanceMeasure());
+ mark.closestCentroidBenchmark(new TanimotoDistanceMeasure());
+
log.info("\n{}", mark.summarize());
} catch (OptionException e) {
CommandLineUtil.printHelp(group);
@@ -743,4 +838,4 @@ public class VectorBenchmarks implements
return sb.toString();
}
-}
+}
\ No newline at end of file