http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/solver/LSMR.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/solver/LSMR.java b/core/src/main/java/org/apache/mahout/math/solver/LSMR.java new file mode 100644 index 0000000..1f3e706 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/solver/LSMR.java @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.solver; + +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Solves sparse least-squares using the LSMR algorithm. + * <p/> + * LSMR solves the system of linear equations A * X = B. If the system is inconsistent, it solves + * the least-squares problem min ||b - Ax||_2. A is a rectangular matrix of dimension m-by-n, where + * all cases are allowed: m=n, m>n, or m<n. B is a vector of length m. The matrix A may be dense + * or sparse (usually sparse). + * <p/> + * Some additional configurable properties adjust the behavior of the algorithm. + * <p/> + * If you set lambda to a non-zero value then LSMR solves the regularized least-squares problem min + * ||(B) - ( A )X|| ||(0) (lambda*I) ||_2 where LAMBDA is a scalar. If LAMBDA is not set, + * the system is solved without regularization. + * <p/> + * You can also set aTolerance and bTolerance. These cause LSMR to iterate until a certain backward + * error estimate is smaller than some quantity depending on ATOL and BTOL. Let RES = B - A*X be + * the residual vector for the current approximate solution X. If A*X = B seems to be consistent, + * LSMR terminates when NORM(RES) <= ATOL*NORM(A)*NORM(X) + BTOL*NORM(B). Otherwise, LSMR terminates + * when NORM(A'*RES) <= ATOL*NORM(A)*NORM(RES). If both tolerances are 1.0e-6 (say), the final + * NORM(RES) should be accurate to about 6 digits. (The final X will usually have fewer correct + * digits, depending on cond(A) and the size of LAMBDA.) + * <p/> + * The default value for ATOL and BTOL is 1e-6. + * <p/> + * Ideally, they should be estimates of the relative error in the entries of A and B respectively. + * For example, if the entries of A have 7 correct digits, set ATOL = 1e-7. This prevents the + * algorithm from doing unnecessary work beyond the uncertainty of the input data. + * <p/> + * You can also set conditionLimit. In that case, LSMR terminates if an estimate of cond(A) exceeds + * conditionLimit. For compatible systems Ax = b, conditionLimit could be as large as 1.0e+12 (say). 
+ * For least-squares problems, conditionLimit should be less than 1.0e+8. If conditionLimit is not + * set, the default value is 1e+8. Maximum precision can be obtained by setting aTolerance = + * bTolerance = conditionLimit = 0, but the number of iterations may then be excessive. + * <p/> + * Setting iterationLimit causes LSMR to terminate if the number of iterations reaches + * iterationLimit. The default is iterationLimit = min(m,n). For ill-conditioned systems, a + * larger value of ITNLIM may be needed. + * <p/> + * Setting localSize causes LSMR to run with rerorthogonalization on the last localSize v_k's. + * (v-vectors generated by Golub-Kahan bidiagonalization) If localSize is not set, LSMR runs without + * reorthogonalization. A localSize > max(n,m) performs reorthogonalization on all v_k's. + * Reorthgonalizing only u_k or both u_k and v_k are not an option here. Details are discussed in + * the SIAM paper. + * <p/> + * getTerminationReason() gives the reason for termination. ISTOP = 0 means X=0 is a solution. = 1 + * means X is an approximate solution to A*X = B, according to ATOL and BTOL. = 2 means X + * approximately solves the least-squares problem according to ATOL. = 3 means COND(A) seems to be + * greater than CONLIM. = 4 is the same as 1 with ATOL = BTOL = EPS. = 5 is the same as 2 with ATOL + * = EPS. = 6 is the same as 3 with CONLIM = 1/EPS. = 7 means ITN reached ITNLIM before the other + * stopping conditions were satisfied. + * <p/> + * getIterationCount() gives ITN = the number of LSMR iterations. + * <p/> + * getResidualNorm() gives an estimate of the residual norm: NORMR = norm(B-A*X). + * <p/> + * getNormalEquationResidual() gives an estimate of the residual for the normal equation: NORMAR = + * NORM(A'*(B-A*X)). + * <p/> + * getANorm() gives an estimate of the Frobenius norm of A. + * <p/> + * getCondition() gives an estimate of the condition number of A. + * <p/> + * getXNorm() gives an estimate of NORM(X). + * <p/> + * LSMR uses an iterative method. For further information, see D. C.-L. Fong and M. A. Saunders + * LSMR: An iterative algorithm for least-square problems Draft of 03 Apr 2010, to be submitted to + * SISC. + * <p/> + * David Chin-lung Fong [email protected] Institute for Computational and Mathematical + * Engineering Stanford University + * <p/> + * Michael Saunders [email protected] Systems Optimization Laboratory Dept of + * MS&E, Stanford University. ----------------------------------------------------------------------- + */ +public final class LSMR { + + private static final Logger log = LoggerFactory.getLogger(LSMR.class); + + private final double lambda; + private int localSize; + private int iterationLimit; + private double conditionLimit; + private double bTolerance; + private double aTolerance; + private int localPointer; + private Vector[] localV; + private double residualNorm; + private double normalEquationResidual; + private double xNorm; + private int iteration; + private double normA; + private double condA; + + public int getIterationCount() { + return iteration; + } + + public double getResidualNorm() { + return residualNorm; + } + + public double getNormalEquationResidual() { + return normalEquationResidual; + } + + public double getANorm() { + return normA; + } + + public double getCondition() { + return condA; + } + + public double getXNorm() { + return xNorm; + } + + /** + * LSMR uses an iterative method to solve a linear system. For further information, see D. C.-L. + * Fong and M. A. 
Saunders LSMR: An iterative algorithm for least-square problems Draft of 03 Apr + * 2010, to be submitted to SISC. + * <p/> + * 08 Dec 2009: First release version of LSMR. 09 Apr 2010: Updated documentation and default + * parameters. 14 Apr 2010: Updated documentation. 03 Jun 2010: LSMR with local + * reorthogonalization (full reorthogonalization is also implemented) + * <p/> + * David Chin-lung Fong [email protected] Institute for Computational and + * Mathematical Engineering Stanford University + * <p/> + * Michael Saunders [email protected] Systems Optimization Laboratory Dept of + * MS&E, Stanford University. ----------------------------------------------------------------------- + */ + + public LSMR() { + // Set default parameters. + lambda = 0; + aTolerance = 1.0e-6; + bTolerance = 1.0e-6; + conditionLimit = 1.0e8; + iterationLimit = -1; + localSize = 0; + } + + public Vector solve(Matrix A, Vector b) { + /* + % Initialize. + + + hdg1 = ' itn x(1) norm r norm A''r'; + hdg2 = ' compatible LS norm A cond A'; + pfreq = 20; % print frequency (for repeating the heading) + pcount = 0; % print counter + + % Determine dimensions m and n, and + % form the first vectors u and v. + % These satisfy beta*u = b, alpha*v = A'u. + */ + log.debug(" itn x(1) norm r norm A'r"); + log.debug(" compatible LS norm A cond A"); + + Matrix transposedA = A.transpose(); + Vector u = b; + + double beta = u.norm(2); + if (beta > 0) { + u = u.divide(beta); + } + + Vector v = transposedA.times(u); + int m = A.numRows(); + int n = A.numCols(); + + int minDim = Math.min(m, n); + if (iterationLimit == -1) { + iterationLimit = minDim; + } + + if (log.isDebugEnabled()) { + log.debug("LSMR - Least-squares solution of Ax = b, based on Matlab Version 1.02, 14 Apr 2010, " + + "Mahout version {}", getClass().getPackage().getImplementationVersion()); + log.debug(String.format("The matrix A has %d rows and %d cols, lambda = %.4g, atol = %g, btol = %g", + m, n, lambda, aTolerance, bTolerance)); + } + + double alpha = v.norm(2); + if (alpha > 0) { + v.assign(Functions.div(alpha)); + } + + + // Initialization for local reorthogonalization + localPointer = 0; + + // Preallocate storage for storing the last few v_k. Since with + // orthogonal v_k's, Krylov subspace method would converge in not + // more iterations than the number of singular values, more + // space is not necessary. + localV = new Vector[Math.min(localSize, minDim)]; + boolean localOrtho = false; + if (localSize > 0) { + localOrtho = true; + localV[0] = v; + } + + + // Initialize variables for 1st iteration. + + iteration = 0; + double zetabar = alpha * beta; + double alphabar = alpha; + + Vector h = v; + Vector hbar = zeros(n); + Vector x = zeros(n); + + // Initialize variables for estimation of ||r||. + + double betadd = beta; + + // Initialize variables for estimation of ||A|| and cond(A) + + double aNorm = alpha * alpha; + + // Items for use in stopping rules. + double normb = beta; + + double ctol = 0; + if (conditionLimit > 0) { + ctol = 1 / conditionLimit; + } + residualNorm = beta; + + // Exit if b=0 or A'b = 0. + + normalEquationResidual = alpha * beta; + if (normalEquationResidual == 0) { + return x; + } + + // Heading for iteration log. 
+ + + if (log.isDebugEnabled()) { + double test2 = alpha / beta; +// log.debug('{} {}', hdg1, hdg2); + log.debug("{} {}", iteration, x.get(0)); + log.debug("{} {}", residualNorm, normalEquationResidual); + double test1 = 1; + log.debug("{} {}", test1, test2); + } + + + //------------------------------------------------------------------ + // Main iteration loop. + //------------------------------------------------------------------ + double rho = 1; + double rhobar = 1; + double cbar = 1; + double sbar = 0; + double betad = 0; + double rhodold = 1; + double tautildeold = 0; + double thetatilde = 0; + double zeta = 0; + double d = 0; + double maxrbar = 0; + double minrbar = 1.0e+100; + StopCode stop = StopCode.CONTINUE; + while (iteration <= iterationLimit && stop == StopCode.CONTINUE) { + + iteration++; + + // Perform the next step of the bidiagonalization to obtain the + // next beta, u, alpha, v. These satisfy the relations + // beta*u = A*v - alpha*u, + // alpha*v = A'*u - beta*v. + + u = A.times(v).minus(u.times(alpha)); + beta = u.norm(2); + if (beta > 0) { + u.assign(Functions.div(beta)); + + // store data for local-reorthogonalization of V + if (localOrtho) { + localVEnqueue(v); + } + v = transposedA.times(u).minus(v.times(beta)); + // local-reorthogonalization of V + if (localOrtho) { + v = localVOrtho(v); + } + alpha = v.norm(2); + if (alpha > 0) { + v.assign(Functions.div(alpha)); + } + } + + // At this point, beta = beta_{k+1}, alpha = alpha_{k+1}. + + // Construct rotation Qhat_{k,2k+1}. + + double alphahat = Math.hypot(alphabar, lambda); + double chat = alphabar / alphahat; + double shat = lambda / alphahat; + + // Use a plane rotation (Q_i) to turn B_i to R_i + + double rhoold = rho; + rho = Math.hypot(alphahat, beta); + double c = alphahat / rho; + double s = beta / rho; + double thetanew = s * alpha; + alphabar = c * alpha; + + // Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar + + double rhobarold = rhobar; + double zetaold = zeta; + double thetabar = sbar * rho; + double rhotemp = cbar * rho; + rhobar = Math.hypot(cbar * rho, thetanew); + cbar = cbar * rho / rhobar; + sbar = thetanew / rhobar; + zeta = cbar * zetabar; + zetabar = -sbar * zetabar; + + + // Update h, h_hat, x. + + hbar = h.minus(hbar.times(thetabar * rho / (rhoold * rhobarold))); + + x.assign(hbar.times(zeta / (rho * rhobar)), Functions.PLUS); + h = v.minus(h.times(thetanew / rho)); + + // Estimate of ||r||. + + // Apply rotation Qhat_{k,2k+1}. + double betaacute = chat * betadd; + double betacheck = -shat * betadd; + + // Apply rotation Q_{k,k+1}. + double betahat = c * betaacute; + betadd = -s * betaacute; + + // Apply rotation Qtilde_{k-1}. + // betad = betad_{k-1} here. + + double thetatildeold = thetatilde; + double rhotildeold = Math.hypot(rhodold, thetabar); + double ctildeold = rhodold / rhotildeold; + double stildeold = thetabar / rhotildeold; + thetatilde = stildeold * rhobar; + rhodold = ctildeold * rhobar; + betad = -stildeold * betad + ctildeold * betahat; + + // betad = betad_k here. + // rhodold = rhod_k here. + + tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold; + double taud = (zeta - thetatilde * tautildeold) / rhodold; + d += betacheck * betacheck; + residualNorm = Math.sqrt(d + (betad - taud) * (betad - taud) + betadd * betadd); + + // Estimate ||A||. + aNorm += beta * beta; + normA = Math.sqrt(aNorm); + aNorm += alpha * alpha; + + // Estimate cond(A). 
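+      // The estimate below tracks the extreme rhobar values produced by the plane rotations so far;
+      // their ratio (combined with rhotemp for the current step) is the running estimate of cond(A).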
+ maxrbar = Math.max(maxrbar, rhobarold); + if (iteration > 1) { + minrbar = Math.min(minrbar, rhobarold); + } + condA = Math.max(maxrbar, rhotemp) / Math.min(minrbar, rhotemp); + + // Test for convergence. + + // Compute norms for convergence testing. + normalEquationResidual = Math.abs(zetabar); + xNorm = x.norm(2); + + // Now use these norms to estimate certain other quantities, + // some of which will be small near a solution. + + double test1 = residualNorm / normb; + double test2 = normalEquationResidual / (normA * residualNorm); + double test3 = 1 / condA; + double t1 = test1 / (1 + normA * xNorm / normb); + double rtol = bTolerance + aTolerance * normA * xNorm / normb; + + // The following tests guard against extremely small values of + // atol, btol or ctol. (The user may have set any or all of + // the parameters atol, btol, conlim to 0.) + // The effect is equivalent to the normAl tests using + // atol = eps, btol = eps, conlim = 1/eps. + + if (iteration > iterationLimit) { + stop = StopCode.ITERATION_LIMIT; + } + if (1 + test3 <= 1) { + stop = StopCode.CONDITION_MACHINE_TOLERANCE; + } + if (1 + test2 <= 1) { + stop = StopCode.LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE; + } + if (1 + t1 <= 1) { + stop = StopCode.CONVERGED_MACHINE_TOLERANCE; + } + + // Allow for tolerances set by the user. + + if (test3 <= ctol) { + stop = StopCode.CONDITION; + } + if (test2 <= aTolerance) { + stop = StopCode.CONVERGED; + } + if (test1 <= rtol) { + stop = StopCode.TRIVIAL; + } + + // See if it is time to print something. + if (log.isDebugEnabled()) { + if ((n <= 40) || (iteration <= 10) || (iteration >= iterationLimit - 10) || ((iteration % 10) == 0) + || (test3 <= 1.1 * ctol) || (test2 <= 1.1 * aTolerance) || (test1 <= 1.1 * rtol) + || (stop != StopCode.CONTINUE)) { + statusDump(x, normA, condA, test1, test2); + } + } + } // iteration loop + + // Print the stopping condition. 
+ log.debug("Finished: {}", stop.getMessage()); + + return x; + /* + + + if show + fprintf('\n\nLSMR finished') + fprintf('\n%s', msg(istop+1,:)) + fprintf('\nistop =%8g normr =%8.1e' , istop, normr ) + fprintf(' normA =%8.1e normAr =%8.1e', normA, normAr) + fprintf('\nitn =%8g condA =%8.1e' , itn , condA ) + fprintf(' normx =%8.1e\n', normx) + end + */ + } + + private void statusDump(Vector x, double normA, double condA, double test1, double test2) { + log.debug("{} {}", residualNorm, normalEquationResidual); + log.debug("{} {}", iteration, x.get(0)); + log.debug("{} {}", test1, test2); + log.debug("{} {}", normA, condA); + } + + private static Vector zeros(int n) { + return new DenseVector(n); + } + + //----------------------------------------------------------------------- + // stores v into the circular buffer localV + //----------------------------------------------------------------------- + + private void localVEnqueue(Vector v) { + if (localV.length > 0) { + localV[localPointer] = v; + localPointer = (localPointer + 1) % localV.length; + } + } + + //----------------------------------------------------------------------- + // Perform local reorthogonalization of V + //----------------------------------------------------------------------- + + private Vector localVOrtho(Vector v) { + for (Vector old : localV) { + if (old != null) { + double x = v.dot(old); + v = v.minus(old.times(x)); + } + } + return v; + } + + private enum StopCode { + CONTINUE("Not done"), + TRIVIAL("The exact solution is x = 0"), + CONVERGED("Ax - b is small enough, given atol, btol"), + LEAST_SQUARE_CONVERGED("The least-squares solution is good enough, given atol"), + CONDITION("The estimate of cond(Abar) has exceeded condition limit"), + CONVERGED_MACHINE_TOLERANCE("Ax - b is small enough for this machine"), + LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE("The least-squares solution is good enough for this machine"), + CONDITION_MACHINE_TOLERANCE("Cond(Abar) seems to be too large for this machine"), + ITERATION_LIMIT("The iteration limit has been reached"); + + private final String message; + + StopCode(String message) { + this.message = message; + } + + public String getMessage() { + return message; + } + } + + public void setAtolerance(double aTolerance) { + this.aTolerance = aTolerance; + } + + public void setBtolerance(double bTolerance) { + this.bTolerance = bTolerance; + } + + public void setConditionLimit(double conditionLimit) { + this.conditionLimit = conditionLimit; + } + + public void setIterationLimit(int iterationLimit) { + this.iterationLimit = iterationLimit; + } + + public void setLocalSize(int localSize) { + this.localSize = localSize; + } + + public double getLambda() { + return lambda; + } + + public double getAtolerance() { + return aTolerance; + } + + public double getBtolerance() { + return bTolerance; + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java b/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java new file mode 100644 index 0000000..91528fc --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.solver; + +import org.apache.mahout.math.Vector; + +/** + * Interface for defining preconditioners used for improving the performance and/or stability of linear + * system solvers. + */ +public interface Preconditioner { + + /** + * Preconditions the specified vector. + * + * @param v The vector to precondition. + * @return The preconditioned vector. + */ + Vector precondition(Vector v); + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/ssvd/SequentialBigSvd.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/ssvd/SequentialBigSvd.java b/core/src/main/java/org/apache/mahout/math/ssvd/SequentialBigSvd.java new file mode 100644 index 0000000..46354da --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/ssvd/SequentialBigSvd.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.ssvd; + +import org.apache.mahout.math.CholeskyDecomposition; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.RandomTrinaryMatrix; +import org.apache.mahout.math.SingularValueDecomposition; +import org.apache.mahout.math.Vector; + +/** + * Implements an in-memory version of stochastic projection based SVD. See SequentialOutOfCoreSvd + * for algorithm notes. 
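+ * <p/>
+ * In outline (mirroring the comments in the constructor below): Y = A * Omega for a random trinary
+ * Omega with p columns; a Cholesky factorization R'R = Y'Y gives an implicit Q = Y R^{-1};
+ * B = Q' A is then factored via L L' = B B'; the SVD of L yields the singular values, and
+ * U = (Y inv(R)) U_0 and V = (B' inv(L')) V_0 recover the left and right singular vectors.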
+ */ +public class SequentialBigSvd { + private final Matrix y; + private final CholeskyDecomposition cd1; + private final CholeskyDecomposition cd2; + private final SingularValueDecomposition svd; + private final Matrix b; + + + public SequentialBigSvd(Matrix A, int p) { + // Y = A * \Omega + y = A.times(new RandomTrinaryMatrix(A.columnSize(), p)); + + // R'R = Y' Y + cd1 = new CholeskyDecomposition(y.transpose().times(y)); + + // B = Q" A = (Y R^{-1} )' A + b = cd1.solveRight(y).transpose().times(A); + + // L L' = B B' + cd2 = new CholeskyDecomposition(b.times(b.transpose())); + + // U_0 D V_0' = L + svd = new SingularValueDecomposition(cd2.getL()); + } + + public Vector getSingularValues() { + return new DenseVector(svd.getSingularValues()); + } + + public Matrix getU() { + // U = (Y inv(R)) U_0 + return cd1.solveRight(y).times(svd.getU()); + } + + public Matrix getV() { + // V = (B' inv(L')) V_0 + return cd2.solveRight(b.transpose()).times(svd.getV()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java b/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java new file mode 100644 index 0000000..d2c8434 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java @@ -0,0 +1,220 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.stats; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Multiset; +import com.google.common.collect.Ordering; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.PriorityQueue; +import java.util.Queue; + +/** + * Utility methods for working with log-likelihood + */ +public final class LogLikelihood { + + private LogLikelihood() { + } + + /** + * Calculates the unnormalized Shannon entropy. This is + * + * -sum x_i log x_i / N = -N sum x_i/N log x_i/N + * + * where N = sum x_i + * + * If the x's sum to 1, then this is the same as the normal + * expression. Leaving this un-normalized makes working with + * counts and computing the LLR easier. + * + * @return The entropy value for the elements + */ + public static double entropy(long... elements) { + long sum = 0; + double result = 0.0; + for (long element : elements) { + Preconditions.checkArgument(element >= 0); + result += xLogX(element); + sum += element; + } + return xLogX(sum) - result; + } + + private static double xLogX(long x) { + return x == 0 ? 
0.0 : x * Math.log(x); + } + + /** + * Merely an optimization for the common two argument case of {@link #entropy(long...)} + * @see #logLikelihoodRatio(long, long, long, long) + */ + private static double entropy(long a, long b) { + return xLogX(a + b) - xLogX(a) - xLogX(b); + } + + /** + * Merely an optimization for the common four argument case of {@link #entropy(long...)} + * @see #logLikelihoodRatio(long, long, long, long) + */ + private static double entropy(long a, long b, long c, long d) { + return xLogX(a + b + c + d) - xLogX(a) - xLogX(b) - xLogX(c) - xLogX(d); + } + + /** + * Calculates the Raw Log-likelihood ratio for two events, call them A and B. Then we have: + * <p/> + * <table border="1" cellpadding="5" cellspacing="0"> + * <tbody><tr><td> </td><td>Event A</td><td>Everything but A</td></tr> + * <tr><td>Event B</td><td>A and B together (k_11)</td><td>B, but not A (k_12)</td></tr> + * <tr><td>Everything but B</td><td>A without B (k_21)</td><td>Neither A nor B (k_22)</td></tr></tbody> + * </table> + * + * @param k11 The number of times the two events occurred together + * @param k12 The number of times the second event occurred WITHOUT the first event + * @param k21 The number of times the first event occurred WITHOUT the second event + * @param k22 The number of times something else occurred (i.e. was neither of these events + * @return The raw log-likelihood ratio + * + * <p/> + * Credit to http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html for the table and the descriptions. + */ + public static double logLikelihoodRatio(long k11, long k12, long k21, long k22) { + Preconditions.checkArgument(k11 >= 0 && k12 >= 0 && k21 >= 0 && k22 >= 0); + // note that we have counts here, not probabilities, and that the entropy is not normalized. + double rowEntropy = entropy(k11 + k12, k21 + k22); + double columnEntropy = entropy(k11 + k21, k12 + k22); + double matrixEntropy = entropy(k11, k12, k21, k22); + if (rowEntropy + columnEntropy < matrixEntropy) { + // round off error + return 0.0; + } + return 2.0 * (rowEntropy + columnEntropy - matrixEntropy); + } + + /** + * Calculates the root log-likelihood ratio for two events. + * See {@link #logLikelihoodRatio(long, long, long, long)}. + + * @param k11 The number of times the two events occurred together + * @param k12 The number of times the second event occurred WITHOUT the first event + * @param k21 The number of times the first event occurred WITHOUT the second event + * @param k22 The number of times something else occurred (i.e. was neither of these events + * @return The root log-likelihood ratio + * + * <p/> + * There is some more discussion here: http://s.apache.org/CGL + * + * And see the response to Wataru's comment here: + * http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html + */ + public static double rootLogLikelihoodRatio(long k11, long k12, long k21, long k22) { + double llr = logLikelihoodRatio(k11, k12, k21, k22); + double sqrt = Math.sqrt(llr); + if ((double) k11 / (k11 + k12) < (double) k21 / (k21 + k22)) { + sqrt = -sqrt; + } + return sqrt; + } + + /** + * Compares two sets of counts to see which items are interestingly over-represented in the first + * set. + * @param a The first counts. + * @param b The reference counts. + * @param maxReturn The maximum number of items to return. Use maxReturn >= a.elementSet.size() to return all + * scores above the threshold. + * @param threshold The minimum score for items to be returned. Use 0 to return all items more common + * in a than b. 
Use -Double.MAX_VALUE (not Double.MIN_VALUE !) to not use a threshold. + * @return A list of scored items with their scores. + */ + public static <T> List<ScoredItem<T>> compareFrequencies(Multiset<T> a, + Multiset<T> b, + int maxReturn, + double threshold) { + int totalA = a.size(); + int totalB = b.size(); + + Ordering<ScoredItem<T>> byScoreAscending = new Ordering<ScoredItem<T>>() { + @Override + public int compare(ScoredItem<T> tScoredItem, ScoredItem<T> tScoredItem1) { + return Double.compare(tScoredItem.score, tScoredItem1.score); + } + }; + Queue<ScoredItem<T>> best = new PriorityQueue<>(maxReturn + 1, byScoreAscending); + + for (T t : a.elementSet()) { + compareAndAdd(a, b, maxReturn, threshold, totalA, totalB, best, t); + } + + // if threshold >= 0 we only iterate through a because anything not there can't be as or more common than in b. + if (threshold < 0) { + for (T t : b.elementSet()) { + // only items missing from a need be scored + if (a.count(t) == 0) { + compareAndAdd(a, b, maxReturn, threshold, totalA, totalB, best, t); + } + } + } + + List<ScoredItem<T>> r = new ArrayList<>(best); + Collections.sort(r, byScoreAscending.reverse()); + return r; + } + + private static <T> void compareAndAdd(Multiset<T> a, + Multiset<T> b, + int maxReturn, + double threshold, + int totalA, + int totalB, + Queue<ScoredItem<T>> best, + T t) { + int kA = a.count(t); + int kB = b.count(t); + double score = rootLogLikelihoodRatio(kA, totalA - kA, kB, totalB - kB); + if (score >= threshold) { + ScoredItem<T> x = new ScoredItem<>(t, score); + best.add(x); + while (best.size() > maxReturn) { + best.poll(); + } + } + } + + public static final class ScoredItem<T> { + private final T item; + private final double score; + + public ScoredItem(T item, double score) { + this.item = item; + this.score = score; + } + + public double getScore() { + return score; + } + + public T getItem() { + return item; + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/stats/OnlineExponentialAverage.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/stats/OnlineExponentialAverage.java b/core/src/main/java/org/apache/mahout/math/stats/OnlineExponentialAverage.java new file mode 100644 index 0000000..54a0ec7 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/stats/OnlineExponentialAverage.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.stats; + +/** + * Computes an online average that is exponentially weighted toward recent time-embedded samples. 
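+ * <p/>
+ * Concretely, for a sample x arriving at time t the update used below is pi = exp(-(t - lastT) / alpha),
+ * s = x + pi * s, w = 1 + pi * w, and the current mean is s / w; meanRate() divides s by a similarly
+ * discounted total elapsed time instead of by w.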
+ */ +public class OnlineExponentialAverage { + + private final double alpha; + private double lastT; + private double s; + private double w; + private double t; + + /** + * Creates an averager that has a specified time constant for discounting old data. The time + * constant, alpha, is the time at which an older sample is discounted to 1/e relative to current + * data. Roughly speaking, data that is more than 3*alpha old doesn't matter any more and data + * that is more recent than alpha/3 is about as important as current data. + * + * See http://tdunning.blogspot.com/2011/03/exponential-weighted-averages-with.html for a + * derivation. See http://tdunning.blogspot.com/2011/03/exponentially-weighted-averaging-for.html + * for the rate method. + * + * @param alpha The time constant for discounting old data and state. + */ + public OnlineExponentialAverage(double alpha) { + this.alpha = alpha; + } + + public void add(double t, double x) { + double pi = Math.exp(-(t - lastT) / alpha); + s = x + pi * s; + w = 1.0 + pi * w; + this.t = t - lastT + pi * this.t; + lastT = t; + } + + public double mean() { + return s / w; + } + + public double meanRate() { + return s / t; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java b/core/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java new file mode 100644 index 0000000..793aa71 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.stats; + +//import com.tdunning.math.stats.TDigest; + +/** + * Computes on-line estimates of mean, variance and all five quartiles (notably including the + * median). Since this is done in a completely incremental fashion (that is what is meant by + * on-line) estimates are available at any time and the amount of memory used is constant. Somewhat + * surprisingly, the quantile estimates are about as good as you would get if you actually kept all + * of the samples. + * <p/> + * The method used for mean and variance is Welford's method. 
See + * <p/> + * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#On-line_algorithm + * <p/> + * The method used for computing the quartiles is a simplified form of the stochastic approximation + * method described in the article "Incremental Quantile Estimation for Massive Tracking" by Chen, + * Lambert and Pinheiro + * <p/> + * See + * <p/> + * http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.105.1580 + */ +public class OnlineSummarizer { + +// private TDigest quantiles = TDigest.createDigest(100.0); + + // mean and variance estimates + private double mean; + private double variance; + + // number of samples seen so far + private int n; + + public void add(double sample) { + n++; + double oldMean = mean; + mean += (sample - mean) / n; + double diff = (sample - mean) * (sample - oldMean); + variance += (diff - variance) / n; + +// quantiles.add(sample); + } + + public int getCount() { + return n; + } + + public double getMean() { + return mean; + } + + public double getSD() { + return Math.sqrt(variance); + } + +// public double getMin() { +// return getQuartile(0); +// } +// +// public double getMax() { +// return getQuartile(4); +// } + +// public double getQuartile(int i) { +// return quantiles.quantile(0.25 * i); +// } +// +// public double quantile(double q) { +// return quantiles.quantile(q); +// } + +// public double getMedian() { +// return getQuartile(2); +// } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/QRDecompositionTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/QRDecompositionTest.java b/core/src/test/java/org/apache/mahout/math/QRDecompositionTest.java new file mode 100644 index 0000000..b85458e --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/QRDecompositionTest.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.apache.mahout.math.function.DoubleDoubleFunction; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.stats.OnlineSummarizer; +import org.junit.Ignore; +import org.junit.Test; + +public final class QRDecompositionTest extends MahoutTestCase { + @Test + public void randomMatrix() { + Matrix a = new DenseMatrix(60, 60).assign(Functions.random()); + QRDecomposition qr = new QRDecomposition(a); + + // how close is Q to actually being orthornormal? + double maxIdent = qr.getQ().transpose().times(qr.getQ()).viewDiagonal().assign(Functions.plus(-1)).norm(1); + assertEquals(0, maxIdent, 1.0e-13); + + // how close is Q R to the original value of A? 
+ Matrix z = qr.getQ().times(qr.getR()).minus(a); + double maxError = z.aggregate(Functions.MIN, Functions.ABS); + assertEquals(0, maxError, 1.0e-13); + } + + @Test + public void rank1() { + Matrix x = new DenseMatrix(3, 3); + x.viewRow(0).assign(new double[]{1, 2, 3}); + x.viewRow(1).assign(new double[]{2, 4, 6}); + x.viewRow(2).assign(new double[]{3, 6, 9}); + + QRDecomposition qr = new QRDecomposition(x); + assertFalse(qr.hasFullRank()); + assertEquals(0, new DenseVector(new double[]{3.741657, 7.483315, 11.22497}).aggregate(qr.getR().viewRow(0), Functions.PLUS, new DoubleDoubleFunction() { + @Override + public double apply(double arg1, double arg2) { + return Math.abs(arg1) - Math.abs(arg2); + } + }), 1.0e-5); + } + + @Test + public void fullRankTall() { + Matrix x = matrix(); + QRDecomposition qr = new QRDecomposition(x); + assertTrue(qr.hasFullRank()); + Matrix rRef = reshape(new double[]{ + -2.99129686445138, 0, 0, 0, 0, + -0.0282260628674372, -2.38850244769059, 0, 0, 0, + 0.733739310355871, 1.48042000631646, 2.29051263117895, 0, 0, + -0.0394082168269326, 0.282829484207801, -0.00438521041803086, -2.90823198084203, 0, + 0.923669647838536, 1.76679276072492, 0.637690104222683, -0.225890909498753, -1.35732293800944}, + 5, 5); + Matrix r = qr.getR(); + + // check identity down to sign + assertEquals(0, r.clone().assign(Functions.ABS).minus(rRef.clone().assign(Functions.ABS)).aggregate(Functions.PLUS, Functions.IDENTITY), 1.0e-12); + + Matrix qRef = reshape(new double[]{ + -0.165178287646573, 0.0510035857637869, 0.13985915987379, -0.120173729496501, + -0.453198314345324, 0.644400679630493, -0.503117990820608, 0.24968739845381, + 0.323968339146224, -0.465266080134262, 0.276508948773268, -0.687909700644343, + 0.0544048888907195, -0.0166677718378263, 0.171309755790717, 0.310339001630029, + 0.674790532821663, 0.0058166082200493, -0.381707516461884, 0.300504956413142, + -0.105751091334003, 0.410450870871096, 0.31113446615821, 0.179338172684956, + 0.361951807617901, 0.763921725548796, 0.380327892605634, -0.287274944594054, + 0.0311604042556675, 0.0386096858143961, 0.0387156960650472, -0.232975755728917, + 0.0358178276684149, 0.173105775703199, 0.327321867815603, 0.328671945345279, + -0.36015879836344, -0.444261660176044, 0.09438499563253, 0.646216148583769 + }, 8, 5); + + printMatrix("qRef", qRef); + + Matrix q = qr.getQ(); + printMatrix("q", q); + + assertEquals(0, q.clone().assign(Functions.ABS).minus(qRef.clone().assign(Functions.ABS)).aggregate(Functions.PLUS, Functions.IDENTITY), 1.0e-12); + + Matrix x1 = qr.solve(reshape(new double[]{ + -0.0178247686747641, 0.68631714634098, -0.335464858468858, 1.50249941751569, + -0.669901640772149, -0.977025038942455, -1.18857546169856, -1.24792900492054 + }, 8, 1)); + Matrix xref = reshape(new double[]{ + -0.0127440093664874, 0.655825940180799, -0.100755415991702, -0.0349559562697406, + -0.190744297762028 + }, 5, 1); + + printMatrix("x1", x1); + printMatrix("xref", xref); + + assertEquals(xref, x1, 1.0e-8); + } + + @Test + public void fullRankWide() { + Matrix x = matrix().transpose(); + QRDecomposition qr = new QRDecomposition(x); + assertTrue(qr.hasFullRank()); + Matrix rActual = qr.getR(); + + Matrix rRef = reshape(new double[]{ + -2.42812464965842, 0, 0, 0, 0, + 0.303587286111356, -2.91663643494775, 0, 0, 0, + -0.201812474153156, -0.765485720168378, 1.09989373598954, 0, 0, + 1.47980701097885, -0.637545820524326, -1.55519859337935, 0.844655127991726, 0, + 0.0248883129453161, 0.00115010570270549, -0.236340588891252, -0.092924118200147, 
1.42910099545547, + -1.1678472412429, 0.531245845248056, 0.351978196071514, -1.03241474816555, -2.20223861735426, + -0.887809959067632, 0.189731251982918, -0.504321849233586, 0.490484123999836, 1.21266692336743, + -0.633888169775463, 1.04738559065986, 0.284041239547031, 0.578183510077156, -0.942314870832456 + }, 5, 8); + printMatrix("rRef", rRef); + printMatrix("rActual", rActual); + assertEquals(0, rActual.clone().assign(Functions.ABS).minus(rRef.clone().assign(Functions.ABS)).aggregate(Functions.PLUS, Functions.IDENTITY), 1.0e-12); +// assertEquals(rRef, rActual, 1.0e-8); + + Matrix qRef = reshape(new double[]{ + -0.203489262374627, 0.316761677948356, -0.784155643293468, 0.394321494579, -0.29641971170211, + 0.0311283614803723, -0.34755265020736, 0.137138511478328, 0.848579887681972, 0.373287266507375, + -0.39603700561249, -0.787812566647329, -0.377864833067864, -0.275080943427399, 0.0636764674878229, + 0.0763976893309043, -0.318551137554327, 0.286407036668598, 0.206004127289883, -0.876482672226889, + 0.89159476695423, -0.238213616975551, -0.376141107880836, -0.0794701657055114, 0.0227025098210165 + }, 5, 5); + + Matrix q = qr.getQ(); + + printMatrix("qRef", qRef); + printMatrix("q", q); + + assertEquals(0, q.clone().assign(Functions.ABS).minus(qRef.clone().assign(Functions.ABS)).aggregate(Functions.PLUS, Functions.IDENTITY), 1.0e-12); +// assertEquals(qRef, q, 1.0e-8); + + Matrix x1 = qr.solve(b()); + Matrix xRef = reshape(new double[]{ + -0.182580239668147, -0.437233627652114, 0.138787653097464, 0.672934739896228, -0.131420217069083, 0, 0, 0 + }, 8, 1); + + printMatrix("xRef", xRef); + printMatrix("x", x1); + assertEquals(xRef, x1, 1.0e-8); + + assertEquals(x, qr.getQ().times(qr.getR()), 1.0e-15); + } + + // TODO: the speedup constant should be checked and oddly, the times don't increase as the counts increase + @Ignore + public void fasterThanBefore() { + + OnlineSummarizer s1 = new OnlineSummarizer(); + OnlineSummarizer s2 = new OnlineSummarizer(); + + Matrix a = new DenseMatrix(60, 60).assign(Functions.random()); + + decompositionSpeedCheck(new Decomposer() { + @Override + public QR decompose(Matrix a) { + return new QRDecomposition(a); + } + }, s1, a, "new"); + + decompositionSpeedCheck(new Decomposer() { + @Override + public QR decompose(Matrix a) { + return new OldQRDecomposition(a); + } + }, s2, a, "old"); + + // should be much more than twice as fast. 
(originally was on s2.getMedian, but we factored out com.tdunning ) + System.out.printf("Speedup is about %.1f times\n", s2.getMean() / s1.getMean()); + assertTrue(s1.getMean() < 0.5 * s2.getMean()); + } + + private interface Decomposer { + QR decompose(Matrix a); + } + + private static void decompositionSpeedCheck(Decomposer qrf, OnlineSummarizer s1, Matrix a, String label) { + int n = 0; + List<Integer> counts = Lists.newArrayList(10, 20, 50, 100, 200, 500); + for (int k : counts) { + double warmup = 0; + double other = 0; + + n += k; + for (int i = 0; i < k; i++) { + QR qr = qrf.decompose(a); + warmup = Math.max(warmup, qr.getQ().transpose().times(qr.getQ()).viewDiagonal().assign(Functions.plus(-1)).norm(1)); + Matrix z = qr.getQ().times(qr.getR()).minus(a); + other = Math.max(other, z.aggregate(Functions.MIN, Functions.ABS)); + } + + double maxIdent = 0; + double maxError = 0; + + long t0 = System.nanoTime(); + for (int i = 0; i < n; i++) { + QR qr = qrf.decompose(a); + + maxIdent = Math.max(maxIdent, qr.getQ().transpose().times(qr.getQ()).viewDiagonal().assign(Functions.plus(-1)).norm(1)); + Matrix z = qr.getQ().times(qr.getR()).minus(a); + maxError = Math.max(maxError, z.aggregate(Functions.MIN, Functions.ABS)); + } + long t1 = System.nanoTime(); + if (k > 100) { + s1.add(t1 - t0); + } + System.out.printf("%s %d\t%.1f\t%g\t%g\t%g\n", label, n, (t1 - t0) / 1.0e3 / n, maxIdent, maxError, warmup); + } + } + + private static void assertEquals(Matrix ref, Matrix actual, double epsilon) { + assertEquals(0, ref.minus(actual).aggregate(Functions.MAX, Functions.ABS), epsilon); + } + + private static void printMatrix(String name, Matrix m) { + int rows = m.numRows(); + int columns = m.numCols(); + System.out.printf("%s - %d x %d\n", name, rows, columns); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + System.out.printf("%10.5f", m.get(i, j)); + } + System.out.printf("\n"); + } + System.out.printf("\n"); + System.out.printf("\n"); + } + + private static Matrix matrix() { + double[] values = { + 0.494097293912641, -0.152566866170993, -0.418360266395271, 0.359475300232312, + 1.35565069667582, -1.92759373242903, 1.50497526839076, -0.746889132087904, + -0.769136838293565, 1.10984954080986, -0.664389974392489, 1.6464660350229, + -0.11715420616969, 0.0216221197371269, -0.394972730980765, -0.748293157213142, + 1.90402764664962, -0.638042862848559, -0.362336344669668, -0.418261074380526, + -0.494211543128429, 1.38828971158414, 0.597110366867923, 1.05341387608687, + -0.957461740877418, -2.35528802598249, -1.03171458944128, 0.644319090271635, + -0.0569108993041965, -0.14419465550881, -0.0456801828174936, + 0.754694392571835, 0.719744008628535, -1.17873249802301, -0.155887528905918, + -1.5159868405466, 0.0918931582603128, 1.42179027361583, -0.100495054250176, + 0.0687986548485584 + }; + return reshape(values, 8, 5); + } + + private static Matrix reshape(double[] values, int rows, int columns) { + Matrix m = new DenseMatrix(rows, columns); + int i = 0; + for (double v : values) { + m.set(i % rows, i / rows, v); + i++; + } + return m; + } + + private static Matrix b() { + return reshape(new double[] + {-0.0178247686747641, 0.68631714634098, -0.335464858468858, 1.50249941751569, -0.669901640772149}, 5, 1); + } +} + http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/TestSingularValueDecomposition.java ---------------------------------------------------------------------- diff --git 
a/core/src/test/java/org/apache/mahout/math/TestSingularValueDecomposition.java b/core/src/test/java/org/apache/mahout/math/TestSingularValueDecomposition.java new file mode 100644 index 0000000..c9e4026 --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/TestSingularValueDecomposition.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math; + +import com.google.common.base.Charsets; +import com.google.common.base.Splitter; +import com.google.common.collect.Iterables; +import com.google.common.io.Resources; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.function.Functions; +import org.junit.Test; + +import java.io.IOException; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + +//To launch this test only : mvn test -Dtest=org.apache.mahout.math.TestSingularValueDecomposition +public final class TestSingularValueDecomposition extends MahoutTestCase { + + private final double[][] testSquare = { + { 24.0 / 25.0, 43.0 / 25.0 }, + { 57.0 / 25.0, 24.0 / 25.0 } + }; + + private final double[][] testNonSquare = { + { -540.0 / 625.0, 963.0 / 625.0, -216.0 / 625.0 }, + { -1730.0 / 625.0, -744.0 / 625.0, 1008.0 / 625.0 }, + { -720.0 / 625.0, 1284.0 / 625.0, -288.0 / 625.0 }, + { -360.0 / 625.0, 192.0 / 625.0, 1756.0 / 625.0 }, + }; + + private static final double NORM_TOLERANCE = 10.0e-14; + + @Test + public void testMoreRows() { + double[] singularValues = { 123.456, 2.3, 1.001, 0.999 }; + int rows = singularValues.length + 2; + int columns = singularValues.length; + Random r = RandomUtils.getRandom(); + SingularValueDecomposition svd = + new SingularValueDecomposition(createTestMatrix(r, rows, columns, singularValues)); + double[] computedSV = svd.getSingularValues(); + assertEquals(singularValues.length, computedSV.length); + for (int i = 0; i < singularValues.length; ++i) { + assertEquals(singularValues[i], computedSV[i], 1.0e-10); + } + } + + @Test + public void testMoreColumns() { + double[] singularValues = { 123.456, 2.3, 1.001, 0.999 }; + int rows = singularValues.length; + int columns = singularValues.length + 2; + Random r = RandomUtils.getRandom(); + SingularValueDecomposition svd = + new SingularValueDecomposition(createTestMatrix(r, rows, columns, singularValues)); + double[] computedSV = svd.getSingularValues(); + assertEquals(singularValues.length, computedSV.length); + for (int i = 0; i < singularValues.length; ++i) { + assertEquals(singularValues[i], computedSV[i], 1.0e-10); + } + } + + /** test dimensions */ + @Test + public void testDimensions() { + Matrix matrix = new DenseMatrix(testSquare); + int m = matrix.numRows(); + int n = matrix.numCols(); 
+ SingularValueDecomposition svd = new SingularValueDecomposition(matrix); + assertEquals(m, svd.getU().numRows()); + assertEquals(m, svd.getU().numCols()); + assertEquals(m, svd.getS().numCols()); + assertEquals(n, svd.getS().numCols()); + assertEquals(n, svd.getV().numRows()); + assertEquals(n, svd.getV().numCols()); + + } + + /** Test based on a dimension 4 Hadamard matrix. */ + // getCovariance to be implemented + @Test + public void testHadamard() { + Matrix matrix = new DenseMatrix(new double[][] { + {15.0 / 2.0, 5.0 / 2.0, 9.0 / 2.0, 3.0 / 2.0 }, + { 5.0 / 2.0, 15.0 / 2.0, 3.0 / 2.0, 9.0 / 2.0 }, + { 9.0 / 2.0, 3.0 / 2.0, 15.0 / 2.0, 5.0 / 2.0 }, + { 3.0 / 2.0, 9.0 / 2.0, 5.0 / 2.0, 15.0 / 2.0 } + }); + SingularValueDecomposition svd = new SingularValueDecomposition(matrix); + assertEquals(16.0, svd.getSingularValues()[0], 1.0e-14); + assertEquals( 8.0, svd.getSingularValues()[1], 1.0e-14); + assertEquals( 4.0, svd.getSingularValues()[2], 1.0e-14); + assertEquals( 2.0, svd.getSingularValues()[3], 1.0e-14); + + Matrix fullCovariance = new DenseMatrix(new double[][] { + { 85.0 / 1024, -51.0 / 1024, -75.0 / 1024, 45.0 / 1024 }, + { -51.0 / 1024, 85.0 / 1024, 45.0 / 1024, -75.0 / 1024 }, + { -75.0 / 1024, 45.0 / 1024, 85.0 / 1024, -51.0 / 1024 }, + { 45.0 / 1024, -75.0 / 1024, -51.0 / 1024, 85.0 / 1024 } + }); + + assertEquals(0.0,Algebra.getNorm(fullCovariance.minus(svd.getCovariance(0.0))),1.0e-14); + + + Matrix halfCovariance = new DenseMatrix(new double[][] { + { 5.0 / 1024, -3.0 / 1024, 5.0 / 1024, -3.0 / 1024 }, + { -3.0 / 1024, 5.0 / 1024, -3.0 / 1024, 5.0 / 1024 }, + { 5.0 / 1024, -3.0 / 1024, 5.0 / 1024, -3.0 / 1024 }, + { -3.0 / 1024, 5.0 / 1024, -3.0 / 1024, 5.0 / 1024 } + }); + assertEquals(0.0,Algebra.getNorm(halfCovariance.minus(svd.getCovariance(6.0))),1.0e-14); + + } + + /** test A = USVt */ + @Test + public void testAEqualUSVt() { + checkAEqualUSVt(new DenseMatrix(testSquare)); + checkAEqualUSVt(new DenseMatrix(testNonSquare)); + checkAEqualUSVt(new DenseMatrix(testNonSquare).transpose()); + } + + public static void checkAEqualUSVt(Matrix matrix) { + SingularValueDecomposition svd = new SingularValueDecomposition(matrix); + Matrix u = svd.getU(); + Matrix s = svd.getS(); + Matrix v = svd.getV(); + + //pad with 0, to be able to check some properties if some singular values are equal to 0 + if (s.numRows()<matrix.numRows()) { + + Matrix sp = new DenseMatrix(s.numRows()+1,s.numCols()); + Matrix up = new DenseMatrix(u.numRows(),u.numCols()+1); + + + for (int i = 0; i < u.numRows(); i++) { + for (int j = 0; j < u.numCols(); j++) { + up.set(i, j, u.get(i, j)); + } + } + + for (int i = 0; i < s.numRows(); i++) { + for (int j = 0; j < s.numCols(); j++) { + sp.set(i, j, s.get(i, j)); + } + } + + u = up; + s = sp; + } + + double norm = Algebra.getNorm(u.times(s).times(v.transpose()).minus(matrix)); + assertEquals(0, norm, NORM_TOLERANCE); + + } + + /** test that U is orthogonal */ + @Test + public void testUOrthogonal() { + checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testSquare)).getU()); + checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testNonSquare)).getU()); + checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testNonSquare).transpose()).getU()); + } + + /** test that V is orthogonal */ + @Test + public void testVOrthogonal() { + checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testSquare)).getV()); + checkOrthogonal(new SingularValueDecomposition(new DenseMatrix(testNonSquare)).getV()); + checkOrthogonal(new 
SingularValueDecomposition(new DenseMatrix(testNonSquare).transpose()).getV()); + } + + public static void checkOrthogonal(Matrix m) { + Matrix mTm = m.transpose().times(m); + Matrix id = new DenseMatrix(mTm.numRows(),mTm.numRows()); + for (int i = 0; i < mTm.numRows(); i++) { + id.set(i, i, 1); + } + assertEquals(0, Algebra.getNorm(mTm.minus(id)), NORM_TOLERANCE); + } + + /** test matrices values */ + @Test + public void testMatricesValues1() { + SingularValueDecomposition svd = + new SingularValueDecomposition(new DenseMatrix(testSquare)); + Matrix uRef = new DenseMatrix(new double[][] { + { 3.0 / 5.0, 4.0 / 5.0 }, + { 4.0 / 5.0, -3.0 / 5.0 } + }); + Matrix sRef = new DenseMatrix(new double[][] { + { 3.0, 0.0 }, + { 0.0, 1.0 } + }); + Matrix vRef = new DenseMatrix(new double[][] { + { 4.0 / 5.0, -3.0 / 5.0 }, + { 3.0 / 5.0, 4.0 / 5.0 } + }); + + // check values against known references + Matrix u = svd.getU(); + + assertEquals(0, Algebra.getNorm(u.minus(uRef)), NORM_TOLERANCE); + Matrix s = svd.getS(); + assertEquals(0, Algebra.getNorm(s.minus(sRef)), NORM_TOLERANCE); + Matrix v = svd.getV(); + assertEquals(0, Algebra.getNorm(v.minus(vRef)), NORM_TOLERANCE); + } + + + /** test condition number */ + @Test + public void testConditionNumber() { + SingularValueDecomposition svd = + new SingularValueDecomposition(new DenseMatrix(testSquare)); + // replace 1.0e-15 with 1.5e-15 + assertEquals(3.0, svd.cond(), 1.5e-15); + } + + @Test + public void testSvdHang() throws IOException, InterruptedException, ExecutionException, TimeoutException { + System.out.printf("starting hanging-svd\n"); + final Matrix m = readTsv("hanging-svd.tsv"); + SingularValueDecomposition svd = new SingularValueDecomposition(m); + assertEquals(0, m.minus(svd.getU().times(svd.getS()).times(svd.getV().transpose())).aggregate(Functions.PLUS, Functions.ABS), 1e-10); + System.out.printf("No hang\n"); + } + + Matrix readTsv(String name) throws IOException { + Splitter onTab = Splitter.on("\t"); + List<String> lines = Resources.readLines((Resources.getResource(name)), Charsets.UTF_8); + int rows = lines.size(); + int columns = Iterables.size(onTab.split(lines.get(0))); + Matrix r = new DenseMatrix(rows, columns); + int row = 0; + for (String line : lines) { + Iterable<String> values = onTab.split(line); + int column = 0; + for (String value : values) { + r.set(row, column, Double.parseDouble(value)); + column++; + } + row++; + } + return r; + } + + + private static Matrix createTestMatrix(Random r, int rows, int columns, double[] singularValues) { + Matrix u = createOrthogonalMatrix(r, rows); + Matrix d = createDiagonalMatrix(singularValues, rows, columns); + Matrix v = createOrthogonalMatrix(r, columns); + return u.times(d).times(v); + } + + + public static Matrix createOrthogonalMatrix(Random r, int size) { + + double[][] data = new double[size][size]; + + for (int i = 0; i < size; ++i) { + double[] dataI = data[i]; + double norm2; + do { + + // generate randomly row I + for (int j = 0; j < size; ++j) { + dataI[j] = 2 * r.nextDouble() - 1; + } + + // project the row in the subspace orthogonal to previous rows + for (int k = 0; k < i; ++k) { + double[] dataK = data[k]; + double dotProduct = 0; + for (int j = 0; j < size; ++j) { + dotProduct += dataI[j] * dataK[j]; + } + for (int j = 0; j < size; ++j) { + dataI[j] -= dotProduct * dataK[j]; + } + } + + // normalize the row + norm2 = 0; + for (double dataIJ : dataI) { + norm2 += dataIJ * dataIJ; + } + double inv = 1.0 / Math.sqrt(norm2); + for (int j = 0; j < size; ++j) { + 
dataI[j] *= inv; + } + + } while (norm2 * size < 0.01); + } + + return new DenseMatrix(data); + + } + + public static Matrix createDiagonalMatrix(double[] diagonal, int rows, int columns) { + double[][] dData = new double[rows][columns]; + for (int i = 0; i < Math.min(rows, columns); ++i) { + dData[i][i] = diagonal[i]; + } + return new DenseMatrix(dData); + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java b/core/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java new file mode 100644 index 0000000..95b19ad --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java @@ -0,0 +1,151 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.als; + +import java.util.Arrays; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.MahoutTestCase; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.SparseMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.map.OpenIntObjectHashMap; +import org.junit.Test; + +public class AlternatingLeastSquaresSolverTest extends MahoutTestCase { + + @Test + public void testYtY() { + + double[][] testMatrix = new double[][] { + new double[] { 1, 2, 3, 4, 5 }, + new double[] { 1, 2, 3, 4, 5 }, + new double[] { 1, 2, 3, 4, 5 }, + new double[] { 1, 2, 3, 4, 5 }, + new double[] { 1, 2, 3, 4, 5 }}; + + double[][] testMatrix2 = new double[][] { + new double[] { 1, 2, 3, 4, 5, 6 }, + new double[] { 5, 4, 3, 2, 1, 7 }, + new double[] { 1, 2, 3, 4, 5, 8 }, + new double[] { 1, 2, 3, 4, 5, 8 }, + new double[] { 11, 12, 13, 20, 27, 8 }}; + + double[][][] testData = new double[][][] { + testMatrix, + testMatrix2 }; + + for (int i = 0; i < testData.length; i++) { + Matrix matrixToTest = new DenseMatrix(testData[i]); + + //test for race conditions by trying a few times + for (int j = 0; j < 100; j++) { + validateYtY(matrixToTest, 4); + } + + //one thread @ a time test + validateYtY(matrixToTest, 1); + } + + } + + private void validateYtY(Matrix matrixToTest, int numThreads) { + + OpenIntObjectHashMap<Vector> matrixToTestAsRowVectors = asRowVectors(matrixToTest); + ImplicitFeedbackAlternatingLeastSquaresSolver solver = new ImplicitFeedbackAlternatingLeastSquaresSolver( + matrixToTest.columnSize(), 1, 1, matrixToTestAsRowVectors, numThreads); 
+ + Matrix yTy = matrixToTest.transpose().times(matrixToTest); + Matrix shouldMatchyTy = solver.getYtransposeY(matrixToTestAsRowVectors); + + for (int row = 0; row < yTy.rowSize(); row++) { + for (int column = 0; column < yTy.columnSize(); column++) { + assertEquals(yTy.getQuick(row, column), shouldMatchyTy.getQuick(row, column), 0); + } + } + } + + private OpenIntObjectHashMap<Vector> asRowVectors(Matrix matrix) { + OpenIntObjectHashMap<Vector> rows = new OpenIntObjectHashMap<>(); + for (int row = 0; row < matrix.numRows(); row++) { + rows.put(row, matrix.viewRow(row).clone()); + } + return rows; + } + + @Test + public void addLambdaTimesNuiTimesE() { + int nui = 5; + double lambda = 0.2; + Matrix matrix = new SparseMatrix(5, 5); + + AlternatingLeastSquaresSolver.addLambdaTimesNuiTimesE(matrix, lambda, nui); + + for (int n = 0; n < 5; n++) { + assertEquals(1.0, matrix.getQuick(n, n), EPSILON); + } + } + + @Test + public void createMiIi() { + Vector f1 = new DenseVector(new double[] { 1, 2, 3 }); + Vector f2 = new DenseVector(new double[] { 4, 5, 6 }); + + Matrix miIi = AlternatingLeastSquaresSolver.createMiIi(Arrays.asList(f1, f2), 3); + + assertEquals(1.0, miIi.getQuick(0, 0), EPSILON); + assertEquals(2.0, miIi.getQuick(1, 0), EPSILON); + assertEquals(3.0, miIi.getQuick(2, 0), EPSILON); + assertEquals(4.0, miIi.getQuick(0, 1), EPSILON); + assertEquals(5.0, miIi.getQuick(1, 1), EPSILON); + assertEquals(6.0, miIi.getQuick(2, 1), EPSILON); + } + + @Test + public void createRiIiMaybeTransposed() { + Vector ratings = new SequentialAccessSparseVector(3); + ratings.setQuick(1, 1.0); + ratings.setQuick(3, 3.0); + ratings.setQuick(5, 5.0); + + Matrix riIiMaybeTransposed = AlternatingLeastSquaresSolver.createRiIiMaybeTransposed(ratings); + assertEquals(1, riIiMaybeTransposed.numCols(), 1); + assertEquals(3, riIiMaybeTransposed.numRows(), 3); + + assertEquals(1.0, riIiMaybeTransposed.getQuick(0, 0), EPSILON); + assertEquals(3.0, riIiMaybeTransposed.getQuick(1, 0), EPSILON); + assertEquals(5.0, riIiMaybeTransposed.getQuick(2, 0), EPSILON); + } + + @Test + public void createRiIiMaybeTransposedExceptionOnNonSequentialVector() { + Vector ratings = new RandomAccessSparseVector(3); + ratings.setQuick(1, 1.0); + ratings.setQuick(3, 3.0); + ratings.setQuick(5, 5.0); + + try { + AlternatingLeastSquaresSolver.createRiIiMaybeTransposed(ratings); + fail(); + } catch (IllegalArgumentException e) {} + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/decomposer/SolverTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/decomposer/SolverTest.java b/core/src/test/java/org/apache/mahout/math/decomposer/SolverTest.java new file mode 100644 index 0000000..13baad8 --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/decomposer/SolverTest.java @@ -0,0 +1,177 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.decomposer; + +import com.google.common.collect.Lists; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.MahoutTestCase; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorIterable; +import org.apache.mahout.math.decomposer.lanczos.LanczosState; +import org.apache.mahout.math.function.Functions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Random; + +public abstract class SolverTest extends MahoutTestCase { + private static final Logger log = LoggerFactory.getLogger(SolverTest.class); + + public static void assertOrthonormal(Matrix eigens) { + assertOrthonormal(eigens, 1.0e-6); + } + + public static void assertOrthonormal(Matrix currentEigens, double errorMargin) { + List<String> nonOrthogonals = Lists.newArrayList(); + for (int i = 0; i < currentEigens.numRows(); i++) { + Vector ei = currentEigens.viewRow(i); + for (int j = 0; j <= i; j++) { + Vector ej = currentEigens.viewRow(j); + if (ei.norm(2) == 0 || ej.norm(2) == 0) { + continue; + } + double dot = ei.dot(ej); + if (i == j) { + assertTrue("not norm 1 : " + dot + " (eigen #" + i + ')', Math.abs(1.0 - dot) < errorMargin); + } else { + if (Math.abs(dot) > errorMargin) { + log.info("not orthogonal : {} (eigens {}, {})", dot, i, j); + nonOrthogonals.add("(" + i + ',' + j + ')'); + } + } + } + log.info("{}:{}", nonOrthogonals.size(), nonOrthogonals); + } + } + + public static void assertOrthonormal(LanczosState state) { + double errorMargin = 1.0e-5; + List<String> nonOrthogonals = Lists.newArrayList(); + for (int i = 0; i < state.getIterationNumber(); i++) { + Vector ei = state.getRightSingularVector(i); + for (int j = 0; j <= i; j++) { + Vector ej = state.getRightSingularVector(j); + if (ei.norm(2) == 0 || ej.norm(2) == 0) { + continue; + } + double dot = ei.dot(ej); + if (i == j) { + assertTrue("not norm 1 : " + dot + " (eigen #" + i + ')', Math.abs(1.0 - dot) < errorMargin); + } else { + if (Math.abs(dot) > errorMargin) { + log.info("not orthogonal : {} (eigens {}, {})", dot, i, j); + nonOrthogonals.add("(" + i + ',' + j + ')'); + } + } + } + if (!nonOrthogonals.isEmpty()) { + log.info("{}:{}", nonOrthogonals.size(), nonOrthogonals); + } + } + } + + public static void assertEigen(Matrix eigens, VectorIterable corpus, double errorMargin, boolean isSymmetric) { + assertEigen(eigens, corpus, eigens.numRows(), errorMargin, isSymmetric); + } + + public static void assertEigen(Matrix eigens, + VectorIterable corpus, + int numEigensToCheck, + double errorMargin, + boolean isSymmetric) { + for (int i = 0; i < numEigensToCheck; i++) { + Vector e = eigens.viewRow(i); + assertEigen(i, e, corpus, errorMargin, isSymmetric); + } + } + + public static void assertEigen(int i, Vector e, VectorIterable corpus, double errorMargin, + boolean isSymmetric) { + if 
(e.getLengthSquared() == 0) { + return; + } + Vector afterMultiply = isSymmetric ? corpus.times(e) : corpus.timesSquared(e); + double dot = afterMultiply.dot(e); + double afterNorm = afterMultiply.getLengthSquared(); + double error = 1 - Math.abs(dot / Math.sqrt(afterNorm * e.getLengthSquared())); + log.info("the eigen-error: {} for eigen {}", error, i); + assertTrue("Error: {" + error + " too high! (for eigen " + i + ')', Math.abs(error) < errorMargin); + } + + /** + * Builds up a consistently random (same seed every time) sparse matrix, with sometimes + * repeated rows. + */ + public static Matrix randomSequentialAccessSparseMatrix(int numRows, + int nonNullRows, + int numCols, + int entriesPerRow, + double entryMean) { + Matrix m = new SparseRowMatrix(numRows, numCols); + //double n = 0; + Random r = RandomUtils.getRandom(); + for (int i = 0; i < nonNullRows; i++) { + Vector v = new SequentialAccessSparseVector(numCols); + for (int j = 0; j < entriesPerRow; j++) { + int col = r.nextInt(numCols); + double val = r.nextGaussian(); + v.set(col, val * entryMean); + } + int c = r.nextInt(numRows); + if (r.nextBoolean() || numRows == nonNullRows) { + m.assignRow(numRows == nonNullRows ? i : c, v); + } else { + Vector other = m.viewRow(r.nextInt(numRows)); + if (other != null && other.getLengthSquared() > 0) { + m.assignRow(c, other.clone()); + } + } + //n += m.getRow(c).getLengthSquared(); + } + return m; + } + + public static Matrix randomHierarchicalMatrix(int numRows, int numCols, boolean symmetric) { + Matrix matrix = new DenseMatrix(numRows, numCols); + // TODO rejigger tests so that it doesn't expect this particular seed + Random r = new Random(1234L); + for (int row = 0; row < numRows; row++) { + Vector v = new DenseVector(numCols); + for (int col = 0; col < numCols; col++) { + double val = r.nextGaussian(); + v.set(col, val); + } + v.assign(Functions.MULT, 1/((row + 1) * v.norm(2))); + matrix.assignRow(row, v); + } + if (symmetric) { + return matrix.times(matrix.transpose()); + } + return matrix; + } + + public static Matrix randomHierarchicalSymmetricMatrix(int size) { + return randomHierarchicalMatrix(size, size, true); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/test/java/org/apache/mahout/math/decomposer/hebbian/TestHebbianSolver.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/mahout/math/decomposer/hebbian/TestHebbianSolver.java b/core/src/test/java/org/apache/mahout/math/decomposer/hebbian/TestHebbianSolver.java new file mode 100644 index 0000000..56ea4f6 --- /dev/null +++ b/core/src/test/java/org/apache/mahout/math/decomposer/hebbian/TestHebbianSolver.java @@ -0,0 +1,207 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.decomposer.hebbian; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.Matrix; + +import org.apache.mahout.math.decomposer.AsyncEigenVerifier; +import org.apache.mahout.math.decomposer.SolverTest; +import org.junit.Test; + +/** + * This test is woefully inadequate, and also requires tons of memory, because it's part + * unit test, part performance test, and part comparison test (between the Hebbian and Lanczos + * approaches). + * TODO: make better. + */ +public final class TestHebbianSolver extends SolverTest { + + public static long timeSolver(Matrix corpus, + double convergence, + int maxNumPasses, + TrainingState state) { + return timeSolver(corpus, + convergence, + maxNumPasses, + 10, + state); + } + + public static long timeSolver(Matrix corpus, + double convergence, + int maxNumPasses, + int desiredRank, + TrainingState state) { + HebbianUpdater updater = new HebbianUpdater(); + AsyncEigenVerifier verifier = new AsyncEigenVerifier(); + HebbianSolver solver = new HebbianSolver(updater, + verifier, + convergence, + maxNumPasses); + long start = System.nanoTime(); + TrainingState finalState = solver.solve(corpus, desiredRank); + assertNotNull(finalState); + state.setCurrentEigens(finalState.getCurrentEigens()); + state.setCurrentEigenValues(finalState.getCurrentEigenValues()); + long time = 0L; + time += System.nanoTime() - start; + verifier.close(); + assertEquals(state.getCurrentEigens().numRows(), desiredRank); + return time / 1000000L; + } + + + + public static long timeSolver(Matrix corpus, TrainingState state) { + return timeSolver(corpus, state, 10); + } + + public static long timeSolver(Matrix corpus, TrainingState state, int rank) { + return timeSolver(corpus, 0.01, 20, rank, state); + } + + @Test + public void testHebbianSolver() { + int numColumns = 800; + Matrix corpus = randomSequentialAccessSparseMatrix(1000, 900, numColumns, 30, 1.0); + int rank = 50; + Matrix eigens = new DenseMatrix(rank, numColumns); + TrainingState state = new TrainingState(eigens, null); + long optimizedTime = timeSolver(corpus, + 0.00001, + 5, + rank, + state); + eigens = state.getCurrentEigens(); + assertEigen(eigens, corpus, 0.05, false); + assertOrthonormal(eigens, 1.0e-6); + System.out.println("Avg solving (Hebbian) time in ms: " + optimizedTime); + } + + /* + public void testSolverWithSerialization() throws Exception + { + _corpusProjectionsVectorFactory = new DenseMapVectorFactory(); + _eigensVectorFactory = new DenseMapVectorFactory(); + + timeSolver(TMP_EIGEN_DIR, + 0.001, + 5, + new TrainingState(null, null)); + + File eigenDir = new File(TMP_EIGEN_DIR + File.separator + HebbianSolver.EIGEN_VECT_DIR); + DiskBufferedDoubleMatrix eigens = new DiskBufferedDoubleMatrix(eigenDir, 10); + + DoubleMatrix inMemoryMatrix = new HashMapDoubleMatrix(_corpusProjectionsVectorFactory, eigens); + + for (Entry<Integer, MapVector> diskEntry : eigens) + { + for (Entry<Integer, MapVector> inMemoryEntry : inMemoryMatrix) + { + if (diskEntry.getKey() - inMemoryEntry.getKey() == 0) + { + assertTrue("vector with index : " + diskEntry.getKey() + " is not the same on disk as in memory", + Math.abs(1 - diskEntry.getValue().dot(inMemoryEntry.getValue())) < 1e-6); + } + else + { + assertTrue("vector with index : " + diskEntry.getKey() + + " is not orthogonal to memory vect with index : " + inMemoryEntry.getKey(), + Math.abs(diskEntry.getValue().dot(inMemoryEntry.getValue())) < 1e-6); + } + } + } + File corpusDir = new File(TMP_EIGEN_DIR + File.separator + 
"corpus"); + corpusDir.mkdir(); + // TODO: persist to disk? + // DiskBufferedDoubleMatrix.persistChunk(corpusDir, corpus, true); + // eigens.delete(); + + // DiskBufferedDoubleMatrix.delete(new File(TMP_EIGEN_DIR)); + } + */ +/* + public void testHebbianVersusLanczos() throws Exception + { + _corpusProjectionsVectorFactory = new DenseMapVectorFactory(); + _eigensVectorFactory = new DenseMapVectorFactory(); + int desiredRank = 200; + long time = timeSolver(TMP_EIGEN_DIR, + 0.00001, + 5, + desiredRank, + new TrainingState()); + + System.out.println("Hebbian time: " + time + "ms"); + File eigenDir = new File(TMP_EIGEN_DIR + File.separator + HebbianSolver.EIGEN_VECT_DIR); + DiskBufferedDoubleMatrix eigens = new DiskBufferedDoubleMatrix(eigenDir, 10); + + DoubleMatrix2D srm = asSparseDoubleMatrix2D(corpus); + long timeA = System.nanoTime(); + EigenvalueDecomposition asSparseRealDecomp = new EigenvalueDecomposition(srm); + for (int i=0; i<desiredRank; i++) + asSparseRealDecomp.getEigenvector(i); + System.out.println("CommonsMath time: " + (System.nanoTime() - timeA)/TimingConstants.NANOS_IN_MILLI + "ms"); + + // System.out.println("Hebbian results:"); + // printEigenVerify(eigens, corpus); + + DoubleMatrix lanczosEigenVectors = new HashMapDoubleMatrix(new HashMapVectorFactory()); + List<Double> lanczosEigenValues = new ArrayList<Double>(); + + LanczosSolver solver = new LanczosSolver(); + solver.solve(corpus, desiredRank*5, lanczosEigenVectors, lanczosEigenValues); + + for (TimingSection section : LanczosSolver.TimingSection.values()) + { + System.out.println("Lanczos " + section.toString() + " = " + (int)(solver.getTimeMillis(section)/1000) + " seconds"); + } + + // System.out.println("\nLanczos results:"); + // printEigenVerify(lanczosEigenVectors, corpus); + } + + private DoubleMatrix2D asSparseDoubleMatrix2D(Matrix corpus) + { + DoubleMatrix2D result = new DenseDoubleMatrix2D(corpus.numRows(), corpus.numRows()); + for (int i=0; i<corpus.numRows(); i++) { + for (int j=i; j<corpus.numRows(); j++) { + double v = corpus.getRow(i).dot(corpus.getRow(j)); + result.set(i, j, v); + result.set(j, i, v); + } + } + return result; + } + + + public static void printEigenVerify(DoubleMatrix eigens, DoubleMatrix corpus) + { + for (Map.Entry<Integer, MapVector> entry : eigens) + { + MapVector eigen = entry.getValue(); + MapVector afterMultiply = corpus.timesSquared(eigen); + double norm = afterMultiply.norm(); + double error = 1 - eigen.dot(afterMultiply) / (eigen.norm() * afterMultiply.norm()); + System.out.println(entry.getKey() + ": error = " + error + ", eVal = " + (norm / eigen.norm())); + } + } + */ + +}
