Author: tdunning
Date: Thu Aug 12 22:36:36 2010
New Revision: 985022
URL: http://svn.apache.org/viewvc?rev=985022&view=rev
Log:
MAHOUT-469 ... added QR tests, modified to standard linear algebra, fixed
exception message. getH still deprecated in QRDecomposition.
Added:
mahout/trunk/math/src/main/java/org/apache/mahout/math/QRDecomposition.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/Algebra.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/IndexException.java
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/Algebra.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/Algebra.java?rev=985022&r1=985021&r2=985022&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/Algebra.java
(original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/Algebra.java Thu Aug
12 22:36:36 2010
@@ -16,8 +16,7 @@
*/
package org.apache.mahout.math;
-import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.Vector;
+
public class Algebra {
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/IndexException.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/IndexException.java?rev=985022&r1=985021&r2=985022&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/IndexException.java
(original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/IndexException.java
Thu Aug 12 22:36:36 2010
@@ -24,7 +24,7 @@ package org.apache.mahout.math;
public class IndexException extends IllegalArgumentException {
public IndexException(int index, int cardinality) {
- super("Index " + index + " is outside allowable range of [0," +
cardinality + ']');
+ super("Index " + index + " is outside allowable range of [0," +
cardinality + ')');
}
}
Added:
mahout/trunk/math/src/main/java/org/apache/mahout/math/QRDecomposition.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/QRDecomposition.java?rev=985022&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/QRDecomposition.java
(added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/QRDecomposition.java
Thu Aug 12 22:36:36 2010
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+Copyright © 1999 CERN - European Organization for Nuclear Research.
+Permission to use, copy, modify, distribute and sell this software and its
documentation for any purpose
+is hereby granted without fee, provided that the above copyright notice appear
in all copies and
+that both that copyright notice and this permission notice appear in
supporting documentation.
+CERN makes no representations about the suitability of this software for any
purpose.
+It is provided "as is" without expressed or implied warranty.
+*/
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.function.Functions;
+
+
+/**
+ For an <tt>m x n</tt> matrix <tt>A</tt> with <tt>m >= n</tt>, the QR
decomposition is an <tt>m x n</tt>
+ orthogonal matrix <tt>Q</tt> and an <tt>n x n</tt> upper triangular matrix
<tt>R</tt> so that
+ <tt>A = Q*R</tt>.
+ <P>
+ The QR decomposition always exists, even if the matrix does not have
+ full rank, so the constructor will never fail. The primary use of the
+ QR decomposition is in the least squares solution of nonsquare systems
+ of simultaneous linear equations. This will fail if <tt>hasFullRank()</tt>
+ returns <tt>false</tt>.
+ */
+
+/** partially deprecated until unit tests are in place. Until this time, this
class/interface is unsupported. */
+public class QRDecomposition {
+
+  /**
+   * Packed decomposition. On and below the diagonal it holds the Householder
+   * vectors (see {@link #getH()}); strictly above the diagonal it holds the
+   * off-diagonal entries of <tt>R</tt> (see {@link #getR()}).
+   */
+  private final Matrix qr;
+
+  /** Number of rows of the decomposed matrix. */
+  private final int originalRows;
+
+  /** Number of columns of the decomposed matrix. */
+  private final int originalColumns;
+
+  /** Diagonal of <tt>R</tt>, one entry per column of the decomposed matrix. */
+  private final Vector rDiag;
+
+  /**
+   * Computes the QR decomposition of <tt>A</tt> by Householder reflections.
+   * The factors can be retrieved afterwards through {@link #getQ()} and
+   * {@link #getR()}.
+   * <p>
+   * No shape check is performed: a matrix with fewer rows than columns is
+   * accepted as well. For such a matrix the trailing entries of the diagonal
+   * of <tt>R</tt> are zero, so {@link #hasFullRank()} returns <tt>false</tt>.
+   *
+   * @param A A rectangular matrix. It is cloned, the decomposition works on
+   *          the clone.
+   */
+  public QRDecomposition(Matrix A) {
+    qr = A.clone();
+    originalRows = A.numRows();
+    originalColumns = A.numCols();
+    rDiag = new DenseVector(originalColumns);
+
+    // A Householder vector exists only for the first min(rows, columns)
+    // columns. Past that point the part of the column that starts at the
+    // diagonal is empty, so no view is requested for it.
+    int reflections = Math.min(originalRows, originalColumns);
+
+    // Precompute and cache the column views (from the diagonal downwards)
+    // so they are not regenerated in every iteration of the main loop.
+    Vector[] qrColumnsPart = new Vector[reflections];
+    for (int k = 0; k < reflections; k++) {
+      qrColumnsPart[k] = qr.viewColumn(k).viewPart(k, originalRows - k);
+    }
+
+    // Main loop, one pass per column.
+    for (int k = 0; k < originalColumns; k++) {
+      // Compute the 2-norm of the k-th column (diagonal downwards) without
+      // under/overflow. The loop is empty when k >= originalRows, which
+      // leaves nrm at 0.
+      double nrm = 0;
+      for (int i = k; i < originalRows; i++) {
+        nrm = Algebra.hypot(nrm, qr.getQuick(i, k));
+      }
+
+      // nrm != 0 implies k < originalRows, hence k < reflections and the
+      // cached view for column k exists.
+      if (nrm != 0.0) {
+        // Form the k-th Householder vector; the norm takes the sign of the
+        // diagonal element.
+        if (qr.getQuick(k, k) < 0) {
+          nrm = -nrm;
+        }
+        qrColumnsPart[k].assign(Functions.div(nrm));
+        qr.setQuick(k, k, qr.getQuick(k, k) + 1);
+
+        // Apply the transformation to the remaining columns.
+        for (int j = k + 1; j < originalColumns; j++) {
+          Vector qrColJ = qr.viewColumn(j).viewPart(k, originalRows - k);
+          double s = -qrColumnsPart[k].dot(qrColJ) / qr.getQuick(k, k);
+          for (int i = k; i < originalRows; i++) {
+            qr.setQuick(i, j, qr.getQuick(i, j) + s * qr.getQuick(i, k));
+          }
+        }
+      }
+      rDiag.setQuick(k, -nrm);
+    }
+  }
+
+  /**
+   * Returns the Householder vectors <tt>H</tt>: a clone of the packed
+   * decomposition with every entry above the diagonal set to zero.
+   *
+   * @return A lower trapezoidal matrix whose columns define the Householder
+   *         reflections.
+   * @deprecated Not covered by tests yet.
+   */
+  @Deprecated
+  public Matrix getH() {
+    Matrix h = qr.clone();
+    int rows = h.numRows();
+    int columns = h.numCols();
+    for (int i = 0; i < rows; i++) {
+      for (int j = i + 1; j < columns; j++) {
+        h.setQuick(i, j, 0);
+      }
+    }
+    return h;
+  }
+
+  /**
+   * Generates and returns the (economy-sized) orthogonal factor <tt>Q</tt>.
+   * The result has as many rows as the decomposed matrix and
+   * <tt>min(rows, columns)</tt> columns.
+   *
+   * @return <tt>Q</tt>
+   */
+  public Matrix getQ() {
+    int columns = Math.min(originalColumns, originalRows);
+    Matrix q = qr.like(originalRows, columns);
+    for (int k = columns - 1; k >= 0; k--) {
+      q.set(k, k, 1);
+
+      // The packed matrix is not modified here, so the pivot is constant for
+      // the whole inner loop. A zero pivot means there is no reflection to
+      // apply for this column.
+      double pivot = qr.get(k, k);
+      if (pivot == 0) {
+        continue;
+      }
+
+      Vector qrColK = qr.viewColumn(k).viewPart(k, originalRows - k);
+      for (int j = k; j < columns; j++) {
+        Vector qColJ = q.viewColumn(j).viewPart(k, originalRows - k);
+        double s = -qrColK.dot(qColJ) / pivot;
+        qColJ.assign(qrColK, Functions.plusMult(s));
+      }
+    }
+    return q;
+  }
+
+  /**
+   * Returns the upper triangular factor, <tt>R</tt>. The result has
+   * <tt>min(rows, columns)</tt> rows and as many columns as the decomposed
+   * matrix.
+   *
+   * @return <tt>R</tt>
+   */
+  public Matrix getR() {
+    int rows = Math.min(originalRows, originalColumns);
+    Matrix r = qr.like(rows, originalColumns);
+    for (int i = 0; i < rows; i++) {
+      for (int j = 0; j < originalColumns; j++) {
+        double value;
+        if (i < j) {
+          // strictly upper part is stored in the packed matrix
+          value = qr.getQuick(i, j);
+        } else if (i == j) {
+          // the diagonal is stored separately
+          value = rDiag.getQuick(i);
+        } else {
+          value = 0;
+        }
+        r.setQuick(i, j, value);
+      }
+    }
+    return r;
+  }
+
+  /**
+   * Returns whether the matrix <tt>A</tt> has full rank, i.e. whether every
+   * diagonal entry of <tt>R</tt> (one per column of <tt>A</tt>) is non-zero.
+   *
+   * @return true if <tt>R</tt>, and hence <tt>A</tt>, has full rank.
+   */
+  public boolean hasFullRank() {
+    for (int j = 0; j < originalColumns; j++) {
+      if (rDiag.getQuick(j) == 0) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Least squares solution of <tt>A*X = B</tt>; <tt>returns X</tt>.
+   * <p>
+   * The diagonal of <tt>R</tt> is used as a divisor without any check, so a
+   * zero on that diagonal is not handled specially.
+   *
+   * @param B A matrix with as many rows as <tt>A</tt> and any number of
+   *          columns.
+   * @return <tt>X</tt> that minimizes the two norm of <tt>Q*R*X - B</tt>.
+   * @throws IllegalArgumentException if <tt>B</tt> does not have as many
+   *         rows as <tt>A</tt>.
+   */
+  public Matrix solve(Matrix B) {
+    if (B.numRows() != originalRows) {
+      throw new IllegalArgumentException("Matrix row dimensions must agree.");
+    }
+
+    int columns = B.numCols();
+    Matrix x = B.like(originalColumns, columns);
+
+    // This can all be done a bit more efficiently if we don't actually
+    // form explicit versions of Q^T and R, but this code isn't so bad
+    // and it is much easier to understand.
+    Matrix qt = getQ().transpose();
+    Matrix y = qt.times(B);
+    Matrix r = getR();
+
+    // Back substitution, from the last usable row up. Rows of X past
+    // min(rows, columns) are never assigned here.
+    for (int k = Math.min(originalColumns, originalRows) - 1; k >= 0; k--) {
+      // X[k,] = Y[k,] / R[k,k]
+      // (X[k,] starts out as 0, so "+=" is the same as "=")
+      x.viewRow(k).assign(y.viewRow(k), Functions.plusMult(1 / r.get(k, k)));
+
+      // Y[0:(k-1),] -= R[0:(k-1),k] * X[k,]
+      Vector rColumn = r.viewColumn(k).viewPart(0, k);
+      for (int c = 0; c < columns; c++) {
+        y.viewColumn(c).viewPart(0, k)
+            .assign(rColumn, Functions.plusMult(-x.get(k, c)));
+      }
+    }
+    return x;
+  }
+
+  /**
+   * Returns a rough string rendition of a QR: the number of rows, the number
+   * of columns and the result of {@link #hasFullRank()}.
+   */
+  @Override
+  public String toString() {
+    return String.format("QR(%d,%d,fullRank=%s)",
+        originalRows, originalColumns, hasFullRank());
+  }
+}
Added:
mahout/trunk/math/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/QRDecompositionTest.java?rev=985022&view=auto
==============================================================================
---
mahout/trunk/math/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
(added)
+++
mahout/trunk/math/src/test/java/org/apache/mahout/math/QRDecompositionTest.java
Thu Aug 12 22:36:36 2010
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.math.function.Functions;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Checks {@link QRDecomposition} against precomputed reference factors for a
+ * tall (8 x 5) matrix and for its wide (5 x 8) transpose.
+ */
+public class QRDecompositionTest {
+
+  /** Decomposes the 8 x 5 fixture and compares R, Q and a solve() result. */
+  @Test
+  public void fullRankTall() {
+    Matrix a = matrix();
+    QRDecomposition decomposition = new QRDecomposition(a);
+    Assert.assertTrue(decomposition.hasFullRank());
+
+    // reference values are listed column by column, see reshape()
+    double[] rValues = {
+      -2.99129686445138, 0, 0, 0, 0,
+      -0.0282260628674372, -2.38850244769059, 0, 0, 0,
+      0.733739310355871, 1.48042000631646, 2.29051263117895, 0, 0,
+      -0.0394082168269326, 0.282829484207801, -0.00438521041803086,
+      -2.90823198084203, 0,
+      0.923669647838536, 1.76679276072492, 0.637690104222683,
+      -0.225890909498753, -1.35732293800944
+    };
+    Matrix rRef = reshape(rValues, 5, 5);
+    assertEquals(rRef, decomposition.getR(), 1e-8);
+
+    double[] qValues = {
+      -0.165178287646573, 0.0510035857637869, 0.13985915987379,
+      -0.120173729496501,
+      -0.453198314345324, 0.644400679630493, -0.503117990820608,
+      0.24968739845381,
+      0.323968339146224, -0.465266080134262, 0.276508948773268,
+      -0.687909700644343,
+      0.0544048888907195, -0.0166677718378263, 0.171309755790717,
+      0.310339001630029,
+      0.674790532821663, 0.0058166082200493, -0.381707516461884,
+      0.300504956413142,
+      -0.105751091334003, 0.410450870871096, 0.31113446615821,
+      0.179338172684956,
+      0.361951807617901, 0.763921725548796, 0.380327892605634,
+      -0.287274944594054,
+      0.0311604042556675, 0.0386096858143961, 0.0387156960650472,
+      -0.232975755728917,
+      0.0358178276684149, 0.173105775703199, 0.327321867815603,
+      0.328671945345279,
+      -0.36015879836344, -0.444261660176044, 0.09438499563253,
+      0.646216148583769
+    };
+    Matrix qRef = reshape(qValues, 8, 5);
+    printMatrix("qRef", qRef);
+
+    Matrix qActual = decomposition.getQ();
+    printMatrix("q", qActual);
+    assertEquals(qRef, qActual, 1e-8);
+
+    double[] rightHandSide = {
+      -0.0178247686747641, 0.68631714634098, -0.335464858468858,
+      1.50249941751569,
+      -0.669901640772149, -0.977025038942455, -1.18857546169856,
+      -1.24792900492054
+    };
+    Matrix solution = decomposition.solve(reshape(rightHandSide, 8, 1));
+
+    double[] solutionValues = {
+      -0.0127440093664874, 0.655825940180799, -0.100755415991702,
+      -0.0349559562697406,
+      -0.190744297762028
+    };
+    Matrix solutionRef = reshape(solutionValues, 5, 1);
+
+    printMatrix("x1", solution);
+    printMatrix("xref", solutionRef);
+    assertEquals(solutionRef, solution, 1e-8);
+  }
+
+  /**
+   * Decomposes the transposed (5 x 8) fixture. hasFullRank() is expected to
+   * be false here; R, Q and a solve() result are compared as well.
+   */
+  @Test
+  public void fullRankWide() {
+    Matrix a = matrix().transpose();
+    QRDecomposition decomposition = new QRDecomposition(a);
+    Assert.assertFalse(decomposition.hasFullRank());
+    Matrix rActual = decomposition.getR();
+
+    double[] rValues = {
+      -2.42812464965842, 0, 0, 0, 0,
+      0.303587286111356, -2.91663643494775, 0, 0, 0,
+      -0.201812474153156, -0.765485720168378, 1.09989373598954, 0, 0,
+      1.47980701097885, -0.637545820524326, -1.55519859337935,
+      0.844655127991726, 0,
+      0.0248883129453161, 0.00115010570270549, -0.236340588891252,
+      -0.092924118200147, 1.42910099545547,
+      -1.1678472412429, 0.531245845248056, 0.351978196071514,
+      -1.03241474816555, -2.20223861735426,
+      -0.887809959067632, 0.189731251982918, -0.504321849233586,
+      0.490484123999836, 1.21266692336743,
+      -0.633888169775463, 1.04738559065986, 0.284041239547031,
+      0.578183510077156, -0.942314870832456
+    };
+    Matrix rRef = reshape(rValues, 5, 8);
+    printMatrix("rRef", rRef);
+    printMatrix("rActual", rActual);
+    assertEquals(rRef, rActual, 1e-8);
+
+    double[] qValues = {
+      -0.203489262374627, 0.316761677948356, -0.784155643293468,
+      0.394321494579, -0.29641971170211,
+      0.0311283614803723, -0.34755265020736, 0.137138511478328,
+      0.848579887681972, 0.373287266507375,
+      -0.39603700561249, -0.787812566647329, -0.377864833067864,
+      -0.275080943427399, 0.0636764674878229,
+      0.0763976893309043, -0.318551137554327, 0.286407036668598,
+      0.206004127289883, -0.876482672226889,
+      0.89159476695423, -0.238213616975551, -0.376141107880836,
+      -0.0794701657055114, 0.0227025098210165
+    };
+    Matrix qRef = reshape(qValues, 5, 5);
+
+    Matrix qActual = decomposition.getQ();
+    printMatrix("qRef", qRef);
+    printMatrix("q", qActual);
+    assertEquals(qRef, qActual, 1e-8);
+
+    Matrix solution = decomposition.solve(b());
+    double[] solutionValues = {
+      -0.182580239668147, -0.437233627652114, 0.138787653097464,
+      0.672934739896228, -0.131420217069083, 0, 0, 0
+    };
+    Matrix solutionRef = reshape(solutionValues, 8, 1);
+
+    printMatrix("xRef", solutionRef);
+    printMatrix("x", solution);
+    assertEquals(solutionRef, solution, 1e-8);
+  }
+
+  /**
+   * Asserts that the largest absolute element-wise difference between the
+   * two matrices is within epsilon of zero.
+   */
+  private void assertEquals(Matrix ref, Matrix actual, double epsilon) {
+    Matrix difference = ref.minus(actual);
+    double largest = difference.aggregate(Functions.max, Functions.abs);
+    Assert.assertEquals(0, largest, epsilon);
+  }
+
+  /**
+   * Prints the name and dimensions of a matrix followed by its entries, one
+   * matrix row per output line, and two empty lines at the end.
+   */
+  private void printMatrix(String name, Matrix m) {
+    int rowCount = m.numRows();
+    int columnCount = m.numCols();
+    System.out.printf("%s - %d x %d\n", name, rowCount, columnCount);
+    for (int row = 0; row < rowCount; row++) {
+      for (int column = 0; column < columnCount; column++) {
+        System.out.printf("%10.5f", m.get(row, column));
+      }
+      System.out.printf("\n");
+    }
+    System.out.printf("\n");
+    System.out.printf("\n");
+  }
+
+  /** Builds the 8 x 5 input matrix shared by both tests. */
+  private Matrix matrix() {
+    double[] values = {
+      0.494097293912641, -0.152566866170993, -0.418360266395271,
+      0.359475300232312,
+      1.35565069667582, -1.92759373242903, 1.50497526839076,
+      -0.746889132087904,
+      -0.769136838293565, 1.10984954080986, -0.664389974392489,
+      1.6464660350229,
+      -0.11715420616969, 0.0216221197371269, -0.394972730980765,
+      -0.748293157213142,
+      1.90402764664962, -0.638042862848559, -0.362336344669668,
+      -0.418261074380526,
+      -0.494211543128429, 1.38828971158414, 0.597110366867923,
+      1.05341387608687,
+      -0.957461740877418, -2.35528802598249, -1.03171458944128,
+      0.644319090271635,
+      -0.0569108993041965, -0.14419465550881, -0.0456801828174936,
+      0.754694392571835, 0.719744008628535, -1.17873249802301,
+      -0.155887528905918,
+      -1.5159868405466, 0.0918931582603128, 1.42179027361583,
+      -0.100495054250176,
+      0.0687986548485584
+    };
+    return reshape(values, 8, 5);
+  }
+
+  /**
+   * Copies the values into a new dense matrix column by column: the first
+   * <tt>rows</tt> values form column 0, the next <tt>rows</tt> values form
+   * column 1, and so on.
+   */
+  private Matrix reshape(double[] values, int rows, int columns) {
+    Matrix m = new DenseMatrix(rows, columns);
+    for (int position = 0; position < values.length; position++) {
+      m.set(position % rows, position / rows, values[position]);
+    }
+    return m;
+  }
+
+  /** Builds the 5 x 1 right-hand side used with the wide decomposition. */
+  private Matrix b() {
+    double[] values = {
+      -0.0178247686747641, 0.68631714634098, -0.335464858468858,
+      1.50249941751569, -0.669901640772149
+    };
+    return reshape(values, 5, 1);
+  }
+}