http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/main/java/org/apache/mahout/math/set/OpenHashSet.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/set/OpenHashSet.java b/core/src/main/java/org/apache/mahout/math/set/OpenHashSet.java deleted file mode 100644 index 285b5a5..0000000 --- a/core/src/main/java/org/apache/mahout/math/set/OpenHashSet.java +++ /dev/null @@ -1,548 +0,0 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.mahout.math.set;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import org.apache.mahout.math.MurmurHash;
import org.apache.mahout.math.function.ObjectProcedure;
import org.apache.mahout.math.map.PrimeFinder;

/**
 * Open hashing (open addressing with double hashing) alternative to java.util.HashSet.
 *
 * <p>Keys are located through {@code hashCode()} and compared with {@code equals()}.
 * {@code null} keys are not supported: every lookup dereferences the key.
 **/
public class OpenHashSet<T> extends AbstractSet implements Set<T> {
  protected static final byte FREE = 0;
  protected static final byte FULL = 1;
  protected static final byte REMOVED = 2;
  protected static final char NO_KEY_VALUE = 0;

  /** The hash table keys. */
  private Object[] table;

  /** The state of each hash table entry (FREE, FULL, REMOVED). */
  private byte[] state;

  /** The number of table entries in state==FREE. */
  private int freeEntries;


  /** Constructs an empty set with default capacity and default load factors. */
  public OpenHashSet() {
    this(DEFAULT_CAPACITY);
  }

  /**
   * Constructs an empty set with the specified initial capacity and default load factors.
   *
   * @param initialCapacity the initial capacity of the set.
   * @throws IllegalArgumentException if the initial capacity is less than zero.
   */
  public OpenHashSet(int initialCapacity) {
    this(initialCapacity, DEFAULT_MIN_LOAD_FACTOR, DEFAULT_MAX_LOAD_FACTOR);
  }

  /**
   * Constructs an empty set with the specified initial capacity and the specified minimum and maximum load factor.
   *
   * @param initialCapacity the initial capacity.
   * @param minLoadFactor   the minimum load factor.
   * @param maxLoadFactor   the maximum load factor.
   * @throws IllegalArgumentException if {@code initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0) ||
   *                                  (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >=
   *                                  maxLoadFactor)}.
   */
  public OpenHashSet(int initialCapacity, double minLoadFactor, double maxLoadFactor) {
    setUp(initialCapacity, minLoadFactor, maxLoadFactor);
  }

  /** Removes all elements from the receiver. Implicitly calls {@code trimToSize()}. */
  @Override
  public void clear() {
    // Reset every slot: Arrays.fill's toIndex is exclusive, so the whole-array overload is used
    // to make sure the last slot is cleared as well.
    Arrays.fill(this.state, FREE);
    // Drop the key references so that cleared elements can be garbage collected.
    Arrays.fill(this.table, null);
    distinct = 0;
    freeEntries = table.length; // delta
    trimToSize();
  }

  /**
   * Returns a copy of the receiver whose slot arrays are independent of the receiver's.
   * The contained keys themselves are not cloned.
   *
   * @return a copy of the receiver.
   */
  @SuppressWarnings("unchecked")
  @Override
  public Object clone() {
    OpenHashSet<T> copy = (OpenHashSet<T>) super.clone();
    copy.table = copy.table.clone();
    copy.state = copy.state.clone();
    return copy;
  }

  /**
   * Returns {@code true} if the receiver contains the specified key.
   *
   * @param key the key to look up; must not be {@code null}.
   * @return {@code true} if the receiver contains the specified key.
   */
  @Override
  @SuppressWarnings("unchecked")
  public boolean contains(Object key) {
    return indexOfKey((T) key) >= 0;
  }

  /**
   * Ensures that the receiver can hold at least the specified number of elements without needing to allocate new
   * internal memory. If necessary, allocates new internal memory and increases the capacity of the receiver. <p> This
   * method never need be called; it is for performance tuning only. Calling this method before {@code add()}ing a
   * large number of elements boosts performance, because the receiver will grow only once instead of potentially
   * many times and hash collisions get less probable.
   *
   * @param minCapacity the desired minimum capacity.
   */
  @Override
  public void ensureCapacity(int minCapacity) {
    if (table.length < minCapacity) {
      int newCapacity = nextPrime(minCapacity);
      rehash(newCapacity);
    }
  }

  /**
   * Applies a procedure to each key of the receiver, if any. Note: Iterates over the keys in no particular order
   * (slots are visited from the highest index down to 0).
   *
   * <p>The procedure must not add to or remove from this set while the iteration is running: a rehash triggered by
   * such a modification replaces the arrays being walked.
   *
   * @param procedure the procedure to be applied. Stops iteration if the procedure returns {@code false}, otherwise
   *                  continues.
   * @return {@code false} if the procedure stopped before all keys were iterated over, {@code true} otherwise.
   */
  @SuppressWarnings("unchecked")
  public boolean forEachKey(ObjectProcedure<T> procedure) {
    for (int i = table.length; i-- > 0;) {
      if (state[i] == FULL) {
        if (!procedure.apply((T) table[i])) {
          return false;
        }
      }
    }
    return true;
  }

  /**
   * @param key the key to be added to the receiver.
   * @return the index where the key would need to be inserted, if it is not already contained. Returns -index-1 if the
   *         key is already contained at slot index. Therefore, if the returned index < 0, then it is already contained
   *         at slot -index-1. If the returned index >= 0, then it is NOT already contained and should be inserted at
   *         slot index.
   */
  protected int indexOfInsertion(T key) {
    Object[] tab = table;
    byte[] stat = state;
    int length = tab.length;

    int hash = key.hashCode() & 0x7FFFFFFF;
    int i = hash % length;
    int decrement = hash % (length - 2); // double hashing, see http://www.eece.unm.edu/faculty/heileman/hash/node4.html
    if (decrement == 0) {
      decrement = 1;
    }

    // stop if we find a removed or free slot, or if we find the key itself
    // do NOT skip over removed slots (yes, open addressing is like that...)
    // Keys are compared with equals(), matching indexOfKey(); an identity comparison here would
    // let equal-but-distinct objects be inserted twice.
    while (stat[i] == FULL && !key.equals(tab[i])) {
      i -= decrement;
      if (i < 0) {
        i += length;
      }
    }

    if (stat[i] == REMOVED) {
      // stop if we find a free slot, or if we find the key itself.
      // do skip over removed slots (yes, open addressing is like that...)
      // assertion: there is at least one FREE slot.
      int firstRemoved = i;
      while (stat[i] != FREE && (stat[i] == REMOVED || !key.equals(tab[i]))) {
        i -= decrement;
        if (i < 0) {
          i += length;
        }
      }
      if (stat[i] == FREE) {
        // key is absent: reuse the first REMOVED slot on the probe path.
        i = firstRemoved;
      }
    }


    if (stat[i] == FULL) {
      // key already contained at slot i.
      // return a negative number identifying the slot.
      return -i - 1;
    }
    // not already contained, should be inserted at slot i.
    // return a number >= 0 identifying the slot.
    return i;
  }

  /**
   * @param key the key to be searched in the receiver.
   * @return the index where the key is contained in the receiver, returns -1 if the key was not found.
   */
  protected int indexOfKey(T key) {
    Object[] tab = table;
    byte[] stat = state;
    int length = tab.length;

    int hash = key.hashCode() & 0x7FFFFFFF;
    int i = hash % length;
    int decrement = hash % (length - 2); // double hashing, see http://www.eece.unm.edu/faculty/heileman/hash/node4.html
    if (decrement == 0) {
      decrement = 1;
    }

    // stop if we find a free slot, or if we find the key itself.
    // do skip over removed slots (yes, open addressing is like that...)
    while (stat[i] != FREE && (stat[i] == REMOVED || (!key.equals(tab[i])))) {
      i -= decrement;
      if (i < 0) {
        i += length;
      }
    }

    if (stat[i] == FREE) {
      return -1;
    } // not found
    return i; //found, return index where key is contained
  }

  /**
   * Fills all keys contained in the receiver into the specified list. Fills the list, starting at index 0. After this
   * call returns the specified list has a new size that equals {@code this.size()}.
   * This method can be used to iterate over the keys of the receiver.
   *
   * @param list the list to be filled, can have any size; it is cleared first.
   */
  @SuppressWarnings("unchecked")
  public void keys(List<T> list) {
    list.clear();

    Object[] tab = table;
    byte[] stat = state;

    for (int i = tab.length; i-- > 0;) {
      if (stat[i] == FULL) {
        list.add((T) tab[i]);
      }
    }
  }

  /**
   * Adds the given key to the receiver if it is not already present.
   *
   * @param key the key to add; must not be {@code null}.
   * @return {@code true} if the key was added, {@code false} if an equal key was already contained.
   */
  @SuppressWarnings("unchecked")
  @Override
  public boolean add(Object key) {
    int i = indexOfInsertion((T) key);
    if (i < 0) { //already contained
      return false;
    }

    if (this.distinct > this.highWaterMark) {
      // Grow first, then retry: slot i belongs to the old table and is no longer valid.
      int newCapacity = chooseGrowCapacity(this.distinct + 1, this.minLoadFactor, this.maxLoadFactor);
      rehash(newCapacity);
      return add(key);
    }

    this.table[i] = key;
    if (this.state[i] == FREE) {
      this.freeEntries--;
    }
    this.state[i] = FULL;
    this.distinct++;

    if (this.freeEntries < 1) { //delta
      // The key is already stored at this point, so only the rehash is needed. Re-invoking
      // add(key) here would find the key and wrongly report "not added".
      int newCapacity = chooseGrowCapacity(this.distinct + 1, this.minLoadFactor, this.maxLoadFactor);
      rehash(newCapacity);
    }

    return true;
  }

