Author: robinanil
Date: Mon Feb 22 17:01:56 2010
New Revision: 912655
URL: http://svn.apache.org/viewvc?rev=912655&view=rev
Log:
MAHOUT-300 First wave of perf improvements
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/TimingStatistics.java lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSparseVector.java lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/TimingStatistics.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/TimingStatistics.java?rev=912655&r1=912654&r2=912655&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/TimingStatistics.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/TimingStatistics.java Mon Feb 22 17:01:56 2010 @@ -78,9 +78,9 @@ @Override public synchronized String toString() { - return '\n' + "nCalls = " + nCalls + ";\n" + "sumTime = " + sumTime / 1000000000.0d + "s;\n" - + "minTime = " + minTime / 1000000.0d + "ms;\n" + "maxTime = " + maxTime / 1000000.0d + "ms;\n" - + "meanTime = " + getMeanTime() / 1000000.0d + "ms;\n" + "stdDevTime = " + getStdDevTime() + return '\n' + "nCalls = " + nCalls + ";\n" + "sum = " + sumTime / 1000000000.0d + "s;\n" + + "min = " + minTime / 1000000.0d + "ms;\n" + "max = " + maxTime / 1000000.0d + "ms;\n" + + "mean = " + getMeanTime() / 1000000.0d + "ms;\n" + "stdDev = " + getStdDevTime() / 1000000.0d + "ms;"; } Modified: lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java?rev=912655&r1=912654&r2=912655&view=diff ============================================================================== --- lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java (original) +++ lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java Mon Feb 22 17:01:56 2010 @@ -17,17 +17,18 @@ package org.apache.mahout.math; -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; -import com.google.gson.reflect.TypeToken; -import org.apache.mahout.math.function.BinaryFunction; -import org.apache.mahout.math.function.UnaryFunction; - import java.lang.reflect.Type; import java.util.HashMap; import java.util.Iterator; import java.util.Map; +import org.apache.mahout.math.function.BinaryFunction; +import org.apache.mahout.math.function.UnaryFunction; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.reflect.TypeToken; + /** Implementations of generic capabilities like sum of elements and dot products */ public abstract class AbstractVector implements Vector { @@ -111,10 +112,8 @@ Iterator<Element> iter = result.iterateNonZero(); while (iter.hasNext()) { Element element = iter.next(); - int index = element.index(); - result.setQuick(index, element.get() / x); + element.set(element.get() / x); } - return result; } @@ -122,15 +121,34 @@ if (size() != x.size()) { throw new CardinalityException(size(), x.size()); } - double result = 0; + if(this == x) return dotSelf(); + double result = 0.0; Iterator<Element> iter = iterateNonZero(); while (iter.hasNext()) { Element element = iter.next(); result += element.get() * x.getQuick(element.index()); } - return result; } + + public double dotSelf() { + double result = 0; + if (this instanceof DenseVector) { + for (int i = 0; i < size(); i++) { + double value = this.getQuick(i); + result += value * value; + } + return 
result; + } else { + Iterator<Element> iter = iterateNonZero(); + while (iter.hasNext()) { + Element element = iter.next(); + double value = element.get(); + result += value * value; + } + return result; + } + } public double get(int index) { if (index >= 0 && index < size()) { @@ -144,17 +162,28 @@ if (size() != x.size()) { throw new CardinalityException(); } - Vector result = clone(); - Iterator<Element> iter = x.iterateNonZero(); - while (iter.hasNext()) { - Element e = iter.next(); - result.setQuick(e.index(), getQuick(e.index()) - e.get()); + if (x instanceof RandomAccessSparseVector || x instanceof DenseVector) { + // TODO: if both are RandomAccess check the numNonDefault to determine which to iterate + Vector result = x.clone(); + Iterator<Element> iter = iterateNonZero(); + while (iter.hasNext()) { + Element e = iter.next(); + result.setQuick(e.index(), result.getQuick(e.index()) - e.get()); + } + return result; + } else { // TODO: check the numNonDefault elements to further optimize + Vector result = clone(); + Iterator<Element> iter = x.iterateNonZero(); + while (iter.hasNext()) { + Element e = iter.next(); + result.setQuick(e.index(), getQuick(e.index()) - e.get()); + } + return result; } - return result; } public Vector normalize() { - return divide(Math.sqrt(dot(this))); + return divide(Math.sqrt(dotSelf())); } public Vector normalize(double power) { @@ -174,7 +203,7 @@ } return val; } else if (power == 2.0) { - return Math.sqrt(dot(this)); + return Math.sqrt(dotSelf()); } else if (power == 1.0) { double val = 0.0; Iterator<Element> iter = this.iterateNonZero(); @@ -205,7 +234,7 @@ if (lengthSquared >= 0.0) { return lengthSquared; } - return lengthSquared = dot(this); + return lengthSquared = dotSelf(); } public double getDistanceSquared(Vector v) { @@ -238,20 +267,41 @@ public double maxValue() { double result = Double.NEGATIVE_INFINITY; - for (int i = 0; i < size(); i++) { - result = Math.max(result, getQuick(i)); + int nonZeroElements = 0; + 
Iterator<Element> iter = this.iterateNonZero(); + while (iter.hasNext()) { + nonZeroElements++; + Element element = iter.next(); + result = Math.max(result, element.get()); } + if (nonZeroElements < size()) return Math.max(result, 0.0); return result; } - + public int maxValueIndex() { int result = -1; double max = Double.NEGATIVE_INFINITY; - for (int i = 0; i < size(); i++) { - double tmp = getQuick(i); + int nonZeroElements = 0; + Iterator<Element> iter = this.iterateNonZero(); + while (iter.hasNext()) { + nonZeroElements++; + Element element = iter.next(); + double tmp = element.get(); if (tmp > max) { max = tmp; - result = i; + result = element.index(); + } + } + // if the maxElement is negative and the vector is sparse then any + // unfilled element(0.0) could be the maxValue hence return -1; + if (nonZeroElements < size() && max < 0.0) { + iter = this.iterateAll(); + while (iter.hasNext()) { + Element element = iter.next(); + double tmp = element.get(); + if (tmp == 0d) { + return element.index(); + } } } return result; @@ -301,11 +351,10 @@ public Vector times(double x) { Vector result = clone(); - Iterator<Element> iter = iterateNonZero(); + Iterator<Element> iter = result.iterateNonZero(); while (iter.hasNext()) { Element element = iter.next(); - int index = element.index(); - result.setQuick(index, element.get() * x); + element.set(element.get() * x); } return result; @@ -320,7 +369,7 @@ while (iter.hasNext()) { Element element = iter.next(); int index = element.index(); - result.setQuick(index, element.get() * x.getQuick(index)); + element.set(element.get() * x.getQuick(index)); } return result; @@ -567,5 +616,4 @@ bindings.put(label, index); set(index, value); } - } Modified: lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java?rev=912655&r1=912654&r2=912655&view=diff 
============================================================================== --- lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java (original) +++ lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java Mon Feb 22 17:01:56 2010 @@ -17,13 +17,13 @@ package org.apache.mahout.math; -import org.apache.mahout.math.function.BinaryFunction; -import org.apache.mahout.math.function.PlusMult; - import java.util.Arrays; import java.util.Iterator; import java.util.NoSuchElementException; +import org.apache.mahout.math.function.BinaryFunction; +import org.apache.mahout.math.function.PlusMult; + /** Implements vector as an array of doubles */ public class DenseVector extends AbstractVector { @@ -314,4 +314,28 @@ v.setQuick(i, values[i] + v.getQuick(i)); } } + + @Override + public double dot(Vector x) { + if (size() != x.size()) { + throw new CardinalityException(size(), x.size()); + } + if(this == x) return dotSelf(); + + double result = 0; + if (x instanceof DenseVector) { + for (int i = 0; i < x.size(); i++) { + result += this.getQuick(i) * x.getQuick(i); + } + return result; + } else { + // Try to get the speed boost associated fast/normal seq access on x and quick lookup on this + Iterator<org.apache.mahout.math.Vector.Element> iter = x.iterateNonZero(); + while (iter.hasNext()) { + org.apache.mahout.math.Vector.Element element = iter.next(); + result += element.get() * getQuick(element.index()); + } + return result; + } + } } Modified: lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java?rev=912655&r1=912654&r2=912655&view=diff ============================================================================== --- lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java (original) +++ 
lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java Mon Feb 22 17:01:56 2010 @@ -23,7 +23,6 @@ import org.apache.mahout.math.function.IntDoubleProcedure; import org.apache.mahout.math.list.IntArrayList; import org.apache.mahout.math.map.OpenIntDoubleHashMap; -import org.apache.mahout.math.set.OpenIntHashSet; /** Implements vector that only stores non-zero doubles */ @@ -248,52 +247,6 @@ values.put(ind, value); } } - - private static class DistanceSquared implements IntDoubleProcedure { - final Vector v; - final OpenIntHashSet skipSet; - public double result = 0.0; - - DistanceSquared(Vector v, OpenIntHashSet skipSet) { - this.v = v; - this.skipSet = skipSet; - } - - public boolean apply(int key, double value) { - if(skipSet.contains(key)) { - // returning true is ok, yes? It is ignored. - return true; - } - skipSet.add(key); - double centroidValue = v.get(key); - double delta = value - centroidValue; - result += (delta * delta);// - (centroidValue * centroidValue); - return true; - } - } - - // TODO is this more optimal than the version in AbstractVector? Should be checked. 
- @Override - public double getDistanceSquared(Vector v) { - if(v instanceof DenseVector) { - // quicker to just use the DenseVector version - return v.getDistanceSquared(this); - } - if (v.size() != size()) { - throw new CardinalityException(); - } - OpenIntHashSet used = new OpenIntHashSet(); - DistanceSquared distanceSquared = new DistanceSquared(v, used); - values.forEachPair(distanceSquared); - Iterator<Vector.Element> it = v.iterateNonZero(); - Vector.Element e; - DistanceSquared otherHalf = new DistanceSquared(this, used); - while(it.hasNext() && (e = it.next()) != null) { - otherHalf.apply(e.index(), e.get()); - } - return distanceSquared.result + otherHalf.result; - } - private static class AddToVector implements IntDoubleProcedure { final Vector v; @@ -315,5 +268,29 @@ } values.forEachPair(new AddToVector(v)); } - + @Override + public double dot(Vector x) { + if (size() != x.size()) { + throw new CardinalityException(size(), x.size()); + } + if(this == x) return dotSelf(); + + double result = 0; + if (x instanceof SequentialAccessSparseVector || x instanceof RandomAccessSparseVector) { + Iterator<org.apache.mahout.math.Vector.Element> iter = x.iterateNonZero(); + while (iter.hasNext()) { + org.apache.mahout.math.Vector.Element element = iter.next(); + result += element.get() * getQuick(element.index()); + } + return result; + } else { + Iterator<org.apache.mahout.math.Vector.Element> iter = iterateNonZero(); + while (iter.hasNext()) { + org.apache.mahout.math.Vector.Element element = iter.next(); + result += element.get() * x.getQuick(element.index()); + } + return result; + } + } + } Modified: lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java?rev=912655&r1=912654&r2=912655&view=diff ============================================================================== --- 
lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java (original) +++ lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java Mon Feb 22 17:01:56 2010 @@ -20,6 +20,8 @@ import java.util.Iterator; import java.util.NoSuchElementException; +import org.apache.mahout.math.Vector.Element; + /** * <p> * Implements vector that only stores non-zero doubles as a pair of parallel arrays (OrderedIntDoubleMapping), @@ -279,4 +281,48 @@ values[offset] = value; } } + + @Override + public double dot(Vector x) { + if (size() != x.size()) { + throw new CardinalityException(size(), x.size()); + } + if(this == x) return dotSelf(); + + double result = 0; + if (x instanceof SequentialAccessSparseVector) { + // For sparse SeqAccVectors. do dot product without lookup in a linear fashion + Iterator<Element> myIter = iterateNonZero(); + Iterator<Element> otherIter = x.iterateNonZero(); + Element myCurrent = null; + Element otherCurrent = null; + while (myIter.hasNext() && otherIter.hasNext()) { + if (myCurrent == null) myCurrent = myIter.next(); + if (otherCurrent == null) otherCurrent = otherIter.next(); + + int myIndex = myCurrent.index(); + int otherIndex = otherCurrent.index(); + + if (myIndex < otherIndex) { + // due to the sparseness skipping occurs more hence checked before equality + myCurrent = null; + } else if (myIndex > otherIndex){ + otherCurrent = null; + } else { // both are equal + result += myCurrent.get() * otherCurrent.get(); + myCurrent = null; + otherCurrent = null; + } + } + return result; + } else { // seq.rand. 
seq.dense + Iterator<Element> iter = iterateNonZero(); + while (iter.hasNext()) { + Element element = iter.next(); + result += element.get() * x.getQuick(element.index()); + } + return result; + } + } + } Modified: lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSparseVector.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSparseVector.java?rev=912655&r1=912654&r2=912655&view=diff ============================================================================== --- lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSparseVector.java (original) +++ lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSparseVector.java Mon Feb 22 17:01:56 2010 @@ -327,7 +327,7 @@ other.set(3, -9); other.set(4, 1); double expected = test.minus(other).getLengthSquared(); - assertEquals("a.getDistanceSquared(b) != a.minus(b).getLengthSquared", expected, test.getDistanceSquared(other)); + assertEquals("a.getDistanceSquared(b) != a.minus(b).getLengthSquared", Math.abs(expected - test.getDistanceSquared(other)) < 10E-7, true); } public void testAssignDouble() { Modified: lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java?rev=912655&r1=912654&r2=912655&view=diff ============================================================================== --- lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java (original) +++ lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java Mon Feb 22 17:01:56 2010 @@ -388,6 +388,39 @@ assertEquals(idx + " does not equal: " + 0, 0, idx); vec1 = new RandomAccessSparseVector(3); + + vec1.setQuick(0, -1); + vec1.setQuick(2, -2); + + max = vec1.maxValue(); + assertEquals(max + " does not equal: " + 0, 0d, max, 0.0); + + idx = vec1.maxValueIndex(); + assertEquals(idx + " does not equal: " + 
1, 1, idx); + + vec1 = new SequentialAccessSparseVector(3); + + vec1.setQuick(0, -1); + vec1.setQuick(2, -2); + + max = vec1.maxValue(); + assertEquals(max + " does not equal: " + 0, 0d, max, 0.0); + + idx = vec1.maxValueIndex(); + assertEquals(idx + " does not equal: " + 1, 1, idx); + + vec1 = new DenseVector(3); + + vec1.setQuick(0, -1); + vec1.setQuick(2, -2); + + max = vec1.maxValue(); + assertEquals(max + " does not equal: " + 0, 0d, max, 0.0); + + idx = vec1.maxValueIndex(); + assertEquals(idx + " does not equal: " + 1, 1, idx); + + vec1 = new RandomAccessSparseVector(3); max = vec1.maxValue(); assertEquals(max + " does not equal 0", 0d, max); Modified: lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java?rev=912655&r1=912654&r2=912655&view=diff ============================================================================== --- lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java (original) +++ lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/benchmark/VectorBenchmarks.java Mon Feb 22 17:01:56 2010 @@ -18,6 +18,7 @@ package org.apache.mahout.benchmark; import java.util.ArrayList; +import java.util.BitSet; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -58,25 +59,41 @@ private final Vector[][] vectors; private final List<Vector> randomVectors = new ArrayList<Vector>(); + private final List<int[]> randomVectorIndices = new ArrayList<int[]>(); + private final List<double[]> randomVectorValues = new ArrayList<double[]>(); private final int cardinality; + private final int sparsity; private final int numVectors; private final int loop; private final int opsPerUnit; private final Map<String,Integer> implType = new HashMap<String,Integer>(); private final Map<String,List<String[]>> statsMap = new HashMap<String,List<String[]>>(); 
- public VectorBenchmarks(int cardinality, int numVectors, int loop, int opsPerUnit) { + public VectorBenchmarks(int cardinality, int sparsity, int numVectors, int loop, int opsPerUnit) { Random r = RandomUtils.getRandom(); this.cardinality = cardinality; + this.sparsity = sparsity; this.numVectors = numVectors; this.loop = loop; this.opsPerUnit = opsPerUnit; for (int i = 0; i < numVectors; i++) { - Vector v = new DenseVector(cardinality); - for (int j = 0; j < cardinality; j++) { + Vector v = new SequentialAccessSparseVector(cardinality, sparsity); // sparsity! + BitSet featureSpace = new BitSet(cardinality); + int[] indexes = new int[sparsity]; + double[] values = new double[sparsity]; + int j = 0; + while (j < sparsity) { double value = r.nextGaussian(); - v.set(j, value); + int index = r.nextInt(cardinality); + if (featureSpace.get(index) == false) { + featureSpace.set(index); + indexes[j] = index; + values[j++] = value; + v.set(index, value); + } } + randomVectorIndices.add(indexes); + randomVectorValues.add(values); randomVectors.add(v); } vectors = new Vector[3][numVectors]; @@ -96,7 +113,7 @@ String implName, String content, int multiplier) { - float speed = multiplier * loop * numVectors * cardinality * 1000.0f * 12 / stats.getSumTime(); + float speed = multiplier * loop * numVectors * sparsity * 1000.0f * 12 / stats.getSumTime(); float opsPerSec = loop * numVectors * 1000000000.0f / stats.getSumTime(); log.info("{} {} \n{} {} \nSpeed: {} UnitsProcessed/sec {} MBytes/sec ", new Object[] {benchmarkName, implName, content, stats.toString(), opsPerSec, speed}); @@ -125,7 +142,7 @@ call.end(); } } - printStats(stats, "Create", "DenseVector"); + printStats(stats, "Create (copy)", "DenseVector"); stats = new TimingStatistics(); for (int l = 0; l < loop; l++) { @@ -135,7 +152,7 @@ call.end(); } } - printStats(stats, "Create", "RandomAccessSparseVector"); + printStats(stats, "Create (copy)", "RandSparseVector"); stats = new TimingStatistics(); for (int l = 0; l < 
loop; l++) { @@ -145,9 +162,64 @@ call.end(); } } - printStats(stats, "Create", "SequentialAccessSparseVector"); + printStats(stats, "Create (copy)", "SeqSparseVector"); } + + private void buildVectorIncrementally(TimingStatistics stats, int randomIndex, Vector v, boolean useSetQuick) { + int[] indexes = randomVectorIndices.get(randomIndex); + double[] values = randomVectorValues.get(randomIndex); + List<Integer> randomOrder = new ArrayList<Integer>(); + for(int i=0; i<indexes.length; i++) { + randomOrder.add(i); + } + Collections.shuffle(randomOrder); + int[] permutation = new int[randomOrder.size()]; + for(int i=0; i<randomOrder.size(); i++) { + permutation[i] = randomOrder.get(i); + } + + TimingStatistics.Call call = stats.newCall(); + if(useSetQuick) { + for(int i : permutation) { + v.setQuick(indexes[i], values[i]); + } + } else { + for(int i : permutation) { + v.set(indexes[i], values[i]); + } + } + call.end(); + } + + public void incrementalCreateBenchmark() { + TimingStatistics stats = new TimingStatistics(); + for (int l = 0; l < loop; l++) { + for (int i = 0; i < numVectors; i++) { + vectors[0][i] = new DenseVector(cardinality); + buildVectorIncrementally(stats, i, vectors[0][i], false); + } + } + printStats(stats, "Create (incrementally)", "DenseVector"); + + stats = new TimingStatistics(); + for (int l = 0; l < loop; l++) { + for (int i = 0; i < numVectors; i++) { + vectors[1][i] = new RandomAccessSparseVector(cardinality); + buildVectorIncrementally(stats, i, vectors[1][i], false); + } + } + printStats(stats, "Create (incrementally)", "RandSparseVector"); + + stats = new TimingStatistics(); + for (int l = 0; l < loop; l++) { + for (int i = 0; i < numVectors; i++) { + vectors[2][i] = new SequentialAccessSparseVector(cardinality); + buildVectorIncrementally(stats, i, vectors[2][i], false); + } + } + printStats(stats, "Create (incrementally)", "SeqSparseVector"); + } public void cloneBenchmark() { TimingStatistics stats = new TimingStatistics(); @@ -168,7 
+240,7 @@ call.end(); } } - printStats(stats, "Clone", "RandomAccessSparseVector"); + printStats(stats, "Clone", "RandSparseVector"); stats = new TimingStatistics(); for (int l = 0; l < loop; l++) { @@ -178,7 +250,7 @@ call.end(); } } - printStats(stats, "Clone", "SequentialAccessSparseVector"); + printStats(stats, "Clone", "SeqSparseVector"); } @@ -194,7 +266,7 @@ } // print result to prevent hotspot from eliminating deadcode printStats(stats, "DotProduct", "DenseVector", "sum = " + result + ' '); - + result = 0; stats = new TimingStatistics(); for (int l = 0; l < loop; l++) { for (int i = 0; i < numVectors; i++) { @@ -204,8 +276,8 @@ } } // print result to prevent hotspot from eliminating deadcode - printStats(stats, "DotProduct", "RandomAccessSparseVector", "sum = " + result + ' '); - + printStats(stats, "DotProduct", "RandSparseVector", "sum = " + result + ' '); + result = 0; stats = new TimingStatistics(); for (int l = 0; l < loop; l++) { for (int i = 0; i < numVectors; i++) { @@ -215,11 +287,78 @@ } } // print result to prevent hotspot from eliminating deadcode - printStats(stats, "DotProduct", "SequentialAccessSparseVector", "sum = " + result + ' '); - + printStats(stats, "DotProduct", "SeqSparseVector", "sum = " + result + ' '); + result = 0; + stats = new TimingStatistics(); + for (int l = 0; l < loop; l++) { + for (int i = 0; i < numVectors; i++) { + TimingStatistics.Call call = stats.newCall(); + result += vectors[0][i].dot(vectors[1][(i + 1) % numVectors]); + call.end(); + } + } + // print result to prevent hotspot from eliminating deadcode + printStats(stats, "DotProduct", "Dense.dot(Rand)", "sum = " + result + ' '); + result = 0; + stats = new TimingStatistics(); + for (int l = 0; l < loop; l++) { + for (int i = 0; i < numVectors; i++) { + TimingStatistics.Call call = stats.newCall(); + result += vectors[0][i].dot(vectors[2][(i + 1) % numVectors]); + call.end(); + } + } + // print result to prevent hotspot from eliminating deadcode + printStats(stats, 
"DotProduct", "Dense.dot(Seq)", "sum = " + result + ' '); + result = 0; + stats = new TimingStatistics(); + for (int l = 0; l < loop; l++) { + for (int i = 0; i < numVectors; i++) { + TimingStatistics.Call call = stats.newCall(); + result += vectors[1][i].dot(vectors[0][(i + 1) % numVectors]); + call.end(); + } + } + // print result to prevent hotspot from eliminating deadcode + printStats(stats, "DotProduct", "Rand.dot(Dense)", "sum = " + result + ' '); + result = 0; + stats = new TimingStatistics(); + for (int l = 0; l < loop; l++) { + for (int i = 0; i < numVectors; i++) { + TimingStatistics.Call call = stats.newCall(); + result += vectors[1][i].dot(vectors[2][(i + 1) % numVectors]); + call.end(); + } + } + // print result to prevent hotspot from eliminating deadcode + printStats(stats, "DotProduct", "Rand.dot(Seq)", "sum = " + result + ' '); + result = 0; + stats = new TimingStatistics(); + for (int l = 0; l < loop; l++) { + for (int i = 0; i < numVectors; i++) { + TimingStatistics.Call call = stats.newCall(); + result += vectors[2][i].dot(vectors[0][(i + 1) % numVectors]); + call.end(); + } + } + // print result to prevent hotspot from eliminating deadcode + printStats(stats, "DotProduct", "Seq.dot(Dense)", "sum = " + result + ' '); + result = 0; + stats = new TimingStatistics(); + for (int l = 0; l < loop; l++) { + for (int i = 0; i < numVectors; i++) { + TimingStatistics.Call call = stats.newCall(); + result += vectors[2][i].dot(vectors[1][(i + 1) % numVectors]); + call.end(); + } + } + // print result to prevent hotspot from eliminating deadcode + printStats(stats, "DotProduct", "Seq.dot(Rand)", "sum = " + result + ' '); + + } - public void distanceMeasureBenchark(DistanceMeasure measure) { + public void distanceMeasureBenchmark(DistanceMeasure measure) { double result = 0; TimingStatistics stats = new TimingStatistics(); for (int l = 0; l < loop; l++) { @@ -238,7 +377,7 @@ } // print result to prevent hotspot from eliminating deadcode printStats(stats, 
measure.getClass().getName(), "DenseVector", "minDistance = " + result + ' '); - + result = 0; stats = new TimingStatistics(); for (int l = 0; l < loop; l++) { for (int i = 0; i < numVectors; i++) { @@ -255,8 +394,9 @@ } } // print result to prevent hotspot from eliminating deadcode - printStats(stats, measure.getClass().getName(), "RandomAccessSparseVector", "minDistance = " + result + printStats(stats, measure.getClass().getName(), "RandSparseVector", "minDistance = " + result + ' '); + result = 0; stats = new TimingStatistics(); for (int l = 0; l < loop; l++) { for (int i = 0; i < numVectors; i++) { @@ -273,7 +413,7 @@ } } // print result to prevent hotspot from eliminating deadcode - printStats(stats, measure.getClass().getName(), "SequentialAccessSparseVector", "minDistance = " + result + printStats(stats, measure.getClass().getName(), "SeqSparseVector", "minDistance = " + result + ' '); } @@ -287,6 +427,9 @@ Option vectorSizeOpt = obuilder.withLongName("vectorSize").withRequired(false).withArgument( abuilder.withName("vs").withMinimum(1).withMaximum(1).create()).withDescription( "Cardinality of the vector. Default 1000").withShortName("vs").create(); + Option vectorSparsityOpt = obuilder.withLongName("sparsity").withRequired(false).withArgument( + abuilder.withName("sp").withMinimum(1).withMaximum(1).create()).withDescription( + "Sparsity of the vector. Default 1000").withShortName("sp").create(); Option numVectorsOpt = obuilder.withLongName("numVectors").withRequired(false).withArgument( abuilder.withName("nv").withMinimum(1).withMaximum(1).create()).withDescription( "Number of Vectors to create. 
Default: 100").withShortName("nv").create(); @@ -301,8 +444,8 @@ Option helpOpt = DefaultOptionCreator.helpOption(); - Group group = gbuilder.withName("Options").withOption(vectorSizeOpt).withOption(numVectorsOpt) - .withOption(loopOpt).withOption(numOpsOpt).withOption(helpOpt).create(); + Group group = gbuilder.withName("Options").withOption(vectorSizeOpt).withOption(vectorSparsityOpt) + .withOption(numVectorsOpt).withOption(loopOpt).withOption(numOpsOpt).withOption(helpOpt).create(); try { Parser parser = new Parser(); @@ -319,12 +462,18 @@ cardinality = Integer.parseInt((String) cmdLine.getValue(vectorSizeOpt)); } + + int sparsity = 1000; + if (cmdLine.hasOption(vectorSparsityOpt)) { + sparsity = Integer.parseInt((String) cmdLine.getValue(vectorSparsityOpt)); + } + int numVectors = 100; if (cmdLine.hasOption(numVectorsOpt)) { numVectors = Integer.parseInt((String) cmdLine.getValue(numVectorsOpt)); } - int loop = 200; + int loop = 600; if (cmdLine.hasOption(loopOpt)) { loop = Integer.parseInt((String) cmdLine.getValue(loopOpt)); @@ -334,15 +483,16 @@ numOps = Integer.parseInt((String) cmdLine.getValue(numOpsOpt)); } - VectorBenchmarks mark = new VectorBenchmarks(cardinality, numVectors, loop, numOps); + VectorBenchmarks mark = new VectorBenchmarks(cardinality, sparsity, numVectors, loop, numOps); mark.createBenchmark(); + mark.incrementalCreateBenchmark(); mark.cloneBenchmark(); mark.dotBenchmark(); - mark.distanceMeasureBenchark(new CosineDistanceMeasure()); - mark.distanceMeasureBenchark(new SquaredEuclideanDistanceMeasure()); - mark.distanceMeasureBenchark(new EuclideanDistanceMeasure()); - mark.distanceMeasureBenchark(new ManhattanDistanceMeasure()); - mark.distanceMeasureBenchark(new TanimotoDistanceMeasure()); + mark.distanceMeasureBenchmark(new CosineDistanceMeasure()); + mark.distanceMeasureBenchmark(new SquaredEuclideanDistanceMeasure()); + mark.distanceMeasureBenchmark(new EuclideanDistanceMeasure()); + mark.distanceMeasureBenchmark(new 
ManhattanDistanceMeasure()); + mark.distanceMeasureBenchmark(new TanimotoDistanceMeasure()); log.info("\n{}", mark.summarize()); } catch (OptionException e) { @@ -353,13 +503,13 @@ @Override public String summarize() { - int pad = 30; + int pad = 24; StringBuilder sb = new StringBuilder(1000); sb.append(StringUtils.rightPad("BenchMarks", pad)); for (int i = 0; i < implType.size(); i++) { for (Entry<String,Integer> e : implType.entrySet()) { if (e.getValue() == i) { - sb.append(StringUtils.rightPad(e.getKey(), pad)); + sb.append(StringUtils.rightPad(e.getKey(), pad).substring(0, pad)); break; } }