Author: jeastman
Date: Wed May 13 22:30:23 2009
New Revision: 774566
URL: http://svn.apache.org/viewvc?rev=774566&view=rev
Log:
- implemented SparseVector.times optimizations suggested by MAHOUT-66
- implemented unit test thereof which demonstrates 5-10ms improvement when used
with
50,000 cardinality, 1000 random element vectors typical of Text clustering
- SparseVector.optimesTimes = true is the default; consider removing it all
later
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseVector.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SquareRootFunction.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseVector.java
URL:
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseVector.java?rev=774566&r1=774565&r2=774566&view=diff
==============================================================================
---
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseVector.java
(original)
+++
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseVector.java
Wed May 13 22:30:23 2009
@@ -39,9 +39,10 @@
private Map<Integer, Double> values;
-
private int cardinality;
+ public static boolean optimizeTimes = true;
+
/**
* Decode a new instance from the argument
*
@@ -96,19 +97,22 @@
}
@Override
- @SuppressWarnings("unchecked")
+ @SuppressWarnings("unchecked")
public String asFormatString() {
StringBuilder out = new StringBuilder();
out.append("[s").append(cardinality).append(", ");
- Map.Entry<Integer, Double>[] entries = (Map.Entry<Integer, Double>[])
values.entrySet().toArray(new Map.Entry[values.size()]);
- Arrays.sort(entries, new Comparator<Map.Entry<Integer, Double>>(){
+ Map.Entry<Integer, Double>[] entries = (Map.Entry<Integer, Double>[])
values
+ .entrySet().toArray(new Map.Entry[values.size()]);
+ Arrays.sort(entries, new Comparator<Map.Entry<Integer, Double>>() {
@Override
- public int compare(Map.Entry<Integer, Double> e1, Map.Entry<Integer,
Double> e2) {
+ public int compare(Map.Entry<Integer, Double> e1,
+ Map.Entry<Integer, Double> e2) {
return e1.getKey().compareTo(e2.getKey());
}
});
for (Map.Entry<Integer, Double> entry : entries) {
-
out.append(entry.getKey()).append(':').append(entry.getValue()).append(", ");
+ out.append(entry.getKey()).append(':').append(entry.getValue()).append(
+ ", ");
}
out.append("] ");
return out.toString();
@@ -188,15 +192,17 @@
return new Iterator();
}
-
@Override
public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
+ if (this == o)
+ return true;
+ if (o == null || getClass() != o.getClass())
+ return false;
SparseVector that = (SparseVector) o;
- return cardinality == that.cardinality && (values == null ? that.values ==
null : values.equals(that.values));
+ return cardinality == that.cardinality
+ && (values == null ? that.values == null : values.equals(that.values));
}
@Override
@@ -273,4 +279,42 @@
this.values = values;
}
+ @Override
+ public Vector times(double x) {
+ Vector result;
+ if (optimizeTimes) {
+ result = like();
+ for (Vector.Element element : this) {
+ double value = element.get();
+ int index = element.index();
+ result.setQuick(index, value * x);
+ }
+ } else {
+ result = copy();
+ for (int i = 0; i < result.cardinality(); i++)
+ result.setQuick(i, getQuick(i) * x);
+ }
+ return result;
+ }
+
+ @Override
+ public Vector times(Vector x) {
+ if (cardinality() != x.cardinality())
+ throw new CardinalityException();
+ Vector result;
+ if (optimizeTimes) {
+ result = like();
+ for (Vector.Element element : this) {
+ double value = element.get();
+ int index = element.index();
+ result.setQuick(index, value * x.getQuick(index));
+ }
+ } else {
+ result = copy();
+ for (int i = 0; i < result.cardinality(); i++)
+ result.setQuick(i, getQuick(i) * x.getQuick(i));
+ }
+ return result;
+ }
+
}
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SquareRootFunction.java
URL:
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SquareRootFunction.java?rev=774566&r1=774565&r2=774566&view=diff
==============================================================================
---
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SquareRootFunction.java
(original)
+++
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SquareRootFunction.java
Wed May 13 22:30:23 2009
@@ -4,7 +4,7 @@
@Override
public double apply(double arg1) {
- return Math.sqrt(arg1);
+ return Math.abs(arg1);
}
}
Modified:
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java
URL:
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java?rev=774566&r1=774565&r2=774566&view=diff
==============================================================================
---
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java
(original)
+++
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/VectorTest.java
Wed May 13 22:30:23 2009
@@ -17,6 +17,9 @@
package org.apache.mahout.matrix;
+import java.util.Date;
+import java.util.Random;
+
import junit.framework.TestCase;
public class VectorTest extends TestCase {
@@ -42,7 +45,6 @@
assertEquals(result + " does not equal: " + 32, 32.0, result);
}
-
public void testDenseVector() throws Exception {
DenseVector vec1 = new DenseVector(3);
DenseVector vec2 = new DenseVector(3);
@@ -67,17 +69,18 @@
test[e.index()] = e.get();
}
- for (int i = 0; i<test.length; i++) {
+ for (int i = 0; i < test.length; i++) {
assertEquals(apriori[i], test[i]);
}
}
public void testEnumeration() throws Exception {
- double[] apriori = {0, 1, 2, 3, 4};
+ double[] apriori = { 0, 1, 2, 3, 4 };
+
+ doTestEnumeration(apriori, new VectorView(new DenseVector(new double[] {
+ -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }), 2, 5));
- doTestEnumeration(apriori, new VectorView(new DenseVector(new double[]{-2,
-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), 2, 5));
-
- doTestEnumeration(apriori, new DenseVector(new double[]{0, 1, 2, 3, 4}));
+ doTestEnumeration(apriori, new DenseVector(new double[] { 0, 1, 2, 3, 4
}));
SparseVector sparse = new SparseVector(5);
sparse.set(0, 0);
@@ -88,4 +91,59 @@
doTestEnumeration(apriori, sparse);
}
-}
\ No newline at end of file
+ public void testSparseVectorTimesX() {
+ Random rnd = new Random();
+ Vector v1 = randomSparseVector(rnd);
+ double x = rnd.nextDouble();
+ long t0 = new Date().getTime();
+ SparseVector.optimizeTimes = false;
+ Vector rRef = null;
+ for (int i = 0; i < 10; i++)
+ rRef = v1.times(x);
+ long t1 = new Date().getTime();
+ SparseVector.optimizeTimes = true;
+ Vector rOpt = null;
+ for (int i = 0; i < 10; i++)
+ rOpt = v1.times(x);
+ long t2 = new Date().getTime();
+ long tOpt = t2 - t1;
+ long tRef = t1 - t0;
+ assertTrue(tOpt < tRef);
+ System.out.println("testSparseVectorTimesX tRef=tOpt=" + (tRef - tOpt)
+ + " ms for 10 iterations");
+ for (int i = 0; i < 50000; i++)
+ assertEquals("i=" + i, rRef.getQuick(i), rOpt.getQuick(i));
+ }
+
+ public void testSparseVectorTimesV() {
+ Random rnd = new Random();
+ Vector v1 = randomSparseVector(rnd);
+ Vector v2 = randomSparseVector(rnd);
+ long t0 = new Date().getTime();
+ SparseVector.optimizeTimes = false;
+ Vector rRef = null;
+ for (int i = 0; i < 10; i++)
+ rRef = v1.times(v2);
+ long t1 = new Date().getTime();
+ SparseVector.optimizeTimes = true;
+ Vector rOpt = null;
+ for (int i = 0; i < 10; i++)
+ rOpt = v1.times(v2);
+ long t2 = new Date().getTime();
+ long tOpt = t2 - t1;
+ long tRef = t1 - t0;
+ assertTrue(tOpt < tRef);
+ System.out.println("testSparseVectorTimesV tRef=tOpt=" + (tRef - tOpt)
+ + " ms for 10 iterations");
+ for (int i = 0; i < 50000; i++)
+ assertEquals("i=" + i, rRef.getQuick(i), rOpt.getQuick(i));
+ }
+
+ private Vector randomSparseVector(Random rnd) {
+ SparseVector v1 = new SparseVector(50000);
+ for (int i = 0; i < 1000; i++)
+ v1.setQuick((int) (rnd.nextDouble() * 50000), rnd.nextDouble());
+ return v1;
+ }
+
+}