  /**
   * Rehashes the contents of the receiver into a new table with a smaller or larger capacity. This method is called
   * automatically when the number of keys in the receiver exceeds the high water mark or falls below the low water
   * mark.
   *
   * @param newCapacity the length of the new slot arrays.
   */
  @SuppressWarnings("unchecked")
  protected void rehash(int newCapacity) {
    int oldCapacity = table.length;

    Object[] oldTable = table;
    byte[] oldState = state;

    Object[] newTable = new Object[newCapacity];
    byte[] newState = new byte[newCapacity];

    this.lowWaterMark = chooseLowWaterMark(newCapacity, this.minLoadFactor);
    this.highWaterMark = chooseHighWaterMark(newCapacity, this.maxLoadFactor);

    // Install the new arrays before re-inserting: indexOfInsertion() probes this.table/this.state.
    this.table = newTable;
    this.state = newState;
    this.freeEntries = newCapacity - this.distinct; // delta

    for (int i = oldCapacity; i-- > 0;) {
      if (oldState[i] == FULL) {
        Object element = oldTable[i];
        int index = indexOfInsertion((T) element);
        newTable[index] = element;
        newState[index] = FULL;
      }
    }
  }

  /**
   * Removes the given key from the receiver, if present.
   *
   * @param key the key to be removed from the receiver; must not be {@code null}.
   * @return {@code true} if the receiver contained the specified key, {@code false} otherwise.
   */
  @SuppressWarnings("unchecked")
  @Override
  public boolean remove(Object key) {
    int i = indexOfKey((T) key);
    if (i < 0) {
      return false;
    } // key not contained

    this.state[i] = REMOVED;
    // Release the reference; REMOVED slots are never read back, so keeping it would only leak.
    this.table[i] = null;
    this.distinct--;

    if (this.distinct < this.lowWaterMark) {
      int newCapacity = chooseShrinkCapacity(this.distinct, this.minLoadFactor, this.maxLoadFactor);
      rehash(newCapacity);
    }

    return true;
  }

  /**
   * Initializes the receiver.
   *
   * @param initialCapacity the initial capacity of the receiver.
   * @param minLoadFactor   the minLoadFactor of the receiver.
   * @param maxLoadFactor   the maxLoadFactor of the receiver.
   * @throws IllegalArgumentException if {@code initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0) ||
   *                                  (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >=
   *                                  maxLoadFactor)}.
   */
  @Override
  protected final void setUp(int initialCapacity, double minLoadFactor, double maxLoadFactor) {
    int capacity = initialCapacity;
    super.setUp(capacity, minLoadFactor, maxLoadFactor);
    capacity = nextPrime(capacity);
    if (capacity == 0) {
      capacity = 1;
    } // open addressing needs at least one FREE slot at any time.

    this.table = new Object[capacity];
    this.state = new byte[capacity];

    // memory will be exhausted long before this pathological case happens, anyway.
    this.minLoadFactor = minLoadFactor;
    if (capacity == PrimeFinder.LARGEST_PRIME) {
      this.maxLoadFactor = 1.0;
    } else {
      this.maxLoadFactor = maxLoadFactor;
    }

    this.distinct = 0;
    this.freeEntries = capacity; // delta

    // lowWaterMark will be established upon first expansion.
    // establishing it now (upon instance construction) would immediately make the table shrink upon first add(...).
    // After all the idea of an "initialCapacity" implies violating lowWaterMarks when an object is young.
    // See ensureCapacity(...)
    this.lowWaterMark = 0;
    this.highWaterMark = chooseHighWaterMark(capacity, this.maxLoadFactor);
  }

  /**
   * Trims the capacity of the receiver to be the receiver's current size. Releases any superfluous internal memory. An
   * application can use this operation to minimize the storage of the receiver.
   */
  @Override
  public void trimToSize() {
    // * 1.2 because open addressing's performance exponentially degrades beyond that point
    // so that even rehashing the table can take very long
    int newCapacity = nextPrime((int) (1 + 1.2 * size()));
    if (table.length > newCapacity) {
      rehash(newCapacity);
    }
  }

  /**
   * Access for unit tests. Each argument is a one-element output array whose element 0 is overwritten.
   *
   * @param capacity      receives the current table length.
   * @param minLoadFactor receives the minimum load factor.
   * @param maxLoadFactor receives the maximum load factor.
   */
  void getInternalFactors(int[] capacity,
                          double[] minLoadFactor,
                          double[] maxLoadFactor) {
    capacity[0] = table.length;
    minLoadFactor[0] = this.minLoadFactor;
    maxLoadFactor[0] = this.maxLoadFactor;
  }

  @Override
  public boolean isEmpty() {
    return size() == 0;
  }

  /**
   * OpenHashSet instances are only equal to other OpenHashSet instances, not to
   * any other collection. Hypothetically, we should check for and permit
   * equals on other Sets.
   */
  @Override
  @SuppressWarnings("unchecked")
  public boolean equals(Object obj) {
    if (obj == this) {
      return true;
    }

    if (!(obj instanceof OpenHashSet)) {
      return false;
    }
    final OpenHashSet<T> other = (OpenHashSet<T>) obj;
    if (other.size() != size()) {
      return false;
    }

    // Same size and every key of this set is in the other one => same elements.
    return forEachKey(new ObjectProcedure<T>() {
      @Override
      public boolean apply(T key) {
        return other.contains(key);
      }
    });
  }

  /**
   * Hashes the element hash codes with MurmurHash, seeded with the class name's hash code.
   *
   * <p>The element hash codes are sorted first so that the result does not depend on slot order;
   * two sets that are {@link #equals(Object) equal} may have different capacities or insertion
   * histories and therefore different slot layouts.
   */
  @Override
  public int hashCode() {
    int[] elementHashes = new int[size()];
    int count = 0;
    for (int i = 0; i < table.length; i++) {
      if (state[i] == FULL) {
        elementHashes[count++] = table[i].hashCode();
      }
    }
    Arrays.sort(elementHashes, 0, count);

    // Each int occupies 4 bytes; allocating only size() bytes overflows on the first putInt.
    ByteBuffer buf = ByteBuffer.allocate(4 * count);
    for (int i = 0; i < count; i++) {
      buf.putInt(elementHashes[i]);
    }
    // Make the bytes just written the readable region of the buffer.
    buf.flip();
    return MurmurHash.hash(buf, this.getClass().getName().hashCode());
  }

  /**
   * Implement the standard Java Collections iterator. The iterator walks a snapshot of the keys,
   * so 'remove' on it does not affect this set. This method is provided for convenience, only.
   */
  @Override
  public Iterator<T> iterator() {
    List<T> keyList = new ArrayList<>();
    keys(keyList);
    return keyList.iterator();
  }

  @Override
  public Object[] toArray() {
    List<T> keyList = new ArrayList<>();
    keys(keyList);
    return keyList.toArray();
  }

  @Override
  public boolean addAll(Collection<? extends T> c) {
    boolean anyAdded = false;
    for (T o : c) {
      boolean added = add(o);
      anyAdded |= added;
    }
    return anyAdded;
  }

  @Override
  public boolean containsAll(Collection<?> c) {
    for (Object o : c) {
      if (!contains(o)) {
        return false;
      }
    }
    return true;
  }

  @Override
  public boolean removeAll(Collection<?> c) {
    boolean anyRemoved = false;
    for (Object o : c) {
      boolean removed = remove(o);
      anyRemoved |= removed;
    }
    return anyRemoved;
  }

  /**
   * Removes every element that is not contained in the given collection.
   *
   * @param c the elements to keep.
   * @return {@code true} if at least one element was removed.
   */
  @Override
  public boolean retainAll(Collection<?> c) {
    boolean modified = false;
    // Iterate over a snapshot: remove() may shrink and rehash the table, which would invalidate
    // an in-place walk over the slot arrays.
    for (T element : keys()) {
      if (!c.contains(element)) {
        remove(element);
        modified = true;
      }
    }
    return modified;
  }

  @Override
  public <T1> T1[] toArray(T1[] a) {
    return keys().toArray(a);
  }

