http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/VectorView.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/VectorView.java b/core/src/main/java/org/apache/mahout/math/VectorView.java new file mode 100644 index 0000000..62c5490 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/VectorView.java @@ -0,0 +1,238 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math; + +import java.util.Iterator; + +import com.google.common.collect.AbstractIterator; + +/** Implements subset view of a Vector */ +public class VectorView extends AbstractVector { + + protected Vector vector; + + // the offset into the Vector + protected int offset; + + /** For serialization purposes only */ + public VectorView() { + super(0); + } + + public VectorView(Vector vector, int offset, int cardinality) { + super(cardinality); + this.vector = vector; + this.offset = offset; + } + + @Override + protected Matrix matrixLike(int rows, int columns) { + return ((AbstractVector) vector).matrixLike(rows, columns); + } + + @Override + public Vector clone() { + VectorView r = (VectorView) super.clone(); + r.vector = vector.clone(); + r.offset = offset; + return r; + } + + @Override + public boolean isDense() { + return vector.isDense(); + } + + @Override + public boolean isSequentialAccess() { + return vector.isSequentialAccess(); + } + + @Override + public VectorView like() { + return new VectorView(vector.like(), offset, size()); + } + + @Override + public Vector like(int cardinality) { + return vector.like(cardinality); + } + + @Override + public double getQuick(int index) { + return vector.getQuick(offset + index); + } + + @Override + public void setQuick(int index, double value) { + vector.setQuick(offset + index, value); + } + + @Override + public int getNumNondefaultElements() { + return size(); + } + + @Override + public Vector viewPart(int offset, int length) { + if (offset < 0) { + throw new IndexException(offset, size()); + } + if (offset + length > size()) { + throw new IndexException(offset + length, size()); + } + return new VectorView(vector, offset + this.offset, length); + } + + /** @return true if index is a valid index in the underlying Vector */ + private boolean isInView(int index) { + return index >= offset && index < offset + size(); + } + + @Override + public Iterator<Element> iterateNonZero() { + return 
new NonZeroIterator(); + } + + @Override + public Iterator<Element> iterator() { + return new AllIterator(); + } + + public final class NonZeroIterator extends AbstractIterator<Element> { + + private final Iterator<Element> it; + + private NonZeroIterator() { + it = vector.nonZeroes().iterator(); + } + + @Override + protected Element computeNext() { + while (it.hasNext()) { + Element el = it.next(); + if (isInView(el.index()) && el.get() != 0) { + Element decorated = el; /* vector.getElement(el.index()); */ + return new DecoratorElement(decorated); + } + } + return endOfData(); + } + + } + + public final class AllIterator extends AbstractIterator<Element> { + + private final Iterator<Element> it; + + private AllIterator() { + it = vector.all().iterator(); + } + + @Override + protected Element computeNext() { + while (it.hasNext()) { + Element el = it.next(); + if (isInView(el.index())) { + Element decorated = vector.getElement(el.index()); + return new DecoratorElement(decorated); + } + } + return endOfData(); // No element was found + } + + } + + private final class DecoratorElement implements Element { + + private final Element decorated; + + private DecoratorElement(Element decorated) { + this.decorated = decorated; + } + + @Override + public double get() { + return decorated.get(); + } + + @Override + public int index() { + return decorated.index() - offset; + } + + @Override + public void set(double value) { + decorated.set(value); + } + } + + @Override + public double getLengthSquared() { + double result = 0.0; + int size = size(); + for (int i = 0; i < size; i++) { + double value = getQuick(i); + result += value * value; + } + return result; + } + + @Override + public double getDistanceSquared(Vector v) { + double result = 0.0; + int size = size(); + for (int i = 0; i < size; i++) { + double delta = getQuick(i) - v.getQuick(i); + result += delta * delta; + } + return result; + } + + @Override + public double getLookupCost() { + return vector.getLookupCost(); 
+ } + + @Override + public double getIteratorAdvanceCost() { + // TODO: remove the 2x after fixing the Element iterator + return 2 * vector.getIteratorAdvanceCost(); + } + + @Override + public boolean isAddConstantTime() { + return vector.isAddConstantTime(); + } + + /** + * Used internally by assign() to update multiple indices and values at once. + * Only really useful for sparse vectors (especially SequentialAccessSparseVector). + * <p> + * If someone ever adds a new type of sparse vectors, this method must merge (index, value) pairs into the vector. + * + * @param updates a mapping of indices to values to merge in the vector. + */ + @Override + public void mergeUpdates(OrderedIntDoubleMapping updates) { + for (int i = 0; i < updates.getNumMappings(); ++i) { + updates.setIndexAt(i, updates.indexAt(i) + offset); + } + vector.mergeUpdates(updates); + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/WeightedVector.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/WeightedVector.java b/core/src/main/java/org/apache/mahout/math/WeightedVector.java new file mode 100644 index 0000000..c8fdfac --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/WeightedVector.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math; + +/** + * Decorates a vector with a floating point weight and an index. 
+ */ +public class WeightedVector extends DelegatingVector { + private static final int INVALID_INDEX = -1; + private double weight; + private int index; + + protected WeightedVector(double weight, int index) { + super(); + this.weight = weight; + this.index = index; + } + + public WeightedVector(Vector v, double weight, int index) { + super(v); + this.weight = weight; + this.index = index; + } + + public WeightedVector(Vector v, Vector projection, int index) { + super(v); + this.index = index; + this.weight = v.dot(projection); + } + + public static WeightedVector project(Vector v, Vector projection) { + return project(v, projection, INVALID_INDEX); + } + + public static WeightedVector project(Vector v, Vector projection, int index) { + return new WeightedVector(v, projection, index); + } + + public double getWeight() { + return weight; + } + + public int getIndex() { + return index; + } + + public void setWeight(double newWeight) { + this.weight = newWeight; + } + + public void setIndex(int index) { + this.index = index; + } + + @Override + public Vector like() { + return new WeightedVector(getVector().like(), weight, index); + } + + @Override + public String toString() { + return String.format("index=%d, weight=%.2f, v=%s", index, weight, getVector()); + } + + @Override + public WeightedVector clone() { + WeightedVector v = (WeightedVector)super.clone(); + v.weight = weight; + v.index = index; + return v; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/WeightedVectorComparator.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/WeightedVectorComparator.java b/core/src/main/java/org/apache/mahout/math/WeightedVectorComparator.java new file mode 100644 index 0000000..9fdd621 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/WeightedVectorComparator.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math; + +import java.io.Serializable; +import java.util.Comparator; + +/** + * Orders {@link WeightedVector} by {@link WeightedVector#getWeight()}. + */ +public final class WeightedVectorComparator implements Comparator<WeightedVector>, Serializable { + + private static final double DOUBLE_EQUALITY_ERROR = 1.0e-8; + + @Override + public int compare(WeightedVector a, WeightedVector b) { + if (a == b) { + return 0; + } + double aWeight = a.getWeight(); + double bWeight = b.getWeight(); + int r = Double.compare(aWeight, bWeight); + if (r != 0 && Math.abs(aWeight - bWeight) >= DOUBLE_EQUALITY_ERROR) { + return r; + } + double diff = a.minus(b).norm(1); + if (diff < 1.0e-12) { + return 0; + } + for (Vector.Element element : a.all()) { + r = Double.compare(element.get(), b.get(element.index())); + if (r != 0) { + return r; + } + } + return 0; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java 
b/core/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java new file mode 100644 index 0000000..dbe1f8b --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java @@ -0,0 +1,116 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.als; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.QRDecomposition; +import org.apache.mahout.math.Vector; + +/** + * See + * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf"> + * this paper.</a> + */ +public final class AlternatingLeastSquaresSolver { + + private AlternatingLeastSquaresSolver() {} + + //TODO make feature vectors a simple array + public static Vector solve(Iterable<Vector> featureVectors, Vector ratingVector, double lambda, int numFeatures) { + + Preconditions.checkNotNull(featureVectors, "Feature Vectors cannot be null"); + Preconditions.checkArgument(!Iterables.isEmpty(featureVectors)); + Preconditions.checkNotNull(ratingVector, "Rating Vector cannot be null"); + Preconditions.checkArgument(ratingVector.getNumNondefaultElements() > 0, "Rating Vector cannot be empty"); + Preconditions.checkArgument(Iterables.size(featureVectors) == ratingVector.getNumNondefaultElements()); + + int nui = ratingVector.getNumNondefaultElements(); + + Matrix MiIi = createMiIi(featureVectors, numFeatures); + Matrix RiIiMaybeTransposed = createRiIiMaybeTransposed(ratingVector); + + /* compute Ai = MiIi * t(MiIi) + lambda * nui * E */ + Matrix Ai = miTimesMiTransposePlusLambdaTimesNuiTimesE(MiIi, lambda, nui); + /* compute Vi = MiIi * t(R(i,Ii)) */ + Matrix Vi = MiIi.times(RiIiMaybeTransposed); + /* compute Ai * ui = Vi */ + return solve(Ai, Vi); + } + + private static Vector solve(Matrix Ai, Matrix Vi) { + return new QRDecomposition(Ai).solve(Vi).viewColumn(0); + } + + static Matrix addLambdaTimesNuiTimesE(Matrix matrix, double lambda, int nui) { + Preconditions.checkArgument(matrix.numCols() == matrix.numRows(), "Must be a Square Matrix"); + double lambdaTimesNui = lambda * nui; + int numCols = 
matrix.numCols(); + for (int n = 0; n < numCols; n++) { + matrix.setQuick(n, n, matrix.getQuick(n, n) + lambdaTimesNui); + } + return matrix; + } + + private static Matrix miTimesMiTransposePlusLambdaTimesNuiTimesE(Matrix MiIi, double lambda, int nui) { + + double lambdaTimesNui = lambda * nui; + int rows = MiIi.numRows(); + + double[][] result = new double[rows][rows]; + + for (int i = 0; i < rows; i++) { + for (int j = i; j < rows; j++) { + double dot = MiIi.viewRow(i).dot(MiIi.viewRow(j)); + if (i != j) { + result[i][j] = dot; + result[j][i] = dot; + } else { + result[i][i] = dot + lambdaTimesNui; + } + } + } + return new DenseMatrix(result, true); + } + + + static Matrix createMiIi(Iterable<Vector> featureVectors, int numFeatures) { + double[][] MiIi = new double[numFeatures][Iterables.size(featureVectors)]; + int n = 0; + for (Vector featureVector : featureVectors) { + for (int m = 0; m < numFeatures; m++) { + MiIi[m][n] = featureVector.getQuick(m); + } + n++; + } + return new DenseMatrix(MiIi, true); + } + + static Matrix createRiIiMaybeTransposed(Vector ratingVector) { + Preconditions.checkArgument(ratingVector.isSequentialAccess(), "Ratings should be iterable in Index or Sequential Order"); + + double[][] RiIiMaybeTransposed = new double[ratingVector.getNumNondefaultElements()][1]; + int index = 0; + for (Vector.Element elem : ratingVector.nonZeroes()) { + RiIiMaybeTransposed[index++][0] = elem.get(); + } + return new DenseMatrix(RiIiMaybeTransposed, true); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java b/core/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java new file mode 100644 index 0000000..28bf4b4 --- /dev/null +++ 
b/core/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java @@ -0,0 +1,171 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.als; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.QRDecomposition; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.list.IntArrayList; +import org.apache.mahout.math.map.OpenIntObjectHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; + +/** see <a href="http://research.yahoo.com/pub/2433">Collaborative Filtering for Implicit Feedback Datasets</a> */ +public class ImplicitFeedbackAlternatingLeastSquaresSolver { + + private final int numFeatures; + private final double alpha; + private final double lambda; + private final int numTrainingThreads; + + private final OpenIntObjectHashMap<Vector> Y; + private 
final Matrix YtransposeY; + + private static final Logger log = LoggerFactory.getLogger(ImplicitFeedbackAlternatingLeastSquaresSolver.class); + + public ImplicitFeedbackAlternatingLeastSquaresSolver(int numFeatures, double lambda, double alpha, + OpenIntObjectHashMap<Vector> Y, int numTrainingThreads) { + this.numFeatures = numFeatures; + this.lambda = lambda; + this.alpha = alpha; + this.Y = Y; + this.numTrainingThreads = numTrainingThreads; + YtransposeY = getYtransposeY(Y); + } + + public Vector solve(Vector ratings) { + return solve(YtransposeY.plus(getYtransponseCuMinusIYPlusLambdaI(ratings)), getYtransponseCuPu(ratings)); + } + + private static Vector solve(Matrix A, Matrix y) { + return new QRDecomposition(A).solve(y).viewColumn(0); + } + + double confidence(double rating) { + return 1 + alpha * rating; + } + + /* Y' Y */ + public Matrix getYtransposeY(final OpenIntObjectHashMap<Vector> Y) { + + ExecutorService queue = Executors.newFixedThreadPool(numTrainingThreads); + if (log.isInfoEnabled()) { + log.info("Starting the computation of Y'Y"); + } + long startTime = System.nanoTime(); + final IntArrayList indexes = Y.keys(); + final int numIndexes = indexes.size(); + + final double[][] YtY = new double[numFeatures][numFeatures]; + + // Compute Y'Y by dot products between the 'columns' of Y + for (int i = 0; i < numFeatures; i++) { + for (int j = i; j < numFeatures; j++) { + + final int ii = i; + final int jj = j; + queue.execute(new Runnable() { + @Override + public void run() { + double dot = 0; + for (int k = 0; k < numIndexes; k++) { + Vector row = Y.get(indexes.getQuick(k)); + dot += row.getQuick(ii) * row.getQuick(jj); + } + YtY[ii][jj] = dot; + if (ii != jj) { + YtY[jj][ii] = dot; + } + } + }); + + } + } + queue.shutdown(); + try { + queue.awaitTermination(1, TimeUnit.DAYS); + } catch (InterruptedException e) { + log.error("Error during Y'Y queue shutdown", e); + throw new RuntimeException("Error during Y'Y queue shutdown"); + } + if 
(log.isInfoEnabled()) { + log.info("Computed Y'Y in " + (System.nanoTime() - startTime) / 1000000.0 + " ms" ); + } + return new DenseMatrix(YtY, true); + } + + /** Y' (Cu - I) Y + λ I */ + private Matrix getYtransponseCuMinusIYPlusLambdaI(Vector userRatings) { + Preconditions.checkArgument(userRatings.isSequentialAccess(), "need sequential access to ratings!"); + + /* (Cu -I) Y */ + OpenIntObjectHashMap<Vector> CuMinusIY = new OpenIntObjectHashMap<>(userRatings.getNumNondefaultElements()); + for (Element e : userRatings.nonZeroes()) { + CuMinusIY.put(e.index(), Y.get(e.index()).times(confidence(e.get()) - 1)); + } + + Matrix YtransponseCuMinusIY = new DenseMatrix(numFeatures, numFeatures); + + /* Y' (Cu -I) Y by outer products */ + for (Element e : userRatings.nonZeroes()) { + for (Vector.Element feature : Y.get(e.index()).all()) { + Vector partial = CuMinusIY.get(e.index()).times(feature.get()); + YtransponseCuMinusIY.viewRow(feature.index()).assign(partial, Functions.PLUS); + } + } + + /* Y' (Cu - I) Y + λ I add lambda on the diagonal */ + for (int feature = 0; feature < numFeatures; feature++) { + YtransponseCuMinusIY.setQuick(feature, feature, YtransponseCuMinusIY.getQuick(feature, feature) + lambda); + } + + return YtransponseCuMinusIY; + } + + /** Y' Cu p(u) */ + private Matrix getYtransponseCuPu(Vector userRatings) { + Preconditions.checkArgument(userRatings.isSequentialAccess(), "need sequential access to ratings!"); + + Vector YtransponseCuPu = new DenseVector(numFeatures); + + for (Element e : userRatings.nonZeroes()) { + YtransponseCuPu.assign(Y.get(e.index()).times(confidence(e.get())), Functions.PLUS); + } + + return columnVectorAsMatrix(YtransponseCuPu); + } + + private Matrix columnVectorAsMatrix(Vector v) { + double[][] matrix = new double[numFeatures][1]; + for (Vector.Element e : v.all()) { + matrix[e.index()][0] = e.get(); + } + return new DenseMatrix(matrix, true); + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java b/core/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java new file mode 100644 index 0000000..0233848 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.decomposer; + +import java.io.Closeable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorIterable; + +public class AsyncEigenVerifier extends SimpleEigenVerifier implements Closeable { + + private final ExecutorService threadPool; + private EigenStatus status; + private boolean finished; + private boolean started; + + public AsyncEigenVerifier() { + threadPool = Executors.newFixedThreadPool(1); + status = new EigenStatus(-1, 0); + } + + @Override + public synchronized EigenStatus verify(VectorIterable corpus, Vector vector) { + if (!finished && !started) { // not yet started or finished, so start! + status = new EigenStatus(-1, 0); + Vector vectorCopy = vector.clone(); + threadPool.execute(new VerifierRunnable(corpus, vectorCopy)); + started = true; + } + if (finished) { + finished = false; + } + return status; + } + + @Override + public void close() { + this.threadPool.shutdownNow(); + } + protected EigenStatus innerVerify(VectorIterable corpus, Vector vector) { + return super.verify(corpus, vector); + } + + private class VerifierRunnable implements Runnable { + private final VectorIterable corpus; + private final Vector vector; + + protected VerifierRunnable(VectorIterable corpus, Vector vector) { + this.corpus = corpus; + this.vector = vector; + } + + @Override + public void run() { + EigenStatus status = innerVerify(corpus, vector); + synchronized (AsyncEigenVerifier.this) { + AsyncEigenVerifier.this.status = status; + finished = true; + started = false; + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java 
b/core/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java new file mode 100644 index 0000000..a284f50 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Holds an eigenvalue, a cosine of an angle, and a flag telling whether the computation that produces
 * them is still in progress.
 */
public class EigenStatus {
  private final double eigenValue;
  private final double cosAngle;
  // Primitive rather than boxed: the flag is only ever read and written as a boolean.
  private volatile boolean inProgress;

  /**
   * Creates a status that is marked as in progress.
   *
   * @param eigenValue the eigenvalue to report
   * @param cosAngle   the cosine to report
   */
  public EigenStatus(double eigenValue, double cosAngle) {
    this(eigenValue, cosAngle, true);
  }

  /**
   * @param eigenValue the eigenvalue to report
   * @param cosAngle   the cosine to report
   * @param inProgress initial value of the in-progress flag
   */
  public EigenStatus(double eigenValue, double cosAngle, boolean inProgress) {
    this.eigenValue = eigenValue;
    this.cosAngle = cosAngle;
    this.inProgress = inProgress;
  }

  public double getCosAngle() {
    return cosAngle;
  }

  public double getEigenValue() {
    return eigenValue;
  }

  /** @return the current value of the in-progress flag */
  public boolean inProgress() {
    return inProgress;
  }

  /** Updates the in-progress flag; package-private. */
  void setInProgress(boolean status) {
    inProgress = status;
  }
}
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.math.decomposer; + +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorIterable; + +public class SimpleEigenVerifier implements SingularVectorVerifier { + + @Override + public EigenStatus verify(VectorIterable corpus, Vector vector) { + Vector resultantVector = corpus.timesSquared(vector); + double newNorm = resultantVector.norm(2); + double oldNorm = vector.norm(2); + double eigenValue; + double cosAngle; + if (newNorm > 0 && oldNorm > 0) { + eigenValue = newNorm / oldNorm; + cosAngle = resultantVector.dot(vector) / newNorm * oldNorm; + } else { + eigenValue = 1.0; + cosAngle = 0.0; + } + return new EigenStatus(eigenValue, cosAngle, false); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java b/core/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java new file mode 100644 index 0000000..a9a7af8 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.decomposer; + +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorIterable; + +public interface SingularVectorVerifier { + EigenStatus verify(VectorIterable eigenMatrix, Vector vector); +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java new file mode 100644 index 0000000..ac9cc41 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.decomposer.hebbian; + +import org.apache.mahout.math.Vector; + + +/** + * Strategy that refines a running singular-vector estimate from one training vector at a time. + */ +public interface EigenUpdater { + /** + * Folds a single training vector into the current estimate. + * + * @param pseudoEigen the running estimate; implementations are expected to modify it in place + * @param trainingVector the corpus row being presented + * @param currentState solver bookkeeping shared between successive calls + */ + void update(Vector pseudoEigen, Vector trainingVector, TrainingState currentState); +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java new file mode 100644 index 0000000..5b5cc9b --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java @@ -0,0 +1,342 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.math.decomposer.hebbian; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; +import java.util.Random; + +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.decomposer.AsyncEigenVerifier; +import org.apache.mahout.math.decomposer.EigenStatus; +import org.apache.mahout.math.decomposer.SingularVectorVerifier; +import org.apache.mahout.math.function.PlusMult; +import org.apache.mahout.math.function.TimesFunction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The Hebbian solver is an iterative, sparse, singular value decomposition solver, based on the paper + * <a href="http://www.dcs.shef.ac.uk/~genevieve/gorrell_webb.pdf">Generalized Hebbian Algorithm for + * Latent Semantic Analysis</a> (2005) by Genevieve Gorrell and Brandyn Webb (a.k.a. Simon Funk). + * <p> + * For each requested singular vector, {@link #solve(Matrix, int)} repeatedly presents every corpus row to the + * {@link EigenUpdater} and then asks the {@link SingularVectorVerifier} whether the current guess has converged. + * TODO: more description here! For now: read the inline comments, and the comments for the constructors. + */ +public class HebbianSolver { + + private static final Logger log = LoggerFactory.getLogger(HebbianSolver.class); + private static final boolean DEBUG = false; + + private final EigenUpdater updater; + private final SingularVectorVerifier verifier; + private final double convergenceTarget; + private final int maxPassesPerEigen; + private final Random rng = RandomUtils.getRandom(); + + private int numPasses = 0; + + /** + * Creates a new HebbianSolver + * + * @param updater + * {@link EigenUpdater} used to do the actual work of iteratively updating the current "best guess" + * singular vector one data-point presentation at a time.
+ * @param verifier + * {@link SingularVectorVerifier } an object which perpetually tries to check how close to + * convergence the current singular vector is (typically is a + * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier } which does this + * in the background in another thread, while the main thread continues to converge) + * @param convergenceTarget a small "epsilon" value which tells the solver how small you want one minus the cosine of the + * angle between a proposed eigenvector and that same vector after being multiplied by the (square of the) input + * corpus + * @param maxPassesPerEigen a cutoff which tells the solver after how many times of checking for convergence (done + * by the verifier) should the solver stop trying, even if it has not reached the convergenceTarget. + */ + public HebbianSolver(EigenUpdater updater, + SingularVectorVerifier verifier, + double convergenceTarget, + int maxPassesPerEigen) { + this.updater = updater; + this.verifier = verifier; + this.convergenceTarget = convergenceTarget; + this.maxPassesPerEigen = maxPassesPerEigen; + } + + /** + * Creates a new HebbianSolver with maxPassesPerEigen = Integer.MAX_VALUE (i.e. keep on iterating until + * convergenceTarget is reached). <b>Not recommended</b> unless only looking for + * the first few (5, maybe 10?) singular + * vectors, as small errors which compound early on quickly put a minimum error on subsequent vectors. + * + * @param updater {@link EigenUpdater} used to do the actual work of iteratively updating the current "best guess" + * singular vector one data-point presentation at a time.
+ * @param verifier {@link org.apache.mahout.math.decomposer.SingularVectorVerifier } + * an object which perpetually tries to check how close to + * convergence the current singular vector is (typically is a + * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier } which does this + * in the background in another thread, while the main thread continues to converge) + * @param convergenceTarget a small "epsilon" value which tells the solver how small you want one minus the cosine of the + * angle between a proposed eigenvector and that same vector after being multiplied by the (square of the) input + * corpus + */ + public HebbianSolver(EigenUpdater updater, + SingularVectorVerifier verifier, + double convergenceTarget) { + this(updater, + verifier, + convergenceTarget, + Integer.MAX_VALUE); + } + + /** + * <b>This is the recommended constructor to use if you're not sure</b> + * Creates a new HebbianSolver with the default {@link HebbianUpdater } to do the updating work, and the default + * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier } to check for convergence in a + * (single) background thread. + * + * @param convergenceTarget a small "epsilon" value which tells the solver how small you want one minus the cosine of the + * angle between a proposed eigenvector and that same vector after being multiplied by the (square of the) input + * corpus + * @param maxPassesPerEigen a cutoff which tells the solver after how many times of checking for convergence (done + * by the verifier) should the solver stop trying, even if it has not reached the convergenceTarget.
+ */ + public HebbianSolver(double convergenceTarget, int maxPassesPerEigen) { + this(new HebbianUpdater(), + new AsyncEigenVerifier(), + convergenceTarget, + maxPassesPerEigen); + } + + /** + * Creates a new HebbianSolver with the default {@link HebbianUpdater } to do the updating work, and the default + * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier } to check for convergence in a (single) + * background thread, with + * maxPassesPerEigen set to Integer.MAX_VALUE. <b>Not recommended</b> unless only looking + * for the first few (5, maybe 10?) singular + * vectors, as small errors which compound early on quickly put a minimum error on subsequent vectors. + * + * @param convergenceTarget a small "epsilon" value which tells the solver how small you want one minus the cosine of the + * angle between a proposed eigenvector and that same vector after being multiplied by the (square of the) input + * corpus + */ + public HebbianSolver(double convergenceTarget) { + this(convergenceTarget, Integer.MAX_VALUE); + } + + /** + * Creates a new HebbianSolver with the default {@link HebbianUpdater } to do the updating work, and the default + * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier } to check for convergence in a (single) + * background thread, with + * convergenceTarget set to 0, which means that the solver will not really care about convergence as a loop-exiting + * criterion (but will be checking for convergence anyways, so it will be logged and singular values will be + * saved). + * + * @param numPassesPerEigen the exact number of times the verifier will check convergence status in the background + * before the solver will move on to the next eigen-vector. + */ + public HebbianSolver(int numPassesPerEigen) { + this(0.0, numPassesPerEigen); + } + + /** + * Primary singular vector solving method. + * + * @param corpus input matrix to find singular vectors of.
Need not be symmetric, should probably be sparse (in + * fact the input vectors are not mutated, and accessed only via dot-products and sums, so they should be + * {@link org.apache.mahout.math.SequentialAccessSparseVector }) + * @param desiredRank the number of singular vectors to find (in roughly decreasing order by singular value) + * @return the final {@link TrainingState } of the solver, after desiredRank singular vectors (and approximate + * singular values) have been found. + */ + public TrainingState solve(Matrix corpus, + int desiredRank) { + int cols = corpus.numCols(); + Matrix eigens = new DenseMatrix(desiredRank, cols); + List<Double> eigenValues = new ArrayList<>(); + log.info("Finding {} singular vectors of matrix with {} rows, via Hebbian", desiredRank, corpus.numRows()); + /* + * The corpusProjections matrix is a running cache of the residual projection of each corpus vector against all + * of the previously found singular vectors. Without this, if multiple passes over the data is made (per + * singular vector), recalculating these projections eventually dominates the computational complexity of the + * solver.
+ */ + Matrix corpusProjections = new DenseMatrix(corpus.numRows(), desiredRank); + TrainingState state = new TrainingState(eigens, corpusProjections); + for (int i = 0; i < desiredRank; i++) { + Vector currentEigen = new DenseVector(cols); + Vector previousEigen = null; + while (hasNotConverged(currentEigen, corpus, state)) { + int randomStartingIndex = getRandomStartingIndex(corpus, eigens); + Vector initialTrainingVector = corpus.viewRow(randomStartingIndex); + state.setTrainingIndex(randomStartingIndex); + updater.update(currentEigen, initialTrainingVector, state); + for (int corpusRow = 0; corpusRow < corpus.numRows(); corpusRow++) { + state.setTrainingIndex(corpusRow); + if (corpusRow != randomStartingIndex) { + updater.update(currentEigen, corpus.viewRow(corpusRow), state); + } + } + state.setFirstPass(false); + if (DEBUG) { + if (previousEigen == null) { + previousEigen = currentEigen.clone(); + } else { + double dot = currentEigen.dot(previousEigen); + if (dot > 0.0) { + dot /= currentEigen.norm(2) * previousEigen.norm(2); + } + // log.info("Current pass * previous pass = {}", dot); + } + } + } + // converged! + double eigenValue = state.getStatusProgress().get(state.getStatusProgress().size() - 1).getEigenValue(); + // it's actually more efficient to do this to normalize than to call currentEigen = currentEigen.normalize(), + // because the latter does a clone, which isn't necessary here. + currentEigen.assign(new TimesFunction(), 1 / currentEigen.norm(2)); + eigens.assignRow(i, currentEigen); + eigenValues.add(eigenValue); + state.setCurrentEigenValues(eigenValues); + log.info("Found eigenvector {}, eigenvalue: {}", i, eigenValue); + + /* + * TODO: Persist intermediate output!
+ */ + state.setFirstPass(true); + state.setNumEigensProcessed(state.getNumEigensProcessed() + 1); + state.setActivationDenominatorSquared(0); + state.setActivationNumerator(0); + state.getStatusProgress().clear(); + numPasses = 0; + } + return state; + } + + /** + * You have to start somewhere... + * TODO: start instead wherever you find a vector with maximum residual length after subtracting off the projection + * TODO: onto all previous eigenvectors. + * <p> + * Rows are drawn at random until one is found that is non-null, has a non-zero 2-norm and has at least 5 + * non-default elements. NOTE(review): the loop has no other exit, so it does not terminate when no corpus row + * meets those conditions. + * + * @param corpus the corpus matrix + * @param eigens not currently used, but should be (see above TODO) + * @return the index into the corpus where the "starting seed" input vector lies. + */ + private int getRandomStartingIndex(Matrix corpus, Matrix eigens) { + int index; + Vector v; + do { + double r = rng.nextDouble(); + index = (int) (r * corpus.numRows()); + v = corpus.viewRow(index); + } while (v == null || v.norm(2) == 0 || v.getNumNondefaultElements() < 5); + return index; + } + + /** + * Uses the {@link SingularVectorVerifier } to check for convergence + * + * @param currentPseudoEigen the purported singular vector whose convergence is being checked + * @param corpus the corpus to check against + * @param state contains the previous eigens, various other solving state {@link TrainingState} + * @return true if another pass over the corpus is needed: always on the first pass, and afterwards for as long as + * no more than maxPassesPerEigen statuses have been recorded <em>and</em> 1 - cosAngle of the latest status is + * still above convergenceTarget; false once either condition fails. + */ + protected boolean hasNotConverged(Vector currentPseudoEigen, + Matrix corpus, + TrainingState state) { + numPasses++; + if (state.isFirstPass()) { + log.info("First pass through the corpus, no need to check convergence..."); + return true; + } + Matrix previousEigens = state.getCurrentEigens(); + log.info("Have made {} passes through the corpus, checking convergence...", numPasses); + /* + * Step 1: orthogonalize currentPseudoEigen by subtracting off eigen(i) * helper.get(i) + * Step 2: zero-out the helper vector because it has already helped.
+ */ + for (int i = 0; i < state.getNumEigensProcessed(); i++) { + Vector previousEigen = previousEigens.viewRow(i); + currentPseudoEigen.assign(previousEigen, new PlusMult(-state.getHelperVector().get(i))); + state.getHelperVector().set(i, 0); + } + if (currentPseudoEigen.norm(2) > 0) { + for (int i = 0; i < state.getNumEigensProcessed(); i++) { + Vector previousEigen = previousEigens.viewRow(i); + log.info("dot with previous: {}", previousEigen.dot(currentPseudoEigen) / currentPseudoEigen.norm(2)); + } + } + /* + * Step 3: verify how eigen-like the prospective eigen is. This is potentially asynchronous. + */ + EigenStatus status = verify(corpus, currentPseudoEigen); + if (status.inProgress()) { + log.info("Verifier not finished, making another pass..."); + } else { + log.info("Has 1 - cosAngle: {}, convergence target is: {}", 1.0 - status.getCosAngle(), convergenceTarget); + state.getStatusProgress().add(status); + } + return + state.getStatusProgress().size() <= maxPassesPerEigen + && 1.0 - status.getCosAngle() > convergenceTarget; + } + + /** + * Hands the candidate to the configured {@link SingularVectorVerifier}; protected so subclasses can replace + * the verification step. + */ + protected EigenStatus verify(Matrix corpus, Vector currentPseudoEigen) { + return verifier.verify(corpus, currentPseudoEigen); + } + + /** + * Command-line entry point; the optional first argument names a properties file. + * NOTE(review): the call that loads that file is commented out, so {@code props} stays empty, both directory + * lookups return null, and the method logs an error and returns before any solver is built. + */ + public static void main(String[] args) { + Properties props = new Properties(); + String propertiesFile = args.length > 0 ?
args[0] : "config/solver.properties"; + // props.load(new FileInputStream(propertiesFile)); + + String corpusDir = props.getProperty("solver.input.dir"); + String outputDir = props.getProperty("solver.output.dir"); + if (corpusDir == null || corpusDir.isEmpty() || outputDir == null || outputDir.isEmpty()) { + log.error("{} must contain values for solver.input.dir and solver.output.dir", propertiesFile); + return; + } + //int inBufferSize = Integer.parseInt(props.getProperty("solver.input.bufferSize")); + int rank = Integer.parseInt(props.getProperty("solver.output.desiredRank")); + double convergence = Double.parseDouble(props.getProperty("solver.convergence")); + int maxPasses = Integer.parseInt(props.getProperty("solver.maxPasses")); + //int numThreads = Integer.parseInt(props.getProperty("solver.verifier.numThreads")); + + HebbianUpdater updater = new HebbianUpdater(); + SingularVectorVerifier verifier = new AsyncEigenVerifier(); + HebbianSolver solver = new HebbianSolver(updater, verifier, convergence, maxPasses); + Matrix corpus = null; + /* + if (numThreads <= 1) { + // corpus = new DiskBufferedDoubleMatrix(new File(corpusDir), inBufferSize); + } else { + // corpus = new ParallelMultiplyingDiskBufferedDoubleMatrix(new File(corpusDir), inBufferSize, numThreads); + } + */ + long now = System.currentTimeMillis(); + TrainingState finalState = solver.solve(corpus, rank); + long time = (System.currentTimeMillis() - now) / 1000; + log.info("Solved {} eigenVectors in {} seconds.
Persisted to {}", + finalState.getCurrentEigens().rowSize(), time, outputDir); + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianUpdater.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianUpdater.java b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianUpdater.java new file mode 100644 index 0000000..2080c3a --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianUpdater.java @@ -0,0 +1,71 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.math.decomposer.hebbian; + + +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.PlusMult; + +public class HebbianUpdater implements EigenUpdater { + + @Override + public void update(Vector pseudoEigen, + Vector trainingVector, + TrainingState currentState) { + double trainingVectorNorm = trainingVector.norm(2); + int numPreviousEigens = currentState.getNumEigensProcessed(); + if (numPreviousEigens > 0 && currentState.isFirstPass()) { + updateTrainingProjectionsVector(currentState, trainingVector, numPreviousEigens - 1); + } + if (currentState.getActivationDenominatorSquared() == 0 || trainingVectorNorm == 0) { + if (currentState.getActivationDenominatorSquared() == 0) { + pseudoEigen.assign(trainingVector, new PlusMult(1)); + currentState.setHelperVector(currentState.currentTrainingProjection().clone()); + double helperNorm = currentState.getHelperVector().norm(2); + currentState.setActivationDenominatorSquared(trainingVectorNorm * trainingVectorNorm - helperNorm * helperNorm); + } + return; + } + currentState.setActivationNumerator(pseudoEigen.dot(trainingVector)); + currentState.setActivationNumerator( + currentState.getActivationNumerator() + - currentState.getHelperVector().dot(currentState.currentTrainingProjection())); + + double activation = currentState.getActivationNumerator() + / Math.sqrt(currentState.getActivationDenominatorSquared()); + currentState.setActivationDenominatorSquared( + currentState.getActivationDenominatorSquared() + + 2 * activation * currentState.getActivationNumerator() + + activation * activation + * (trainingVector.getLengthSquared() - currentState.currentTrainingProjection().getLengthSquared())); + if (numPreviousEigens > 0) { + currentState.getHelperVector().assign(currentState.currentTrainingProjection(), new PlusMult(activation)); + } + pseudoEigen.assign(trainingVector, new PlusMult(activation)); + } + + private static void 
updateTrainingProjectionsVector(TrainingState state, + Vector trainingVector, + int previousEigenIndex) { + Vector previousEigen = state.mostRecentEigen(); + Vector currentTrainingVectorProjection = state.currentTrainingProjection(); + double projection = previousEigen.dot(trainingVector); + currentTrainingVectorProjection.set(previousEigenIndex, projection); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/TrainingState.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/TrainingState.java b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/TrainingState.java new file mode 100644 index 0000000..af6c2ef --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/decomposer/hebbian/TrainingState.java @@ -0,0 +1,143 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.decomposer.hebbian; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.decomposer.EigenStatus; + +public class TrainingState { + + private Matrix currentEigens; + private int numEigensProcessed; + private List<Double> currentEigenValues; + private Matrix trainingProjections; + private int trainingIndex; + private Vector helperVector; + private boolean firstPass; + private List<EigenStatus> statusProgress; + private double activationNumerator; + private double activationDenominatorSquared; + + TrainingState(Matrix eigens, Matrix projections) { + currentEigens = eigens; + trainingProjections = projections; + trainingIndex = 0; + helperVector = new DenseVector(eigens.numRows()); + firstPass = true; + statusProgress = new ArrayList<>(); + activationNumerator = 0; + activationDenominatorSquared = 0; + numEigensProcessed = 0; + } + + public Vector mostRecentEigen() { + return currentEigens.viewRow(numEigensProcessed - 1); + } + + public Vector currentTrainingProjection() { + if (trainingProjections.viewRow(trainingIndex) == null) { + trainingProjections.assignRow(trainingIndex, new DenseVector(currentEigens.numCols())); + } + return trainingProjections.viewRow(trainingIndex); + } + + public Matrix getCurrentEigens() { + return currentEigens; + } + + public void setCurrentEigens(Matrix currentEigens) { + this.currentEigens = currentEigens; + } + + public int getNumEigensProcessed() { + return numEigensProcessed; + } + + public void setNumEigensProcessed(int numEigensProcessed) { + this.numEigensProcessed = numEigensProcessed; + } + + public List<Double> getCurrentEigenValues() { + return currentEigenValues; + } + + public void setCurrentEigenValues(List<Double> currentEigenValues) { + this.currentEigenValues = currentEigenValues; + } + + public Matrix getTrainingProjections() { 
+ return trainingProjections; + } + + public void setTrainingProjections(Matrix trainingProjections) { + this.trainingProjections = trainingProjections; + } + + public int getTrainingIndex() { + return trainingIndex; + } + + public void setTrainingIndex(int trainingIndex) { + this.trainingIndex = trainingIndex; + } + + public Vector getHelperVector() { + return helperVector; + } + + public void setHelperVector(Vector helperVector) { + this.helperVector = helperVector; + } + + public boolean isFirstPass() { + return firstPass; + } + + public void setFirstPass(boolean firstPass) { + this.firstPass = firstPass; + } + + public List<EigenStatus> getStatusProgress() { + return statusProgress; + } + + public void setStatusProgress(List<EigenStatus> statusProgress) { + this.statusProgress = statusProgress; + } + + public double getActivationNumerator() { + return activationNumerator; + } + + public void setActivationNumerator(double activationNumerator) { + this.activationNumerator = activationNumerator; + } + + public double getActivationDenominatorSquared() { + return activationDenominatorSquared; + } + + public void setActivationDenominatorSquared(double activationDenominatorSquared) { + this.activationDenominatorSquared = activationDenominatorSquared; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosSolver.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosSolver.java b/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosSolver.java new file mode 100644 index 0000000..61a77db --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosSolver.java @@ -0,0 +1,213 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.decomposer.lanczos; + + +import java.util.EnumMap; +import java.util.Map; + +import com.google.common.base.Preconditions; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorIterable; +import org.apache.mahout.math.function.DoubleFunction; +import org.apache.mahout.math.function.PlusMult; +import org.apache.mahout.math.solver.EigenDecomposition; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Simple implementation of the <a href="http://en.wikipedia.org/wiki/Lanczos_algorithm">Lanczos algorithm</a> for + * finding eigenvalues of a symmetric matrix, applied to non-symmetric matrices by applying Matrix.timesSquared(vector) + * as the "matrix-multiplication" method.<p> + * + * See the SSVD code for a better option + * {@link org.apache.mahout.math.ssvd.SequentialBigSvd} + * See also the docs on + * <a href="https://mahout.apache.org/users/dim-reduction/ssvd.html">stochastic + * projection SVD</a> + * <p> + * To avoid floating point overflow problems which arise in power-methods like Lanczos, an initial pass is made + * through the input matrix to + * <ul> + * <li>generate a good starting seed vector by summing all the rows of the input matrix, and</li> + *
<li>compute the trace(inputMatrix<sup>t</sup>*matrix)</li> + * </ul> + * <p> + * This latter value, being the sum of the squares of all of the singular values, is used to rescale the entire matrix, effectively + * forcing the largest singular value to be strictly less than one, and transforming floating point <em>overflow</em> + * problems into floating point <em>underflow</em> (ie, very small singular values will become invisible, as they + * will appear to be zero and the algorithm will terminate). + * <p> + * NOTE(review): the code in this class does not perform that separate pass; when the state carries no positive + * scale factor, {@code solve} sets it to the 2-norm of the first product vector (see {@code calculateScaleFactor}). + * <p>This implementation uses {@link EigenDecomposition} to do the + * eigenvalue extraction from the small (desiredRank x desiredRank) tridiagonal matrix. Numerical stability is + * achieved via brute-force: re-orthogonalization against all previous eigenvectors is computed after every pass. + * This can be made smarter if (when!) this proves to be a major bottleneck. Of course, this step can be parallelized + * as well. + * @see org.apache.mahout.math.ssvd.SequentialBigSvd + * @deprecated see {@link org.apache.mahout.math.ssvd.SequentialBigSvd} for the suggested alternative + */ +@Deprecated +public class LanczosSolver { + + private static final Logger log = LoggerFactory.getLogger(LanczosSolver.class); + + /** Magnitude above which alpha or beta is treated as out of range and the iteration stops early. */ + public static final double SAFE_MAX = 1.0e150; + + /** Phases of {@code solve} whose elapsed time is accumulated. */ + public enum TimingSection { + ITERATE, ORTHOGANLIZE, TRIDIAG_DECOMP, FINAL_EIGEN_CREATE + } + + private final Map<TimingSection, Long> startTimes = new EnumMap<>(TimingSection.class); + private final Map<TimingSection, Long> times = new EnumMap<>(TimingSection.class); + + /** Element-wise function that multiplies its argument by a fixed factor. */ + private static final class Scale extends DoubleFunction { + private final double d; + + private Scale(double d) { + this.d = d; + } + + @Override + public double apply(double arg1) { + return arg1 * d; + } + } + + /** + * Same as {@link #solve(LanczosState, int, boolean)} with {@code isSymmetric} set to false. + */ + public void solve(LanczosState state, + int desiredRank) { + solve(state, desiredRank, false); + } + + /** + * Continues the iteration recorded in {@code state} from its current iteration number until + * {@code desiredRank} is reached or alpha/beta leave the safe range, then decomposes the tridiagonal matrix + * and writes the resulting singular vectors and values back into {@code state}. + * + * @param state holds the corpus, basis vectors, tridiagonal matrix, scale factor and results + * @param desiredRank iteration number at which the loop stops + * @param isSymmetric if true each step uses {@code corpus.times} and the scaled eigenvalues are stored as is; + * otherwise {@code corpus.timesSquared} is used and the square root of each scaled eigenvalue is stored + */ + public void solve(LanczosState state, + int desiredRank, + boolean isSymmetric) { + VectorIterable corpus = state.getCorpus(); + log.info("Finding {} singular vectors of matrix with {} rows, via Lanczos", + desiredRank, corpus.numRows()); + int i =
state.getIterationNumber(); + Vector currentVector = state.getBasisVector(i - 1); + Vector previousVector = state.getBasisVector(i - 2); + double beta = 0; + Matrix triDiag = state.getDiagonalMatrix(); + while (i < desiredRank) { + startTime(TimingSection.ITERATE); + Vector nextVector = isSymmetric ? corpus.times(currentVector) : corpus.timesSquared(currentVector); + log.info("{} passes through the corpus so far...", i); + if (state.getScaleFactor() <= 0) { + state.setScaleFactor(calculateScaleFactor(nextVector)); + } + nextVector.assign(new Scale(1.0 / state.getScaleFactor())); + if (previousVector != null) { + nextVector.assign(previousVector, new PlusMult(-beta)); + } + // now orthogonalize + double alpha = currentVector.dot(nextVector); + nextVector.assign(currentVector, new PlusMult(-alpha)); + endTime(TimingSection.ITERATE); + startTime(TimingSection.ORTHOGANLIZE); + orthoganalizeAgainstAllButLast(nextVector, state); + endTime(TimingSection.ORTHOGANLIZE); + // and normalize + beta = nextVector.norm(2); + if (outOfRange(beta) || outOfRange(alpha)) { + log.warn("Lanczos parameters out of range: alpha = {}, beta = {}. Bailing out early!", + alpha, beta); + break; + } + nextVector.assign(new Scale(1 / beta)); + state.setBasisVector(i, nextVector); + previousVector = currentVector; + currentVector = nextVector; + // save the projections and norms!
+ triDiag.set(i - 1, i - 1, alpha); + if (i < desiredRank - 1) { + triDiag.set(i - 1, i, beta); + triDiag.set(i, i - 1, beta); + } + state.setIterationNumber(++i); + } + startTime(TimingSection.TRIDIAG_DECOMP); + + log.info("Lanczos iteration complete - now to diagonalize the tri-diagonal auxiliary matrix."); + // at this point, have tridiag all filled out, and basis is all filled out, and orthonormalized + EigenDecomposition decomp = new EigenDecomposition(triDiag); + + Matrix eigenVects = decomp.getV(); + Vector eigenVals = decomp.getRealEigenvalues(); + endTime(TimingSection.TRIDIAG_DECOMP); + startTime(TimingSection.FINAL_EIGEN_CREATE); + for (int row = 0; row < i; row++) { + Vector realEigen = null; + + Vector ejCol = eigenVects.viewColumn(row); + int size = Math.min(ejCol.size(), state.getBasisSize()); + for (int j = 0; j < size; j++) { + double d = ejCol.get(j); + Vector rowJ = state.getBasisVector(j); + if (realEigen == null) { + realEigen = rowJ.like(); + } + realEigen.assign(rowJ, new PlusMult(d)); + } + + Preconditions.checkState(realEigen != null); + assert realEigen != null; + + realEigen = realEigen.normalize(); + state.setRightSingularVector(row, realEigen); + double e = eigenVals.get(row) * state.getScaleFactor(); + if (!isSymmetric) { + e = Math.sqrt(e); + } + log.info("Eigenvector {} found with eigenvalue {}", row, e); + state.setSingularValue(row, e); + } + log.info("LanczosSolver finished."); + endTime(TimingSection.FINAL_EIGEN_CREATE); + } + + /** + * @return the 2-norm of the given vector, used by {@code solve} as the scale factor when the state has none + */ + protected static double calculateScaleFactor(Vector nextVector) { + return nextVector.norm(2); + } + + /** @return true if {@code d} is NaN or its magnitude exceeds {@link #SAFE_MAX}. */ + private static boolean outOfRange(double d) { + return Double.isNaN(d) || d > SAFE_MAX || -d > SAFE_MAX; + } + + /** + * Subtracts from {@code nextVector} its component along each basis vector stored in {@code state} at indices + * below the current iteration number; missing basis vectors and zero dot products are skipped. + */ + protected static void orthoganalizeAgainstAllButLast(Vector nextVector, LanczosState state) { + for (int i = 0; i < state.getIterationNumber(); i++) { + Vector basisVector = state.getBasisVector(i); + double alpha; + if (basisVector == null || (alpha = nextVector.dot(basisVector)) == 0.0) { +
continue; + } + nextVector.assign(basisVector, new PlusMult(-alpha)); + } + } + + /** Records the current {@code System.nanoTime()} as the start of the given section. */ + private void startTime(TimingSection section) { + startTimes.put(section, System.nanoTime()); + } + + /** Adds the time elapsed since the matching {@code startTime} call to the running total of the section. */ + private void endTime(TimingSection section) { + if (!times.containsKey(section)) { + times.put(section, 0L); + } + times.put(section, times.get(section) + System.nanoTime() - startTimes.get(section)); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosState.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosState.java b/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosState.java new file mode 100644 index 0000000..2ba34bd --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/decomposer/lanczos/LanczosState.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.math.decomposer.lanczos; + +import com.google.common.collect.Maps; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorIterable; + +import java.util.Map; + +/** Mutable holder for the state of a Lanczos run: the corpus, the desired rank, the scale factor, the iteration number, a desiredRank x desiredRank dense matrix, and index-keyed maps of basis vectors, singular vectors and singular values. */ @Deprecated +public class LanczosState { + + protected Matrix diagonalMatrix; + protected final VectorIterable corpus; + protected double scaleFactor; + protected int iterationNumber; + protected final int desiredRank; + protected Map<Integer, Vector> basis; + protected final Map<Integer, Double> singularValues; + protected Map<Integer, Vector> singularVectors; + + /** Stores corpus and desiredRank, registers initialVector as basis vector 0, allocates the desiredRank x desiredRank matrix, and starts with scale factor 0 and iteration number 1. */ public LanczosState(VectorIterable corpus, int desiredRank, Vector initialVector) { + this.corpus = corpus; + this.desiredRank = desiredRank; + intitializeBasisAndSingularVectors(); + setBasisVector(0, initialVector); + scaleFactor = 0; + diagonalMatrix = new DenseMatrix(desiredRank, desiredRank); + singularValues = Maps.newHashMap(); + iterationNumber = 1; + } + + /** Creates the empty maps that hold the basis vectors and the singular vectors. */ private void intitializeBasisAndSingularVectors() { + basis = Maps.newHashMap(); + singularVectors = Maps.newHashMap(); + } + + public Matrix getDiagonalMatrix() { + return diagonalMatrix; + } + + public int getIterationNumber() { + return iterationNumber; + } + + public double getScaleFactor() { + return scaleFactor; + } + + public VectorIterable getCorpus() { + return corpus; + } + + /** @return the singular vector stored under index i, or null if none has been set */ public Vector getRightSingularVector(int i) { + return singularVectors.get(i); + } + + /** @return the boxed singular value stored under index i, or null if none has been set */ public Double getSingularValue(int i) { + return singularValues.get(i); + } + + /** @return the basis vector stored under index i, or null if none has been set */ public Vector getBasisVector(int i) { + return basis.get(i); + } + + /** @return the number of basis vectors currently stored */ public int getBasisSize() { + return basis.size(); + } + + public void setBasisVector(int i, Vector basisVector) { + basis.put(i, basisVector); + } + + public void setScaleFactor(double scale) { + scaleFactor = scale; + } + + public void setIterationNumber(int i) { + iterationNumber = i; + } + + public void setRightSingularVector(int i, 
Vector vector) { + singularVectors.put(i, vector); + } + + public void setSingularValue(int i, double value) { + singularValues.put(i, value); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/flavor/BackEnum.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/flavor/BackEnum.java b/core/src/main/java/org/apache/mahout/math/flavor/BackEnum.java new file mode 100644 index 0000000..1782f04 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/flavor/BackEnum.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.flavor; + +/** + * Matrix backends + */ +public enum BackEnum { + JVMMEM, + NETLIB_BLAS +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java b/core/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java new file mode 100644 index 0000000..e1d93f2 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/flavor/MatrixFlavor.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.flavor; + +/** A set of matrix structure properties that I denote as "flavor" (by analogy to quarks) */ +public interface MatrixFlavor { + + /** + * The system backing the matrix -- such as java memory, lapack/atlas, Magma etc. 
+ */ + BackEnum getBacking(); + + /** + * The traversing-structure hint for the matrix. + */ + TraversingStructureEnum getStructure() ; + + /** Whether this flavor is marked as dense. */ boolean isDense(); + + /** + * This is the default flavor for {@link org.apache.mahout.math.DenseMatrix}-like structures. + */ + MatrixFlavor DENSELIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.ROWWISE, true); + /** + * This is the default flavor for {@link org.apache.mahout.math.SparseRowMatrix}-like structures. + */ + MatrixFlavor SPARSELIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.ROWWISE, false); + + /** + * This is the default flavor for {@link org.apache.mahout.math.SparseMatrix}-like structures, i.e. sparse matrix blocks, + * where few, perhaps most, rows may be missing entirely. + */ + MatrixFlavor SPARSEROWLIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.SPARSEROWWISE, false); + + /** + * This is the default flavor for {@link org.apache.mahout.math.DiagonalMatrix} and the like. + */ + MatrixFlavor DIAGONALLIKE = new FlavorImpl(BackEnum.JVMMEM, TraversingStructureEnum.VECTORBACKED, false); + + /** Plain holder that returns the backing, structure and density it was constructed with. */ final class FlavorImpl implements MatrixFlavor { + private BackEnum pBacking; + private TraversingStructureEnum pStructure; + private boolean pDense; + + public FlavorImpl(BackEnum backing, TraversingStructureEnum structure, boolean dense) { + pBacking = backing; + pStructure = structure; + pDense = dense; + } + + @Override + public BackEnum getBacking() { + return pBacking; + } + + @Override + public TraversingStructureEnum getStructure() { + return pStructure; + } + + @Override + public boolean isDense() { + return pDense; + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java b/core/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java new file mode 100644 index 
0000000..13c2cf4 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/flavor/TraversingStructureEnum.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.flavor; + +/** Structure hint: which kind of traversal a matrix's backing structure directly supports. */ +public enum TraversingStructureEnum { + + UNKNOWN, + + /** + * Backing vectors are directly available as row views. + */ + ROWWISE, + + /** + * Column vectors are directly available as column views. + */ + COLWISE, + + /** + * Only some row-wise vectors are really present (can use iterateNonEmpty). Corresponds to + * {@link org.apache.mahout.math.SparseMatrix}. 
+ */ + SPARSEROWWISE, + + SPARSECOLWISE, + + SPARSEHASH, + + VECTORBACKED, + + BLOCKIFIED +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/DoubleDoubleFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/function/DoubleDoubleFunction.java b/core/src/main/java/org/apache/mahout/math/function/DoubleDoubleFunction.java new file mode 100644 index 0000000..466ddd6 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/function/DoubleDoubleFunction.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* +Copyright 1999 CERN - European Organization for Nuclear Research. +Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose +is hereby granted without fee, provided that the above copyright notice appear in all copies and +that both that copyright notice and this permission notice appear in supporting documentation. +CERN makes no representations about the suitability of this software for any purpose. +It is provided "as is" without expressed or implied warranty. 
+*/ + +package org.apache.mahout.math.function; + +/** + * Abstract base class that represents a function object: a function that takes two arguments and returns a single value. + * The algebraic property methods below return false unless a subclass overrides them. + */ +public abstract class DoubleDoubleFunction { + + /** + * Apply the function to the arguments and return the result + * + * @param arg1 a double for the first argument + * @param arg2 a double for the second argument + * @return the result of applying the function + */ + public abstract double apply(double arg1, double arg2); + + /** + * @return true iff f(x, 0) = x for any x + */ + public boolean isLikeRightPlus() { + return false; + } + + /** + * @return true iff f(0, y) = 0 for any y + */ + public boolean isLikeLeftMult() { + return false; + } + + /** + * @return true iff f(x, 0) = 0 for any x + */ + public boolean isLikeRightMult() { + return false; + } + + /** + * @return true iff f(x, 0) = f(0, y) = 0 for any x, y + */ + public boolean isLikeMult() { + return isLikeLeftMult() && isLikeRightMult(); + } + + /** + * @return true iff f(x, y) = f(y, x) for any x, y + */ + public boolean isCommutative() { + return false; + } + + /** + * @return true iff f(x, f(y, z)) = f(f(x, y), z) for any x, y, z + */ + public boolean isAssociative() { + return false; + } + + /** + * @return true iff f(x, y) = f(y, x) for any x, y AND f(x, f(y, z)) = f(f(x, y), z) for any x, y, z + */ + public boolean isAssociativeAndCommutative() { + return isAssociative() && isCommutative(); + } + + /** + * @return true iff f(0, 0) != 0, determined by evaluating apply(0.0, 0.0) + */ + public boolean isDensifying() { + return apply(0.0, 0.0) != 0.0; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/DoubleFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/function/DoubleFunction.java b/core/src/main/java/org/apache/mahout/math/function/DoubleFunction.java new file mode 100644 index 0000000..7545154 --- /dev/null 
+++ b/core/src/main/java/org/apache/mahout/math/function/DoubleFunction.java @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.mahout.math.function; + +/* +Copyright 1999 CERN - European Organization for Nuclear Research. +Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose +is hereby granted without fee, provided that the above copyright notice appear in all copies and +that both that copyright notice and this permission notice appear in supporting documentation. +CERN makes no representations about the suitability of this software for any purpose. +It is provided "as is" without expressed or implied warranty. +*/ + +/** + * Abstract base class that represents a function object: a function that takes a single argument and returns a single value. 
+ * @see org.apache.mahout.math.map + */ +public abstract class DoubleFunction { + + /** + * Apply the function to the argument and return the result + * + * @param x double for the argument + * @return the result of applying the function + */ + public abstract double apply(double x); + + /** @return true iff f(0) != 0, determined by evaluating apply(0.0) */ public boolean isDensifying() { + return Math.abs(apply(0.0)) != 0.0; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/545648f6/core/src/main/java/org/apache/mahout/math/function/FloatFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/mahout/math/function/FloatFunction.java b/core/src/main/java/org/apache/mahout/math/function/FloatFunction.java new file mode 100644 index 0000000..94dfe32 --- /dev/null +++ b/core/src/main/java/org/apache/mahout/math/function/FloatFunction.java @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.mahout.math.function; + + +/** + * Interface that represents a function object: a function that takes a single argument and returns a single value. + * + */ +public interface FloatFunction { + + /** + * Applies a function to an argument. 
+ * + * @param argument argument passed to the function. + * @return the result of the function. + */ + float apply(float argument); +}
