http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/Multinomial.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/random/Multinomial.java b/core/src/main/java/org/apache/mahout/math/random/Multinomial.java new file mode 100644 index 0000000..d79c32c --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/random/Multinomial.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.random; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import com.google.common.base.Preconditions; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Multiset; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.list.DoubleArrayList; + +/** + * Multinomial sampler that allows updates to element probabilities. The basic idea is that sampling is + * done by using a simple balanced tree. 
Probabilities are kept in the tree so that we can navigate to + * any leaf in log N time. Updates are simple because we can just propagate them upwards. + * <p/> + * In order to facilitate access by value, we maintain an additional map from value to tree node. + */ +public final class Multinomial<T> implements Sampler<T>, Iterable<T> { + // these lists use heap ordering. Thus, the root is at location 1, first level children at 2 and 3, second level + // at 4, 5 and 6, 7. + private final DoubleArrayList weight = new DoubleArrayList(); + private final List<T> values = Lists.newArrayList(); + private final Map<T, Integer> items = Maps.newHashMap(); + private Random rand = RandomUtils.getRandom(); + + public Multinomial() { + weight.add(0); + values.add(null); + } + + public Multinomial(Multiset<T> counts) { + this(); + Preconditions.checkArgument(!counts.isEmpty(), "Need some data to build sampler"); + rand = RandomUtils.getRandom(); + for (T t : counts.elementSet()) { + add(t, counts.count(t)); + } + } + + public Multinomial(Iterable<WeightedThing<T>> things) { + this(); + for (WeightedThing<T> thing : things) { + add(thing.getValue(), thing.getWeight()); + } + } + + public void add(T value, double w) { + Preconditions.checkNotNull(value); + Preconditions.checkArgument(!items.containsKey(value)); + + int n = this.weight.size(); + if (n == 1) { + weight.add(w); + values.add(value); + items.put(value, 1); + } else { + // parent comes down + weight.add(weight.get(n / 2)); + values.add(values.get(n / 2)); + items.put(values.get(n / 2), n); + n++; + + // new item goes in + items.put(value, n); + this.weight.add(w); + values.add(value); + + // parents get incremented all the way to the root + while (n > 1) { + n /= 2; + this.weight.set(n, this.weight.get(n) + w); + } + } + } + + public double getWeight(T value) { + if (items.containsKey(value)) { + return weight.get(items.get(value)); + } else { + return 0; + } + } + + public double getProbability(T value) { + if 
(items.containsKey(value)) { + return weight.get(items.get(value)) / weight.get(1); + } else { + return 0; + } + } + + public double getWeight() { + if (weight.size() > 1) { + return weight.get(1); + } else { + return 0; + } + } + + public void delete(T value) { + set(value, 0); + } + + public void set(T value, double newP) { + Preconditions.checkArgument(items.containsKey(value)); + int n = items.get(value); + if (newP <= 0) { + // this makes the iterator not see such an element even though we leave a phantom in the tree + // Leaving the phantom behind simplifies tree maintenance and testing, but isn't really necessary. + items.remove(value); + } + double oldP = weight.get(n); + while (n > 0) { + weight.set(n, weight.get(n) - oldP + newP); + n /= 2; + } + } + + @Override + public T sample() { + Preconditions.checkArgument(!weight.isEmpty()); + return sample(rand.nextDouble()); + } + + public T sample(double u) { + u *= weight.get(1); + + int n = 1; + while (2 * n < weight.size()) { + // children are at 2n and 2n+1 + double left = weight.get(2 * n); + if (u <= left) { + n = 2 * n; + } else { + u -= left; + n = 2 * n + 1; + } + } + return values.get(n); + } + + /** + * Exposed for testing only. Returns a list of the leaf weights. These are in an + * order such that probing just before and after the cumulative sum of these weights + * will touch every element of the tree twice and thus will make it possible to test + * every possible left/right decision in navigating the tree. 
+ */ + List<Double> getWeights() { + List<Double> r = Lists.newArrayList(); + int i = Integer.highestOneBit(weight.size()); + while (i < weight.size()) { + r.add(weight.get(i)); + i++; + } + i /= 2; + while (i < Integer.highestOneBit(weight.size())) { + r.add(weight.get(i)); + i++; + } + return r; + } + + @Override + public Iterator<T> iterator() { + return new AbstractIterator<T>() { + Iterator<T> valuesIterator = Iterables.skip(values, 1).iterator(); + @Override + protected T computeNext() { + while (valuesIterator.hasNext()) { + T next = valuesIterator.next(); + if (items.containsKey(next)) { + return next; + } + } + return endOfData(); + } + }; + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/Normal.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/random/Normal.java b/core/src/main/java/org/apache/mahout/math/random/Normal.java new file mode 100644 index 0000000..c162f26 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/random/Normal.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.random; + +import org.apache.mahout.common.RandomUtils; + +import java.util.Random; + +public final class Normal extends AbstractSamplerFunction { + private final Random rand = RandomUtils.getRandom(); + private double mean = 0; + private double sd = 1; + + public Normal() {} + + public Normal(double mean, double sd) { + this.mean = mean; + this.sd = sd; + } + + @Override + public Double sample() { + return rand.nextGaussian() * sd + mean; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/PoissonSampler.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/random/PoissonSampler.java b/core/src/main/java/org/apache/mahout/math/random/PoissonSampler.java new file mode 100644 index 0000000..e4e49f8 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/random/PoissonSampler.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.random; + +import com.google.common.collect.Lists; +import org.apache.commons.math3.distribution.PoissonDistribution; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.RandomWrapper; + +import java.util.List; + +/** + * Samples from a Poisson distribution. Should probably not be used for lambda > 1000 or so. + */ +public final class PoissonSampler extends AbstractSamplerFunction { + + private double limit; + private Multinomial<Integer> partial; + private final RandomWrapper gen; + private final PoissonDistribution pd; + + public PoissonSampler(double lambda) { + limit = 1; + gen = RandomUtils.getRandom(); + pd = new PoissonDistribution(gen.getRandomGenerator(), + lambda, + PoissonDistribution.DEFAULT_EPSILON, + PoissonDistribution.DEFAULT_MAX_ITERATIONS); + } + + @Override + public Double sample() { + return sample(gen.nextDouble()); + } + + double sample(double u) { + if (u < limit) { + List<WeightedThing<Integer>> steps = Lists.newArrayList(); + limit = 1; + int i = 0; + while (u / 20 < limit) { + double pdf = pd.probability(i); + limit -= pdf; + steps.add(new WeightedThing<>(i, pdf)); + i++; + } + steps.add(new WeightedThing<>(steps.size(), limit)); + partial = new Multinomial<>(steps); + } + return partial.sample(u); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/random/Sampler.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/random/Sampler.java b/core/src/main/java/org/apache/mahout/math/random/Sampler.java new file mode 100644 index 0000000..51460fa --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/random/Sampler.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
/**
 * Something that can produce samples of a generic type.
 *
 * @param <T> the type of the values that are produced.
 */
public interface Sampler<T> {

  /**
   * Produces one sample.
   *
   * @return the sampled value.
   */
  T sample();
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.random; + +import com.google.common.base.Preconditions; +import org.apache.mahout.common.RandomUtils; + +/** + * Handy for creating multinomial distributions of things. + */ +public final class WeightedThing<T> implements Comparable<WeightedThing<T>> { + private double weight; + private final T value; + + public WeightedThing(T thing, double weight) { + this.value = Preconditions.checkNotNull(thing); + this.weight = weight; + } + + public WeightedThing(double weight) { + this.value = null; + this.weight = weight; + } + + public T getValue() { + return value; + } + + public double getWeight() { + return weight; + } + + public void setWeight(double weight) { + this.weight = weight; + } + + @Override + public int compareTo(WeightedThing<T> other) { + return Double.compare(this.weight, other.weight); + } + + @Override + public boolean equals(Object o) { + if (o instanceof WeightedThing) { + @SuppressWarnings("unchecked") + WeightedThing<T> other = (WeightedThing<T>) o; + return weight == other.weight && value.equals(other.value); + } + return false; + } + + @Override + public int hashCode() { + return 31 * RandomUtils.hashDouble(weight) + value.hashCode(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/set/AbstractSet.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/set/AbstractSet.java b/core/src/main/java/org/apache/mahout/math/set/AbstractSet.java new file mode 100644 index 
0000000..7691420 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/set/AbstractSet.java @@ -0,0 +1,188 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* +Copyright 1999 CERN - European Organization for Nuclear Research. +Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose +is hereby granted without fee, provided that the above copyright notice appear in all copies and +that both that copyright notice and this permission notice appear in supporting documentation. +CERN makes no representations about the suitability of this software for any purpose. +It is provided "as is" without expressed or implied warranty. +*/ +package org.apache.mahout.math.set; + +import org.apache.mahout.math.PersistentObject; +import org.apache.mahout.math.map.PrimeFinder; + +public abstract class AbstractSet extends PersistentObject { + //public static boolean debug = false; // debug only + + /** The number of distinct associations in the map; its "size()". 
*/ + protected int distinct; + + /** + * The table capacity c=table.length always satisfies the invariant <tt>c * minLoadFactor <= s <= c * + * maxLoadFactor</tt>, where s=size() is the number of associations currently contained. The term "c * minLoadFactor" + * is called the "lowWaterMark", "c * maxLoadFactor" is called the "highWaterMark". In other words, the table capacity + * (and proportionally the memory used by this class) oscillates within these constraints. The terms are precomputed + * and cached to avoid recalculating them each time put(..) or removeKey(...) is called. + */ + protected int lowWaterMark; + protected int highWaterMark; + + /** The minimum load factor for the hashtable. */ + protected double minLoadFactor; + + /** The maximum load factor for the hashtable. */ + protected double maxLoadFactor; + + // these are public access for unit tests. + public static final int DEFAULT_CAPACITY = 277; + public static final double DEFAULT_MIN_LOAD_FACTOR = 0.2; + public static final double DEFAULT_MAX_LOAD_FACTOR = 0.5; + + /** + * Chooses a new prime table capacity optimized for growing that (approximately) satisfies the invariant <tt>c * + * minLoadFactor <= size <= c * maxLoadFactor</tt> and has at least one FREE slot for the given size. + */ + protected int chooseGrowCapacity(int size, double minLoad, double maxLoad) { + return nextPrime(Math.max(size + 1, (int) ((4 * size / (3 * minLoad + maxLoad))))); + } + + /** + * Returns new high water mark threshold based on current capacity and maxLoadFactor. + * + * @return int the new threshold. + */ + protected int chooseHighWaterMark(int capacity, double maxLoad) { + return Math.min(capacity - 2, (int) (capacity * maxLoad)); //makes sure there is always at least one FREE slot + } + + /** + * Returns new low water mark threshold based on current capacity and minLoadFactor. + * + * @return int the new threshold. 
+ */ + protected int chooseLowWaterMark(int capacity, double minLoad) { + return (int) (capacity * minLoad); + } + + /** + * Chooses a new prime table capacity neither favoring shrinking nor growing, that (approximately) satisfies the + * invariant <tt>c * minLoadFactor <= size <= c * maxLoadFactor</tt> and has at least one FREE slot for the given + * size. + */ + protected int chooseMeanCapacity(int size, double minLoad, double maxLoad) { + return nextPrime(Math.max(size + 1, (int) ((2 * size / (minLoad + maxLoad))))); + } + + /** + * Chooses a new prime table capacity optimized for shrinking that (approximately) satisfies the invariant <tt>c * + * minLoadFactor <= size <= c * maxLoadFactor</tt> and has at least one FREE slot for the given size. + */ + protected int chooseShrinkCapacity(int size, double minLoad, double maxLoad) { + return nextPrime(Math.max(size + 1, (int) ((4 * size / (minLoad + 3 * maxLoad))))); + } + + /** Removes all (key,value) associations from the receiver. */ + public abstract void clear(); + + /** + * Ensures that the receiver can hold at least the specified number of elements without needing to allocate new + * internal memory. If necessary, allocates new internal memory and increases the capacity of the receiver. <p> This + * method never need be called; it is for performance tuning only. Calling this method before <tt>put()</tt>ing a + * large number of associations boosts performance, because the receiver will grow only once instead of potentially + * many times. <p> <b>This default implementation does nothing.</b> Override this method if necessary. + * + * @param minCapacity the desired minimum capacity. + */ + public void ensureCapacity(int minCapacity) { + } + + /** + * Returns <tt>true</tt> if the receiver contains no (key,value) associations. + * + * @return <tt>true</tt> if the receiver contains no (key,value) associations. 
+ */ + public boolean isEmpty() { + return distinct == 0; + } + + /** + * Returns a prime number which is <code>>= desiredCapacity</code> and very close to <code>desiredCapacity</code> + * (within 11% if <code>desiredCapacity >= 1000</code>). + * + * @param desiredCapacity the capacity desired by the user. + * @return the capacity which should be used for a hashtable. + */ + protected int nextPrime(int desiredCapacity) { + return PrimeFinder.nextPrime(desiredCapacity); + } + + /** + * Initializes the receiver. You will almost certainly need to override this method in subclasses to initialize the + * hash table. + * + * @param initialCapacity the initial capacity of the receiver. + * @param minLoadFactor the minLoadFactor of the receiver. + * @param maxLoadFactor the maxLoadFactor of the receiver. + * @throws IllegalArgumentException if <tt>initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0) || + * (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >= + * maxLoadFactor)</tt>. + */ + protected void setUp(int initialCapacity, double minLoadFactor, double maxLoadFactor) { + if (initialCapacity < 0) { + throw new IllegalArgumentException("Initial Capacity must not be less than zero: " + initialCapacity); + } + if (minLoadFactor < 0.0 || minLoadFactor >= 1.0) { + throw new IllegalArgumentException("Illegal minLoadFactor: " + minLoadFactor); + } + if (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) { + throw new IllegalArgumentException("Illegal maxLoadFactor: " + maxLoadFactor); + } + if (minLoadFactor >= maxLoadFactor) { + throw new IllegalArgumentException( + "Illegal minLoadFactor: " + minLoadFactor + " and maxLoadFactor: " + maxLoadFactor); + } + } + + /** + * Returns the number of (key,value) associations currently contained. + * + * @return the number of (key,value) associations currently contained. + */ + public int size() { + return distinct; + } + + /** + * Trims the capacity of the receiver to be the receiver's current size. 
Releases any superfluous internal memory. An + * application can use this operation to minimize the storage of the receiver. <p> This default implementation does + * nothing. Override this method if necessary. + */ + public void trimToSize() { + } + + protected static boolean equalsMindTheNull(Object a, Object b) { + if (a == null && b == null) { + return true; + } + if (a == null || b == null) { + return false; + } + return a.equals(b); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/set/HashUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/set/HashUtils.java b/core/src/main/java/org/apache/mahout/math/set/HashUtils.java new file mode 100644 index 0000000..f5dfeb0 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/set/HashUtils.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.set; + +/** + * Computes hashes of primitive values. Providing these as statics allows the templated code + * to compute hashes of sets. 
/**
 * Computes hashes of primitive values. Providing these as statics allows the templated code
 * to compute hashes of sets.
 */
public final class HashUtils {

  private HashUtils() {
  }

  /** @return {@code x} widened to int. */
  public static int hash(byte x) {
    return x;
  }

  /** @return {@code x} widened to int. */
  public static int hash(short x) {
    return x;
  }

  /** @return {@code x} widened to int. */
  public static int hash(char x) {
    return x;
  }

  /** @return {@code x} itself. */
  public static int hash(int x) {
    return x;
  }

  /**
   * Mixes the bit pattern of {@code x}, shifted right by three, with the bit pattern of
   * {@code PI * x}.
   *
   * @return the sum of the two terms (wrapping on overflow).
   */
  public static int hash(float x) {
    // The parentheses matter: '+' binds tighter than '>>>', so without them the second term
    // would be used as part of the shift distance instead of being added to the hash.
    return (Float.floatToIntBits(x) >>> 3) + Float.floatToIntBits((float) (Math.PI * x));
  }

  /** @return the long hash of 17 times the IEEE bit pattern of {@code x}. */
  public static int hash(double x) {
    return hash(17 * Double.doubleToLongBits(x));
  }

  /** @return the upper 32 bits of {@code 11 * x} xor-ed with {@code x}, truncated to int. */
  public static int hash(long x) {
    return (int) (((x * 11) >>> 32) ^ x);
  }
}
+ */ + +package org.apache.mahout.math.set; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Set; + +import org.apache.mahout.math.MurmurHash; +import org.apache.mahout.math.function.ObjectProcedure; +import org.apache.mahout.math.map.PrimeFinder; + +/** + * Open hashing alternative to java.util.HashSet. + **/ +public class OpenHashSet<T> extends AbstractSet implements Set<T> { + protected static final byte FREE = 0; + protected static final byte FULL = 1; + protected static final byte REMOVED = 2; + protected static final char NO_KEY_VALUE = 0; + + /** The hash table keys. */ + private Object[] table; + + /** The state of each hash table entry (FREE, FULL, REMOVED). */ + private byte[] state; + + /** The number of table entries in state==FREE. */ + private int freeEntries; + + + /** Constructs an empty map with default capacity and default load factors. */ + public OpenHashSet() { + this(DEFAULT_CAPACITY); + } + + /** + * Constructs an empty map with the specified initial capacity and default load factors. + * + * @param initialCapacity the initial capacity of the map. + * @throws IllegalArgumentException if the initial capacity is less than zero. + */ + public OpenHashSet(int initialCapacity) { + this(initialCapacity, DEFAULT_MIN_LOAD_FACTOR, DEFAULT_MAX_LOAD_FACTOR); + } + + /** + * Constructs an empty map with the specified initial capacity and the specified minimum and maximum load factor. + * + * @param initialCapacity the initial capacity. + * @param minLoadFactor the minimum load factor. + * @param maxLoadFactor the maximum load factor. + * @throws IllegalArgumentException if <tt>initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0) || + * (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >= + * maxLoadFactor)</tt>. 
+ */ + public OpenHashSet(int initialCapacity, double minLoadFactor, double maxLoadFactor) { + setUp(initialCapacity, minLoadFactor, maxLoadFactor); + } + + /** Removes all values associations from the receiver. Implicitly calls <tt>trimToSize()</tt>. */ + @Override + public void clear() { + Arrays.fill(this.state, 0, state.length - 1, FREE); + distinct = 0; + freeEntries = table.length; // delta + trimToSize(); + } + + /** + * Returns a deep copy of the receiver. + * + * @return a deep copy of the receiver. + */ + @SuppressWarnings("unchecked") + @Override + public Object clone() { + OpenHashSet<T> copy = (OpenHashSet<T>) super.clone(); + copy.table = copy.table.clone(); + copy.state = copy.state.clone(); + return copy; + } + + /** + * Returns <tt>true</tt> if the receiver contains the specified key. + * + * @return <tt>true</tt> if the receiver contains the specified key. + */ + @Override + @SuppressWarnings("unchecked") + public boolean contains(Object key) { + return indexOfKey((T)key) >= 0; + } + + /** + * Ensures that the receiver can hold at least the specified number of associations without needing to allocate new + * internal memory. If necessary, allocates new internal memory and increases the capacity of the receiver. <p> This + * method never need be called; it is for performance tuning only. Calling this method before <tt>add()</tt>ing a + * large number of associations boosts performance, because the receiver will grow only once instead of potentially + * many times and hash collisions get less probable. + * + * @param minCapacity the desired minimum capacity. + */ + @Override + public void ensureCapacity(int minCapacity) { + if (table.length < minCapacity) { + int newCapacity = nextPrime(minCapacity); + rehash(newCapacity); + } + } + + /** + * Applies a procedure to each key of the receiver, if any. Note: Iterates over the keys in no particular order. + * Subclasses can define a particular order, for example, "sorted by key". 
All methods which <i>can</i> be expressed + * in terms of this method (most methods can) <i>must guarantee</i> to use the <i>same</i> order defined by this + * method, even if it is no particular order. This is necessary so that, for example, methods <tt>keys</tt> and + * <tt>values</tt> will yield association pairs, not two uncorrelated lists. + * + * @param procedure the procedure to be applied. Stops iteration if the procedure returns <tt>false</tt>, otherwise + * continues. + * @return <tt>false</tt> if the procedure stopped before all keys where iterated over, <tt>true</tt> otherwise. + */ + @SuppressWarnings("unchecked") + public boolean forEachKey(ObjectProcedure<T> procedure) { + for (int i = table.length; i-- > 0;) { + if (state[i] == FULL) { + if (!procedure.apply((T)table[i])) { + return false; + } + } + } + return true; + } + + /** + * @param key the key to be added to the receiver. + * @return the index where the key would need to be inserted, if it is not already contained. Returns -index-1 if the + * key is already contained at slot index. Therefore, if the returned index < 0, then it is already contained + * at slot -index-1. If the returned index >= 0, then it is NOT already contained and should be inserted at + * slot index. + */ + protected int indexOfInsertion(T key) { + Object[] tab = table; + byte[] stat = state; + int length = tab.length; + + int hash = key.hashCode() & 0x7FFFFFFF; + int i = hash % length; + int decrement = hash % (length - 2); // double hashing, see http://www.eece.unm.edu/faculty/heileman/hash/node4.html + //int decrement = (hash / length) % length; + if (decrement == 0) { + decrement = 1; + } + + // stop if we find a removed or free slot, or if we find the key itself + // do NOT skip over removed slots (yes, open addressing is like that...) 
+ while (stat[i] == FULL && tab[i] != key) { + i -= decrement; + //hashCollisions++; + if (i < 0) { + i += length; + } + } + + if (stat[i] == REMOVED) { + // stop if we find a free slot, or if we find the key itself. + // do skip over removed slots (yes, open addressing is like that...) + // assertion: there is at least one FREE slot. + int j = i; + while (stat[i] != FREE && (stat[i] == REMOVED || tab[i] != key)) { + i -= decrement; + //hashCollisions++; + if (i < 0) { + i += length; + } + } + if (stat[i] == FREE) { + i = j; + } + } + + + if (stat[i] == FULL) { + // key already contained at slot i. + // return a negative number identifying the slot. + return -i - 1; + } + // not already contained, should be inserted at slot i. + // return a number >= 0 identifying the slot. + return i; + } + + /** + * @param key the key to be searched in the receiver. + * @return the index where the key is contained in the receiver, returns -1 if the key was not found. + */ + protected int indexOfKey(T key) { + Object[] tab = table; + byte[] stat = state; + int length = tab.length; + + int hash = key.hashCode() & 0x7FFFFFFF; + int i = hash % length; + int decrement = hash % (length - 2); // double hashing, see http://www.eece.unm.edu/faculty/heileman/hash/node4.html + //int decrement = (hash / length) % length; + if (decrement == 0) { + decrement = 1; + } + + // stop if we find a free slot, or if we find the key itself. + // do skip over removed slots (yes, open addressing is like that...) + while (stat[i] != FREE && (stat[i] == REMOVED || (!key.equals(tab[i])))) { + i -= decrement; + //hashCollisions++; + if (i < 0) { + i += length; + } + } + + if (stat[i] == FREE) { + return -1; + } // not found + return i; //found, return index where key is contained + } + + /** + * Fills all keys contained in the receiver into the specified list. Fills the list, starting at index 0. After this + * call returns the specified list has a new size that equals <tt>this.size()</tt>. 
+ * This method can be used + * to iterate over the keys of the receiver. + * + * @param list the list to be filled, can have any size. + */ + @SuppressWarnings("unchecked") + public void keys(List<T> list) { + list.clear(); + + + Object [] tab = table; + byte[] stat = state; + + for (int i = tab.length; i-- > 0;) { + if (stat[i] == FULL) { + list.add((T)tab[i]); + } + } + } + + @SuppressWarnings("unchecked") + @Override + public boolean add(Object key) { + int i = indexOfInsertion((T)key); + if (i < 0) { //already contained + return false; + } + + if (this.distinct > this.highWaterMark) { + int newCapacity = chooseGrowCapacity(this.distinct + 1, this.minLoadFactor, this.maxLoadFactor); + rehash(newCapacity); + return add(key); + } + + this.table[i] = key; + if (this.state[i] == FREE) { + this.freeEntries--; + } + this.state[i] = FULL; + this.distinct++; + + if (this.freeEntries < 1) { //delta + int newCapacity = chooseGrowCapacity(this.distinct + 1, this.minLoadFactor, this.maxLoadFactor); + rehash(newCapacity); + return add(key); + } + + return true; + } + + /** + * Rehashes the contents of the receiver into a new table with a smaller or larger capacity. This method is called + * automatically when the number of keys in the receiver exceeds the high water mark or falls below the low water + * mark. 
+ */ + @SuppressWarnings("unchecked") + protected void rehash(int newCapacity) { + int oldCapacity = table.length; + //if (oldCapacity == newCapacity) return; + + Object[] oldTable = table; + byte[] oldState = state; + + Object[] newTable = new Object[newCapacity]; + byte[] newState = new byte[newCapacity]; + + this.lowWaterMark = chooseLowWaterMark(newCapacity, this.minLoadFactor); + this.highWaterMark = chooseHighWaterMark(newCapacity, this.maxLoadFactor); + + this.table = newTable; + this.state = newState; + this.freeEntries = newCapacity - this.distinct; // delta + + for (int i = oldCapacity; i-- > 0;) { + if (oldState[i] == FULL) { + Object element = oldTable[i]; + int index = indexOfInsertion((T)element); + newTable[index] = element; + newState[index] = FULL; + } + } + } + + /** + * Removes the given key with its associated element from the receiver, if present. + * + * @param key the key to be removed from the receiver. + * @return <tt>true</tt> if the receiver contained the specified key, <tt>false</tt> otherwise. + */ + @SuppressWarnings("unchecked") + @Override + public boolean remove(Object key) { + int i = indexOfKey((T)key); + if (i < 0) { + return false; + } // key not contained + + this.state[i] = REMOVED; + this.distinct--; + + if (this.distinct < this.lowWaterMark) { + int newCapacity = chooseShrinkCapacity(this.distinct, this.minLoadFactor, this.maxLoadFactor); + rehash(newCapacity); + } + + return true; + } + + /** + * Initializes the receiver. + * + * @param initialCapacity the initial capacity of the receiver. + * @param minLoadFactor the minLoadFactor of the receiver. + * @param maxLoadFactor the maxLoadFactor of the receiver. + * @throws IllegalArgumentException if <tt>initialCapacity < 0 || (minLoadFactor < 0.0 || minLoadFactor >= 1.0) || + * (maxLoadFactor <= 0.0 || maxLoadFactor >= 1.0) || (minLoadFactor >= + * maxLoadFactor)</tt>. 
+ */ + @Override + protected final void setUp(int initialCapacity, double minLoadFactor, double maxLoadFactor) { + int capacity = initialCapacity; + super.setUp(capacity, minLoadFactor, maxLoadFactor); + capacity = nextPrime(capacity); + if (capacity == 0) { + capacity = 1; + } // open addressing needs at least one FREE slot at any time. + + this.table = new Object[capacity]; + this.state = new byte[capacity]; + + // memory will be exhausted long before this pathological case happens, anyway. + this.minLoadFactor = minLoadFactor; + if (capacity == PrimeFinder.LARGEST_PRIME) { + this.maxLoadFactor = 1.0; + } else { + this.maxLoadFactor = maxLoadFactor; + } + + this.distinct = 0; + this.freeEntries = capacity; // delta + + // lowWaterMark will be established upon first expansion. + // establishing it now (upon instance construction) would immediately make the table shrink upon first put(...). + // After all the idea of an "initialCapacity" implies violating lowWaterMarks when an object is young. + // See ensureCapacity(...) + this.lowWaterMark = 0; + this.highWaterMark = chooseHighWaterMark(capacity, this.maxLoadFactor); + } + + /** + * Trims the capacity of the receiver to be the receiver's current size. Releases any superfluous internal memory. An + * application can use this operation to minimize the storage of the receiver. + */ + @Override + public void trimToSize() { + // * 1.2 because open addressing's performance exponentially degrades beyond that point + // so that even rehashing the table can take very long + int newCapacity = nextPrime((int) (1 + 1.2 * size())); + if (table.length > newCapacity) { + rehash(newCapacity); + } + } + + /** + * Access for unit tests. 
+ * @param capacity + * @param minLoadFactor + * @param maxLoadFactor + */ + void getInternalFactors(int[] capacity, + double[] minLoadFactor, + double[] maxLoadFactor) { + capacity[0] = table.length; + minLoadFactor[0] = this.minLoadFactor; + maxLoadFactor[0] = this.maxLoadFactor; + } + + @Override + public boolean isEmpty() { + return size() == 0; + } + + /** + * OpenHashSet instances are only equal to other OpenHashSet instances, not to + * any other collection. Hypothetically, we should check for and permit + * equals on other Sets. + */ + @Override + @SuppressWarnings("unchecked") + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof OpenHashSet)) { + return false; + } + final OpenHashSet<T> other = (OpenHashSet<T>) obj; + if (other.size() != size()) { + return false; + } + + return forEachKey(new ObjectProcedure<T>() { + @Override + public boolean apply(T key) { + return other.contains(key); + } + }); + } + + @Override + public int hashCode() { + ByteBuffer buf = ByteBuffer.allocate(size()); + for (int i = 0; i < table.length; i++) { + Object v = table[i]; + if (state[i] == FULL) { + buf.putInt(v.hashCode()); + } + } + return MurmurHash.hash(buf, this.getClass().getName().hashCode()); + } + + /** + * Implement the standard Java Collections iterator. Note that 'remove' is silently + * ineffectual here. This method is provided for convenience, only. + */ + @Override + public Iterator<T> iterator() { + List<T> keyList = new ArrayList<>(); + keys(keyList); + return keyList.iterator(); + } + + @Override + public Object[] toArray() { + List<T> keyList = new ArrayList<>(); + keys(keyList); + return keyList.toArray(); + } + + @Override + public boolean addAll(Collection<? 
extends T> c) { + boolean anyAdded = false; + for (T o : c) { + boolean added = add(o); + anyAdded |= added; + } + return anyAdded; + } + + @Override + public boolean containsAll(Collection<?> c) { + for (Object o : c) { + if (!contains(o)) { + return false; + } + } + return true; + } + + @Override + public boolean removeAll(Collection<?> c) { + boolean anyRemoved = false; + for (Object o : c) { + boolean removed = remove(o); + anyRemoved |= removed; + } + return anyRemoved; + } + + @Override + public boolean retainAll(Collection<?> c) { + final Collection<?> finalCollection = c; + final boolean[] modified = new boolean[1]; + modified[0] = false; + forEachKey(new ObjectProcedure<T>() { + @Override + public boolean apply(T element) { + if (!finalCollection.contains(element)) { + remove(element); + modified[0] = true; + } + return true; + } + }); + return modified[0]; + } + + @Override + public <T1> T1[] toArray(T1[] a) { + return keys().toArray(a); + } + + public List<T> keys() { + List<T> keys = new ArrayList<>(); + keys(keys); + return keys; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java b/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java new file mode 100644 index 0000000..02bde9b --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/solver/ConjugateGradientSolver.java @@ -0,0 +1,213 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.solver; + +import org.apache.mahout.math.CardinalityException; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorIterable; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.function.PlusMult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * <p>Implementation of a conjugate gradient iterative solver for linear systems. Implements both + * standard conjugate gradient and pre-conditioned conjugate gradient. + * + * <p>Conjugate gradient requires the matrix A in the linear system Ax = b to be symmetric and positive + * definite. For convenience, this implementation could be extended relatively easily to handle the + * case where the input matrix is non-symmetric, in which case the system A'Ax = A'b would be solved. + * Because this requires only one pass through the matrix A, it is faster than explicitly computing A'A, + * then passing the results to the solver. + * + * <p>For inputs that may be ill conditioned (often the case for highly sparse input), this solver + * also accepts a parameter, lambda, which adds a scaled identity to the matrix A, solving the system + * (A + lambda*I)x = b. This obviously changes the solution, but it will guarantee solvability. The + * ridge regression approach to linear regression is a common use of this feature. 
+ * + * <p>If only an approximate solution is required, the maximum number of iterations or the error threshold + * may be specified to end the algorithm early at the expense of accuracy. When the matrix A is ill conditioned, + * it may sometimes be necessary to increase the maximum number of iterations above the default of A.numCols() + * due to numerical issues. + * + * <p>By default the solver will run a.numCols() iterations or until the residual falls below 1E-9. + * + * <p>For more information on the conjugate gradient algorithm, see Golub & van Loan, "Matrix Computations", + * sections 10.2 and 10.3 or the <a href="http://en.wikipedia.org/wiki/Conjugate_gradient">conjugate gradient + * wikipedia article</a>. + */ + +public class ConjugateGradientSolver { + + public static final double DEFAULT_MAX_ERROR = 1.0e-9; + + private static final Logger log = LoggerFactory.getLogger(ConjugateGradientSolver.class); + private static final PlusMult PLUS_MULT = new PlusMult(1.0); + + private int iterations; + private double residualNormSquared; + + public ConjugateGradientSolver() { + this.iterations = 0; + this.residualNormSquared = Double.NaN; + } + + /** + * Solves the system Ax = b with default termination criteria. A must be symmetric, square, and positive definite. + * Only the squareness of a is checked, since testing for symmetry and positive definiteness are too expensive. If + * an invalid matrix is specified, then the algorithm may not yield a valid result. + * + * @param a The linear operator A. + * @param b The vector b. + * @return The result x of solving the system. + * @throws IllegalArgumentException if a is not square or if the size of b is not equal to the number of columns of a. + * + */ + public Vector solve(VectorIterable a, Vector b) { + return solve(a, b, null, b.size() + 2, DEFAULT_MAX_ERROR); + } + + /** + * Solves the system Ax = b with default termination criteria using the specified preconditioner. 
A must be + * symmetric, square, and positive definite. Only the squareness of a is checked, since testing for symmetry + * and positive definiteness are too expensive. If an invalid matrix is specified, then the algorithm may not + * yield a valid result. + * + * @param a The linear operator A. + * @param b The vector b. + * @param precond A preconditioner to use on A during the solution process. + * @return The result x of solving the system. + * @throws IllegalArgumentException if a is not square or if the size of b is not equal to the number of columns of a. + * + */ + public Vector solve(VectorIterable a, Vector b, Preconditioner precond) { + return solve(a, b, precond, b.size() + 2, DEFAULT_MAX_ERROR); + } + + + /** + * Solves the system Ax = b, where A is a linear operator and b is a vector. Uses the specified preconditioner + * to improve numeric stability and possibly speed convergence. This version of solve() allows control over the + * termination and iteration parameters. + * + * @param a The matrix A. + * @param b The vector b. + * @param preconditioner The preconditioner to apply. + * @param maxIterations The maximum number of iterations to run. + * @param maxError The maximum amount of residual error to tolerate. The algorithm will run until the residual falls + * below this value or until maxIterations are completed. + * @return The result x of solving the system. + * @throws IllegalArgumentException if the matrix is not square, if the size of b is not equal to the number of + * columns of A, if maxError is less than zero, or if maxIterations is not positive. 
+ */ + + public Vector solve(VectorIterable a, + Vector b, + Preconditioner preconditioner, + int maxIterations, + double maxError) { + + if (a.numRows() != a.numCols()) { + throw new IllegalArgumentException("Matrix must be square, symmetric and positive definite."); + } + + if (a.numCols() != b.size()) { + throw new CardinalityException(a.numCols(), b.size()); + } + + if (maxIterations <= 0) { + throw new IllegalArgumentException("Max iterations must be positive."); + } + + if (maxError < 0.0) { + throw new IllegalArgumentException("Max error must be non-negative."); + } + + Vector x = new DenseVector(b.size()); + + iterations = 0; + Vector residual = b.minus(a.times(x)); + residualNormSquared = residual.dot(residual); + + log.info("Conjugate gradient initial residual norm = {}", Math.sqrt(residualNormSquared)); + double previousConditionedNormSqr = 0.0; + Vector updateDirection = null; + while (Math.sqrt(residualNormSquared) > maxError && iterations < maxIterations) { + Vector conditionedResidual; + double conditionedNormSqr; + if (preconditioner == null) { + conditionedResidual = residual; + conditionedNormSqr = residualNormSquared; + } else { + conditionedResidual = preconditioner.precondition(residual); + conditionedNormSqr = residual.dot(conditionedResidual); + } + + ++iterations; + + if (iterations == 1) { + updateDirection = new DenseVector(conditionedResidual); + } else { + double beta = conditionedNormSqr / previousConditionedNormSqr; + + // updateDirection = residual + beta * updateDirection + updateDirection.assign(Functions.MULT, beta); + updateDirection.assign(conditionedResidual, Functions.PLUS); + } + + Vector aTimesUpdate = a.times(updateDirection); + + double alpha = conditionedNormSqr / updateDirection.dot(aTimesUpdate); + + // x = x + alpha * updateDirection + PLUS_MULT.setMultiplicator(alpha); + x.assign(updateDirection, PLUS_MULT); + + // residual = residual - alpha * A * updateDirection + PLUS_MULT.setMultiplicator(-alpha); + 
residual.assign(aTimesUpdate, PLUS_MULT); + + previousConditionedNormSqr = conditionedNormSqr; + residualNormSquared = residual.dot(residual); + + log.info("Conjugate gradient iteration {} residual norm = {}", iterations, Math.sqrt(residualNormSquared)); + } + return x; + } + + /** + * Returns the number of iterations run once the solver is complete. + * + * @return The number of iterations run. + */ + public int getIterations() { + return iterations; + } + + /** + * Returns the norm of the residual at the completion of the solver. Usually this should be close to zero except in + * the case of a non positive definite matrix A, which results in an unsolvable system, or for ill conditioned A, in + * which case more iterations than the default may be needed. + * + * @return The norm of the residual in the solution. + */ + public double getResidualNorm() { + return Math.sqrt(residualNormSquared); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java b/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java new file mode 100644 index 0000000..871ba44 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/solver/EigenDecomposition.java @@ -0,0 +1,892 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Adapted from the public domain Jama code. + */ + +package org.apache.mahout.math.solver; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +/** + * Eigenvalues and eigenvectors of a real matrix. + * <p/> + * If A is symmetric, then A = V*D*V' where the eigenvalue matrix D is diagonal and the eigenvector + * matrix V is orthogonal. I.e. A = V.times(D.times(V.transpose())) and V.times(V.transpose()) + * equals the identity matrix. + * <p/> + * If A is not symmetric, then the eigenvalue matrix D is block diagonal with the real eigenvalues + * in 1-by-1 blocks and any complex eigenvalues, lambda + i*mu, in 2-by-2 blocks, [lambda, mu; -mu, + * lambda]. The columns of V represent the eigenvectors in the sense that A*V = V*D, i.e. + * A.times(V) equals V.times(D). The matrix V may be badly conditioned, or even singular, so the + * validity of the equation A = V*D*inverse(V) depends upon V.cond(). + */ +public class EigenDecomposition { + + /** Row and column dimension (square matrix). */ + private final int n; + /** Arrays for internal storage of eigenvalues. */ + private final Vector d; + private final Vector e; + /** Array for internal storage of eigenvectors. 
*/
  private final Matrix v;

  /**
   * Decomposes x, choosing the symmetric or the non-symmetric algorithm according to isSymmetric(x).
   *
   * @param x the matrix to decompose
   */
  public EigenDecomposition(Matrix x) {
    this(x, isSymmetric(x));
  }

  /**
   * Decomposes x with the algorithm selected by the caller.
   *
   * @param x           the matrix to decompose; the dimension n is taken from its column count
   *                    (NOTE(review): no squareness check is visible here - confirm callers pass square input)
   * @param isSymmetric true to run the symmetric path (tred2 + tql2), false to run the
   *                    general path (orthes + hqr2)
   */
  public EigenDecomposition(Matrix x, boolean isSymmetric) {
    n = x.columnSize();
    d = new DenseVector(n);
    e = new DenseVector(n);
    v = new DenseMatrix(n, n);

    if (isSymmetric) {
      v.assign(x);

      // Tridiagonalize.
      tred2();

      // Diagonalize.
      tql2();

    } else {
      // Reduce to Hessenberg form.
      // Reduce Hessenberg to real Schur form.
      hqr2(orthes(x));
    }
  }

  /**
   * Return the eigenvector matrix
   *
   * @return V, as the result of v.like().assign(v) (presumably a fresh copy - confirm like() semantics)
   */
  public Matrix getV() {
    return v.like().assign(v);
  }

  /**
   * Return the real parts of the eigenvalues
   *
   * @return the internal vector d itself, not a copy
   */
  public Vector getRealEigenvalues() {
    return d;
  }

  /**
   * Return the imaginary parts of the eigenvalues
   *
   * @return the internal vector e itself, not a copy
   */
  public Vector getImagEigenvalues() {
    return e;
  }

  /**
   * Return the block diagonal eigenvalue matrix
   *
   * @return D
   */
  public Matrix getD() {
    Matrix x = new DenseMatrix(n, n);
    x.assign(0);
    // real parts go on the diagonal
    x.viewDiagonal().assign(d);
    for (int i = 0; i < n; i++) {
      // note: this local shadows the eigenvector field v; here it is the imaginary part of eigenvalue i
      double v = e.getQuick(i);
      if (v > 0) {
        // positive imaginary part: stored just right of the diagonal
        x.setQuick(i, i + 1, v);
      } else if (v < 0) {
        // negative imaginary part: stored just left of the diagonal
        x.setQuick(i, i - 1, v);
      }
    }
    return x;
  }

  // Symmetric Householder reduction to tridiagonal form.
  // Works in place on the field v, and uses the fields d and e as working storage and as outputs.
  private void tred2() {
    //  This is derived from the Algol procedures tred2 by
    //  Bowdler, Martin, Reinsch, and Wilkinson, Handbook for
    //  Auto. Comp., Vol.ii-Linear Algebra, and the corresponding
    //  Fortran subroutine in EISPACK.

    d.assign(v.viewColumn(n - 1));

    // Householder reduction to tridiagonal form.

    for (int i = n - 1; i > 0; i--) {

      // Scale to avoid under/overflow.

      double scale = d.viewPart(0, i).norm(1);
      double h = 0.0;


      if (scale == 0.0) {
        // nothing to eliminate for this index: just copy values over and zero row/column i
        e.setQuick(i, d.getQuick(i - 1));
        for (int j = 0; j < i; j++) {
          d.setQuick(j, v.getQuick(i - 1, j));
          v.setQuick(i, j, 0.0);
          v.setQuick(j, i, 0.0);
        }
      } else {

        // Generate Householder vector.

        for (int k = 0; k < i; k++) {
          d.setQuick(k, d.getQuick(k) / scale);
          h += d.getQuick(k) * d.getQuick(k);
        }
        double f = d.getQuick(i - 1);
        double g = Math.sqrt(h);
        // give g the opposite sign of f so that f - g below does not cancel
        if (f > 0) {
          g = -g;
        }
        e.setQuick(i, scale * g);
        h -= f * g;
        d.setQuick(i - 1, f - g);
        for (int j = 0; j < i; j++) {
          e.setQuick(j, 0.0);
        }

        // Apply similarity transformation to remaining columns.

        for (int j = 0; j < i; j++) {
          f = d.getQuick(j);
          v.setQuick(j, i, f);
          g = e.getQuick(j) + v.getQuick(j, j) * f;
          for (int k = j + 1; k <= i - 1; k++) {
            g += v.getQuick(k, j) * d.getQuick(k);
            e.setQuick(k, e.getQuick(k) + v.getQuick(k, j) * f);
          }
          e.setQuick(j, g);
        }
        f = 0.0;
        for (int j = 0; j < i; j++) {
          e.setQuick(j, e.getQuick(j) / h);
          f += e.getQuick(j) * d.getQuick(j);
        }
        double hh = f / (h + h);
        for (int j = 0; j < i; j++) {
          e.setQuick(j, e.getQuick(j) - hh * d.getQuick(j));
        }
        for (int j = 0; j < i; j++) {
          f = d.getQuick(j);
          g = e.getQuick(j);
          for (int k = j; k <= i - 1; k++) {
            v.setQuick(k, j, v.getQuick(k, j) - (f * e.getQuick(k) + g * d.getQuick(k)));
          }
          d.setQuick(j, v.getQuick(i - 1, j));
          v.setQuick(i, j, 0.0);
        }
      }
      // d(i) temporarily holds h; the accumulation loop below reads it back as d(i + 1)
      d.setQuick(i, h);
    }

    // Accumulate transformations.

    for (int i = 0; i < n - 1; i++) {
      // park the diagonal entry in the last row; it is read back into d after the loop
      v.setQuick(n - 1, i, v.getQuick(i, i));
      v.setQuick(i, i, 1.0);
      double h = d.getQuick(i + 1);
      if (h != 0.0) {
        for (int k = 0; k <= i; k++) {
          d.setQuick(k, v.getQuick(k, i + 1) / h);
        }
        for (int j = 0; j <= i; j++) {
          double g = 0.0;
          for (int k = 0; k <= i; k++) {
            g += v.getQuick(k, i + 1) * v.getQuick(k, j);
          }
          for (int k = 0; k <= i; k++) {
            v.setQuick(k, j, v.getQuick(k, j) - g * d.getQuick(k));
          }
        }
      }
      for (int k = 0; k <= i; k++) {
        v.setQuick(k, i + 1, 0.0);
      }
    }
    // the last row of v held the parked diagonal; move it to d and turn the row into a unit row
    d.assign(v.viewRow(n - 1));
    v.viewRow(n - 1).assign(0);
    v.setQuick(n - 1, n - 1, 1.0);
    e.setQuick(0, 0.0);
  }

  // Symmetric tridiagonal QL algorithm.
+ private void tql2() { + + // This is derived from the Algol procedures tql2, by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + e.viewPart(0, n - 1).assign(e.viewPart(1, n - 1)); + e.setQuick(n - 1, 0.0); + + double f = 0.0; + double tst1 = 0.0; + double eps = Math.pow(2.0, -52.0); + for (int l = 0; l < n; l++) { + + // Find small subdiagonal element + + tst1 = Math.max(tst1, Math.abs(d.getQuick(l)) + Math.abs(e.getQuick(l))); + int m = l; + while (m < n) { + if (Math.abs(e.getQuick(m)) <= eps * tst1) { + break; + } + m++; + } + + // If m == l, d.getQuick(l) is an eigenvalue, + // otherwise, iterate. + + if (m > l) { + do { + // Compute implicit shift + + double g = d.getQuick(l); + double p = (d.getQuick(l + 1) - g) / (2.0 * e.getQuick(l)); + double r = Math.hypot(p, 1.0); + if (p < 0) { + r = -r; + } + d.setQuick(l, e.getQuick(l) / (p + r)); + d.setQuick(l + 1, e.getQuick(l) * (p + r)); + double dl1 = d.getQuick(l + 1); + double h = g - d.getQuick(l); + for (int i = l + 2; i < n; i++) { + d.setQuick(i, d.getQuick(i) - h); + } + f += h; + + // Implicit QL transformation. + + p = d.getQuick(m); + double c = 1.0; + double c2 = c; + double c3 = c; + double el1 = e.getQuick(l + 1); + double s = 0.0; + double s2 = 0.0; + for (int i = m - 1; i >= l; i--) { + c3 = c2; + c2 = c; + s2 = s; + g = c * e.getQuick(i); + h = c * p; + r = Math.hypot(p, e.getQuick(i)); + e.setQuick(i + 1, s * r); + s = e.getQuick(i) / r; + c = p / r; + p = c * d.getQuick(i) - s * g; + d.setQuick(i + 1, h + s * (c * g + s * d.getQuick(i))); + + // Accumulate transformation. + + for (int k = 0; k < n; k++) { + h = v.getQuick(k, i + 1); + v.setQuick(k, i + 1, s * v.getQuick(k, i) + c * h); + v.setQuick(k, i, c * v.getQuick(k, i) - s * h); + } + } + p = -s * s2 * c3 * el1 * e.getQuick(l) / dl1; + e.setQuick(l, s * p); + d.setQuick(l, c * p); + + // Check for convergence. 
+ + } while (Math.abs(e.getQuick(l)) > eps * tst1); + } + d.setQuick(l, d.getQuick(l) + f); + e.setQuick(l, 0.0); + } + + // Sort eigenvalues and corresponding vectors. + + for (int i = 0; i < n - 1; i++) { + int k = i; + double p = d.getQuick(i); + for (int j = i + 1; j < n; j++) { + if (d.getQuick(j) > p) { + k = j; + p = d.getQuick(j); + } + } + if (k != i) { + d.setQuick(k, d.getQuick(i)); + d.setQuick(i, p); + for (int j = 0; j < n; j++) { + p = v.getQuick(j, i); + v.setQuick(j, i, v.getQuick(j, k)); + v.setQuick(j, k, p); + } + } + } + } + + // Nonsymmetric reduction to Hessenberg form. + private Matrix orthes(Matrix x) { + // Working storage for nonsymmetric algorithm. + Vector ort = new DenseVector(n); + Matrix hessenBerg = new DenseMatrix(n, n).assign(x); + + // This is derived from the Algol procedures orthes and ortran, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutines in EISPACK. + + int low = 0; + int high = n - 1; + + for (int m = low + 1; m <= high - 1; m++) { + + // Scale column. + + Vector hColumn = hessenBerg.viewColumn(m - 1).viewPart(m, high - m + 1); + double scale = hColumn.norm(1); + + if (scale != 0.0) { + // Compute Householder transformation. 
+ + ort.viewPart(m, high - m + 1).assign(hColumn, Functions.plusMult(1 / scale)); + double h = ort.viewPart(m, high - m + 1).getLengthSquared(); + + double g = Math.sqrt(h); + if (ort.getQuick(m) > 0) { + g = -g; + } + h -= ort.getQuick(m) * g; + ort.setQuick(m, ort.getQuick(m) - g); + + // Apply Householder similarity transformation + // H = (I-u*u'/h)*H*(I-u*u')/h) + + Vector ortPiece = ort.viewPart(m, high - m + 1); + for (int j = m; j < n; j++) { + double f = ortPiece.dot(hessenBerg.viewColumn(j).viewPart(m, high - m + 1)) / h; + hessenBerg.viewColumn(j).viewPart(m, high - m + 1).assign(ortPiece, Functions.plusMult(-f)); + } + + for (int i = 0; i <= high; i++) { + double f = ortPiece.dot(hessenBerg.viewRow(i).viewPart(m, high - m + 1)) / h; + hessenBerg.viewRow(i).viewPart(m, high - m + 1).assign(ortPiece, Functions.plusMult(-f)); + } + ort.setQuick(m, scale * ort.getQuick(m)); + hessenBerg.setQuick(m, m - 1, scale * g); + } + } + + // Accumulate transformations (Algol's ortran). + + v.assign(0); + v.viewDiagonal().assign(1); + + for (int m = high - 1; m >= low + 1; m--) { + if (hessenBerg.getQuick(m, m - 1) != 0.0) { + ort.viewPart(m + 1, high - m).assign(hessenBerg.viewColumn(m - 1).viewPart(m + 1, high - m)); + for (int j = m; j <= high; j++) { + double g = ort.viewPart(m, high - m + 1).dot(v.viewColumn(j).viewPart(m, high - m + 1)); + // Double division avoids possible underflow + g = g / ort.getQuick(m) / hessenBerg.getQuick(m, m - 1); + v.viewColumn(j).viewPart(m, high - m + 1).assign(ort.viewPart(m, high - m + 1), Functions.plusMult(g)); + } + } + } + return hessenBerg; + } + + + // Complex scalar division. 
+ private double cdivr; + private double cdivi; + + private void cdiv(double xr, double xi, double yr, double yi) { + double r; + double d; + if (Math.abs(yr) > Math.abs(yi)) { + r = yi / yr; + d = yr + r * yi; + cdivr = (xr + r * xi) / d; + cdivi = (xi - r * xr) / d; + } else { + r = yr / yi; + d = yi + r * yr; + cdivr = (r * xr + xi) / d; + cdivi = (r * xi - xr) / d; + } + } + + + // Nonsymmetric reduction from Hessenberg to real Schur form. + + private void hqr2(Matrix h) { + + // This is derived from the Algol procedure hqr2, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + // Initialize + + int nn = this.n; + int n = nn - 1; + int low = 0; + int high = nn - 1; + double eps = Math.pow(2.0, -52.0); + double exshift = 0.0; + double p = 0; + double q = 0; + double r = 0; + double s = 0; + double z = 0; + double w; + double x; + double y; + + // Store roots isolated by balanc and compute matrix norm + + double norm = h.aggregate(Functions.PLUS, Functions.ABS); + + // Outer loop over eigenvalue index + + int iter = 0; + while (n >= low) { + + // Look for single small sub-diagonal element + + int l = n; + while (l > low) { + s = Math.abs(h.getQuick(l - 1, l - 1)) + Math.abs(h.getQuick(l, l)); + if (s == 0.0) { + s = norm; + } + if (Math.abs(h.getQuick(l, l - 1)) < eps * s) { + break; + } + l--; + } + + // Check for convergence + + if (l == n) { + // One root found + h.setQuick(n, n, h.getQuick(n, n) + exshift); + d.setQuick(n, h.getQuick(n, n)); + e.setQuick(n, 0.0); + n--; + iter = 0; + + + } else if (l == n - 1) { + // Two roots found + w = h.getQuick(n, n - 1) * h.getQuick(n - 1, n); + p = (h.getQuick(n - 1, n - 1) - h.getQuick(n, n)) / 2.0; + q = p * p + w; + z = Math.sqrt(Math.abs(q)); + h.setQuick(n, n, h.getQuick(n, n) + exshift); + h.setQuick(n - 1, n - 1, h.getQuick(n - 1, n - 1) + exshift); + x = h.getQuick(n, n); + + // Real pair + if (q >= 0) { + if (p >= 0) { 
+ z = p + z; + } else { + z = p - z; + } + d.setQuick(n - 1, x + z); + d.setQuick(n, d.getQuick(n - 1)); + if (z != 0.0) { + d.setQuick(n, x - w / z); + } + e.setQuick(n - 1, 0.0); + e.setQuick(n, 0.0); + x = h.getQuick(n, n - 1); + s = Math.abs(x) + Math.abs(z); + p = x / s; + q = z / s; + r = Math.sqrt(p * p + q * q); + p /= r; + q /= r; + + // Row modification + + for (int j = n - 1; j < nn; j++) { + z = h.getQuick(n - 1, j); + h.setQuick(n - 1, j, q * z + p * h.getQuick(n, j)); + h.setQuick(n, j, q * h.getQuick(n, j) - p * z); + } + + // Column modification + + for (int i = 0; i <= n; i++) { + z = h.getQuick(i, n - 1); + h.setQuick(i, n - 1, q * z + p * h.getQuick(i, n)); + h.setQuick(i, n, q * h.getQuick(i, n) - p * z); + } + + // Accumulate transformations + + for (int i = low; i <= high; i++) { + z = v.getQuick(i, n - 1); + v.setQuick(i, n - 1, q * z + p * v.getQuick(i, n)); + v.setQuick(i, n, q * v.getQuick(i, n) - p * z); + } + + // Complex pair + + } else { + d.setQuick(n - 1, x + p); + d.setQuick(n, x + p); + e.setQuick(n - 1, z); + e.setQuick(n, -z); + } + n -= 2; + iter = 0; + + // No convergence yet + + } else { + + // Form shift + + x = h.getQuick(n, n); + y = 0.0; + w = 0.0; + if (l < n) { + y = h.getQuick(n - 1, n - 1); + w = h.getQuick(n, n - 1) * h.getQuick(n - 1, n); + } + + // Wilkinson's original ad hoc shift + + if (iter == 10) { + exshift += x; + for (int i = low; i <= n; i++) { + h.setQuick(i, i, x); + } + s = Math.abs(h.getQuick(n, n - 1)) + Math.abs(h.getQuick(n - 1, n - 2)); + x = y = 0.75 * s; + w = -0.4375 * s * s; + } + + // MATLAB's new ad hoc shift + + if (iter == 30) { + s = (y - x) / 2.0; + s = s * s + w; + if (s > 0) { + s = Math.sqrt(s); + if (y < x) { + s = -s; + } + s = x - w / ((y - x) / 2.0 + s); + for (int i = low; i <= n; i++) { + h.setQuick(i, i, h.getQuick(i, i) - s); + } + exshift += s; + x = y = w = 0.964; + } + } + + iter++; // (Could check iteration count here.) 
+ + // Look for two consecutive small sub-diagonal elements + + int m = n - 2; + while (m >= l) { + z = h.getQuick(m, m); + r = x - z; + s = y - z; + p = (r * s - w) / h.getQuick(m + 1, m) + h.getQuick(m, m + 1); + q = h.getQuick(m + 1, m + 1) - z - r - s; + r = h.getQuick(m + 2, m + 1); + s = Math.abs(p) + Math.abs(q) + Math.abs(r); + p /= s; + q /= s; + r /= s; + if (m == l) { + break; + } + double hmag = Math.abs(h.getQuick(m - 1, m - 1)) + Math.abs(h.getQuick(m + 1, m + 1)); + double threshold = eps * Math.abs(p) * (Math.abs(z) + hmag); + if (Math.abs(h.getQuick(m, m - 1)) * (Math.abs(q) + Math.abs(r)) < threshold) { + break; + } + m--; + } + + for (int i = m + 2; i <= n; i++) { + h.setQuick(i, i - 2, 0.0); + if (i > m + 2) { + h.setQuick(i, i - 3, 0.0); + } + } + + // Double QR step involving rows l:n and columns m:n + + for (int k = m; k <= n - 1; k++) { + boolean notlast = k != n - 1; + if (k != m) { + p = h.getQuick(k, k - 1); + q = h.getQuick(k + 1, k - 1); + r = notlast ? h.getQuick(k + 2, k - 1) : 0.0; + x = Math.abs(p) + Math.abs(q) + Math.abs(r); + if (x != 0.0) { + p /= x; + q /= x; + r /= x; + } + } + if (x == 0.0) { + break; + } + s = Math.sqrt(p * p + q * q + r * r); + if (p < 0) { + s = -s; + } + if (s != 0) { + if (k != m) { + h.setQuick(k, k - 1, -s * x); + } else if (l != m) { + h.setQuick(k, k - 1, -h.getQuick(k, k - 1)); + } + p += s; + x = p / s; + y = q / s; + z = r / s; + q /= p; + r /= p; + + // Row modification + + for (int j = k; j < nn; j++) { + p = h.getQuick(k, j) + q * h.getQuick(k + 1, j); + if (notlast) { + p += r * h.getQuick(k + 2, j); + h.setQuick(k + 2, j, h.getQuick(k + 2, j) - p * z); + } + h.setQuick(k, j, h.getQuick(k, j) - p * x); + h.setQuick(k + 1, j, h.getQuick(k + 1, j) - p * y); + } + + // Column modification + + for (int i = 0; i <= Math.min(n, k + 3); i++) { + p = x * h.getQuick(i, k) + y * h.getQuick(i, k + 1); + if (notlast) { + p += z * h.getQuick(i, k + 2); + h.setQuick(i, k + 2, h.getQuick(i, k + 2) - p * r); 
+ } + h.setQuick(i, k, h.getQuick(i, k) - p); + h.setQuick(i, k + 1, h.getQuick(i, k + 1) - p * q); + } + + // Accumulate transformations + + for (int i = low; i <= high; i++) { + p = x * v.getQuick(i, k) + y * v.getQuick(i, k + 1); + if (notlast) { + p += z * v.getQuick(i, k + 2); + v.setQuick(i, k + 2, v.getQuick(i, k + 2) - p * r); + } + v.setQuick(i, k, v.getQuick(i, k) - p); + v.setQuick(i, k + 1, v.getQuick(i, k + 1) - p * q); + } + } // (s != 0) + } // k loop + } // check convergence + } // while (n >= low) + + // Backsubstitute to find vectors of upper triangular form + + if (norm == 0.0) { + return; + } + + for (n = nn - 1; n >= 0; n--) { + p = d.getQuick(n); + q = e.getQuick(n); + + // Real vector + + double t; + if (q == 0) { + int l = n; + h.setQuick(n, n, 1.0); + for (int i = n - 1; i >= 0; i--) { + w = h.getQuick(i, i) - p; + r = 0.0; + for (int j = l; j <= n; j++) { + r += h.getQuick(i, j) * h.getQuick(j, n); + } + if (e.getQuick(i) < 0.0) { + z = w; + s = r; + } else { + l = i; + if (e.getQuick(i) == 0.0) { + if (w == 0.0) { + h.setQuick(i, n, -r / (eps * norm)); + } else { + h.setQuick(i, n, -r / w); + } + + // Solve real equations + + } else { + x = h.getQuick(i, i + 1); + y = h.getQuick(i + 1, i); + q = (d.getQuick(i) - p) * (d.getQuick(i) - p) + e.getQuick(i) * e.getQuick(i); + t = (x * s - z * r) / q; + h.setQuick(i, n, t); + if (Math.abs(x) > Math.abs(z)) { + h.setQuick(i + 1, n, (-r - w * t) / x); + } else { + h.setQuick(i + 1, n, (-s - y * t) / z); + } + } + + // Overflow control + + t = Math.abs(h.getQuick(i, n)); + if (eps * t * t > 1) { + for (int j = i; j <= n; j++) { + h.setQuick(j, n, h.getQuick(j, n) / t); + } + } + } + } + + // Complex vector + + } else if (q < 0) { + int l = n - 1; + + // Last vector component imaginary so matrix is triangular + + if (Math.abs(h.getQuick(n, n - 1)) > Math.abs(h.getQuick(n - 1, n))) { + h.setQuick(n - 1, n - 1, q / h.getQuick(n, n - 1)); + h.setQuick(n - 1, n, -(h.getQuick(n, n) - p) / h.getQuick(n, 
n - 1)); + } else { + cdiv(0.0, -h.getQuick(n - 1, n), h.getQuick(n - 1, n - 1) - p, q); + h.setQuick(n - 1, n - 1, cdivr); + h.setQuick(n - 1, n, cdivi); + } + h.setQuick(n, n - 1, 0.0); + h.setQuick(n, n, 1.0); + for (int i = n - 2; i >= 0; i--) { + double ra = 0.0; + double sa = 0.0; + for (int j = l; j <= n; j++) { + ra += h.getQuick(i, j) * h.getQuick(j, n - 1); + sa += h.getQuick(i, j) * h.getQuick(j, n); + } + w = h.getQuick(i, i) - p; + + if (e.getQuick(i) < 0.0) { + z = w; + r = ra; + s = sa; + } else { + l = i; + if (e.getQuick(i) == 0) { + cdiv(-ra, -sa, w, q); + h.setQuick(i, n - 1, cdivr); + h.setQuick(i, n, cdivi); + } else { + + // Solve complex equations + + x = h.getQuick(i, i + 1); + y = h.getQuick(i + 1, i); + double vr = (d.getQuick(i) - p) * (d.getQuick(i) - p) + e.getQuick(i) * e.getQuick(i) - q * q; + double vi = (d.getQuick(i) - p) * 2.0 * q; + if (vr == 0.0 && vi == 0.0) { + double hmag = Math.abs(x) + Math.abs(y); + vr = eps * norm * (Math.abs(w) + Math.abs(q) + hmag + Math.abs(z)); + } + cdiv(x * r - z * ra + q * sa, x * s - z * sa - q * ra, vr, vi); + h.setQuick(i, n - 1, cdivr); + h.setQuick(i, n, cdivi); + if (Math.abs(x) > (Math.abs(z) + Math.abs(q))) { + h.setQuick(i + 1, n - 1, (-ra - w * h.getQuick(i, n - 1) + q * h.getQuick(i, n)) / x); + h.setQuick(i + 1, n, (-sa - w * h.getQuick(i, n) - q * h.getQuick(i, n - 1)) / x); + } else { + cdiv(-r - y * h.getQuick(i, n - 1), -s - y * h.getQuick(i, n), z, q); + h.setQuick(i + 1, n - 1, cdivr); + h.setQuick(i + 1, n, cdivi); + } + } + + // Overflow control + + t = Math.max(Math.abs(h.getQuick(i, n - 1)), Math.abs(h.getQuick(i, n))); + if (eps * t * t > 1) { + for (int j = i; j <= n; j++) { + h.setQuick(j, n - 1, h.getQuick(j, n - 1) / t); + h.setQuick(j, n, h.getQuick(j, n) / t); + } + } + } + } + } + } + + // Vectors of isolated roots + + for (int i = 0; i < nn; i++) { + if (i < low || i > high) { + for (int j = i; j < nn; j++) { + v.setQuick(i, j, h.getQuick(i, j)); + } + } + } + + // 
Back transformation to get eigenvectors of original matrix + + for (int j = nn - 1; j >= low; j--) { + for (int i = low; i <= high; i++) { + z = 0.0; + for (int k = low; k <= Math.min(j, high); k++) { + z += v.getQuick(i, k) * h.getQuick(k, j); + } + v.setQuick(i, j, z); + } + } + } + + private static boolean isSymmetric(Matrix a) { + /* + Symmetry flag. + */ + int n = a.columnSize(); + + boolean isSymmetric = true; + for (int j = 0; (j < n) && isSymmetric; j++) { + for (int i = 0; (i < n) && isSymmetric; i++) { + isSymmetric = a.getQuick(i, j) == a.getQuick(j, i); + } + } + return isSymmetric; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java b/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java new file mode 100644 index 0000000..7524564 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/solver/JacobiConditioner.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/

package org.apache.mahout.math.solver;

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/**
 * Jacobi preconditioner for a square matrix A, i.e. inv(diag(A)).
 * <p/>
 * The reciprocals of the diagonal entries are computed once, at construction
 * time, and reused for every call to {@link #precondition(Vector)}. A diagonal
 * entry equal to zero is not rejected; its stored reciprocal is the result of the
 * floating-point division 1.0 / 0.0.
 */
public final class JacobiConditioner implements Preconditioner {

  /** Entry i holds 1 / A(i, i). */
  private final DenseVector inverseDiagonal;

  /**
   * Builds the preconditioner from the main diagonal of {@code a}.
   *
   * @param a the matrix to precondition; must have as many rows as columns
   * @throws IllegalArgumentException if {@code a} is not square
   */
  public JacobiConditioner(Matrix a) {
    int order = a.numCols();
    if (order != a.numRows()) {
      throw new IllegalArgumentException("Matrix must be square.");
    }

    DenseVector reciprocals = new DenseVector(order);
    for (int k = 0; k < order; k++) {
      double diagonalEntry = a.getQuick(k, k);
      reciprocals.setQuick(k, 1.0 / diagonalEntry);
    }
    inverseDiagonal = reciprocals;
  }

  /**
   * Applies the preconditioner to a vector.
   *
   * @param v the vector to precondition
   * @return the result of {@code v.times(inverseDiagonal)}; presumably the
   *         element-wise product - confirm against {@code Vector.times(Vector)}
   */
  @Override
  public Vector precondition(Vector v) {
    Vector scaled = v.times(inverseDiagonal);
    return scaled;
  }

}
