Author: tdunning
Date: Mon Feb 4 23:51:58 2013
New Revision: 1442436
URL: http://svn.apache.org/viewvc?rev=1442436&view=rev
Log:
MAHOUT-1148 - Demonstrate speedup
Added:
mahout/trunk/math/src/main/java/org/apache/mahout/math/QR.java
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/OldQRDecomposition.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/QRDecomposition.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/OldQRDecomposition.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/OldQRDecomposition.java?rev=1442436&r1=1442435&r2=1442436&view=diff
==============================================================================
---
mahout/trunk/math/src/main/java/org/apache/mahout/math/OldQRDecomposition.java
(original)
+++
mahout/trunk/math/src/main/java/org/apache/mahout/math/OldQRDecomposition.java
Mon Feb 4 23:51:58 2013
@@ -41,7 +41,7 @@ import java.util.Locale;
*/
/** partially deprecated until unit tests are in place. Until this time, this
class/interface is unsupported. */
-public class OldQRDecomposition {
+public class OldQRDecomposition implements QR {
/** Array for internal storage of decomposition. */
private final Matrix qr;
Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/QR.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/QR.java?rev=1442436&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/QR.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/QR.java Mon Feb 4
23:51:58 2013
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math;
+
+public interface QR {
+ Matrix getQ();
+
+ Matrix getR();
+
+ boolean hasFullRank();
+
+ Matrix solve(Matrix B);
+}
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/QRDecomposition.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/QRDecomposition.java?rev=1442436&r1=1442435&r2=1442436&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/QRDecomposition.java
(original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/QRDecomposition.java
Mon Feb 4 23:51:58 2013
@@ -23,10 +23,8 @@
*/
package org.apache.mahout.math;
-import com.google.common.collect.Lists;
import org.apache.mahout.math.function.Functions;
-import java.util.List;
import java.util.Locale;
@@ -42,8 +40,7 @@ import java.util.Locale;
returns <tt>false</tt>.
*/
-public class QRDecomposition {
- private static final int N = 10;
+public class QRDecomposition implements QR {
private final Matrix q, r;
private final boolean fullRank;
private final int rows;
@@ -111,6 +108,7 @@ public class QRDecomposition {
*
* @return <tt>Q</tt>
*/
+ @Override
public Matrix getQ() {
return q;
}
@@ -120,6 +118,7 @@ public class QRDecomposition {
*
* @return <tt>R</tt>
*/
+ @Override
public Matrix getR() {
return r;
}
@@ -129,6 +128,7 @@ public class QRDecomposition {
*
* @return true if <tt>R</tt>, and hence <tt>A</tt>, has full rank.
*/
+ @Override
public boolean hasFullRank() {
return fullRank;
}
@@ -140,6 +140,7 @@ public class QRDecomposition {
* @return <tt>X</tt> that minimizes the two norm of <tt>Q*R*X - B</tt>.
* @throws IllegalArgumentException if <tt>B.rows() != A.rows()</tt>.
*/
+ @Override
public Matrix solve(Matrix B) {
if (B.numRows() != rows) {
throw new IllegalArgumentException("Matrix row dimensions must agree.");
@@ -175,37 +176,4 @@ public class QRDecomposition {
public String toString() {
return String.format(Locale.ENGLISH, "QR(%d x %d,fullRank=%s)", rows,
columns, hasFullRank());
}
-
- public static void main(String[] args) {
- Matrix a = new DenseMatrix(60, 60).assign(Functions.random());
-
- int n = 0;
- List<Integer> counts = Lists.newArrayList(10, 20, 50, 100, 200, 500, 1000,
2000, 5000);
- for (int k : counts) {
- double warmup = 0;
- double other = 0;
-
- n += k;
- for (int i = 0; i < k; i++) {
- QRDecomposition qr = new QRDecomposition(a);
- warmup = Math.max(warmup,
qr.getQ().transpose().times(qr.getQ()).viewDiagonal().assign(Functions.plus(-1)).norm(1));
- Matrix z = qr.getQ().times(qr.getR()).minus(a);
- other = Math.max(other, z.aggregate(Functions.MIN, Functions.ABS));
- }
-
- double maxIdent = 0;
- double maxError = 0;
-
- long t0 = System.nanoTime();
- for (int i = 0; i < N; i++) {
- QRDecomposition qr = new QRDecomposition(a);
-
- maxIdent = Math.max(maxIdent,
qr.getQ().transpose().times(qr.getQ()).viewDiagonal().assign(Functions.plus(-1)).norm(1));
- Matrix z = qr.getQ().times(qr.getR()).minus(a);
- maxError = Math.max(maxError, z.aggregate(Functions.MIN,
Functions.ABS));
- }
- System.out.printf("%d\t%.1f\t%g\t%g\t%g\n", n, (System.nanoTime() - t0)
/ 1e3 / N, maxIdent, maxError, warmup);
-// System.out.printf("%g, %g\n", maxIdent, maxError);
- }
- }
}
Modified:
mahout/trunk/math/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/QRDecompositionTest.java?rev=1442436&r1=1442435&r2=1442436&view=diff
==============================================================================
---
mahout/trunk/math/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
(original)
+++
mahout/trunk/math/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
Mon Feb 4 23:51:58 2013
@@ -17,10 +17,14 @@
package org.apache.mahout.math;
+import com.google.common.collect.Lists;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.OnlineSummarizer;
import org.junit.Test;
+import java.util.List;
+
public final class QRDecompositionTest extends MahoutTestCase {
@Test
public void randomMatrix() {
@@ -156,6 +160,71 @@ public final class QRDecompositionTest e
assertEquals(x, qr.getQ().times(qr.getR()), 1e-15);
}
+ @Test
+ public void fasterThanBefore() {
+
+ OnlineSummarizer s1 = new OnlineSummarizer();
+ OnlineSummarizer s2 = new OnlineSummarizer();
+
+ Matrix a = new DenseMatrix(60, 60).assign(Functions.random());
+
+ decompositionSpeedCheck(new Decomposer() {
+ @Override
+ public QR decompose(Matrix a) {
+ return new QRDecomposition(a);
+ }
+ }, s1, a, "new");
+
+ decompositionSpeedCheck(new Decomposer() {
+ @Override
+ public QR decompose(Matrix a) {
+ return new OldQRDecomposition(a);
+ }
+ }, s2, a, "old");
+
+ // should be much more than twice as fast.
+ System.out.printf("Speedup is about %.1f times\n", s2.getMedian() /
s1.getMedian());
+ assertTrue(s1.getMedian() < 0.5 * s2.getMedian());
+ }
+
+ private interface Decomposer {
+ public QR decompose(Matrix a);
+ }
+
+ private void decompositionSpeedCheck(Decomposer qrf, OnlineSummarizer s1,
Matrix a, String label) {
+ int n = 0;
+ List<Integer> counts = Lists.newArrayList(10, 20, 50, 100, 200, 500);
+ for (int k : counts) {
+ double warmup = 0;
+ double other = 0;
+
+ n += k;
+ for (int i = 0; i < k; i++) {
+ QR qr = qrf.decompose(a);
+ warmup = Math.max(warmup,
qr.getQ().transpose().times(qr.getQ()).viewDiagonal().assign(Functions.plus(-1)).norm(1));
+ Matrix z = qr.getQ().times(qr.getR()).minus(a);
+ other = Math.max(other, z.aggregate(Functions.MIN, Functions.ABS));
+ }
+
+ double maxIdent = 0;
+ double maxError = 0;
+
+ long t0 = System.nanoTime();
+ for (int i = 0; i < n; i++) {
+ QR qr = qrf.decompose(a);
+
+ maxIdent = Math.max(maxIdent,
qr.getQ().transpose().times(qr.getQ()).viewDiagonal().assign(Functions.plus(-1)).norm(1));
+ Matrix z = qr.getQ().times(qr.getR()).minus(a);
+ maxError = Math.max(maxError, z.aggregate(Functions.MIN,
Functions.ABS));
+ }
+ long t1 = System.nanoTime();
+ if (k > 100) {
+ s1.add(t1 - t0);
+ }
+ System.out.printf("%s %d\t%.1f\t%g\t%g\t%g\n", label, n, (t1 - t0) / 1e3
/ n, maxIdent, maxError, warmup);
+ }
+ }
+
private static void assertEquals(Matrix ref, Matrix actual, double epsilon) {
assertEquals(0, ref.minus(actual).aggregate(Functions.MAX, Functions.ABS),
epsilon);
}