  /**
   * Returns a new list holding a snapshot of all keys of the receiver, in no particular order.
   *
   * @return the keys.
   */
  public List<T> keys() {
    List<T> keys = new ArrayList<>();
    keys(keys);
    return keys;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java b/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java deleted file mode 100644 index 02bde9b..0000000 --- a/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java +++ /dev/null @@ -1,213 +0,0 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.solver;

import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.PlusMult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * <p>Implementation of a conjugate gradient iterative solver for linear systems. Implements both
 * standard conjugate gradient and pre-conditioned conjugate gradient.
 *
 * <p>Conjugate gradient requires the matrix A in the linear system Ax = b to be symmetric and positive
 * definite. For convenience, this implementation could be extended relatively easily to handle the
 * case where the input matrix is non-symmetric, in which case the system A'Ax = b would be solved.
 * Because this requires only one pass through the matrix A, it is faster than explicitly computing A'A,
 * then passing the results to the solver.
 *
 * <p>For inputs that may be ill conditioned (often the case for highly sparse input), a common remedy is to
 * add a scaled identity to the matrix A and solve (A + lambda*I)x = b. This obviously changes the solution,
 * but it will guarantee solvability; the ridge regression approach to linear regression is a common use.
 * The {@code solve} methods of this class take no lambda parameter, so a caller wanting this must pass an
 * operator {@code a} that already includes the lambda*I term.
 *
 * <p>If only an approximate solution is required, the maximum number of iterations or the error threshold
 * may be specified to end the algorithm early at the expense of accuracy. When the matrix A is ill conditioned,
 * it may sometimes be necessary to increase the maximum number of iterations above the default
 * due to numerical issues.
 *
 * <p>By default the solver will run at most b.size() + 2 iterations or until the residual norm is no larger
 * than {@link #DEFAULT_MAX_ERROR} (1E-9).
 *
 * <p>The iteration count and residual of the most recent solve are stored in the instance, so one instance
 * must not run several solves concurrently.
 *
 * <p>For more information on the conjugate gradient algorithm, see Golub & van Loan, "Matrix Computations",
 * sections 10.2 and 10.3 or the <a href="http://en.wikipedia.org/wiki/Conjugate_gradient">conjugate gradient
 * wikipedia article</a>.
 */

public class ConjugateGradientSolver {

  public static final double DEFAULT_MAX_ERROR = 1.0e-9;

  private static final Logger log = LoggerFactory.getLogger(ConjugateGradientSolver.class);

  private int iterations;
  private double residualNormSquared;

  public ConjugateGradientSolver() {
    this.iterations = 0;
    this.residualNormSquared = Double.NaN;
  }

  /**
   * Solves the system Ax = b with default termination criteria. A must be symmetric, square, and positive definite.
   * Only the squareness of a is checked, since testing for symmetry and positive definiteness are too expensive. If
   * an invalid matrix is specified, then the algorithm may not yield a valid result.
   *
   * @param a  The linear operator A.
   * @param b  The vector b.
   * @return The result x of solving the system.
   * @throws IllegalArgumentException if a is not square or if the size of b is not equal to the number of columns of a.
   *
   */
  public Vector solve(VectorIterable a, Vector b) {
    return solve(a, b, null, b.size() + 2, DEFAULT_MAX_ERROR);
  }

  /**
   * Solves the system Ax = b with default termination criteria using the specified preconditioner. A must be
   * symmetric, square, and positive definite. Only the squareness of a is checked, since testing for symmetry
   * and positive definiteness are too expensive. If an invalid matrix is specified, then the algorithm may not
   * yield a valid result.
   *
   * @param a  The linear operator A.
   * @param b  The vector b.
   * @param precond A preconditioner to use on A during the solution process.
   * @return The result x of solving the system.
   * @throws IllegalArgumentException if a is not square or if the size of b is not equal to the number of columns of a.
   *
   */
  public Vector solve(VectorIterable a, Vector b, Preconditioner precond) {
    return solve(a, b, precond, b.size() + 2, DEFAULT_MAX_ERROR);
  }


  /**
   * Solves the system Ax = b, where A is a linear operator and b is a vector. Uses the specified preconditioner
   * to improve numeric stability and possibly speed convergence. This version of solve() allows control over the
   * termination and iteration parameters.
   *
   * @param a  The matrix A.
   * @param b  The vector b.
   * @param preconditioner The preconditioner to apply; {@code null} selects plain (unpreconditioned)
   *                       conjugate gradient.
   * @param maxIterations The maximum number of iterations to run.
   * @param maxError The maximum amount of residual error to tolerate. The algorithm will run until the residual
   *                 norm is no larger than this value or until maxIterations are completed.
   * @return The result x of solving the system.
   * @throws IllegalArgumentException if the matrix is not square, if the size of b is not equal to the number of
   *                                  columns of A, if maxError is less than zero, or if maxIterations is not positive.
   */

  public Vector solve(VectorIterable a,
                      Vector b,
                      Preconditioner preconditioner,
                      int maxIterations,
                      double maxError) {

    if (a.numRows() != a.numCols()) {
      throw new IllegalArgumentException("Matrix must be square, symmetric and positive definite.");
    }

    if (a.numCols() != b.size()) {
      throw new CardinalityException(a.numCols(), b.size());
    }

    if (maxIterations <= 0) {
      throw new IllegalArgumentException("Max iterations must be positive.");
    }

    if (maxError < 0.0) {
      throw new IllegalArgumentException("Max error must be non-negative.");
    }

    // A per-call function object: its multiplier is mutated on every step, so sharing one
    // instance statically would let concurrent solves overwrite each other's step size.
    PlusMult plusMult = new PlusMult(1.0);

    // Start from x = 0, so the initial residual is b - A*0.
    Vector x = new DenseVector(b.size());

    iterations = 0;
    Vector residual = b.minus(a.times(x));
    residualNormSquared = residual.dot(residual);

    log.info("Conjugate gradient initial residual norm = {}", Math.sqrt(residualNormSquared));
    double previousConditionedNormSqr = 0.0;
    Vector updateDirection = null;
    while (Math.sqrt(residualNormSquared) > maxError && iterations < maxIterations) {
      Vector conditionedResidual;
      double conditionedNormSqr;
      if (preconditioner == null) {
        conditionedResidual = residual;
        conditionedNormSqr = residualNormSquared;
      } else {
        conditionedResidual = preconditioner.precondition(residual);
        conditionedNormSqr = residual.dot(conditionedResidual);
      }

      ++iterations;

      if (iterations == 1) {
        // Copy: updateDirection is modified in place later and must not alias the residual.
        updateDirection = new DenseVector(conditionedResidual);
      } else {
        double beta = conditionedNormSqr / previousConditionedNormSqr;

        // updateDirection = conditionedResidual + beta * updateDirection
        updateDirection.assign(Functions.MULT, beta);
        updateDirection.assign(conditionedResidual, Functions.PLUS);
      }

      Vector aTimesUpdate = a.times(updateDirection);

      double alpha = conditionedNormSqr / updateDirection.dot(aTimesUpdate);

      // x = x + alpha * updateDirection
      plusMult.setMultiplicator(alpha);
      x.assign(updateDirection, plusMult);

      // residual = residual - alpha * A * updateDirection
      plusMult.setMultiplicator(-alpha);
      residual.assign(aTimesUpdate, plusMult);

      previousConditionedNormSqr = conditionedNormSqr;
      residualNormSquared = residual.dot(residual);

      log.info("Conjugate gradient iteration {} residual norm = {}", iterations, Math.sqrt(residualNormSquared));
    }
    return x;
  }

  /**
   * Returns the number of iterations run once the solver is complete.
   *
   * @return The number of iterations run.
   */
  public int getIterations() {
    return iterations;
  }

  /**
   * Returns the norm of the residual at the completion of the solver. Usually this should be close to zero except in
   * the case of a non positive definite matrix A, which results in an unsolvable system, or for ill conditioned A, in
   * which case more iterations than the default may be needed. Before the first solve this is NaN.
   *
   * @return The norm of the residual in the solution.
   */
  public double getResidualNorm() {
    return Math.sqrt(residualNormSquared);
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java b/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java deleted file mode 100644 index 871ba44..0000000 --- a/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java +++ /dev/null @@ -1,892 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Adapted from the public domain Jama code. - */ - -package org.apache.mahout.math.solver; - -import org.apache.mahout.math.DenseMatrix; -import org.apache.mahout.math.DenseVector; -import org.apache.mahout.math.Matrix; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.function.Functions; - -/** - * Eigenvalues and eigenvectors of a real matrix. - * <p/> - * If A is symmetric, then A = V*D*V' where the eigenvalue matrix D is diagonal and the eigenvector - * matrix V is orthogonal. I.e. A = V.times(D.times(V.transpose())) and V.times(V.transpose()) - * equals the identity matrix. - * <p/> - * If A is not symmetric, then the eigenvalue matrix D is block diagonal with the real eigenvalues - * in 1-by-1 blocks and any complex eigenvalues, lambda + i*mu, in 2-by-2 blocks, [lambda, mu; -mu, - * lambda]. The columns of V represent the eigenvectors in the sense that A*V = V*D, i.e. - * A.times(V) equals V.times(D). The matrix V may be badly conditioned, or even singular, so the - * validity of the equation A = V*D*inverse(V) depends upon V.cond(). - */ -public class EigenDecomposition { - - /** Row and column dimension (square matrix). */ - private final int n; - /** Arrays for internal storage of eigenvalues. */ - private final Vector d; - private final Vector e; - /** Array for internal storage of eigenvectors. 
*/ - private final Matrix v; - - public EigenDecomposition(Matrix x) { - this(x, isSymmetric(x)); - } - - public EigenDecomposition(Matrix x, boolean isSymmetric) { - n = x.columnSize(); - d = new DenseVector(n); - e = new DenseVector(n); - v = new DenseMatrix(n, n); - - if (isSymmetric) { - v.assign(x); - - // Tridiagonalize. - tred2(); - - // Diagonalize. - tql2(); - - } else { - // Reduce to Hessenberg form. - // Reduce Hessenberg to real Schur form. - hqr2(orthes(x)); - } - } - - /** - * Return the eigenvector matrix - * - * @return V - */ - public Matrix getV() { - return v.like().assign(v); - } - - /** - * Return the real parts of the eigenvalues - */ - public Vector getRealEigenvalues() { - return d; - } - - /** - * Return the imaginary parts of the eigenvalues - */ - public Vector getImagEigenvalues() { - return e; - } - - /** - * Return the block diagonal eigenvalue matrix - * - * @return D - */ - public Matrix getD() { - Matrix x = new DenseMatrix(n, n); - x.assign(0); - x.viewDiagonal().assign(d); - for (int i = 0; i < n; i++) { - double v = e.getQuick(i); - if (v > 0) { - x.setQuick(i, i + 1, v); - } else if (v < 0) { - x.setQuick(i, i - 1, v); - } - } - return x; - } - - // Symmetric Householder reduction to tridiagonal form. - private void tred2() { - // This is derived from the Algol procedures tred2 by - // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for - // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding - // Fortran subroutine in EISPACK. - - d.assign(v.viewColumn(n - 1)); - - // Householder reduction to tridiagonal form. - - for (int i = n - 1; i > 0; i--) { - - // Scale to avoid under/overflow. - - double scale = d.viewPart(0, i).norm(1); - double h = 0.0; - - - if (scale == 0.0) { - e.setQuick(i, d.getQuick(i - 1)); - for (int j = 0; j < i; j++) { - d.setQuick(j, v.getQuick(i - 1, j)); - v.setQuick(i, j, 0.0); - v.setQuick(j, i, 0.0); - } - } else { - - // Generate Householder vector. 
- - for (int k = 0; k < i; k++) { - d.setQuick(k, d.getQuick(k) / scale); - h += d.getQuick(k) * d.getQuick(k); - } - double f = d.getQuick(i - 1); - double g = Math.sqrt(h); - if (f > 0) { - g = -g; - } - e.setQuick(i, scale * g); - h -= f * g; - d.setQuick(i - 1, f - g); - for (int j = 0; j < i; j++) { - e.setQuick(j, 0.0); - } - - // Apply similarity transformation to remaining columns. - - for (int j = 0; j < i; j++) { - f = d.getQuick(j); - v.setQuick(j, i, f); - g = e.getQuick(j) + v.getQuick(j, j) * f; - for (int k = j + 1; k <= i - 1; k++) { - g += v.getQuick(k, j) * d.getQuick(k); - e.setQuick(k, e.getQuick(k) + v.getQuick(k, j) * f); - } - e.setQuick(j, g); - } - f = 0.0; - for (int j = 0; j < i; j++) { - e.setQuick(j, e.getQuick(j) / h); - f += e.getQuick(j) * d.getQuick(j); - } - double hh = f / (h + h); - for (int j = 0; j < i; j++) { - e.setQuick(j, e.getQuick(j) - hh * d.getQuick(j)); - } - for (int j = 0; j < i; j++) { - f = d.getQuick(j); - g = e.getQuick(j); - for (int k = j; k <= i - 1; k++) { - v.setQuick(k, j, v.getQuick(k, j) - (f * e.getQuick(k) + g * d.getQuick(k))); - } - d.setQuick(j, v.getQuick(i - 1, j)); - v.setQuick(i, j, 0.0); - } - } - d.setQuick(i, h); - } - - // Accumulate transformations. - - for (int i = 0; i < n - 1; i++) { - v.setQuick(n - 1, i, v.getQuick(i, i)); - v.setQuick(i, i, 1.0); - double h = d.getQuick(i + 1); - if (h != 0.0) { - for (int k = 0; k <= i; k++) { - d.setQuick(k, v.getQuick(k, i + 1) / h); - } - for (int j = 0; j <= i; j++) { - double g = 0.0; - for (int k = 0; k <= i; k++) { - g += v.getQuick(k, i + 1) * v.getQuick(k, j); - } - for (int k = 0; k <= i; k++) { - v.setQuick(k, j, v.getQuick(k, j) - g * d.getQuick(k)); - } - } - } - for (int k = 0; k <= i; k++) { - v.setQuick(k, i + 1, 0.0); - } - } - d.assign(v.viewRow(n - 1)); - v.viewRow(n - 1).assign(0); - v.setQuick(n - 1, n - 1, 1.0); - e.setQuick(0, 0.0); - } - - // Symmetric tridiagonal QL algorithm. 
- private void tql2() { - - // This is derived from the Algol procedures tql2, by - // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for - // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding - // Fortran subroutine in EISPACK. - - e.viewPart(0, n - 1).assign(e.viewPart(1, n - 1)); - e.setQuick(n - 1, 0.0); - - double f = 0.0; - double tst1 = 0.0; - double eps = Math.pow(2.0, -52.0); - for (int l = 0; l < n; l++) { - - // Find small subdiagonal element - - tst1 = Math.max(tst1, Math.abs(d.getQuick(l)) + Math.abs(e.getQuick(l))); - int m = l; - while (m < n) { - if (Math.abs(e.getQuick(m)) <= eps * tst1) { - break; - } - m++; - } - - // If m == l, d.getQuick(l) is an eigenvalue, - // otherwise, iterate. - - if (m > l) { - do { - // Compute implicit shift - - double g = d.getQuick(l); - double p = (d.getQuick(l + 1) - g) / (2.0 * e.getQuick(l)); - double r = Math.hypot(p, 1.0); - if (p < 0) { - r = -r; - } - d.setQuick(l, e.getQuick(l) / (p + r)); - d.setQuick(l + 1, e.getQuick(l) * (p + r)); - double dl1 = d.getQuick(l + 1); - double h = g - d.getQuick(l); - for (int i = l + 2; i < n; i++) { - d.setQuick(i, d.getQuick(i) - h); - } - f += h; - - // Implicit QL transformation. - - p = d.getQuick(m); - double c = 1.0; - double c2 = c; - double c3 = c; - double el1 = e.getQuick(l + 1); - double s = 0.0; - double s2 = 0.0; - for (int i = m - 1; i >= l; i--) { - c3 = c2; - c2 = c; - s2 = s; - g = c * e.getQuick(i); - h = c * p; - r = Math.hypot(p, e.getQuick(i)); - e.setQuick(i + 1, s * r); - s = e.getQuick(i) / r; - c = p / r; - p = c * d.getQuick(i) - s * g; - d.setQuick(i + 1, h + s * (c * g + s * d.getQuick(i))); - - // Accumulate transformation. - - for (int k = 0; k < n; k++) { - h = v.getQuick(k, i + 1); - v.setQuick(k, i + 1, s * v.getQuick(k, i) + c * h); - v.setQuick(k, i, c * v.getQuick(k, i) - s * h); - } - } - p = -s * s2 * c3 * el1 * e.getQuick(l) / dl1; - e.setQuick(l, s * p); - d.setQuick(l, c * p); - - // Check for convergence. 
- - } while (Math.abs(e.getQuick(l)) > eps * tst1); - } - d.setQuick(l, d.getQuick(l) + f); - e.setQuick(l, 0.0); - } - - // Sort eigenvalues and corresponding vectors. - - for (int i = 0; i < n - 1; i++) { - int k = i; - double p = d.getQuick(i); - for (int j = i + 1; j < n; j++) { - if (d.getQuick(j) > p) { - k = j; - p = d.getQuick(j); - } - } - if (k != i) { - d.setQuick(k, d.getQuick(i)); - d.setQuick(i, p); - for (int j = 0; j < n; j++) { - p = v.getQuick(j, i); - v.setQuick(j, i, v.getQuick(j, k)); - v.setQuick(j, k, p); - } - } - } - } - - // Nonsymmetric reduction to Hessenberg form. - private Matrix orthes(Matrix x) { - // Working storage for nonsymmetric algorithm. - Vector ort = new DenseVector(n); - Matrix hessenBerg = new DenseMatrix(n, n).assign(x); - - // This is derived from the Algol procedures orthes and ortran, - // by Martin and Wilkinson, Handbook for Auto. Comp., - // Vol.ii-Linear Algebra, and the corresponding - // Fortran subroutines in EISPACK. - - int low = 0; - int high = n - 1; - - for (int m = low + 1; m <= high - 1; m++) { - - // Scale column. - - Vector hColumn = hessenBerg.viewColumn(m - 1).viewPart(m, high - m + 1); - double scale = hColumn.norm(1); - - if (scale != 0.0) { - // Compute Householder transformation. 
- - ort.viewPart(m, high - m + 1).assign(hColumn, Functions.plusMult(1 / scale)); - double h = ort.viewPart(m, high - m + 1).getLengthSquared(); - - double g = Math.sqrt(h); - if (ort.getQuick(m) > 0) { - g = -g; - } - h -= ort.getQuick(m) * g; - ort.setQuick(m, ort.getQuick(m) - g); - - // Apply Householder similarity transformation - // H = (I-u*u'/h)*H*(I-u*u')/h) - - Vector ortPiece = ort.viewPart(m, high - m + 1); - for (int j = m; j < n; j++) { - double f = ortPiece.dot(hessenBerg.viewColumn(j).viewPart(m, high - m + 1)) / h; - hessenBerg.viewColumn(j).viewPart(m, high - m + 1).assign(ortPiece, Functions.plusMult(-f)); - } - - for (int i = 0; i <= high; i++) { - double f = ortPiece.dot(hessenBerg.viewRow(i).viewPart(m, high - m + 1)) / h; - hessenBerg.viewRow(i).viewPart(m, high - m + 1).assign(ortPiece, Functions.plusMult(-f)); - } - ort.setQuick(m, scale * ort.getQuick(m)); - hessenBerg.setQuick(m, m - 1, scale * g); - } - } - - // Accumulate transformations (Algol's ortran). - - v.assign(0); - v.viewDiagonal().assign(1); - - for (int m = high - 1; m >= low + 1; m--) { - if (hessenBerg.getQuick(m, m - 1) != 0.0) { - ort.viewPart(m + 1, high - m).assign(hessenBerg.viewColumn(m - 1).viewPart(m + 1, high - m)); - for (int j = m; j <= high; j++) { - double g = ort.viewPart(m, high - m + 1).dot(v.viewColumn(j).viewPart(m, high - m + 1)); - // Double division avoids possible underflow - g = g / ort.getQuick(m) / hessenBerg.getQuick(m, m - 1); - v.viewColumn(j).viewPart(m, high - m + 1).assign(ort.viewPart(m, high - m + 1), Functions.plusMult(g)); - } - } - } - return hessenBerg; - } - - - // Complex scalar division. 
- private double cdivr; - private double cdivi; - - private void cdiv(double xr, double xi, double yr, double yi) { - double r; - double d; - if (Math.abs(yr) > Math.abs(yi)) { - r = yi / yr; - d = yr + r * yi; - cdivr = (xr + r * xi) / d; - cdivi = (xi - r * xr) / d; - } else { - r = yr / yi; - d = yi + r * yr; - cdivr = (r * xr + xi) / d; - cdivi = (r * xi - xr) / d; - } - } - - - // Nonsymmetric reduction from Hessenberg to real Schur form. - - private void hqr2(Matrix h) { - - // This is derived from the Algol procedure hqr2, - // by Martin and Wilkinson, Handbook for Auto. Comp., - // Vol.ii-Linear Algebra, and the corresponding - // Fortran subroutine in EISPACK. - - // Initialize - - int nn = this.n; - int n = nn - 1; - int low = 0; - int high = nn - 1; - double eps = Math.pow(2.0, -52.0); - double exshift = 0.0; - double p = 0; - double q = 0; - double r = 0; - double s = 0; - double z = 0; - double w; - double x; - double y; - - // Store roots isolated by balanc and compute matrix norm - - double norm = h.aggregate(Functions.PLUS, Functions.ABS); - - // Outer loop over eigenvalue index - - int iter = 0; - while (n >= low) { - - // Look for single small sub-diagonal element - - int l = n; - while (l > low) { - s = Math.abs(h.getQuick(l - 1, l - 1)) + Math.abs(h.getQuick(l, l)); - if (s == 0.0) { - s = norm; - } - if (Math.abs(h.getQuick(l, l - 1)) < eps * s) { - break; - } - l--; - } - - // Check for convergence - - if (l == n) { - // One root found - h.setQuick(n, n, h.getQuick(n, n) + exshift); - d.setQuick(n, h.getQuick(n, n)); - e.setQuick(n, 0.0); - n--; - iter = 0; - - - } else if (l == n - 1) { - // Two roots found - w = h.getQuick(n, n - 1) * h.getQuick(n - 1, n); - p = (h.getQuick(n - 1, n - 1) - h.getQuick(n, n)) / 2.0; - q = p * p + w; - z = Math.sqrt(Math.abs(q)); - h.setQuick(n, n, h.getQuick(n, n) + exshift); - h.setQuick(n - 1, n - 1, h.getQuick(n - 1, n - 1) + exshift); - x = h.getQuick(n, n); - - // Real pair - if (q >= 0) { - if (p >= 0) { 
- z = p + z; - } else { - z = p - z; - } - d.setQuick(n - 1, x + z); - d.setQuick(n, d.getQuick(n - 1)); - if (z != 0.0) { - d.setQuick(n, x - w / z); - } - e.setQuick(n - 1, 0.0); - e.setQuick(n, 0.0); - x = h.getQuick(n, n - 1); - s = Math.abs(x) + Math.abs(z); - p = x / s; - q = z / s; - r = Math.sqrt(p * p + q * q); - p /= r; - q /= r; - - // Row modification - - for (int j = n - 1; j < nn; j++) { - z = h.getQuick(n - 1, j); - h.setQuick(n - 1, j, q * z + p * h.getQuick(n, j)); - h.setQuick(n, j, q * h.getQuick(n, j) - p * z); - } - - // Column modification - - for (int i = 0; i <= n; i++) { - z = h.getQuick(i, n - 1); - h.setQuick(i, n - 1, q * z + p * h.getQuick(i, n)); - h.setQuick(i, n, q * h.getQuick(i, n) - p * z); - } - - // Accumulate transformations - - for (int i = low; i <= high; i++) { - z = v.getQuick(i, n - 1); - v.setQuick(i, n - 1, q * z + p * v.getQuick(i, n)); - v.setQuick(i, n, q * v.getQuick(i, n) - p * z); - } - - // Complex pair - - } else { - d.setQuick(n - 1, x + p); - d.setQuick(n, x + p); - e.setQuick(n - 1, z); - e.setQuick(n, -z); - } - n -= 2; - iter = 0; - - // No convergence yet - - } else { - - // Form shift - - x = h.getQuick(n, n); - y = 0.0; - w = 0.0; - if (l < n) { - y = h.getQuick(n - 1, n - 1); - w = h.getQuick(n, n - 1) * h.getQuick(n - 1, n); - } - - // Wilkinson's original ad hoc shift - - if (iter == 10) { - exshift += x; - for (int i = low; i <= n; i++) { - h.setQuick(i, i, x); - } - s = Math.abs(h.getQuick(n, n - 1)) + Math.abs(h.getQuick(n - 1, n - 2)); - x = y = 0.75 * s; - w = -0.4375 * s * s; - } - - // MATLAB's new ad hoc shift - - if (iter == 30) { - s = (y - x) / 2.0; - s = s * s + w; - if (s > 0) { - s = Math.sqrt(s); - if (y < x) { - s = -s; - } - s = x - w / ((y - x) / 2.0 + s); - for (int i = low; i <= n; i++) { - h.setQuick(i, i, h.getQuick(i, i) - s); - } - exshift += s; - x = y = w = 0.964; - } - } - - iter++; // (Could check iteration count here.) 
- - // Look for two consecutive small sub-diagonal elements - - int m = n - 2; - while (m >= l) { - z = h.getQuick(m, m); - r = x - z; - s = y - z; - p = (r * s - w) / h.getQuick(m + 1, m) + h.getQuick(m, m + 1); - q = h.getQuick(m + 1, m + 1) - z - r - s; - r = h.getQuick(m + 2, m + 1); - s = Math.abs(p) + Math.abs(q) + Math.abs(r); - p /= s; - q /= s; - r /= s; - if (m == l) { - break; - } - double hmag = Math.abs(h.getQuick(m - 1, m - 1)) + Math.abs(h.getQuick(m + 1, m + 1)); - double threshold = eps * Math.abs(p) * (Math.abs(z) + hmag); - if (Math.abs(h.getQuick(m, m - 1)) * (Math.abs(q) + Math.abs(r)) < threshold) { - break; - } - m--; - } - - for (int i = m + 2; i <= n; i++) { - h.setQuick(i, i - 2, 0.0); - if (i > m + 2) { - h.setQuick(i, i - 3, 0.0); - } - } - - // Double QR step involving rows l:n and columns m:n - - for (int k = m; k <= n - 1; k++) { - boolean notlast = k != n - 1; - if (k != m) { - p = h.getQuick(k, k - 1); - q = h.getQuick(k + 1, k - 1); - r = notlast ? h.getQuick(k + 2, k - 1) : 0.0; - x = Math.abs(p) + Math.abs(q) + Math.abs(r); - if (x != 0.0) { - p /= x; - q /= x; - r /= x; - } - } - if (x == 0.0) { - break; - } - s = Math.sqrt(p * p + q * q + r * r); - if (p < 0) { - s = -s; - } - if (s != 0) { - if (k != m) { - h.setQuick(k, k - 1, -s * x); - } else if (l != m) { - h.setQuick(k, k - 1, -h.getQuick(k, k - 1)); - } - p += s; - x = p / s; - y = q / s; - z = r / s; - q /= p; - r /= p; - - // Row modification - - for (int j = k; j < nn; j++) { - p = h.getQuick(k, j) + q * h.getQuick(k + 1, j); - if (notlast) { - p += r * h.getQuick(k + 2, j); - h.setQuick(k + 2, j, h.getQuick(k + 2, j) - p * z); - } - h.setQuick(k, j, h.getQuick(k, j) - p * x); - h.setQuick(k + 1, j, h.getQuick(k + 1, j) - p * y); - } - - // Column modification - - for (int i = 0; i <= Math.min(n, k + 3); i++) { - p = x * h.getQuick(i, k) + y * h.getQuick(i, k + 1); - if (notlast) { - p += z * h.getQuick(i, k + 2); - h.setQuick(i, k + 2, h.getQuick(i, k + 2) - p * r); 
- } - h.setQuick(i, k, h.getQuick(i, k) - p); - h.setQuick(i, k + 1, h.getQuick(i, k + 1) - p * q); - } - - // Accumulate transformations - - for (int i = low; i <= high; i++) { - p = x * v.getQuick(i, k) + y * v.getQuick(i, k + 1); - if (notlast) { - p += z * v.getQuick(i, k + 2); - v.setQuick(i, k + 2, v.getQuick(i, k + 2) - p * r); - } - v.setQuick(i, k, v.getQuick(i, k) - p); - v.setQuick(i, k + 1, v.getQuick(i, k + 1) - p * q); - } - } // (s != 0) - } // k loop - } // check convergence - } // while (n >= low) - - // Backsubstitute to find vectors of upper triangular form - - if (norm == 0.0) { - return; - } - - for (n = nn - 1; n >= 0; n--) { - p = d.getQuick(n); - q = e.getQuick(n); - - // Real vector - - double t; - if (q == 0) { - int l = n; - h.setQuick(n, n, 1.0); - for (int i = n - 1; i >= 0; i--) { - w = h.getQuick(i, i) - p; - r = 0.0; - for (int j = l; j <= n; j++) { - r += h.getQuick(i, j) * h.getQuick(j, n); - } - if (e.getQuick(i) < 0.0) { - z = w; - s = r; - } else { - l = i; - if (e.getQuick(i) == 0.0) { - if (w == 0.0) { - h.setQuick(i, n, -r / (eps * norm)); - } else { - h.setQuick(i, n, -r / w); - } - - // Solve real equations - - } else { - x = h.getQuick(i, i + 1); - y = h.getQuick(i + 1, i); - q = (d.getQuick(i) - p) * (d.getQuick(i) - p) + e.getQuick(i) * e.getQuick(i); - t = (x * s - z * r) / q; - h.setQuick(i, n, t); - if (Math.abs(x) > Math.abs(z)) { - h.setQuick(i + 1, n, (-r - w * t) / x); - } else { - h.setQuick(i + 1, n, (-s - y * t) / z); - } - } - - // Overflow control - - t = Math.abs(h.getQuick(i, n)); - if (eps * t * t > 1) { - for (int j = i; j <= n; j++) { - h.setQuick(j, n, h.getQuick(j, n) / t); - } - } - } - } - - // Complex vector - - } else if (q < 0) { - int l = n - 1; - - // Last vector component imaginary so matrix is triangular - - if (Math.abs(h.getQuick(n, n - 1)) > Math.abs(h.getQuick(n - 1, n))) { - h.setQuick(n - 1, n - 1, q / h.getQuick(n, n - 1)); - h.setQuick(n - 1, n, -(h.getQuick(n, n) - p) / h.getQuick(n, 
n - 1)); - } else { - cdiv(0.0, -h.getQuick(n - 1, n), h.getQuick(n - 1, n - 1) - p, q); - h.setQuick(n - 1, n - 1, cdivr); - h.setQuick(n - 1, n, cdivi); - } - h.setQuick(n, n - 1, 0.0); - h.setQuick(n, n, 1.0); - for (int i = n - 2; i >= 0; i--) { - double ra = 0.0; - double sa = 0.0; - for (int j = l; j <= n; j++) { - ra += h.getQuick(i, j) * h.getQuick(j, n - 1); - sa += h.getQuick(i, j) * h.getQuick(j, n); - } - w = h.getQuick(i, i) - p; - - if (e.getQuick(i) < 0.0) { - z = w; - r = ra; - s = sa; - } else { - l = i; - if (e.getQuick(i) == 0) { - cdiv(-ra, -sa, w, q); - h.setQuick(i, n - 1, cdivr); - h.setQuick(i, n, cdivi); - } else { - - // Solve complex equations - - x = h.getQuick(i, i + 1); - y = h.getQuick(i + 1, i); - double vr = (d.getQuick(i) - p) * (d.getQuick(i) - p) + e.getQuick(i) * e.getQuick(i) - q * q; - double vi = (d.getQuick(i) - p) * 2.0 * q; - if (vr == 0.0 && vi == 0.0) { - double hmag = Math.abs(x) + Math.abs(y); - vr = eps * norm * (Math.abs(w) + Math.abs(q) + hmag + Math.abs(z)); - } - cdiv(x * r - z * ra + q * sa, x * s - z * sa - q * ra, vr, vi); - h.setQuick(i, n - 1, cdivr); - h.setQuick(i, n, cdivi); - if (Math.abs(x) > (Math.abs(z) + Math.abs(q))) { - h.setQuick(i + 1, n - 1, (-ra - w * h.getQuick(i, n - 1) + q * h.getQuick(i, n)) / x); - h.setQuick(i + 1, n, (-sa - w * h.getQuick(i, n) - q * h.getQuick(i, n - 1)) / x); - } else { - cdiv(-r - y * h.getQuick(i, n - 1), -s - y * h.getQuick(i, n), z, q); - h.setQuick(i + 1, n - 1, cdivr); - h.setQuick(i + 1, n, cdivi); - } - } - - // Overflow control - - t = Math.max(Math.abs(h.getQuick(i, n - 1)), Math.abs(h.getQuick(i, n))); - if (eps * t * t > 1) { - for (int j = i; j <= n; j++) { - h.setQuick(j, n - 1, h.getQuick(j, n - 1) / t); - h.setQuick(j, n, h.getQuick(j, n) / t); - } - } - } - } - } - } - - // Vectors of isolated roots - - for (int i = 0; i < nn; i++) { - if (i < low || i > high) { - for (int j = i; j < nn; j++) { - v.setQuick(i, j, h.getQuick(i, j)); - } - } - } - - // 
Back transformation to get eigenvectors of original matrix - - for (int j = nn - 1; j >= low; j--) { - for (int i = low; i <= high; i++) { - z = 0.0; - for (int k = low; k <= Math.min(j, high); k++) { - z += v.getQuick(i, k) * h.getQuick(k, j); - } - v.setQuick(i, j, z); - } - } - } - - private static boolean isSymmetric(Matrix a) { - /* - Symmetry flag. - */ - int n = a.columnSize(); - - boolean isSymmetric = true; - for (int j = 0; (j < n) && isSymmetric; j++) { - for (int i = 0; (i < n) && isSymmetric; i++) { - isSymmetric = a.getQuick(i, j) == a.getQuick(j, i); - } - } - return isSymmetric; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java b/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java deleted file mode 100644 index 7524564..0000000 --- a/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.math.solver; - -import org.apache.mahout.math.DenseVector; -import org.apache.mahout.math.Matrix; -import org.apache.mahout.math.Vector; - -/** - * Implements the Jacobi preconditioner for a matrix A. This is defined as inv(diag(A)). - */ -public final class JacobiConditioner implements Preconditioner { - - private final DenseVector inverseDiagonal; - - public JacobiConditioner(Matrix a) { - if (a.numCols() != a.numRows()) { - throw new IllegalArgumentException("Matrix must be square."); - } - - inverseDiagonal = new DenseVector(a.numCols()); - for (int i = 0; i < a.numCols(); ++i) { - inverseDiagonal.setQuick(i, 1.0 / a.getQuick(i, i)); - } - } - - @Override - public Vector precondition(Vector v) { - return v.times(inverseDiagonal); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/main/java/org/apache/mahout/math/solver/LSMR.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/solver/LSMR.java b/core/src/main/java/org/apache/mahout/math/solver/LSMR.java deleted file mode 100644 index 1f3e706..0000000 --- a/core/src/main/java/org/apache/mahout/math/solver/LSMR.java +++ /dev/null @@ -1,565 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.math.solver; - -import org.apache.mahout.math.DenseVector; -import org.apache.mahout.math.Matrix; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.function.Functions; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Solves sparse least-squares using the LSMR algorithm. - * <p/> - * LSMR solves the system of linear equations A * X = B. If the system is inconsistent, it solves - * the least-squares problem min ||b - Ax||_2. A is a rectangular matrix of dimension m-by-n, where - * all cases are allowed: m=n, m>n, or m<n. B is a vector of length m. The matrix A may be dense - * or sparse (usually sparse). - * <p/> - * Some additional configurable properties adjust the behavior of the algorithm. - * <p/> - * If you set lambda to a non-zero value then LSMR solves the regularized least-squares problem min - * ||(B) - ( A )X|| ||(0) (lambda*I) ||_2 where LAMBDA is a scalar. If LAMBDA is not set, - * the system is solved without regularization. - * <p/> - * You can also set aTolerance and bTolerance. These cause LSMR to iterate until a certain backward - * error estimate is smaller than some quantity depending on ATOL and BTOL. Let RES = B - A*X be - * the residual vector for the current approximate solution X. If A*X = B seems to be consistent, - * LSMR terminates when NORM(RES) <= ATOL*NORM(A)*NORM(X) + BTOL*NORM(B). Otherwise, LSMR terminates - * when NORM(A'*RES) <= ATOL*NORM(A)*NORM(RES). If both tolerances are 1.0e-6 (say), the final - * NORM(RES) should be accurate to about 6 digits. (The final X will usually have fewer correct - * digits, depending on cond(A) and the size of LAMBDA.) - * <p/> - * The default value for ATOL and BTOL is 1e-6. - * <p/> - * Ideally, they should be estimates of the relative error in the entries of A and B respectively. 
- * For example, if the entries of A have 7 correct digits, set ATOL = 1e-7. This prevents the - * algorithm from doing unnecessary work beyond the uncertainty of the input data. - * <p/> - * You can also set conditionLimit. In that case, LSMR terminates if an estimate of cond(A) exceeds - * conditionLimit. For compatible systems Ax = b, conditionLimit could be as large as 1.0e+12 (say). - * For least-squares problems, conditionLimit should be less than 1.0e+8. If conditionLimit is not - * set, the default value is 1e+8. Maximum precision can be obtained by setting aTolerance = - * bTolerance = conditionLimit = 0, but the number of iterations may then be excessive. - * <p/> - * Setting iterationLimit causes LSMR to terminate if the number of iterations reaches - * iterationLimit. The default is iterationLimit = min(m,n). For ill-conditioned systems, a - * larger value of ITNLIM may be needed. - * <p/> - * Setting localSize causes LSMR to run with rerorthogonalization on the last localSize v_k's. - * (v-vectors generated by Golub-Kahan bidiagonalization) If localSize is not set, LSMR runs without - * reorthogonalization. A localSize > max(n,m) performs reorthogonalization on all v_k's. - * Reorthgonalizing only u_k or both u_k and v_k are not an option here. Details are discussed in - * the SIAM paper. - * <p/> - * getTerminationReason() gives the reason for termination. ISTOP = 0 means X=0 is a solution. = 1 - * means X is an approximate solution to A*X = B, according to ATOL and BTOL. = 2 means X - * approximately solves the least-squares problem according to ATOL. = 3 means COND(A) seems to be - * greater than CONLIM. = 4 is the same as 1 with ATOL = BTOL = EPS. = 5 is the same as 2 with ATOL - * = EPS. = 6 is the same as 3 with CONLIM = 1/EPS. = 7 means ITN reached ITNLIM before the other - * stopping conditions were satisfied. - * <p/> - * getIterationCount() gives ITN = the number of LSMR iterations. 
- * <p/> - * getResidualNorm() gives an estimate of the residual norm: NORMR = norm(B-A*X). - * <p/> - * getNormalEquationResidual() gives an estimate of the residual for the normal equation: NORMAR = - * NORM(A'*(B-A*X)). - * <p/> - * getANorm() gives an estimate of the Frobenius norm of A. - * <p/> - * getCondition() gives an estimate of the condition number of A. - * <p/> - * getXNorm() gives an estimate of NORM(X). - * <p/> - * LSMR uses an iterative method. For further information, see D. C.-L. Fong and M. A. Saunders - * LSMR: An iterative algorithm for least-square problems Draft of 03 Apr 2010, to be submitted to - * SISC. - * <p/> - * David Chin-lung Fong [email protected] Institute for Computational and Mathematical - * Engineering Stanford University - * <p/> - * Michael Saunders [email protected] Systems Optimization Laboratory Dept of - * MS&E, Stanford University. ----------------------------------------------------------------------- - */ -public final class LSMR { - - private static final Logger log = LoggerFactory.getLogger(LSMR.class); - - private final double lambda; - private int localSize; - private int iterationLimit; - private double conditionLimit; - private double bTolerance; - private double aTolerance; - private int localPointer; - private Vector[] localV; - private double residualNorm; - private double normalEquationResidual; - private double xNorm; - private int iteration; - private double normA; - private double condA; - - public int getIterationCount() { - return iteration; - } - - public double getResidualNorm() { - return residualNorm; - } - - public double getNormalEquationResidual() { - return normalEquationResidual; - } - - public double getANorm() { - return normA; - } - - public double getCondition() { - return condA; - } - - public double getXNorm() { - return xNorm; - } - - /** - * LSMR uses an iterative method to solve a linear system. For further information, see D. C.-L. - * Fong and M. A. 
Saunders LSMR: An iterative algorithm for least-square problems Draft of 03 Apr - * 2010, to be submitted to SISC. - * <p/> - * 08 Dec 2009: First release version of LSMR. 09 Apr 2010: Updated documentation and default - * parameters. 14 Apr 2010: Updated documentation. 03 Jun 2010: LSMR with local - * reorthogonalization (full reorthogonalization is also implemented) - * <p/> - * David Chin-lung Fong [email protected] Institute for Computational and - * Mathematical Engineering Stanford University - * <p/> - * Michael Saunders [email protected] Systems Optimization Laboratory Dept of - * MS&E, Stanford University. ----------------------------------------------------------------------- - */ - - public LSMR() { - // Set default parameters. - lambda = 0; - aTolerance = 1.0e-6; - bTolerance = 1.0e-6; - conditionLimit = 1.0e8; - iterationLimit = -1; - localSize = 0; - } - - public Vector solve(Matrix A, Vector b) { - /* - % Initialize. - - - hdg1 = ' itn x(1) norm r norm A''r'; - hdg2 = ' compatible LS norm A cond A'; - pfreq = 20; % print frequency (for repeating the heading) - pcount = 0; % print counter - - % Determine dimensions m and n, and - % form the first vectors u and v. - % These satisfy beta*u = b, alpha*v = A'u. 
- */ - log.debug(" itn x(1) norm r norm A'r"); - log.debug(" compatible LS norm A cond A"); - - Matrix transposedA = A.transpose(); - Vector u = b; - - double beta = u.norm(2); - if (beta > 0) { - u = u.divide(beta); - } - - Vector v = transposedA.times(u); - int m = A.numRows(); - int n = A.numCols(); - - int minDim = Math.min(m, n); - if (iterationLimit == -1) { - iterationLimit = minDim; - } - - if (log.isDebugEnabled()) { - log.debug("LSMR - Least-squares solution of Ax = b, based on Matlab Version 1.02, 14 Apr 2010, " - + "Mahout version {}", getClass().getPackage().getImplementationVersion()); - log.debug(String.format("The matrix A has %d rows and %d cols, lambda = %.4g, atol = %g, btol = %g", - m, n, lambda, aTolerance, bTolerance)); - } - - double alpha = v.norm(2); - if (alpha > 0) { - v.assign(Functions.div(alpha)); - } - - - // Initialization for local reorthogonalization - localPointer = 0; - - // Preallocate storage for storing the last few v_k. Since with - // orthogonal v_k's, Krylov subspace method would converge in not - // more iterations than the number of singular values, more - // space is not necessary. - localV = new Vector[Math.min(localSize, minDim)]; - boolean localOrtho = false; - if (localSize > 0) { - localOrtho = true; - localV[0] = v; - } - - - // Initialize variables for 1st iteration. - - iteration = 0; - double zetabar = alpha * beta; - double alphabar = alpha; - - Vector h = v; - Vector hbar = zeros(n); - Vector x = zeros(n); - - // Initialize variables for estimation of ||r||. - - double betadd = beta; - - // Initialize variables for estimation of ||A|| and cond(A) - - double aNorm = alpha * alpha; - - // Items for use in stopping rules. - double normb = beta; - - double ctol = 0; - if (conditionLimit > 0) { - ctol = 1 / conditionLimit; - } - residualNorm = beta; - - // Exit if b=0 or A'b = 0. - - normalEquationResidual = alpha * beta; - if (normalEquationResidual == 0) { - return x; - } - - // Heading for iteration log. 
- - - if (log.isDebugEnabled()) { - double test2 = alpha / beta; -// log.debug('{} {}', hdg1, hdg2); - log.debug("{} {}", iteration, x.get(0)); - log.debug("{} {}", residualNorm, normalEquationResidual); - double test1 = 1; - log.debug("{} {}", test1, test2); - } - - - //------------------------------------------------------------------ - // Main iteration loop. - //------------------------------------------------------------------ - double rho = 1; - double rhobar = 1; - double cbar = 1; - double sbar = 0; - double betad = 0; - double rhodold = 1; - double tautildeold = 0; - double thetatilde = 0; - double zeta = 0; - double d = 0; - double maxrbar = 0; - double minrbar = 1.0e+100; - StopCode stop = StopCode.CONTINUE; - while (iteration <= iterationLimit && stop == StopCode.CONTINUE) { - - iteration++; - - // Perform the next step of the bidiagonalization to obtain the - // next beta, u, alpha, v. These satisfy the relations - // beta*u = A*v - alpha*u, - // alpha*v = A'*u - beta*v. - - u = A.times(v).minus(u.times(alpha)); - beta = u.norm(2); - if (beta > 0) { - u.assign(Functions.div(beta)); - - // store data for local-reorthogonalization of V - if (localOrtho) { - localVEnqueue(v); - } - v = transposedA.times(u).minus(v.times(beta)); - // local-reorthogonalization of V - if (localOrtho) { - v = localVOrtho(v); - } - alpha = v.norm(2); - if (alpha > 0) { - v.assign(Functions.div(alpha)); - } - } - - // At this point, beta = beta_{k+1}, alpha = alpha_{k+1}. - - // Construct rotation Qhat_{k,2k+1}. 
- - double alphahat = Math.hypot(alphabar, lambda); - double chat = alphabar / alphahat; - double shat = lambda / alphahat; - - // Use a plane rotation (Q_i) to turn B_i to R_i - - double rhoold = rho; - rho = Math.hypot(alphahat, beta); - double c = alphahat / rho; - double s = beta / rho; - double thetanew = s * alpha; - alphabar = c * alpha; - - // Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar - - double rhobarold = rhobar; - double zetaold = zeta; - double thetabar = sbar * rho; - double rhotemp = cbar * rho; - rhobar = Math.hypot(cbar * rho, thetanew); - cbar = cbar * rho / rhobar; - sbar = thetanew / rhobar; - zeta = cbar * zetabar; - zetabar = -sbar * zetabar; - - - // Update h, h_hat, x. - - hbar = h.minus(hbar.times(thetabar * rho / (rhoold * rhobarold))); - - x.assign(hbar.times(zeta / (rho * rhobar)), Functions.PLUS); - h = v.minus(h.times(thetanew / rho)); - - // Estimate of ||r||. - - // Apply rotation Qhat_{k,2k+1}. - double betaacute = chat * betadd; - double betacheck = -shat * betadd; - - // Apply rotation Q_{k,k+1}. - double betahat = c * betaacute; - betadd = -s * betaacute; - - // Apply rotation Qtilde_{k-1}. - // betad = betad_{k-1} here. - - double thetatildeold = thetatilde; - double rhotildeold = Math.hypot(rhodold, thetabar); - double ctildeold = rhodold / rhotildeold; - double stildeold = thetabar / rhotildeold; - thetatilde = stildeold * rhobar; - rhodold = ctildeold * rhobar; - betad = -stildeold * betad + ctildeold * betahat; - - // betad = betad_k here. - // rhodold = rhod_k here. - - tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold; - double taud = (zeta - thetatilde * tautildeold) / rhodold; - d += betacheck * betacheck; - residualNorm = Math.sqrt(d + (betad - taud) * (betad - taud) + betadd * betadd); - - // Estimate ||A||. - aNorm += beta * beta; - normA = Math.sqrt(aNorm); - aNorm += alpha * alpha; - - // Estimate cond(A). 
- maxrbar = Math.max(maxrbar, rhobarold); - if (iteration > 1) { - minrbar = Math.min(minrbar, rhobarold); - } - condA = Math.max(maxrbar, rhotemp) / Math.min(minrbar, rhotemp); - - // Test for convergence. - - // Compute norms for convergence testing. - normalEquationResidual = Math.abs(zetabar); - xNorm = x.norm(2); - - // Now use these norms to estimate certain other quantities, - // some of which will be small near a solution. - - double test1 = residualNorm / normb; - double test2 = normalEquationResidual / (normA * residualNorm); - double test3 = 1 / condA; - double t1 = test1 / (1 + normA * xNorm / normb); - double rtol = bTolerance + aTolerance * normA * xNorm / normb; - - // The following tests guard against extremely small values of - // atol, btol or ctol. (The user may have set any or all of - // the parameters atol, btol, conlim to 0.) - // The effect is equivalent to the normAl tests using - // atol = eps, btol = eps, conlim = 1/eps. - - if (iteration > iterationLimit) { - stop = StopCode.ITERATION_LIMIT; - } - if (1 + test3 <= 1) { - stop = StopCode.CONDITION_MACHINE_TOLERANCE; - } - if (1 + test2 <= 1) { - stop = StopCode.LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE; - } - if (1 + t1 <= 1) { - stop = StopCode.CONVERGED_MACHINE_TOLERANCE; - } - - // Allow for tolerances set by the user. - - if (test3 <= ctol) { - stop = StopCode.CONDITION; - } - if (test2 <= aTolerance) { - stop = StopCode.CONVERGED; - } - if (test1 <= rtol) { - stop = StopCode.TRIVIAL; - } - - // See if it is time to print something. - if (log.isDebugEnabled()) { - if ((n <= 40) || (iteration <= 10) || (iteration >= iterationLimit - 10) || ((iteration % 10) == 0) - || (test3 <= 1.1 * ctol) || (test2 <= 1.1 * aTolerance) || (test1 <= 1.1 * rtol) - || (stop != StopCode.CONTINUE)) { - statusDump(x, normA, condA, test1, test2); - } - } - } // iteration loop - - // Print the stopping condition. 
- log.debug("Finished: {}", stop.getMessage()); - - return x; - /* - - - if show - fprintf('\n\nLSMR finished') - fprintf('\n%s', msg(istop+1,:)) - fprintf('\nistop =%8g normr =%8.1e' , istop, normr ) - fprintf(' normA =%8.1e normAr =%8.1e', normA, normAr) - fprintf('\nitn =%8g condA =%8.1e' , itn , condA ) - fprintf(' normx =%8.1e\n', normx) - end - */ - } - - private void statusDump(Vector x, double normA, double condA, double test1, double test2) { - log.debug("{} {}", residualNorm, normalEquationResidual); - log.debug("{} {}", iteration, x.get(0)); - log.debug("{} {}", test1, test2); - log.debug("{} {}", normA, condA); - } - - private static Vector zeros(int n) { - return new DenseVector(n); - } - - //----------------------------------------------------------------------- - // stores v into the circular buffer localV - //----------------------------------------------------------------------- - - private void localVEnqueue(Vector v) { - if (localV.length > 0) { - localV[localPointer] = v; - localPointer = (localPointer + 1) % localV.length; - } - } - - //----------------------------------------------------------------------- - // Perform local reorthogonalization of V - //----------------------------------------------------------------------- - - private Vector localVOrtho(Vector v) { - for (Vector old : localV) { - if (old != null) { - double x = v.dot(old); - v = v.minus(old.times(x)); - } - } - return v; - } - - private enum StopCode { - CONTINUE("Not done"), - TRIVIAL("The exact solution is x = 0"), - CONVERGED("Ax - b is small enough, given atol, btol"), - LEAST_SQUARE_CONVERGED("The least-squares solution is good enough, given atol"), - CONDITION("The estimate of cond(Abar) has exceeded condition limit"), - CONVERGED_MACHINE_TOLERANCE("Ax - b is small enough for this machine"), - LEAST_SQUARE_CONVERGED_MACHINE_TOLERANCE("The least-squares solution is good enough for this machine"), - CONDITION_MACHINE_TOLERANCE("Cond(Abar) seems to be too large for this 
machine"), - ITERATION_LIMIT("The iteration limit has been reached"); - - private final String message; - - StopCode(String message) { - this.message = message; - } - - public String getMessage() { - return message; - } - } - - public void setAtolerance(double aTolerance) { - this.aTolerance = aTolerance; - } - - public void setBtolerance(double bTolerance) { - this.bTolerance = bTolerance; - } - - public void setConditionLimit(double conditionLimit) { - this.conditionLimit = conditionLimit; - } - - public void setIterationLimit(int iterationLimit) { - this.iterationLimit = iterationLimit; - } - - public void setLocalSize(int localSize) { - this.localSize = localSize; - } - - public double getLambda() { - return lambda; - } - - public double getAtolerance() { - return aTolerance; - } - - public double getBtolerance() { - return bTolerance; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java b/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java deleted file mode 100644 index 91528fc..0000000 --- a/core/src/main/java/org/apache/mahout/math/solver/Preconditioner.java +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.math.solver; - -import org.apache.mahout.math.Vector; - -/** - * Interface for defining preconditioners used for improving the performance and/or stability of linear - * system solvers. - */ -public interface Preconditioner { - - /** - * Preconditions the specified vector. - * - * @param v The vector to precondition. - * @return The preconditioned vector. - */ - Vector precondition(Vector v); - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/49ad8cb4/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java b/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java deleted file mode 100644 index d2c8434..0000000 --- a/core/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java +++ /dev/null @@ -1,220 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.stats;

import com.google.common.base.Preconditions;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;

/**
 * Utility methods for working with log-likelihood ratios computed from event counts.
 */
public final class LogLikelihood {

  private LogLikelihood() {
  }

  /**
   * Calculates the unnormalized Shannon entropy of a set of counts. With {@code N = sum_i x_i}
   * this is
   *
   * <pre>
   *   N log N - sum_i x_i log x_i  =  -sum_i x_i log(x_i / N)  =  -N sum_i (x_i / N) log(x_i / N)
   * </pre>
   *
   * that is, N times the usual entropy (natural logarithm) of the proportions {@code x_i / N}.
   * Leaving the value un-normalized makes working with counts and computing the log-likelihood
   * ratio easier.
   *
   * @param elements the counts; each must be non-negative
   * @return the un-normalized entropy of the counts
   * @throws IllegalArgumentException if any count is negative
   */
  public static double entropy(long... elements) {
    long sum = 0;
    double result = 0.0;
    for (long element : elements) {
      Preconditions.checkArgument(element >= 0, "Counts must be non-negative, got %s", element);
      result += xLogX(element);
      sum += element;
    }
    return xLogX(sum) - result;
  }

  /** Returns {@code x * ln(x)}, using the limit value 0 for {@code x == 0}. */
  private static double xLogX(long x) {
    return x == 0 ? 0.0 : x * Math.log(x);
  }

  /**
   * Merely an optimization for the common two argument case of {@link #entropy(long...)}.
   * Arguments are not validated here; the caller checks them.
   *
   * @see #logLikelihoodRatio(long, long, long, long)
   */
  private static double entropy(long a, long b) {
    return xLogX(a + b) - xLogX(a) - xLogX(b);
  }

  /**
   * Merely an optimization for the common four argument case of {@link #entropy(long...)}.
   * Arguments are not validated here; the caller checks them.
   *
   * @see #logLikelihoodRatio(long, long, long, long)
   */
  private static double entropy(long a, long b, long c, long d) {
    return xLogX(a + b + c + d) - xLogX(a) - xLogX(b) - xLogX(c) - xLogX(d);
  }

  /**
   * Calculates the raw log-likelihood ratio for two events, call them A and B. Then we have:
   *
   * <table border="1" cellpadding="5" cellspacing="0">
   * <tbody><tr><td>&nbsp;</td><td>Event A</td><td>Everything but A</td></tr>
   * <tr><td>Event B</td><td>A and B together (k_11)</td><td>B, but not A (k_12)</td></tr>
   * <tr><td>Everything but B</td><td>A without B (k_21)</td><td>Neither A nor B (k_22)</td></tr></tbody>
   * </table>
   *
   * <p>Credit to http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html for the table
   * and the descriptions.
   *
   * @param k11 The number of times the two events occurred together
   * @param k12 The number of times the second event occurred WITHOUT the first event
   * @param k21 The number of times the first event occurred WITHOUT the second event
   * @param k22 The number of times something else occurred (i.e. was neither of these events)
   * @return The raw log-likelihood ratio; never negative
   * @throws IllegalArgumentException if any count is negative
   */
  public static double logLikelihoodRatio(long k11, long k12, long k21, long k22) {
    Preconditions.checkArgument(k11 >= 0 && k12 >= 0 && k21 >= 0 && k22 >= 0,
        "Counts must be non-negative, got k11=%s, k12=%s, k21=%s, k22=%s", k11, k12, k21, k22);
    // note that we have counts here, not probabilities, and that the entropy is not normalized.
    double rowEntropy = entropy(k11 + k12, k21 + k22);
    double columnEntropy = entropy(k11 + k21, k12 + k22);
    double matrixEntropy = entropy(k11, k12, k21, k22);
    if (rowEntropy + columnEntropy < matrixEntropy) {
      // round off error
      return 0.0;
    }
    return 2.0 * (rowEntropy + columnEntropy - matrixEntropy);
  }

  /**
   * Calculates the root log-likelihood ratio for two events: the square root of
   * {@link #logLikelihoodRatio(long, long, long, long)}, negated when
   * {@code k11 / (k11 + k12)} is smaller than {@code k21 / (k21 + k22)}.
   *
   * <p>There is some more discussion here: http://s.apache.org/CGL
   *
   * <p>And see the response to Wataru's comment here:
   * http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html
   *
   * @param k11 The number of times the two events occurred together
   * @param k12 The number of times the second event occurred WITHOUT the first event
   * @param k21 The number of times the first event occurred WITHOUT the second event
   * @param k22 The number of times something else occurred (i.e. was neither of these events)
   * @return The root log-likelihood ratio
   * @throws IllegalArgumentException if any count is negative
   */
  public static double rootLogLikelihoodRatio(long k11, long k12, long k21, long k22) {
    double llr = logLikelihoodRatio(k11, k12, k21, k22);
    double sqrt = Math.sqrt(llr);
    // If a row is empty its ratio is 0.0 / 0 = NaN; the comparison is then false and the
    // result keeps its non-negative sign.
    if ((double) k11 / (k11 + k12) < (double) k21 / (k21 + k22)) {
      sqrt = -sqrt;
    }
    return sqrt;
  }

  /**
   * Compares two sets of counts to see which items are interestingly over-represented in the first
   * set.
   *
   * @param a The first counts.
   * @param b The reference counts.
   * @param maxReturn  The maximum number of items to return; must not be negative. Use
   *                   maxReturn >= a.elementSet().size() (for example Integer.MAX_VALUE) to return
   *                   all scores above the threshold.
   * @param threshold  The minimum score for items to be returned.  Use 0 to return all items more common
   *                   in a than b.  Use -Double.MAX_VALUE (not Double.MIN_VALUE !) to not use a threshold.
   * @return A list of scored items with their scores, highest score first.
   * @throws IllegalArgumentException if maxReturn is negative
   */
  public static <T> List<ScoredItem<T>> compareFrequencies(Multiset<T> a,
                                                           Multiset<T> b,
                                                           int maxReturn,
                                                           double threshold) {
    Preconditions.checkArgument(maxReturn >= 0, "maxReturn must not be negative, got %s", maxReturn);

    int totalA = a.size();
    int totalB = b.size();

    Ordering<ScoredItem<T>> byScoreAscending = new Ordering<ScoredItem<T>>() {
      @Override
      public int compare(ScoredItem<T> tScoredItem, ScoredItem<T> tScoredItem1) {
        return Double.compare(tScoredItem.score, tScoredItem1.score);
      }
    };

    // Size the heap by the number of items that can actually be scored instead of by maxReturn
    // alone: maxReturn + 1 overflows for Integer.MAX_VALUE (PriorityQueue rejects a capacity
    // below 1) and a merely large maxReturn would pre-allocate a huge, mostly unused array.
    long candidates = a.elementSet().size();
    if (threshold < 0) {
      candidates += b.elementSet().size();
    }
    // The queue transiently holds one element more than it is allowed to keep.
    int initialCapacity = (int) Math.min(Math.min((long) maxReturn, candidates) + 1L, Integer.MAX_VALUE);
    Queue<ScoredItem<T>> best = new PriorityQueue<>(initialCapacity, byScoreAscending);

    for (T t : a.elementSet()) {
      compareAndAdd(a, b, maxReturn, threshold, totalA, totalB, best, t);
    }

    // if threshold >= 0 we only iterate through a because anything not there can't be as or more common than in b.
    if (threshold < 0) {
      for (T t : b.elementSet()) {
        // only items missing from a need be scored
        if (a.count(t) == 0) {
          compareAndAdd(a, b, maxReturn, threshold, totalA, totalB, best, t);
        }
      }
    }

    List<ScoredItem<T>> r = new ArrayList<>(best);
    Collections.sort(r, byScoreAscending.reverse());
    return r;
  }

  /**
   * Scores item {@code t} with {@link #rootLogLikelihoodRatio(long, long, long, long)} and, if the
   * score reaches {@code threshold}, adds it to {@code best}, evicting the lowest-scored entries
   * so that at most {@code maxReturn} remain.
   */
  private static <T> void compareAndAdd(Multiset<T> a,
                                        Multiset<T> b,
                                        int maxReturn,
                                        double threshold,
                                        int totalA,
                                        int totalB,
                                        Queue<ScoredItem<T>> best,
                                        T t) {
    int kA = a.count(t);
    int kB = b.count(t);
    double score = rootLogLikelihoodRatio(kA, totalA - kA, kB, totalB - kB);
    if (score >= threshold) {
      ScoredItem<T> x = new ScoredItem<>(t, score);
      best.add(x);
      while (best.size() > maxReturn) {
        best.poll();
      }
    }
  }

  /** An item paired with the score it was given by {@code compareFrequencies}. */
  public static final class ScoredItem<T> {
    private final T item;
    private final double score;

    public ScoredItem(T item, double score) {
      this.item = item;
      this.score = score;
    }

    public double getScore() {
      return score;
    }

    public T getItem() {
      return item;
    }
  }
}
