http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java
new file mode 100644
index 0000000..378a885
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java
@@ -0,0 +1,220 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import org.apache.commons.lang3.Validate;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.UpperTriangular;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+/**
+ * Job that accumulates Y'Y output.
+ * <p>
+ * Each mapper turns every input row into a row of Y (via {@link Omega}) and
+ * accumulates the outer products of those rows into a per-task
+ * {@link UpperTriangular} Y'Y. It emits that partial result once, as a packed
+ * vector, when the task finishes. A single {@link YtYReducer} sums all partial
+ * results into one packed upper-triangular vector.
+ */
+public final class YtYJob {
+
+  public static final String PROP_OMEGA_SEED = "ssvd.omegaseed";
+  public static final String PROP_K = "ssvd.k";
+  public static final String PROP_P = "ssvd.p";
+
+  // we have single output, so we use standard output
+  public static final String OUTPUT_YT_Y = "part-";
+
+  private YtYJob() {
+  }
+
+  /**
+   * Computes a Y row for each input vector and adds its outer product to a
+   * task-local upper-triangular Y'Y.
+   * <p>
+   * The partial Y'Y is written once, in {@link #cleanup(Context)}. The key is
+   * this task's id and the value is the packed triangle data.
+   */
+  public static class YtYMapper extends
+    Mapper<Writable, VectorWritable, IntWritable, VectorWritable> {
+
+    private int kp;
+    private Omega omega;
+    private UpperTriangular mYtY;
+
+    /*
+     * We keep yRow in dense form here but take care not to densify while
+     * doing the YtY products. A sparse vector is unlikely to bring much
+     * performance benefit: y is more often dense than sparse, so for bulk
+     * dense operations a dense vector should perform somewhat better than a
+     * RandomAccessSparseVector with frequent updates.
+     */
+    private Vector yRow;
+
+    /**
+     * Reads k, p and the Omega seed from the job configuration.
+     *
+     * @throws IllegalArgumentException if k or p is missing or not positive
+     * @throws NumberFormatException if the omega seed is missing or malformed
+     */
+    @Override
+    protected void setup(Context context) throws IOException,
+      InterruptedException {
+      int k = context.getConfiguration().getInt(PROP_K, -1);
+      int p = context.getConfiguration().getInt(PROP_P, -1);
+
+      Validate.isTrue(k > 0, "invalid k parameter");
+      Validate.isTrue(p > 0, "invalid p parameter");
+
+      kp = k + p;
+      long omegaSeed =
+        Long.parseLong(context.getConfiguration().get(PROP_OMEGA_SEED));
+
+      omega = new Omega(omegaSeed, k + p);
+
+      mYtY = new UpperTriangular(kp);
+
+      // see which one works better!
+      // yRow = new RandomAccessSparseVector(kp);
+      yRow = new DenseVector(kp);
+    }
+
+    /**
+     * Computes the Y row for {@code value} into {@code yRow} and adds
+     * {@code yRow[i] * yRow[j]} to cell (i, j) of Y'Y for every j &gt;= i.
+     * Zero entries are skipped.
+     */
+    @Override
+    protected void map(Writable key, VectorWritable value, Context context)
+      throws IOException, InterruptedException {
+      omega.computeYRow(value.get(), yRow);
+      // compute outer product update for YtY
+
+      if (yRow.isDense()) {
+        for (int i = 0; i < kp; i++) {
+          double yi;
+          if ((yi = yRow.getQuick(i)) == 0.0) {
+            continue; // avoid densing up here unnecessarily
+          }
+          for (int j = i; j < kp; j++) {
+            double yj;
+            if ((yj = yRow.getQuick(j)) != 0.0) {
+              mYtY.setQuick(i, j, mYtY.getQuick(i, j) + yi * yj);
+            }
+          }
+        }
+      } else {
+        /*
+         * The disadvantage of using a sparse vector here (aside from creating
+         * some short-lived references) is that we do twice as many iterations
+         * as necessary if the y row is fairly dense.
+         *
+         * NOTE(review): yRow is always created as a DenseVector in setup(), so
+         * this branch only runs if that choice is changed.
+         */
+        for (Vector.Element eli : yRow.nonZeroes()) {
+          int i = eli.index();
+          for (Vector.Element elj : yRow.nonZeroes()) {
+            int j = elj.index();
+            if (j < i) {
+              continue;
+            }
+            mYtY.setQuick(i, j, mYtY.getQuick(i, j) + eli.get() * elj.get());
+          }
+        }
+      }
+    }
+
+    /**
+     * Emits this task's partial Y'Y. The key is the task id and the value is
+     * the packed upper-triangular data of {@code mYtY}.
+     */
+    @Override
+    protected void cleanup(Context context) throws IOException,
+      InterruptedException {
+      context.write(new IntWritable(context.getTaskAttemptID().getTaskID()
+                                      .getId()),
+                    new VectorWritable(new DenseVector(mYtY.getData())));
+    }
+  }
+
+  /**
+   * Sums the packed upper-triangular partial Y'Y vectors emitted by all
+   * mappers. It writes the total once, under key 0, in
+   * {@link #cleanup(Context)}.
+   */
+  public static class YtYReducer extends
+    Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+    private final VectorWritable accum = new VectorWritable();
+    private DenseVector acc;
+
+    /**
+     * Allocates the accumulator for a packed kp x kp upper triangle.
+     *
+     * @throws IllegalArgumentException if k or p is missing or not positive
+     */
+    @Override
+    protected void setup(Context context) throws IOException,
+      InterruptedException {
+      int k = context.getConfiguration().getInt(PROP_K, -1);
+      int p = context.getConfiguration().getInt(PROP_P, -1);
+
+      Validate.isTrue(k > 0, "invalid k parameter");
+      Validate.isTrue(p > 0, "invalid p parameter");
+
+      int kp = k + p;
+      // Mappers emit UpperTriangular(kp).getData(), the packed upper triangle
+      // of a kp x kp matrix: kp * (kp + 1) / 2 values, not kp. An accumulator
+      // of size kp would make addAll() fail with a cardinality mismatch.
+      acc = new DenseVector(kp * (kp + 1) / 2);
+      accum.set(acc);
+    }
+
+    @Override
+    protected void cleanup(Context context) throws IOException,
+      InterruptedException {
+      context.write(new IntWritable(), accum);
+    }
+
+    /** Adds every partial Y'Y received under {@code key} to the accumulator. */
+    @Override
+    protected void reduce(IntWritable key,
+                          Iterable<VectorWritable> values,
+                          Context context) throws IOException,
+      InterruptedException {
+      for (VectorWritable vw : values) {
+        acc.addAll(vw.get());
+      }
+    }
+  }
+
+  /**
+   * Configures the Y'Y job, submits it and blocks until it finishes.
+   *
+   * @param conf base configuration; the job stores k, p and seed in its copy
+   * @param inputPaths sequence files of {@code (Writable, VectorWritable)} rows
+   * @param outputPath output directory for the single summed, packed Y'Y
+   * @param k value stored under {@link #PROP_K}; must be positive
+   * @param p value stored under {@link #PROP_P}; must be positive
+   * @param seed seed for {@link Omega}, stored under {@link #PROP_OMEGA_SEED}
+   * @throws IOException if the job does not complete successfully
+   */
+  public static void run(Configuration conf,
+                         Path[] inputPaths,
+                         Path outputPath,
+                         int k,
+                         int p,
+                         long seed) throws ClassNotFoundException,
+    InterruptedException, IOException {
+
+    Job job = new Job(conf);
+    job.setJobName("YtY-job");
+    job.setJarByClass(YtYJob.class);
+
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    FileInputFormat.setInputPaths(job, inputPaths);
+    FileOutputFormat.setOutputPath(job, outputPath);
+
+    // The compression type below applies to sequence file output. Without an
+    // explicit output format Hadoop would use the default text output format.
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    SequenceFileOutputFormat.setOutputCompressionType(job,
+                                                      CompressionType.BLOCK);
+
+    job.setMapOutputKeyClass(IntWritable.class);
+    job.setMapOutputValueClass(VectorWritable.class);
+
+    job.setOutputKeyClass(IntWritable.class);
+    job.setOutputValueClass(VectorWritable.class);
+
+    job.setMapperClass(YtYMapper.class);
+    // Without an explicit reducer class Hadoop runs the identity reducer. That
+    // would emit every mapper's partial Y'Y instead of their sum.
+    job.setReducerClass(YtYReducer.class);
+
+    job.getConfiguration().setLong(PROP_OMEGA_SEED, seed);
+    job.getConfiguration().setInt(PROP_K, k);
+    job.getConfiguration().setInt(PROP_P, p);
+
+    /*
+     * We must reduce to just one matrix, which means we need only one reducer.
+     * That is fine: each mapper outputs only one vector (a packed
+     * UpperTriangular), so even with thousands of mappers one reducer should
+     * cope.
+     */
+    job.setNumReduceTasks(1);
+
+    job.submit();
+    job.waitForCompletion(false);
+
+    if (!job.isSuccessful()) {
+      throw new IOException("YtY job unsuccessful.");
+    }
+
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
new file mode 100644
index 0000000..7033efe
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
@@ -0,0 +1,638 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.math.AbstractVector;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.OrderedIntDoubleMapping;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.UpperTriangular;
+
+/**
+ * Givens thin solver. Standard Givens operations are reordered in a way that
+ * helps us push them through MapReduce operations in a block fashion.
+ * <p>
+ * Rows of an m x n block A (m &gt;= n) are fed one at a time through
+ * {@link #appendRow(double[])}, or all at once through {@link #solve(Matrix)}.
+ * The solver keeps n rows of R-tilde (each of length n) and n rows of
+ * Qt-tilde (each of length m) in two ring buffers. "Pushing a row down"
+ * therefore only moves a start index instead of copying arrays.
+ * <p>
+ * The static helpers ({@link #mergeR}, {@link #mergeRonQ}, {@link #mergeQrUp},
+ * {@link #mergeQrDown}, {@link #computeQtHat}) combine the R-tilde and
+ * Qt-tilde factors of several blocks, again using Givens rotations.
+ */
+public class GivensThinSolver {
+
+  // Scratch buffer holding the incoming A row while it is rotated. After
+  // appendRow() it is swapped into R, and the recycled R row becomes the buffer.
+  private double[] vARow;
+  // Scratch buffer (length m) for the Qt row that accompanies the incoming row.
+  private double[] vQtRow;
+  // n rows of Qt-tilde, each of length m; ring buffer rooted at qtStartRow.
+  private final double[][] mQt;
+  // n rows of R-tilde, each of length n; ring buffer rooted at rStartRow.
+  private final double[][] mR;
+  // Physical index of logical Qt row 0.
+  private int qtStartRow;
+  // Physical index of logical R row 0.
+  private int rStartRow;
+  private int m;
+  private final int n; // m-row cnt, n- column count, m>=n
+  // Number of rows appended since construction or the last reset().
+  private int cnt;
+  // Scratch output of givens(): cs[0] = c (cosine), cs[1] = s (sine).
+  private final double[] cs = new double[2];
+
+  /**
+   * Creates a solver for a block of {@code m} rows and {@code n} columns.
+   *
+   * @param m number of rows accepted before {@link #isFull()} becomes true
+   * @param n number of columns in each row
+   * @throws IllegalArgumentException if {@code m < n}
+   */
+  public GivensThinSolver(int m, int n) {
+    if (!(m >= n)) {
+      throw new IllegalArgumentException("Givens thin QR: must be true: m>=n");
+    }
+
+    this.m = m;
+    this.n = n;
+
+    mQt = new double[n][];
+    mR = new double[n][];
+    vARow = new double[n];
+    vQtRow = new double[m];
+
+    for (int i = 0; i < n; i++) {
+      mQt[i] = new double[this.m];
+      mR[i] = new double[this.n];
+    }
+    cnt = 0;
+  }
+
+  /**
+   * Resets the appended-row counter so the solver can accept another block.
+   * Only the counter is reset. The Qt/R buffers and ring offsets are left
+   * untouched, so logical rows at index &gt;= cnt still hold data from the
+   * previous block.
+   */
+  public void reset() {
+    cnt = 0;
+  }
+
+  /**
+   * Appends every row of {@code a}, in order, via {@link #appendRow(double[])}.
+   * The dimensions are checked only by {@code assert}, i.e. only when
+   * assertions are enabled; {@code a} is expected to be exactly m x n.
+   */
+  public void solve(Matrix a) {
+
+    assert a.rowSize() == m;
+    assert a.columnSize() == n;
+
+    double[] aRow = new double[n];
+    for (int i = 0; i < m; i++) {
+      Vector aRowV = a.viewRow(i);
+      for (int j = 0; j < n; j++) {
+        aRow[j] = aRowV.getQuick(j);
+      }
+      appendRow(aRow);
+    }
+  }
+
+  /** @return true once exactly m rows have been appended */
+  public boolean isFull() {
+    return cnt == m;
+  }
+
+  /** @return the current block height m (may change through adjust()) */
+  public int getM() {
+    return m;
+  }
+
+  /** @return the column count n */
+  public int getN() {
+    return n;
+  }
+
+  /** @return the number of rows appended since construction or reset() */
+  public int getCnt() {
+    return cnt;
+  }
+
+  /**
+   * Changes the block height m and resizes every Qt row to {@code newM}.
+   * <p>
+   * appendRow() seeds the Qt row of the row being appended with a 1 at column
+   * {@code m - cnt - 1}, so the populated Qt columns are always the last cnt
+   * ones. Growing therefore pads zeros at the front and shifts the existing
+   * columns to the end; shrinking keeps the last {@code newM} columns.
+   *
+   * @param newM the new row count; must be &gt;= n and &gt;= the number of
+   *        rows already appended
+   * @throws IllegalArgumentException if {@code newM < n} or {@code newM < cnt}
+   */
+  public void adjust(int newM) {
+    if (newM == m) {
+      // no adjustment is required.
+      return;
+    }
+    if (newM < n) {
+      throw new IllegalArgumentException("new m can't be less than n");
+    }
+    if (newM < cnt) {
+      throw new IllegalArgumentException(
+          "new m can't be less than rows accumulated");
+    }
+    vQtRow = new double[newM];
+
+    // grow or shrink qt rows
+    if (newM > m) {
+      // grow qt rows
+      for (int i = 0; i < n; i++) {
+        mQt[i] = Arrays.copyOf(mQt[i], newM);
+        // move the populated tail [0, m) to [newM - m, newM), then zero the head
+        System.arraycopy(mQt[i], 0, mQt[i], newM - m, m);
+        Arrays.fill(mQt[i], 0, newM - m, 0);
+      }
+    } else {
+      // shrink qt rows
+      for (int i = 0; i < n; i++) {
+        mQt[i] = Arrays.copyOfRange(mQt[i], m - newM, m);
+      }
+    }
+
+    m = newM;
+
+  }
+
+  /** Shrinks m to the number of rows appended so far (see adjust()). */
+  public void trim() {
+    adjust(cnt);
+  }
+
+  /**
+   * API for row-by-row addition.
+   * <p>
+   * The row is copied into an internal buffer and paired with a fresh Qt row
+   * that is zero except for a 1 at column {@code m - cnt - 1}. If rows were
+   * appended before, a chain of Givens rotations follows. The first rotation
+   * acts on the incoming row and R row 0 at column 0. The following rotations
+   * act on R rows {@code i - 1} and {@code i} at column {@code i}, for
+   * {@code 1 <= i < min(cnt, n)}. Each rotation is mirrored on the matching Qt
+   * rows.
+   * <p>
+   * Finally, Qt and R are both pushed one row down, and the incoming row
+   * becomes logical row 0. The array previously at logical row n - 1 is
+   * recycled as the next scratch buffer. The row counter is incremented even
+   * if the update throws.
+   *
+   * @param aRow a row of A; only its first n entries are read, and the array is
+   *        not retained, so callers may reuse it
+   * @throws IllegalStateException if m rows have already been appended
+   */
+  public void appendRow(double[] aRow) {
+    if (cnt >= m) {
+      throw new IllegalStateException("thin QR solver fed more rows than initialized for");
+    }
+    try {
+      /*
+       * Moving pointers around is inefficient, but for sanity's sake it is
+       * kept this way so there is no need to guess how an R-tilde index maps
+       * to an actual block index.
+       */
+      Arrays.fill(vQtRow, 0);
+      vQtRow[m - cnt - 1] = 1;
+      // only the first min(cnt, n) logical R rows hold rows of this block
+      int height = cnt > n ? n : cnt;
+      System.arraycopy(aRow, 0, vARow, 0, n);
+
+      if (height > 0) {
+        givens(vARow[0], getRRow(0)[0], cs);
+        applyGivensInPlace(cs[0], cs[1], vARow, getRRow(0), 0, n);
+        applyGivensInPlace(cs[0], cs[1], vQtRow, getQtRow(0), 0, m);
+      }
+
+      for (int i = 1; i < height; i++) {
+        givens(getRRow(i - 1)[i], getRRow(i)[i], cs);
+        applyGivensInPlace(cs[0], cs[1], getRRow(i - 1), getRRow(i), i,
+                           n - i);
+        applyGivensInPlace(cs[0], cs[1], getQtRow(i - 1), getQtRow(i), 0,
+                           m);
+      }
+      /*
+       * push qt and r-tilde 1 row down
+       *
+       * just swap the references to reduce GC churning
+       */
+      pushQtDown();
+      double[] swap = getQtRow(0);
+      setQtRow(0, vQtRow);
+      vQtRow = swap;
+
+      pushRDown();
+      swap = getRRow(0);
+      setRRow(0, vARow);
+      vARow = swap;
+
+    } finally {
+      cnt++;
+    }
+  }
+
+  // Ring-buffer accessors. Logical row r (0 <= r < n) is stored at physical
+  // index (r + start) mod n. push*Down() decrements the start index, so the
+  // former logical row n - 1 becomes logical row 0 and every other row moves
+  // one logical index down.
+  private double[] getQtRow(int row) {
+
+    return mQt[(row += qtStartRow) >= n ? row - n : row];
+  }
+
+  private void setQtRow(int row, double[] qtRow) {
+    mQt[(row += qtStartRow) >= n ? row - n : row] = qtRow;
+  }
+
+  private void pushQtDown() {
+    qtStartRow = qtStartRow == 0 ? n - 1 : qtStartRow - 1;
+  }
+
+  private double[] getRRow(int row) {
+    row += rStartRow;
+    return mR[row >= n ? row - n : row];
+  }
+
+  private void setRRow(int row, double[] rrow) {
+    mR[(row += rStartRow) >= n ? row - n : row] = rrow;
+  }
+
+  private void pushRDown() {
+    rStartRow = rStartRow == 0 ? n - 1 : rStartRow - 1;
+  }
+
+  /*
+   * warning: both of these return actually n+1 rows with the last one being //
+   * not interesting.
+   *
+   * NOTE(review): the warning above looks stale. getRTilde() and
+   * getThinQtTilde() both visibly return exactly n rows.
+   */
+  /**
+   * Packs the n logical R rows, in order, into a new {@link UpperTriangular}.
+   * Presumably this is only meaningful once at least n rows of the current
+   * block have been appended, because reset() does not clear older rows.
+   */
+  public UpperTriangular getRTilde() {
+    UpperTriangular packedR = new UpperTriangular(n);
+    for (int i = 0; i < n; i++) {
+      packedR.assignNonZeroElementsInRow(i, getRRow(i));
+    }
+    return packedR;
+  }
+
+  /**
+   * Returns the n Qt-tilde rows (each of length m) in logical order.
+   * <p>
+   * When the ring offset is 0, the internal array itself is returned.
+   * Otherwise a new outer array is built that still references the internal
+   * row arrays. Either way the result aliases solver state and is modified by
+   * later appendRow() calls.
+   */
+  public double[][] getThinQtTilde() {
+    if (qtStartRow != 0) {
+      /*
+       * rotate qt rows into place
+       *
+       * double[~500][], once per block, not a big deal.
+       */
+      double[][] qt = new double[n][];
+      System.arraycopy(mQt, qtStartRow, qt, 0, n - qtStartRow);
+      System.arraycopy(mQt, 0, qt, n - qtStartRow, qtStartRow);
+      return qt;
+    }
+    return mQt;
+  }
+
+  /**
+   * Rotates two rows in place. For every j in [offset, offset + len), using
+   * the values before the update:
+   * {@code row1[j] = c*row1[j] - s*row2[j]} and
+   * {@code row2[j] = s*row1[j] + c*row2[j]}.
+   */
+  public static void applyGivensInPlace(double c, double s, double[] row1,
+                                        double[] row2, int offset, int len) {
+
+    int n = offset + len;
+    for (int j = offset; j < n; j++) {
+      double tau1 = row1[j];
+      double tau2 = row2[j];
+      row1[j] = c * tau1 - s * tau2;
+      row2[j] = s * tau1 + c * tau2;
+    }
+  }
+
+  /** Same rotation as the array overload, applied through getQuick/setQuick. */
+  public static void applyGivensInPlace(double c, double s, Vector row1,
+                                        Vector row2, int offset, int len) {
+
+    int n = offset + len;
+    for (int j = offset; j < n; j++) {
+      double tau1 = row1.getQuick(j);
+      double tau2 = row2.getQuick(j);
+      row1.setQuick(j, c * tau1 - s * tau2);
+      row2.setQuick(j, s * tau1 + c * tau2);
+    }
+  }
+
+  /** Same rotation, applied to rows i and k of {@code mx} across all columns. */
+  public static void applyGivensInPlace(double c, double s, int i, int k,
+                                        Matrix mx) {
+    int n = mx.columnSize();
+
+    for (int j = 0; j < n; j++) {
+      double tau1 = mx.get(i, j);
+      double tau2 = mx.get(k, j);
+      mx.set(i, j, c * tau1 - s * tau2);
+      mx.set(k, j, s * tau1 + c * tau2);
+    }
+  }
+
+  /**
+   * Decodes a rotation stored as a single number by {@link #toRho} into
+   * {@code csOut[0] = c} and {@code csOut[1] = s}. For a unit-norm (c, s),
+   * the decoded pair equals the encoded one up to a common sign.
+   */
+  public static void fromRho(double rho, double[] csOut) {
+    if (rho == 1) {
+      csOut[0] = 0;
+      csOut[1] = 1;
+      return;
+    }
+    if (Math.abs(rho) < 1) {
+      csOut[1] = 2 * rho;
+      csOut[0] = Math.sqrt(1 - csOut[1] * csOut[1]);
+      return;
+    }
+    csOut[0] = 2 / rho;
+    csOut[1] = Math.sqrt(1 - csOut[0] * csOut[0]);
+  }
+
+  /**
+   * Computes a Givens rotation for the pair (a, b), writing c to
+   * {@code csOut[0]} and s to {@code csOut[1]}. The result satisfies
+   * {@code c*c + s*s == 1} and {@code s*a + c*b == 0}, up to rounding.
+   * <p>
+   * Applied with applyGivensInPlace, with a in row1 and b in row2, the
+   * rotation zeroes b and turns a into {@code +/-sqrt(a*a + b*b)}. Dividing by
+   * the larger of |a| and |b| keeps {@code tau} within [-1, 1].
+   */
+  public static void givens(double a, double b, double[] csOut) {
+    if (b == 0) {
+      csOut[0] = 1;
+      csOut[1] = 0;
+      return;
+    }
+    if (Math.abs(b) > Math.abs(a)) {
+      double tau = -a / b;
+      csOut[1] = 1 / Math.sqrt(1 + tau * tau);
+      csOut[0] = csOut[1] * tau;
+    } else {
+      double tau = -b / a;
+      csOut[0] = 1 / Math.sqrt(1 + tau * tau);
+      csOut[1] = csOut[0] * tau;
+    }
+  }
+
+  /**
+   * Encodes a rotation (c, s) as a single number; see {@link #fromRho}.
+   * The encoding is 1 when c == 0, {@code sign(c)*s/2} when |s| &lt; |c|, and
+   * {@code sign(s)*2/c} otherwise.
+   */
+  public static double toRho(double c, double s) {
+    if (c == 0) {
+      return 1;
+    }
+    if (Math.abs(s) < Math.abs(c)) {
+      return Math.signum(c) * s / 2;
+    } else {
+      return Math.signum(s) * 2 / c;
+    }
+  }
+
+  /**
+   * Merges two kp x kp upper-triangular factors in place.
+   * <p>
+   * For v = 0..kp-1 and u = v..kp-1, a rotation pairs the diagonal entry
+   * r1(u, u) with r2(u - v, u), an entry on the v-th upper diagonal of r2.
+   * This drives r2's upper triangle to zero and leaves r1 as the triangular
+   * factor of the stacked matrix [r1; r2]. Both arguments are modified.
+   */
+  public static void mergeR(UpperTriangular r1, UpperTriangular r2) {
+    TriangularRowView r1Row = new TriangularRowView(r1);
+    TriangularRowView r2Row = new TriangularRowView(r2);
+
+    int kp = r1Row.size();
+    assert kp == r2Row.size();
+
+    double[] cs = new double[2];
+
+    for (int v = 0; v < kp; v++) {
+      for (int u = v; u < kp; u++) {
+        givens(r1Row.setViewedRow(u).get(u), r2Row.setViewedRow(u - v).get(u),
+               cs);
+        applyGivensInPlace(cs[0], cs[1], r1Row, r2Row, u, kp - u);
+      }
+    }
+  }
+
+  /**
+   * Dense-array version of {@link #mergeR(UpperTriangular, UpperTriangular)}.
+   * kp is taken from {@code r1[0].length}, and both arrays need at least kp
+   * rows. Both arguments are modified.
+   */
+  public static void mergeR(double[][] r1, double[][] r2) {
+    int kp = r1[0].length;
+    assert kp == r2[0].length;
+
+    double[] cs = new double[2];
+
+    for (int v = 0; v < kp; v++) {
+      for (int u = v; u < kp; u++) {
+        givens(r1[u][u], r2[u - v][u], cs);
+        applyGivensInPlace(cs[0], cs[1], r1[u], r2[u - v], u, kp - u);
+      }
+    }
+
+  }
+
+  /**
+   * Same as {@link #mergeR(UpperTriangular, UpperTriangular)}, but mirrors
+   * every rotation on the Qt rows {@code qt1[u]} and {@code qt2[u - v]} across
+   * all r columns. All four arguments are modified.
+   */
+  public static void mergeRonQ(UpperTriangular r1, UpperTriangular r2,
+                               double[][] qt1, double[][] qt2) {
+    TriangularRowView r1Row = new TriangularRowView(r1);
+    TriangularRowView r2Row = new TriangularRowView(r2);
+    int kp = r1Row.size();
+    assert kp == r2Row.size();
+    assert kp == qt1.length;
+    assert kp == qt2.length;
+
+    int r = qt1[0].length;
+    assert qt2[0].length == r;
+
+    double[] cs = new double[2];
+
+    for (int v = 0; v < kp; v++) {
+      for (int u = v; u < kp; u++) {
+        givens(r1Row.setViewedRow(u).get(u), r2Row.setViewedRow(u - v).get(u),
+               cs);
+        applyGivensInPlace(cs[0], cs[1], r1Row, r2Row, u, kp - u);
+        applyGivensInPlace(cs[0], cs[1], qt1[u], qt2[u - v], 0, r);
+      }
+    }
+  }
+
+  /** Dense-array version of the UpperTriangular mergeRonQ overload. */
+  public static void mergeRonQ(double[][] r1, double[][] r2, double[][] qt1,
+                               double[][] qt2) {
+
+    int kp = r1[0].length;
+    assert kp == r2[0].length;
+    assert kp == qt1.length;
+    assert kp == qt2.length;
+
+    int r = qt1[0].length;
+    assert qt2[0].length == r;
+    double[] cs = new double[2];
+
+    /*
+     * Pairwise givens(a, b), where each a comes off the main diagonal of r1
+     * and each b comes off the v-th upper diagonal of r2 (entry (u - v, u)).
+     */
+    for (int v = 0; v < kp; v++) {
+      for (int u = v; u < kp; u++) {
+        givens(r1[u][u], r2[u - v][u], cs);
+        applyGivensInPlace(cs[0], cs[1], r1[u], r2[u - v], u, kp - u);
+        applyGivensInPlace(cs[0], cs[1], qt1[u], qt2[u - v], 0, r);
+      }
+    }
+  }
+
+  // Returns the merged Q, which in this case is qt1.
+  // r2 is merged into r1 (see mergeRonQ). The rotations are mirrored on qt1
+  // and on a scratch all-zero kp x r matrix in the qt2 position, which is then
+  // discarded. qt1, r1 and r2 are modified in place.
+  public static double[][] mergeQrUp(double[][] qt1, double[][] r1,
+                                     double[][] r2) {
+    int kp = qt1.length;
+    int r = qt1[0].length;
+
+    double[][] qTilde = new double[kp][];
+    for (int i = 0; i < kp; i++) {
+      qTilde[i] = new double[r];
+    }
+    mergeRonQ(r1, r2, qt1, qTilde);
+    return qt1;
+  }
+
+  // Returns the merged Q, which in this case is qt1.
+  // UpperTriangular version of the overload above; qt1, r1 and r2 are
+  // modified in place.
+  public static double[][] mergeQrUp(double[][] qt1, UpperTriangular r1, UpperTriangular r2) {
+    int kp = qt1.length;
+    int r = qt1[0].length;
+
+    double[][] qTilde = new double[kp][];
+    for (int i = 0; i < kp; i++) {
+      qTilde[i] = new double[r];
+    }
+    mergeRonQ(r1, r2, qt1, qTilde);
+    return qt1;
+  }
+
+  /**
+   * Merges r2 into r1 (see mergeRonQ), mirroring the rotations on a fresh
+   * all-zero kp x r matrix in the qt1 position and on {@code qt2}. Returns the
+   * fresh matrix. qt2, r1 and r2 are modified in place.
+   */
+  public static double[][] mergeQrDown(double[][] r1, double[][] qt2, double[][] r2) {
+    int kp = qt2.length;
+    int r = qt2[0].length;
+
+    double[][] qTilde = new double[kp][];
+    for (int i = 0; i < kp; i++) {
+      qTilde[i] = new double[r];
+    }
+    mergeRonQ(r1, r2, qTilde, qt2);
+    return qTilde;
+
+  }
+
+  /** UpperTriangular version of {@link #mergeQrDown(double[][], double[][], double[][])}. */
+  public static double[][] mergeQrDown(UpperTriangular r1, double[][] qt2, UpperTriangular r2) {
+    int kp = qt2.length;
+    int r = qt2[0].length;
+
+    double[][] qTilde = new double[kp][];
+    for (int i = 0; i < kp; i++) {
+      qTilde[i] = new double[r];
+    }
+    mergeRonQ(r1, r2, qTilde, qt2);
+    return qTilde;
+
+  }
+
+  /**
+   * Computes the Qt-hat of block {@code i} from a Qt-tilde and the R-tilde
+   * factors yielded by {@code rIter}.
+   * <p>
+   * The first R yielded becomes an accumulator, and the next i - 1 are merged
+   * into it with mergeR. If i &gt; 0, the R at index i is merged next with
+   * mergeQrDown, which produces a new Qt from {@code qt}. Every remaining R is
+   * then merged with mergeQrUp, which updates that Qt in place.
+   * <p>
+   * NOTE(review): presumably rIter must yield the R-tilde of every block in
+   * block order, and {@code qt} is block i's Qt-tilde; confirm against callers.
+   * Both {@code qt} and the UpperTriangular instances yielded by
+   * {@code rIter} are modified.
+   */
+  public static double[][] computeQtHat(double[][] qt, int i,
+                                        Iterator<UpperTriangular> rIter) {
+    UpperTriangular rTilde = rIter.next();
+    for (int j = 1; j < i; j++) {
+      mergeR(rTilde, rIter.next());
+    }
+    if (i > 0) {
+      qt = mergeQrDown(rTilde, qt, rIter.next());
+    }
+    while (rIter.hasNext()) {
+      qt = mergeQrUp(qt, rTilde, rIter.next());
+    }
+    return qt;
+  }
+
+  // test helpers
+  /**
+   * Test helper: checks, within {@code epsilon}, that the rows of {@code qt}
+   * are orthonormal, tolerating all-zero rows as rank deficiency.
+   * <p>
+   * A row whose 2-norm is within epsilon of 1 counts towards the rank. A row
+   * whose norm is neither ~1 nor ~0 fails immediately. Each row is then dotted
+   * with itself and with all previous rows.
+   * <p>
+   * NOTE(review): the expected self-dot is 1 only while {@code rank > i}, i.e.
+   * while no zero row has been seen yet. A unit row that follows a zero row is
+   * therefore reported as a failure, which suggests zero rows are expected to
+   * come last.
+   *
+   * @return if {@code insufficientRank}, whether fewer than n rows were unit
+   *         rows; otherwise whether all n rows were unit rows
+   */
+  public static boolean isOrthonormal(double[][] qt, boolean insufficientRank, double epsilon) {
+    int n = qt.length;
+    int rank = 0;
+    for (int i = 0; i < n; i++) {
+      Vector ei = new DenseVector(qt[i], true);
+
+      double norm = ei.norm(2);
+
+      if (Math.abs(1.0 - norm) < epsilon) {
+        rank++;
+      } else if (Math.abs(norm) > epsilon) {
+        return false; // not a rank deficiency, either
+      }
+
+      for (int j = 0; j <= i; j++) {
+        Vector ej = new DenseVector(qt[j], true);
+        double dot = ei.dot(ej);
+        if (!(Math.abs((i == j && rank > j ? 1.0 : 0.0) - dot) < epsilon)) {
+          return false;
+        }
+      }
+    }
+    return insufficientRank ? rank < n : rank == n;
+  }
+
+  /**
+   * Same check as {@link #isOrthonormal}, except that row i is the
+   * concatenation of row i of every block in {@code qtHats}.
+   * {@code qtHats} is iterated several times, so it must be re-iterable. The
+   * row count is taken from the first block.
+   */
+  public static boolean isOrthonormalBlocked(Iterable<double[][]> qtHats,
+                                             boolean insufficientRank, double epsilon) {
+    int n = qtHats.iterator().next().length;
+    int rank = 0;
+    for (int i = 0; i < n; i++) {
+      List<Vector> ei = Lists.newArrayList();
+      // Vector e_i=new DenseVector (qt[i],true);
+      for (double[][] qtHat : qtHats) {
+        ei.add(new DenseVector(qtHat[i], true));
+      }
+
+      double norm = 0;
+      for (Vector v : ei) {
+        norm += v.dot(v);
+      }
+      norm = Math.sqrt(norm);
+      if (Math.abs(1 - norm) < epsilon) {
+        rank++;
+      } else if (Math.abs(norm) > epsilon) {
+        return false; // not a rank deficiency, either
+      }
+
+      for (int j = 0; j <= i; j++) {
+        List<Vector> ej = Lists.newArrayList();
+        for (double[][] qtHat : qtHats) {
+          ej.add(new DenseVector(qtHat[j], true));
+        }
+
+        // Vector e_j = new DenseVector ( qt[j], true);
+        double dot = 0;
+        for (int k = 0; k < ei.size(); k++) {
+          dot += ei.get(k).dot(ej.get(k));
+        }
+        if (!(Math.abs((i == j && rank > j ? 1 : 0) - dot) < epsilon)) {
+          return false;
+        }
+      }
+    }
+    return insufficientRank ? rank < n : rank == n;
+  }
+
+  /**
+   * Mutable {@link Vector} view of a single row of an {@link UpperTriangular}.
+   * The row is selected, and can be re-selected, with setViewedRow(int).
+   * getQuick/setQuick delegate to the viewed matrix, while iteration and the
+   * other unsupported operations throw UnsupportedOperationException. Within
+   * this class the view is only accessed at columns on or to the right of the
+   * viewed row's diagonal.
+   */
+  private static final class TriangularRowView extends AbstractVector {
+    private final UpperTriangular viewed;
+    private int rowNum;
+
+    private TriangularRowView(UpperTriangular viewed) {
+      super(viewed.columnSize());
+      this.viewed = viewed;
+
+    }
+
+    // Selects the viewed row; returns this view for call chaining.
+    TriangularRowView setViewedRow(int row) {
+      rowNum = row;
+      return this;
+    }
+
+    @Override
+    public boolean isDense() {
+      return true;
+    }
+
+    @Override
+    public boolean isSequentialAccess() {
+      return false;
+    }
+
+    @Override
+    public Iterator<Element> iterator() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public Iterator<Element> iterateNonZero() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public double getQuick(int index) {
+      return viewed.getQuick(rowNum, index);
+    }
+
+    @Override
+    public Vector like() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public void setQuick(int index, double value) {
+      viewed.setQuick(rowNum, index, value);
+
+    }
+
+    @Override
+    public int getNumNondefaultElements() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public double getLookupCost() {
+      return 1;
+    }
+
+    @Override
+    public double getIteratorAdvanceCost() {
+      return 1;
+    }
+
+    @Override
+    public boolean isAddConstantTime() {
+      return true;
+    }
+
+    @Override
+    public Matrix matrixLike(int rows, int columns) {
+      throw new UnsupportedOperationException();
+    }
+
+    /**
+     * Used internally by assign() to update multiple indices and values at once.
+     * Only really useful for sparse vectors (especially SequentialAccessSparseVector).
+     * <p/>
+     * If someone ever adds a new type of sparse vectors, this method must merge (index, value) pairs into the vector.
+     *
+     * @param updates a mapping of indices to values to merge in the vector.
+     */
+    @Override
+    public void mergeUpdates(OrderedIntDoubleMapping updates) {
+      int[] indices = updates.getIndices();
+      double[] values = updates.getValues();
+      for (int i = 0; i < updates.getNumMappings(); ++i) {
+        viewed.setQuick(rowNum, indices[i], values[i]);
+      }
+    }
+
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java
new file mode 100644
index 0000000..09be91f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Gram-Schmidt quick helper.
+ */
+public final class GramSchmidt {
+
+  private GramSchmidt() {
+  }
+
+  /**
+   * Orthonormalizes the columns of {@code mx} in place, left to right.
+   * <p>
+   * For each column, the projections onto the already processed columns are
+   * subtracted. Each projection is computed against the partially updated
+   * column, which is the modified Gram-Schmidt ordering. The column is then
+   * divided by its 2-norm. Results are written back through the views
+   * returned by viewColumn().
+   * <p>
+   * NOTE(review): there is no zero-norm guard. A column that becomes all zero
+   * after projection is divided by 0 and ends up NaN.
+   */
+  public static void orthonormalizeColumns(Matrix mx) {
+
+    int n = mx.numCols();
+
+    for (int c = 0; c < n; c++) {
+      Vector col = mx.viewColumn(c);
+      for (int c1 = 0; c1 < c; c1++) {
+        Vector viewC1 = mx.viewColumn(c1);
+        col.assign(col.minus(viewC1.times(viewC1.dot(col))));
+
+      }
+      final double norm2 = col.norm(2);
+      col.assign(new DoubleFunction() {
+        @Override
+        public double apply(double x) {
+          return x / norm2;
+        }
+      });
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java
new file mode 100644
index 0000000..8509e0a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java
@@ -0,0 +1,284 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import java.io.Closeable;
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.lib.MultipleOutputs;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.iterator.CopyConstructorIterator;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.stochasticsvd.DenseBlockWritable;
+import org.apache.mahout.math.UpperTriangular;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * QR first step without MR abstractions, expressed purely in terms of
+ * iterators and collectors (although OutputCollector is probably an outdated
+ * API).
+ * <p>
+ * Rows of Y arrive through {@link #collect(Writable, Vector)}; the key is
+ * ignored. They are QR-decomposed in blocks of {@link #PROP_AROWBLOCK_SIZE}
+ * rows by a {@link GivensThinSolver}. A lookahead buffer of k + p rows
+ * ensures the last block is topped up with those buffered rows at close time.
+ * <p>
+ * With a single block, its Qt and R are emitted directly on close. With
+ * several blocks, each finished block's Qt-tilde is spilled to a local
+ * temporary sequence file and its R-tilde is kept in memory. On close, a
+ * second pass turns every spilled block into a Qt-hat
+ * ({@link GivensThinSolver#computeQtHat}) and emits it to {@code qtHatOut},
+ * then emits the single merged R to {@code rHatOut}.
+ */
+@SuppressWarnings("deprecation")
+public class QRFirstStep implements Closeable, OutputCollector<Writable, Vector> {
+
+  public static final String PROP_K = "ssvd.k";
+  public static final String PROP_P = "ssvd.p";
+  public static final String PROP_AROWBLOCK_SIZE = "ssvd.arowblock.size";
+
+  private int kp;
+  // up to kp most recent rows, not yet handed to the solver
+  private List<double[]> yLookahead;
+  private GivensThinSolver qSolver;
+  // number of blocks flushed so far (incremented once more in cleanup())
+  private int blockCnt;
+  private final DenseBlockWritable value = new DenseBlockWritable();
+  private final Writable tempKey = new IntWritable();
+  // NOTE(review): only created and closed here; never written to by this class
+  private MultipleOutputs outputs;
+  // resources closed, in deque order, by cleanup()
+  private final Deque<Closeable> closeables = Lists.newLinkedList();
+  private SequenceFile.Writer tempQw;
+  private Path tempQPath;
+  // R-tilde of every flushed block, in block order
+  private final List<UpperTriangular> rSubseq = Lists.newArrayList();
+  private final Configuration jobConf;
+
+  private final OutputCollector<? super Writable, ? super DenseBlockWritable> qtHatOut;
+  private final OutputCollector<? super Writable, ? super VectorWritable> rHatOut;
+
+  /**
+   * @param jobConf configuration providing {@link #PROP_K}, {@link #PROP_P}
+   *        and {@link #PROP_AROWBLOCK_SIZE}
+   * @param qtHatOut receives the Qt-hat block(s), keyed by NullWritable
+   * @param rHatOut receives the single R-hat (packed upper triangle), keyed by
+   *        NullWritable
+   */
+  public QRFirstStep(Configuration jobConf,
+                     OutputCollector<? super Writable, ? super DenseBlockWritable> qtHatOut,
+                     OutputCollector<? super Writable, ? super VectorWritable> rHatOut) {
+    this.jobConf = jobConf;
+    this.qtHatOut = qtHatOut;
+    this.rHatOut = rHatOut;
+    setup();
+  }
+
+  /** Finishes the computation and emits all output; see {@link #cleanup()}. */
+  @Override
+  public void close() throws IOException {
+    cleanup();
+  }
+
+  /** @return k + p, the width of every Y row */
+  public int getKP() {
+    return kp;
+  }
+
+  /**
+   * Saves the current (full) block. Its R-tilde is kept in {@code rSubseq},
+   * its Qt-tilde is appended to the temporary file, and the solver is reset
+   * for the next block.
+   */
+  private void flushSolver() throws IOException {
+    UpperTriangular r = qSolver.getRTilde();
+    double[][] qt = qSolver.getThinQtTilde();
+
+    rSubseq.add(r);
+
+    value.setBlock(qt);
+    getTempQw().append(tempKey, value);
+
+    /*
+     * This probably should be a sparse row matrix, but the compressor should
+     * handle that on disk, and in memory we want it dense anyway. Sparse
+     * random-access implementations would mostly be a memory management
+     * disaster of rehashes and GC thrashing. (IMHO)
+     */
+    value.setBlock(null);
+    qSolver.reset();
+  }
+
+  // Emits the final output: directly for a single block, otherwise through a
+  // second pass running a modified version of computeQHatSequence.
+  private void flushQBlocks() throws IOException {
+    if (blockCnt == 1) {
+      /*
+       * Only one block: no temp file and no second pass. This should be the
+       * default mode for efficiency in most cases.
+       */
+      value.setBlock(qSolver.getThinQtTilde());
+      outputQHat(value);
+      outputR(new VectorWritable(new DenseVector(qSolver.getRTilde().getData(),
+                                                 true)));
+
+    } else {
+      secondPass();
+    }
+  }
+
+  private void outputQHat(DenseBlockWritable value) throws IOException {
+    qtHatOut.collect(NullWritable.get(), value);
+  }
+
+  private void outputR(VectorWritable value) throws IOException {
+    rHatOut.collect(NullWritable.get(), value);
+  }
+
+  /**
+   * Reads every spilled Qt-tilde block back in order, converts it to a Qt-hat
+   * with computeQtHat over an iterator of the R-tilde list, and emits it.
+   * <p>
+   * After block 1, R[1] is merged into R[0] and removed, and the same happens
+   * for each later block. The current block's R therefore always sits at
+   * index 1 for blocks after the first, which is why the index passed stays
+   * at 1. At the end a single merged R remains and is emitted.
+   */
+  private void secondPass() throws IOException {
+    qSolver = null; // release mem
+    FileSystem localFs = FileSystem.getLocal(jobConf);
+    SequenceFile.Reader tempQr =
+      new SequenceFile.Reader(localFs, tempQPath, jobConf);
+    closeables.addFirst(tempQr);
+    int qCnt = 0;
+    while (tempQr.next(tempKey, value)) {
+      value
+        .setBlock(GivensThinSolver.computeQtHat(value.getBlock(),
+                                                qCnt,
+                                                new CopyConstructorIterator<>(rSubseq.iterator())));
+      if (qCnt == 1) {
+        /*
+         * Just merge r[0] <- r[1] so the merge doesn't have to be repeated in
+         * subsequent computeQHat iterators.
+         */
+        GivensThinSolver.mergeR(rSubseq.get(0), rSubseq.remove(1));
+      } else {
+        qCnt++;
+      }
+      outputQHat(value);
+    }
+
+    assert rSubseq.size() == 1;
+
+    outputR(new VectorWritable(new DenseVector(rSubseq.get(0).getData(), true)));
+
+  }
+
+  /**
+   * Buffers one incoming Y row.
+   * <p>
+   * Once the lookahead list holds kp rows, its oldest row is first handed to
+   * the solver. If the solver is already full, it is flushed as a completed
+   * block beforehand. That row's array is then recycled for the incoming
+   * data. Dense input is read at indices 0..kp-1; sparse input has its
+   * non-zero entries copied by index.
+   */
+  protected void map(Vector incomingYRow) throws IOException {
+    double[] yRow;
+    if (yLookahead.size() == kp) {
+      if (qSolver.isFull()) {
+
+        flushSolver();
+        blockCnt++;
+
+      }
+      yRow = yLookahead.remove(0);
+
+      qSolver.appendRow(yRow);
+    } else {
+      yRow = new double[kp];
+    }
+
+    if (incomingYRow.isDense()) {
+      for (int i = 0; i < kp; i++) {
+        yRow[i] = incomingYRow.get(i);
+      }
+    } else {
+      Arrays.fill(yRow, 0);
+      for (Element yEl : incomingYRow.nonZeroes()) {
+        yRow[yEl.index()] = yEl.get();
+      }
+    }
+
+    yLookahead.add(yRow);
+  }
+
+  /**
+   * Reads the block size, k and p with Integer.parseInt, which throws
+   * NumberFormatException if any is missing or malformed. It then allocates
+   * the lookahead buffer and the solver.
+   */
+  protected void setup() {
+
+    int r = Integer.parseInt(jobConf.get(PROP_AROWBLOCK_SIZE));
+    int k = Integer.parseInt(jobConf.get(PROP_K));
+    int p = Integer.parseInt(jobConf.get(PROP_P));
+    kp = k + p;
+
+    yLookahead = Lists.newArrayListWithCapacity(kp);
+    qSolver = new GivensThinSolver(r, kp);
+    outputs = new MultipleOutputs(new JobConf(jobConf));
+    closeables.addFirst(new Closeable() {
+      @Override
+      public void close() throws IOException {
+        outputs.close();
+      }
+    });
+
+  }
+
+  /**
+   * Drains the lookahead buffer into the solver, growing it if needed. It then
+   * flushes the last block, emits Qt-hat and R-hat (see flushQBlocks()), and
+   * finally closes all registered resources, even on failure.
+   * <p>
+   * Nothing is emitted when no row was ever collected, or when an earlier call
+   * already drained the lookahead buffer, so repeated close() calls do not
+   * emit duplicates.
+   */
+  protected void cleanup() throws IOException {
+    try {
+      /*
+       * map() always leaves the row it received in yLookahead. The buffer is
+       * therefore empty only if no row was collected at all, or if it was
+       * already drained by a previous cleanup(). qSolver is still non-null on
+       * the first call, so testing it is not enough: with no rows,
+       * qSolver.adjust(0) below would reject newM < n.
+       */
+      if (yLookahead.isEmpty()) {
+        return;
+      }
+      if (qSolver == null) {
+        qSolver = new GivensThinSolver(yLookahead.size(), kp);
+      }
+      // grow q solver up if necessary
+
+      qSolver.adjust(qSolver.getCnt() + yLookahead.size());
+      while (!yLookahead.isEmpty()) {
+
+        qSolver.appendRow(yLookahead.remove(0));
+
+      }
+      assert qSolver.isFull();
+      if (++blockCnt > 1) {
+        flushSolver();
+        assert tempQw != null;
+        closeables.remove(tempQw);
+        Closeables.close(tempQw, false);
+      }
+      flushQBlocks();
+
+    } finally {
+      IOUtils.close(closeables);
+    }
+
+  }
+
+  /**
+   * Lazily creates the local temporary sequence file that receives spilled
+   * Qt-tilde blocks. The file lives in a {@code qw_<millis>} directory under
+   * {@code java.io.tmpdir}. Both the writer and a delete-on-close hook for
+   * the file are registered in {@code closeables}.
+   * <p>
+   * NOTE(review): only q-temp.seq itself is deleted, so the qw_* directory
+   * (and any side files the local file system writes next to it) may be left
+   * behind.
+   */
+  private SequenceFile.Writer getTempQw() throws IOException {
+    if (tempQw == null) {
+      /*
+       * Temporary Q output will hopefully not exceed the size of the IO cache,
+       * in which case it is managed by the kernel rather than the Java GC. If
+       * the IO cache is not big enough, at least the access is always
+       * sequential.
+       */
+      String taskTmpDir = System.getProperty("java.io.tmpdir");
+
+      FileSystem localFs = FileSystem.getLocal(jobConf);
+      Path parent = new Path(taskTmpDir);
+      Path sub = new Path(parent, "qw_" + System.currentTimeMillis());
+      tempQPath = new Path(sub, "q-temp.seq");
+      tempQw =
+        SequenceFile.createWriter(localFs,
+                                  jobConf,
+                                  tempQPath,
+                                  IntWritable.class,
+                                  DenseBlockWritable.class,
+                                  CompressionType.BLOCK);
+      closeables.addFirst(tempQw);
+      closeables.addFirst(new IOUtils.DeleteFileOnClose(new File(tempQPath
+        .toString())));
+    }
+    return tempQw;
+  }
+
+  /** OutputCollector entry point: the key is ignored and {@code vw} is the next Y row. */
+  @Override
+  public void collect(Writable key, Vector vw) throws IOException {
+    map(vw);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java
new file mode 100644
index 0000000..545f1f9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop.stochasticsvd.qr; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +import org.apache.commons.lang3.Validate; +import org.apache.mahout.common.iterator.CopyConstructorIterator; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.stochasticsvd.DenseBlockWritable; +import org.apache.mahout.math.UpperTriangular; + +import com.google.common.collect.Lists; + +/** + * Second/last step of QR iterations. Takes input of qtHats and rHats and + * provides iterator to pull ready rows of final Q. + * + */ +public class QRLastStep implements Closeable, Iterator<Vector> { + + private final Iterator<DenseBlockWritable> qHatInput; + + private final List<UpperTriangular> mRs = Lists.newArrayList(); + private final int blockNum; + private double[][] mQt; + private int cnt; + private int r; + private int kp; + private Vector qRow; + + /** + * + * @param qHatInput + * the Q-Hat input that was output in the first step + * @param rHatInput + * all RHat outputs int the group in order of groups + * @param blockNum + * our RHat number in the group + */ + public QRLastStep(Iterator<DenseBlockWritable> qHatInput, + Iterator<VectorWritable> rHatInput, + int blockNum) { + this.blockNum = blockNum; + this.qHatInput = qHatInput; + /* + * in this implementation we actually preload all Rs into memory to make R + * sequence modifications more efficient. 
+ */ + int block = 0; + while (rHatInput.hasNext()) { + Vector value = rHatInput.next().get(); + if (block < blockNum && block > 0) { + GivensThinSolver.mergeR(mRs.get(0), new UpperTriangular(value)); + } else { + mRs.add(new UpperTriangular(value)); + } + block++; + } + + } + + private boolean loadNextQt() { + boolean more = qHatInput.hasNext(); + if (!more) { + return false; + } + DenseBlockWritable v = qHatInput.next(); + mQt = + GivensThinSolver + .computeQtHat(v.getBlock(), + blockNum == 0 ? 0 : 1, + new CopyConstructorIterator<>(mRs.iterator())); + r = mQt[0].length; + kp = mQt.length; + if (qRow == null) { + qRow = new DenseVector(kp); + } + return true; + } + + @Override + public boolean hasNext() { + if (mQt != null && cnt == r) { + mQt = null; + } + boolean result = true; + if (mQt == null) { + result = loadNextQt(); + cnt = 0; + } + return result; + } + + @Override + public Vector next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + Validate.isTrue(hasNext(), "Q input overrun"); + /* + * because Q blocks are initially stored in inverse order + */ + int qRowIndex = r - cnt - 1; + for (int j = 0; j < kp; j++) { + qRow.setQuick(j, mQt[j][qRowIndex]); + } + cnt++; + return qRow; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public void close() throws IOException { + mQt = null; + mRs.clear(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java new file mode 100644 index 0000000..51484c7 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.neighborhood; + +import java.util.Iterator; +import java.util.List; +import java.util.PriorityQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.collect.Ordering; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.WeightedVector; +import org.apache.mahout.math.random.WeightedThing; + +/** + * Search for nearest neighbors using a complete search (i.e. looping through + * the references and comparing each vector to the query). + */ +public class BruteSearch extends UpdatableSearcher { + /** + * The list of reference vectors. 
+ */ + private final List<Vector> referenceVectors; + + public BruteSearch(DistanceMeasure distanceMeasure) { + super(distanceMeasure); + referenceVectors = Lists.newArrayList(); + } + + @Override + public void add(Vector vector) { + referenceVectors.add(vector); + } + + @Override + public int size() { + return referenceVectors.size(); + } + + /** + * Scans the list of reference vectors one at a time for @limit neighbors of + * the query vector. + * The weights of the WeightedVectors are not taken into account. + * + * @param query The query vector. + * @param limit The number of results to returned; must be at least 1. + * @return A list of the closest @limit neighbors for the given query. + */ + @Override + public List<WeightedThing<Vector>> search(Vector query, int limit) { + Preconditions.checkArgument(limit > 0, "limit must be greater then 0!"); + limit = Math.min(limit, referenceVectors.size()); + // A priority queue of the best @limit elements, ordered from worst to best so that the worst + // element is always on top and can easily be removed. + PriorityQueue<WeightedThing<Integer>> bestNeighbors = + new PriorityQueue<>(limit, Ordering.natural().reverse()); + // The resulting list of weighted WeightedVectors (the weight is the distance from the query). + List<WeightedThing<Vector>> results = + Lists.newArrayListWithCapacity(limit); + int rowNumber = 0; + for (Vector row : referenceVectors) { + double distance = distanceMeasure.distance(query, row); + // Only add a new neighbor if the result is better than the worst element + // in the queue or the queue isn't full. + if (bestNeighbors.size() < limit || bestNeighbors.peek().getWeight() > distance) { + bestNeighbors.add(new WeightedThing<>(rowNumber, distance)); + if (bestNeighbors.size() > limit) { + bestNeighbors.poll(); + } else { + // Increase the size of the results list by 1 so we can add elements in the reverse + // order from the queue. 
+ results.add(null); + } + } + ++rowNumber; + } + for (int i = limit - 1; i >= 0; --i) { + WeightedThing<Integer> neighbor = bestNeighbors.poll(); + results.set(i, new WeightedThing<>( + referenceVectors.get(neighbor.getValue()), neighbor.getWeight())); + } + return results; + } + + /** + * Returns the closest vector to the query. + * When only one the nearest vector is needed, use this method, NOT search(query, limit) because + * it's faster (less overhead). + * + * @param query the vector to search for + * @param differentThanQuery if true, returns the closest vector different than the query (this + * only matters if the query is among the searched vectors), otherwise, + * returns the closest vector to the query (even the same vector). + * @return the weighted vector closest to the query + */ + @Override + public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) { + double bestDistance = Double.POSITIVE_INFINITY; + Vector bestVector = null; + for (Vector row : referenceVectors) { + double distance = distanceMeasure.distance(query, row); + if (distance < bestDistance && (!differentThanQuery || !row.equals(query))) { + bestDistance = distance; + bestVector = row; + } + } + return new WeightedThing<>(bestVector, bestDistance); + } + + /** + * Searches with a list full of queries in a threaded fashion. + * + * @param queries The queries to search for. + * @param limit The number of results to return. + * @param numThreads Number of threads to use in searching. + * @return A list of result lists. 
+ */ + public List<List<WeightedThing<Vector>>> search(Iterable<WeightedVector> queries, + final int limit, int numThreads) throws InterruptedException { + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + List<Callable<Object>> tasks = Lists.newArrayList(); + + final List<List<WeightedThing<Vector>>> results = Lists.newArrayList(); + int i = 0; + for (final Vector query : queries) { + results.add(null); + final int index = i++; + tasks.add(new Callable<Object>() { + @Override + public Object call() throws Exception { + results.set(index, BruteSearch.this.search(query, limit)); + return null; + } + }); + } + + executor.invokeAll(tasks); + executor.shutdown(); + + return results; + } + + @Override + public Iterator<Vector> iterator() { + return referenceVectors.iterator(); + } + + @Override + public boolean remove(Vector query, double epsilon) { + int rowNumber = 0; + for (Vector row : referenceVectors) { + double distance = distanceMeasure.distance(query, row); + if (distance < epsilon) { + referenceVectors.remove(rowNumber); + return true; + } + rowNumber++; + } + return false; + } + + @Override + public void clear() { + referenceVectors.clear(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java new file mode 100644 index 0000000..006f4b6 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.neighborhood; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Set; + +import com.google.common.base.Preconditions; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.random.RandomProjector; +import org.apache.mahout.math.random.WeightedThing; + +/** + * Does approximate nearest neighbor search by projecting the vectors similar to ProjectionSearch. + * The main difference between this class and the ProjectionSearch is the use of sorted arrays + * instead of binary search trees to implement the sets of scalar projections. + * + * Instead of taking log n time to add a vector to each of the vectors, * the pending additions are + * kept separate and are searched using a brute search. When there are "enough" pending additions, + * they're committed into the main pool of vectors. + */ +public class FastProjectionSearch extends UpdatableSearcher { + // The list of vectors that have not yet been projected (that are pending). 
+ private final List<Vector> pendingAdditions = Lists.newArrayList(); + + // The list of basis vectors. Populated when the first vector's dimension is know by calling + // initialize once. + private Matrix basisMatrix = null; + + // The list of sorted lists of scalar projections. The outer list has one entry for each basis + // vector that all the other vectors will be projected on. + // For each basis vector, the inner list has an entry for each vector that has been projected. + // These entries are WeightedThing<Vector> where the weight is the value of the scalar + // projection and the value is the vector begin referred to. + private List<List<WeightedThing<Vector>>> scalarProjections; + + // The number of projection used for approximating the distance. + private final int numProjections; + + // The number of elements to keep on both sides of the closest estimated distance as possible + // candidates for the best actual distance. + private final int searchSize; + + // Initially, the dimension of the vectors searched by this searcher is unknown. After adding + // the first vector, the basis will be initialized. This marks whether initialization has + // happened or not so we only do it once. + private boolean initialized = false; + + // Removing vectors from the searcher is done lazily to avoid the linear time cost of removing + // elements from an array. This member keeps track of the number of removed vectors (marked as + // "impossible" values in the array) so they can be removed when updating the structure. + private int numPendingRemovals = 0; + + private static final double ADDITION_THRESHOLD = 0.05; + private static final double REMOVAL_THRESHOLD = 0.02; + + public FastProjectionSearch(DistanceMeasure distanceMeasure, int numProjections, int searchSize) { + super(distanceMeasure); + Preconditions.checkArgument(numProjections > 0 && numProjections < 100, + "Unreasonable value for number of projections. 
Must be: 0 < numProjections < 100"); + this.numProjections = numProjections; + this.searchSize = searchSize; + scalarProjections = Lists.newArrayListWithCapacity(numProjections); + for (int i = 0; i < numProjections; ++i) { + scalarProjections.add(Lists.<WeightedThing<Vector>>newArrayList()); + } + } + + private void initialize(int numDimensions) { + if (initialized) { + return; + } + basisMatrix = RandomProjector.generateBasisNormal(numProjections, numDimensions); + initialized = true; + } + + /** + * Add a new Vector to the Searcher that will be checked when getting + * the nearest neighbors. + * <p/> + * The vector IS NOT CLONED. Do not modify the vector externally otherwise the internal + * Searcher data structures could be invalidated. + */ + @Override + public void add(Vector vector) { + initialize(vector.size()); + pendingAdditions.add(vector); + } + + /** + * Returns the number of WeightedVectors being searched for nearest neighbors. + */ + @Override + public int size() { + return pendingAdditions.size() + scalarProjections.get(0).size() - numPendingRemovals; + } + + /** + * When querying the Searcher for the closest vectors, a list of WeightedThing<Vector>s is + * returned. The value of the WeightedThing is the neighbor and the weight is the + * the distance (calculated by some metric - see a concrete implementation) between the query + * and neighbor. + * The actual type of vector in the pair is the same as the vector added to the Searcher. 
+ */ + @Override + public List<WeightedThing<Vector>> search(Vector query, int limit) { + reindex(false); + + Set<Vector> candidates = Sets.newHashSet(); + Vector projection = basisMatrix.times(query); + for (int i = 0; i < basisMatrix.numRows(); ++i) { + List<WeightedThing<Vector>> currProjections = scalarProjections.get(i); + int middle = Collections.binarySearch(currProjections, + new WeightedThing<Vector>(projection.get(i))); + if (middle < 0) { + middle = -(middle + 1); + } + for (int j = Math.max(0, middle - searchSize); + j < Math.min(currProjections.size(), middle + searchSize + 1); ++j) { + if (currProjections.get(j).getValue() == null) { + continue; + } + candidates.add(currProjections.get(j).getValue()); + } + } + + List<WeightedThing<Vector>> top = + Lists.newArrayListWithCapacity(candidates.size() + pendingAdditions.size()); + for (Vector candidate : Iterables.concat(candidates, pendingAdditions)) { + top.add(new WeightedThing<>(candidate, distanceMeasure.distance(candidate, query))); + } + Collections.sort(top); + + return top.subList(0, Math.min(top.size(), limit)); + } + + /** + * Returns the closest vector to the query. + * When only one the nearest vector is needed, use this method, NOT search(query, limit) because + * it's faster (less overhead). + * + * @param query the vector to search for + * @param differentThanQuery if true, returns the closest vector different than the query (this + * only matters if the query is among the searched vectors), otherwise, + * returns the closest vector to the query (even the same vector). 
+ * @return the weighted vector closest to the query + */ + @Override + public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) { + reindex(false); + + double bestDistance = Double.POSITIVE_INFINITY; + Vector bestVector = null; + + Vector projection = basisMatrix.times(query); + for (int i = 0; i < basisMatrix.numRows(); ++i) { + List<WeightedThing<Vector>> currProjections = scalarProjections.get(i); + int middle = Collections.binarySearch(currProjections, + new WeightedThing<Vector>(projection.get(i))); + if (middle < 0) { + middle = -(middle + 1); + } + for (int j = Math.max(0, middle - searchSize); + j < Math.min(currProjections.size(), middle + searchSize + 1); ++j) { + if (currProjections.get(j).getValue() == null) { + continue; + } + Vector vector = currProjections.get(j).getValue(); + double distance = distanceMeasure.distance(vector, query); + if (distance < bestDistance && (!differentThanQuery || !vector.equals(query))) { + bestDistance = distance; + bestVector = vector; + } + } + } + + for (Vector vector : pendingAdditions) { + double distance = distanceMeasure.distance(vector, query); + if (distance < bestDistance && (!differentThanQuery || !vector.equals(query))) { + bestDistance = distance; + bestVector = vector; + } + } + + return new WeightedThing<>(bestVector, bestDistance); + } + + @Override + public boolean remove(Vector vector, double epsilon) { + WeightedThing<Vector> closestPair = searchFirst(vector, false); + if (distanceMeasure.distance(closestPair.getValue(), vector) > epsilon) { + return false; + } + + boolean isProjected = true; + Vector projection = basisMatrix.times(vector); + for (int i = 0; i < basisMatrix.numRows(); ++i) { + List<WeightedThing<Vector>> currProjections = scalarProjections.get(i); + WeightedThing<Vector> searchedThing = new WeightedThing<>(projection.get(i)); + int middle = Collections.binarySearch(currProjections, searchedThing); + if (middle < 0) { + isProjected = false; + break; + } + // 
Elements to be removed are kept in the sorted array until the next reindex, but their inner vector + // is set to null. + scalarProjections.get(i).set(middle, searchedThing); + } + if (isProjected) { + ++numPendingRemovals; + return true; + } + + for (int i = 0; i < pendingAdditions.size(); ++i) { + if (pendingAdditions.get(i).equals(vector)) { + pendingAdditions.remove(i); + break; + } + } + return true; + } + + private void reindex(boolean force) { + int numProjected = scalarProjections.get(0).size(); + if (force || pendingAdditions.size() > ADDITION_THRESHOLD * numProjected + || numPendingRemovals > REMOVAL_THRESHOLD * numProjected) { + + // We only need to copy the first list because when iterating we use only that list for the Vector + // references. + // see public Iterator<Vector> iterator() + List<List<WeightedThing<Vector>>> scalarProjections = Lists.newArrayListWithCapacity(numProjections); + for (int i = 0; i < numProjections; ++i) { + if (i == 0) { + scalarProjections.add(Lists.newArrayList(this.scalarProjections.get(i))); + } else { + scalarProjections.add(this.scalarProjections.get(i)); + } + } + + // Project every pending vector onto every basis vector. + for (Vector pending : pendingAdditions) { + Vector projection = basisMatrix.times(pending); + for (int i = 0; i < numProjections; ++i) { + scalarProjections.get(i).add(new WeightedThing<>(pending, projection.get(i))); + } + } + pendingAdditions.clear(); + // For each basis vector, sort the resulting list (for binary search) and remove the number + // of pending removals (it's the same for every basis vector) at the end (the weights are + // set to Double.POSITIVE_INFINITY when removing). 
+ for (int i = 0; i < numProjections; ++i) { + List<WeightedThing<Vector>> currProjections = scalarProjections.get(i); + for (WeightedThing<Vector> v : currProjections) { + if (v.getValue() == null) { + v.setWeight(Double.POSITIVE_INFINITY); + } + } + Collections.sort(currProjections); + for (int j = 0; j < numPendingRemovals; ++j) { + currProjections.remove(currProjections.size() - 1); + } + } + numPendingRemovals = 0; + + this.scalarProjections = scalarProjections; + } + } + + @Override + public void clear() { + pendingAdditions.clear(); + for (int i = 0; i < numProjections; ++i) { + scalarProjections.get(i).clear(); + } + numPendingRemovals = 0; + } + + /** + * This iterates on the snapshot of the contents first instantiated regardless of any future modifications. + * Changes done after the iterator is created will not be visible to the iterator but will be visible + * when searching. + * @return iterator through the vectors in this searcher. + */ + @Override + public Iterator<Vector> iterator() { + reindex(true); + return new AbstractIterator<Vector>() { + private final Iterator<WeightedThing<Vector>> data = scalarProjections.get(0).iterator(); + @Override + protected Vector computeNext() { + do { + if (!data.hasNext()) { + return endOfData(); + } + WeightedThing<Vector> next = data.next(); + if (next.getValue() != null) { + return next.getValue(); + } + } while (true); + } + }; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java new file mode 100644 index 0000000..eb91813 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more 
+ * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.neighborhood; + +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.WeightedVector; + +/** + * Decorates a weighted vector with a locality sensitive hash. + * + * The LSH function implemented is the random hyperplane based hash function. + * See "Similarity Estimation Techniques from Rounding Algorithms" by Moses S. Charikar, section 3. + * http://www.cs.princeton.edu/courses/archive/spring04/cos598B/bib/CharikarEstim.pdf + */ +public class HashedVector extends WeightedVector { + protected static final int INVALID_INDEX = -1; + + /** + * Value of the locality sensitive hash. It is 64 bit. 
+ */ + private final long hash; + + public HashedVector(Vector vector, long hash, int index) { + super(vector, 1, index); + this.hash = hash; + } + + public HashedVector(Vector vector, Matrix projection, int index, long mask) { + super(vector, 1, index); + this.hash = mask & computeHash64(vector, projection); + } + + public HashedVector(WeightedVector weightedVector, Matrix projection, long mask) { + super(weightedVector.getVector(), weightedVector.getWeight(), weightedVector.getIndex()); + this.hash = mask & computeHash64(weightedVector, projection); + } + + public static long computeHash64(Vector vector, Matrix projection) { + long hash = 0; + for (Element element : projection.times(vector).nonZeroes()) { + if (element.get() > 0) { + hash += 1L << element.index(); + } + } + return hash; + } + + public static HashedVector hash(WeightedVector v, Matrix projection) { + return hash(v, projection, 0); + } + + public static HashedVector hash(WeightedVector v, Matrix projection, long mask) { + return new HashedVector(v, projection, mask); + } + + public int hammingDistance(long otherHash) { + return Long.bitCount(hash ^ otherHash); + } + + public long getHash() { + return hash; + } + + @Override + public String toString() { + return String.format("index=%d, hash=%08x, v=%s", getIndex(), hash, getVector()); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof HashedVector)) { + return o instanceof Vector && this.minus((Vector) o).norm(1) == 0; + } + HashedVector v = (HashedVector) o; + return v.hash == this.hash && this.minus(v).norm(1) == 0; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + (int) (hash ^ (hash >>> 32)); + return result; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java ---------------------------------------------------------------------- 
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java new file mode 100644 index 0000000..aa1f103 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.neighborhood; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import com.google.common.collect.Multiset; +import org.apache.lucene.util.PriorityQueue; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.random.RandomProjector; +import org.apache.mahout.math.random.WeightedThing; +import org.apache.mahout.math.stats.OnlineSummarizer; + +/** + * Implements a Searcher that uses locality sensitivity hash as a first pass approximation + * to estimate distance without floating point math. The clever bit about this implementation + * is that it does an adaptive cutoff for the cutoff on the bitwise distance. Making this + * cutoff adaptive means that we only needs to make a single pass through the data. + */ +public class LocalitySensitiveHashSearch extends UpdatableSearcher { + /** + * Number of bits in the locality sensitive hash. 64 bits fix neatly into a long. + */ + private static final int BITS = 64; + + /** + * Bit mask for the computed hash. Currently, it's 0xffffffffffff. + */ + private static final long BIT_MASK = -1L; + + /** + * The maximum Hamming distance between two hashes that the hash limit can grow back to. + * It starts at BITS and decreases as more points than are needed are added to the candidate priority queue. + * But, after the observed distribution of distances becomes too good (we're seeing less than some percentage of the + * total number of points; using the hash strategy somewhere less than 25%) the limit is increased to compute + * more distances. 
+ * This is because + */ + private static final int MAX_HASH_LIMIT = 32; + + /** + * Minimum number of points with a given Hamming from the query that must be observed to consider raising the minimum + * distance for a candidate. + */ + private static final int MIN_DISTRIBUTION_COUNT = 10; + + private final Multiset<HashedVector> trainingVectors = HashMultiset.create(); + + /** + * This matrix of BITS random vectors is used to compute the Locality Sensitive Hash + * we compute the dot product with these vectors using a matrix multiplication and then use just + * sign of each result as one bit in the hash + */ + private Matrix projection; + + /** + * The search size determines how many top results we retain. We do this because the hash distance + * isn't guaranteed to be entirely monotonic with respect to the real distance. To the extent that + * actual distance is well approximated by hash distance, then the searchSize can be decreased to + * roughly the number of results that you want. + */ + private int searchSize; + + /** + * Controls how the hash limit is raised. 0 means use minimum of distribution, 1 means use first quartile. + * Intermediate values indicate an interpolation should be used. Negative values mean to never increase. + */ + private double hashLimitStrategy = 0.9; + + /** + * Number of evaluations of the full distance between two points that was required. + */ + private int distanceEvaluations = 0; + + /** + * Whether the projection matrix was initialized. This has to be deferred until the size of the vectors is known, + * effectively until the first vector is added. 
+ */ + private boolean initialized = false; + + public LocalitySensitiveHashSearch(DistanceMeasure distanceMeasure, int searchSize) { + super(distanceMeasure); + this.searchSize = searchSize; + this.projection = null; + } + + private void initialize(int numDimensions) { + if (initialized) { + return; + } + initialized = true; + projection = RandomProjector.generateBasisNormal(BITS, numDimensions); + } + + private PriorityQueue<WeightedThing<Vector>> searchInternal(Vector query) { + long queryHash = HashedVector.computeHash64(query, projection); + + // We keep an approximation of the closest vectors here. + PriorityQueue<WeightedThing<Vector>> top = Searcher.getCandidateQueue(getSearchSize()); + + // We scan the vectors using bit counts as an approximation of the dot product so we can do as few + // full distance computations as possible. Our goal is to only do full distance computations for + // vectors with hash distance at most as large as the searchSize biggest hash distance seen so far. + + OnlineSummarizer[] distribution = new OnlineSummarizer[BITS + 1]; + for (int i = 0; i < BITS + 1; i++) { + distribution[i] = new OnlineSummarizer(); + } + + distanceEvaluations = 0; + + // We keep the counts of the hash distances here. This lets us accurately + // judge what hash distance cutoff we should use. + int[] hashCounts = new int[BITS + 1]; + + // Maximum number of different bits to still consider a vector a candidate for nearest neighbor. + // Starts at the maximum number of bits, but decreases and can increase. + int hashLimit = BITS; + int limitCount = 0; + double distanceLimit = Double.POSITIVE_INFINITY; + + // In this loop, we have the invariants that: + // + // limitCount = sum_{i<hashLimit} hashCount[i] + // and + // limitCount >= searchSize && limitCount - hashCount[hashLimit-1] < searchSize + for (HashedVector vector : trainingVectors) { + // This computes the Hamming Distance between the vector's hash and the query's hash. 
+ // The result is correlated with the angle between the vectors. + int bitDot = vector.hammingDistance(queryHash); + if (bitDot <= hashLimit) { + distanceEvaluations++; + + double distance = distanceMeasure.distance(query, vector); + distribution[bitDot].add(distance); + + if (distance < distanceLimit) { + top.insertWithOverflow(new WeightedThing<Vector>(vector, distance)); + if (top.size() == searchSize) { + distanceLimit = top.top().getWeight(); + } + + hashCounts[bitDot]++; + limitCount++; + while (hashLimit > 0 && limitCount - hashCounts[hashLimit - 1] > searchSize) { + hashLimit--; + limitCount -= hashCounts[hashLimit]; + } + + if (hashLimitStrategy >= 0) { + while (hashLimit < MAX_HASH_LIMIT && distribution[hashLimit].getCount() > MIN_DISTRIBUTION_COUNT + && ((1 - hashLimitStrategy) * distribution[hashLimit].getQuartile(0) + + hashLimitStrategy * distribution[hashLimit].getQuartile(1)) < distanceLimit) { + limitCount += hashCounts[hashLimit]; + hashLimit++; + } + } + } + } + } + return top; + } + + @Override + public List<WeightedThing<Vector>> search(Vector query, int limit) { + PriorityQueue<WeightedThing<Vector>> top = searchInternal(query); + List<WeightedThing<Vector>> results = Lists.newArrayListWithExpectedSize(top.size()); + while (top.size() != 0) { + WeightedThing<Vector> wv = top.pop(); + results.add(new WeightedThing<>(((HashedVector) wv.getValue()).getVector(), wv.getWeight())); + } + Collections.reverse(results); + if (limit < results.size()) { + results = results.subList(0, limit); + } + return results; + } + + /** + * Returns the closest vector to the query. + * When only one the nearest vector is needed, use this method, NOT search(query, limit) because + * it's faster (less overhead). + * This is nearly the same as search(). 
+ * + * @param query the vector to search for + * @param differentThanQuery if true, returns the closest vector different than the query (this + * only matters if the query is among the searched vectors), otherwise, + * returns the closest vector to the query (even the same vector). + * @return the weighted vector closest to the query + */ + @Override + public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) { + // We get the top searchSize neighbors. + PriorityQueue<WeightedThing<Vector>> top = searchInternal(query); + // We then cut the number down to just the best 2. + while (top.size() > 2) { + top.pop(); + } + // If there are fewer than 2 results, we just return the one we have. + if (top.size() < 2) { + return removeHash(top.pop()); + } + // There are exactly 2 results. + WeightedThing<Vector> secondBest = top.pop(); + WeightedThing<Vector> best = top.pop(); + // If the best result is the same as the query, but we don't want to return the query. + if (differentThanQuery && best.getValue().equals(query)) { + best = secondBest; + } + return removeHash(best); + } + + protected static WeightedThing<Vector> removeHash(WeightedThing<Vector> input) { + return new WeightedThing<>(((HashedVector) input.getValue()).getVector(), input.getWeight()); + } + + @Override + public void add(Vector vector) { + initialize(vector.size()); + trainingVectors.add(new HashedVector(vector, projection, HashedVector.INVALID_INDEX, BIT_MASK)); + } + + @Override + public int size() { + return trainingVectors.size(); + } + + public int getSearchSize() { + return searchSize; + } + + public void setSearchSize(int size) { + searchSize = size; + } + + public void setRaiseHashLimitStrategy(double strategy) { + hashLimitStrategy = strategy; + } + + /** + * This is only for testing. + * @return the number of times the actual distance between two vectors was computed. 
+ */ + public int resetEvaluationCount() { + int result = distanceEvaluations; + distanceEvaluations = 0; + return result; + } + + @Override + public Iterator<Vector> iterator() { + return Iterators.transform(trainingVectors.iterator(), new Function<HashedVector, Vector>() { + @Override + public Vector apply(org.apache.mahout.math.neighborhood.HashedVector input) { + Preconditions.checkNotNull(input); + //noinspection ConstantConditions + return input.getVector(); + } + }); + } + + @Override + public boolean remove(Vector v, double epsilon) { + return trainingVectors.remove(new HashedVector(v, projection, HashedVector.INVALID_INDEX, BIT_MASK)); + } + + @Override + public void clear() { + trainingVectors.clear(); + } +}
