http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java b/mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java new file mode 100644 index 0000000..dd38971 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.hadoop.solver; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.util.Tool; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.DistributedRowMatrix; +import org.apache.mahout.math.solver.ConjugateGradientSolver; +import org.apache.mahout.math.solver.Preconditioner; + +/** + * Distributed implementation of the conjugate gradient solver. More or less, this is just the standard solver + * but wrapped with some methods that make it easy to run it on a DistributedRowMatrix. + */ +public class DistributedConjugateGradientSolver extends ConjugateGradientSolver implements Tool { + + private Configuration conf; + private Map<String, List<String>> parsedArgs; + + /** + * + * Runs the distributed conjugate gradient solver programmatically to solve the system (A + lambda*I)x = b. + * + * @param inputPath Path to the matrix A + * @param tempPath Path to scratch output path, deleted after the solver completes + * @param numRows Number of rows in A + * @param numCols Number of columns in A + * @param b Vector b + * @param preconditioner Optional preconditioner for the system + * @param maxIterations Maximum number of iterations to run, defaults to numCols + * @param maxError Maximum error tolerated in the result. If the norm of the residual falls below this, + * then the algorithm stops and returns. + * @return The vector that solves the system. 
+ */ + public Vector runJob(Path inputPath, + Path tempPath, + int numRows, + int numCols, + Vector b, + Preconditioner preconditioner, + int maxIterations, + double maxError) { + DistributedRowMatrix matrix = new DistributedRowMatrix(inputPath, tempPath, numRows, numCols); + matrix.setConf(conf); + + return solve(matrix, b, preconditioner, maxIterations, maxError); + } + + @Override + public Configuration getConf() { + return conf; + } + + @Override + public void setConf(Configuration conf) { + this.conf = conf; + } + + @Override + public int run(String[] strings) throws Exception { + Path inputPath = new Path(AbstractJob.getOption(parsedArgs, "--input")); + Path outputPath = new Path(AbstractJob.getOption(parsedArgs, "--output")); + Path tempPath = new Path(AbstractJob.getOption(parsedArgs, "--tempDir")); + Path vectorPath = new Path(AbstractJob.getOption(parsedArgs, "--vector")); + int numRows = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--numRows")); + int numCols = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--numCols")); + int maxIterations = parsedArgs.containsKey("--maxIter") + ? Integer.parseInt(AbstractJob.getOption(parsedArgs, "--maxIter")) + : numCols + 2; + double maxError = parsedArgs.containsKey("--maxError") + ? 
Double.parseDouble(AbstractJob.getOption(parsedArgs, "--maxError")) + : ConjugateGradientSolver.DEFAULT_MAX_ERROR; + + Vector b = loadInputVector(vectorPath); + Vector x = runJob(inputPath, tempPath, numRows, numCols, b, null, maxIterations, maxError); + saveOutputVector(outputPath, x); + tempPath.getFileSystem(conf).delete(tempPath, true); + + return 0; + } + + public DistributedConjugateGradientSolverJob job() { + return new DistributedConjugateGradientSolverJob(); + } + + private Vector loadInputVector(Path path) throws IOException { + FileSystem fs = path.getFileSystem(conf); + try (SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf)) { + VectorWritable value = new VectorWritable(); + if (!reader.next(new IntWritable(), value)) { + throw new IOException("Input vector file is empty."); + } + return value.get(); + } + } + + private void saveOutputVector(Path path, Vector v) throws IOException { + FileSystem fs = path.getFileSystem(conf); + try (SequenceFile.Writer writer = + new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class)) { + writer.append(new IntWritable(0), new VectorWritable(v)); + } + } + + public class DistributedConjugateGradientSolverJob extends AbstractJob { + @Override + public void setConf(Configuration conf) { + DistributedConjugateGradientSolver.this.setConf(conf); + } + + @Override + public Configuration getConf() { + return DistributedConjugateGradientSolver.this.getConf(); + } + + @Override + public int run(String[] args) throws Exception { + addInputOption(); + addOutputOption(); + addOption("numRows", "nr", "Number of rows in the input matrix", true); + addOption("numCols", "nc", "Number of columns in the input matrix", true); + addOption("vector", "b", "Vector to solve against", true); + addOption("lambda", "l", "Scalar in A + lambda * I [default = 0]", "0.0"); + addOption("symmetric", "sym", "Is the input matrix square and symmetric?", "true"); + addOption("maxIter", "x", "Maximum number of 
iterations to run"); + addOption("maxError", "err", "Maximum residual error to allow before stopping"); + + DistributedConjugateGradientSolver.this.parsedArgs = parseArguments(args); + if (DistributedConjugateGradientSolver.this.parsedArgs == null) { + return -1; + } else { + Configuration conf = getConf(); + if (conf == null) { + conf = new Configuration(); + } + DistributedConjugateGradientSolver.this.setConf(conf); + return DistributedConjugateGradientSolver.this.run(args); + } + } + } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new DistributedConjugateGradientSolver().job(), args); + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java new file mode 100644 index 0000000..ad0baf3 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java @@ -0,0 +1,148 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.hadoop.stats; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; + +import java.io.IOException; + +/** + * Methods for calculating basic stats (mean, variance, stdDev, etc.) in map/reduce + */ +public final class BasicStats { + + private BasicStats() { + } + + /** + * Calculate the variance of values stored as + * + * @param input The input file containing the key and the count + * @param output The output to store the intermediate values + * @param baseConf + * @return The variance (based on sample estimation) + */ + public static double variance(Path input, Path output, + Configuration baseConf) + throws IOException, InterruptedException, ClassNotFoundException { + VarianceTotals varianceTotals = computeVarianceTotals(input, output, baseConf); + return varianceTotals.computeVariance(); + } + + /** + * Calculate the variance by a predefined mean of values stored as + * + * @param input The input file containing the key and the count + * @param output The output to store the intermediate values + * @param mean The mean based on which to compute the variance + * @param baseConf + * @return The variance (based on sample estimation) + */ + public static double varianceForGivenMean(Path input, Path output, double mean, + Configuration baseConf) + throws IOException, InterruptedException, ClassNotFoundException { + VarianceTotals varianceTotals = 
computeVarianceTotals(input, output, baseConf); + return varianceTotals.computeVarianceForGivenMean(mean); + } + + private static VarianceTotals computeVarianceTotals(Path input, Path output, + Configuration baseConf) throws IOException, InterruptedException, + ClassNotFoundException { + Configuration conf = new Configuration(baseConf); + conf.set("io.serializations", + "org.apache.hadoop.io.serializer.JavaSerialization," + + "org.apache.hadoop.io.serializer.WritableSerialization"); + Job job = HadoopUtil.prepareJob(input, output, SequenceFileInputFormat.class, + StandardDeviationCalculatorMapper.class, IntWritable.class, DoubleWritable.class, + StandardDeviationCalculatorReducer.class, IntWritable.class, DoubleWritable.class, + SequenceFileOutputFormat.class, conf); + HadoopUtil.delete(conf, output); + job.setCombinerClass(StandardDeviationCalculatorReducer.class); + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + + // Now extract the computed sum + Path filesPattern = new Path(output, "part-*"); + double sumOfSquares = 0; + double sum = 0; + double totalCount = 0; + for (Pair<Writable, Writable> record : new SequenceFileDirIterable<>( + filesPattern, PathType.GLOB, null, null, true, conf)) { + + int key = ((IntWritable) record.getFirst()).get(); + if (key == StandardDeviationCalculatorMapper.SUM_OF_SQUARES.get()) { + sumOfSquares += ((DoubleWritable) record.getSecond()).get(); + } else if (key == StandardDeviationCalculatorMapper.TOTAL_COUNT + .get()) { + totalCount += ((DoubleWritable) record.getSecond()).get(); + } else if (key == StandardDeviationCalculatorMapper.SUM + .get()) { + sum += ((DoubleWritable) record.getSecond()).get(); + } + } + + VarianceTotals varianceTotals = new VarianceTotals(); + varianceTotals.setSum(sum); + varianceTotals.setSumOfSquares(sumOfSquares); + varianceTotals.setTotalCount(totalCount); + + return varianceTotals; + } + + /** + * Calculate the standard 
deviation + * + * @param input The input file containing the key and the count + * @param output The output file to write the counting results to + * @param baseConf The base configuration + * @return The standard deviation + */ + public static double stdDev(Path input, Path output, + Configuration baseConf) throws IOException, InterruptedException, + ClassNotFoundException { + return Math.sqrt(variance(input, output, baseConf)); + } + + /** + * Calculate the standard deviation given a predefined mean + * + * @param input The input file containing the key and the count + * @param output The output file to write the counting results to + * @param mean The mean based on which to compute the standard deviation + * @param baseConf The base configuration + * @return The standard deviation + */ + public static double stdDevForGivenMean(Path input, Path output, double mean, + Configuration baseConf) throws IOException, InterruptedException, + ClassNotFoundException { + return Math.sqrt(varianceForGivenMean(input, output, mean, baseConf)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java new file mode 100644 index 0000000..03271da --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java @@ -0,0 +1,55 @@ +package org.apache.mahout.math.hadoop.stats; +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Mapper; + +import java.io.IOException; + +public class StandardDeviationCalculatorMapper extends + Mapper<IntWritable, Writable, IntWritable, DoubleWritable> { + + public static final IntWritable SUM_OF_SQUARES = new IntWritable(1); + public static final IntWritable SUM = new IntWritable(2); + public static final IntWritable TOTAL_COUNT = new IntWritable(-1); + + @Override + protected void map(IntWritable key, Writable value, Context context) + throws IOException, InterruptedException { + if (key.get() == -1) { + return; + } + //Kind of ugly, but such is life + double df = Double.NaN; + if (value instanceof LongWritable) { + df = ((LongWritable)value).get(); + } else if (value instanceof DoubleWritable) { + df = ((DoubleWritable)value).get(); + } + if (!Double.isNaN(df)) { + // For calculating the sum of squares + context.write(SUM_OF_SQUARES, new DoubleWritable(df * df)); + context.write(SUM, new DoubleWritable(df)); + // For calculating the total number of entries + context.write(TOTAL_COUNT, new DoubleWritable(1)); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java 
---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java new file mode 100644 index 0000000..0a27eec --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java @@ -0,0 +1,37 @@ +package org.apache.mahout.math.hadoop.stats; +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Reducer; + +import java.io.IOException; + +public class StandardDeviationCalculatorReducer extends + Reducer<IntWritable, DoubleWritable, IntWritable, DoubleWritable> { + + @Override + protected void reduce(IntWritable key, Iterable<DoubleWritable> values, + Context context) throws IOException, InterruptedException { + double sum = 0.0; + for (DoubleWritable value : values) { + sum += value.get(); + } + context.write(key, new DoubleWritable(sum)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java new file mode 100644 index 0000000..87448bc --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java @@ -0,0 +1,68 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.hadoop.stats; + +/** + * Holds the total values needed to compute mean and standard deviation + * Provides methods for their computation + */ +public final class VarianceTotals { + + private double sumOfSquares; + private double sum; + private double totalCount; + + public double getSumOfSquares() { + return sumOfSquares; + } + + public void setSumOfSquares(double sumOfSquares) { + this.sumOfSquares = sumOfSquares; + } + + public double getSum() { + return sum; + } + + public void setSum(double sum) { + this.sum = sum; + } + + public double getTotalCount() { + return totalCount; + } + + public void setTotalCount(double totalCount) { + this.totalCount = totalCount; + } + + public double computeMean() { + return sum / totalCount; + } + + public double computeVariance() { + return ((totalCount * sumOfSquares) - (sum * sum)) + / (totalCount * (totalCount - 1.0)); + } + + public double computeVarianceForGivenMean(double mean) { + return (sumOfSquares - totalCount * mean * mean) + / (totalCount - 1.0); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java new file mode 100644 index 0000000..359b281 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java @@ -0,0 +1,585 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop.stochasticsvd; + +import java.io.Closeable; +import java.io.IOException; +import java.text.NumberFormat; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Deque; +import java.util.Iterator; +import java.util.regex.Matcher; + +import com.google.common.collect.Lists; +import org.apache.commons.lang3.Validate; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.SequenceFile.CompressionType; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.OutputCollector; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.IOUtils; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator; +import 
org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRFirstStep; + +/** + * Computes ABt products, then first step of QR which is pushed down to the + * reducer. + */ +@SuppressWarnings("deprecation") +public final class ABtDenseOutJob { + + public static final String PROP_BT_PATH = "ssvd.Bt.path"; + public static final String PROP_BT_BROADCAST = "ssvd.Bt.broadcast"; + public static final String PROP_SB_PATH = "ssvdpca.sb.path"; + public static final String PROP_SQ_PATH = "ssvdpca.sq.path"; + public static final String PROP_XI_PATH = "ssvdpca.xi.path"; + + private ABtDenseOutJob() { + } + + /** + * So, here, i preload A block into memory. + * <P> + * + * A sparse matrix seems to be ideal for that but there are two reasons why i + * am not using it: + * <UL> + * <LI>1) I don't know the full block height. so i may need to reallocate it + * from time to time. Although this probably not a showstopper. + * <LI>2) I found that RandomAccessSparseVectors seem to take much more memory + * than the SequentialAccessSparseVectors. 
+ * </UL> + * <P> + * + */ + public static class ABtMapper + extends + Mapper<Writable, VectorWritable, SplitPartitionedWritable, DenseBlockWritable> { + + private SplitPartitionedWritable outKey; + private final Deque<Closeable> closeables = new ArrayDeque<>(); + private SequenceFileDirIterator<IntWritable, VectorWritable> btInput; + private Vector[] aCols; + private double[][] yiCols; + private int aRowCount; + private int kp; + private int blockHeight; + private boolean distributedBt; + private Path[] btLocalPath; + private Configuration localFsConfig; + /* + * xi and s_q are PCA-related corrections, per MAHOUT-817 + */ + protected Vector xi; + protected Vector sq; + + @Override + protected void map(Writable key, VectorWritable value, Context context) + throws IOException, InterruptedException { + + Vector vec = value.get(); + + int vecSize = vec.size(); + if (aCols == null) { + aCols = new Vector[vecSize]; + } else if (aCols.length < vecSize) { + aCols = Arrays.copyOf(aCols, vecSize); + } + + if (vec.isDense()) { + for (int i = 0; i < vecSize; i++) { + extendAColIfNeeded(i, aRowCount + 1); + aCols[i].setQuick(aRowCount, vec.getQuick(i)); + } + } else if (vec.size() > 0) { + for (Vector.Element vecEl : vec.nonZeroes()) { + int i = vecEl.index(); + extendAColIfNeeded(i, aRowCount + 1); + aCols[i].setQuick(aRowCount, vecEl.get()); + } + } + aRowCount++; + } + + private void extendAColIfNeeded(int col, int rowCount) { + if (aCols[col] == null) { + aCols[col] = + new SequentialAccessSparseVector(rowCount < blockHeight ? 
blockHeight + : rowCount, 1); + } else if (aCols[col].size() < rowCount) { + Vector newVec = + new SequentialAccessSparseVector(rowCount + blockHeight, + aCols[col].getNumNondefaultElements() << 1); + newVec.viewPart(0, aCols[col].size()).assign(aCols[col]); + aCols[col] = newVec; + } + } + + @Override + protected void cleanup(Context context) throws IOException, + InterruptedException { + try { + + yiCols = new double[kp][]; + + for (int i = 0; i < kp; i++) { + yiCols[i] = new double[Math.min(aRowCount, blockHeight)]; + } + + int numPasses = (aRowCount - 1) / blockHeight + 1; + + String propBtPathStr = context.getConfiguration().get(PROP_BT_PATH); + Validate.notNull(propBtPathStr, "Bt input is not set"); + Path btPath = new Path(propBtPathStr); + DenseBlockWritable dbw = new DenseBlockWritable(); + + /* + * so it turns out that it may be much more efficient to do a few + * independent passes over Bt accumulating the entire block in memory + * than pass huge amount of blocks out to combiner. so we aim of course + * to fit entire s x (k+p) dense block in memory where s is the number + * of A rows in this split. If A is much sparser than (k+p) avg # of + * elements per row then the block may exceed the split size. if this + * happens, and if the given blockHeight is not high enough to + * accomodate this (because of memory constraints), then we start + * splitting s into several passes. since computation is cpu-bound + * anyway, it should be o.k. for supersparse inputs. (as ok it can be + * that projection is thicker than the original anyway, why would one + * use that many k+p then). 
+ */ + int lastRowIndex = -1; + for (int pass = 0; pass < numPasses; pass++) { + + if (distributedBt) { + + btInput = + new SequenceFileDirIterator<>(btLocalPath, true, localFsConfig); + + } else { + + btInput = + new SequenceFileDirIterator<>(btPath, PathType.GLOB, null, null, true, context.getConfiguration()); + } + closeables.addFirst(btInput); + Validate.isTrue(btInput.hasNext(), "Empty B' input!"); + + int aRowBegin = pass * blockHeight; + int bh = Math.min(blockHeight, aRowCount - aRowBegin); + + /* + * check if we need to trim block allocation + */ + if (pass > 0) { + if (bh == blockHeight) { + for (int i = 0; i < kp; i++) { + Arrays.fill(yiCols[i], 0.0); + } + } else { + + for (int i = 0; i < kp; i++) { + yiCols[i] = null; + } + for (int i = 0; i < kp; i++) { + yiCols[i] = new double[bh]; + } + } + } + + while (btInput.hasNext()) { + Pair<IntWritable, VectorWritable> btRec = btInput.next(); + int btIndex = btRec.getFirst().get(); + Vector btVec = btRec.getSecond().get(); + Vector aCol; + if (btIndex > aCols.length || (aCol = aCols[btIndex]) == null + || aCol.size() == 0) { + + /* 100% zero A column in the block, skip it as sparse */ + continue; + } + int j = -1; + for (Vector.Element aEl : aCol.nonZeroes()) { + j = aEl.index(); + + /* + * now we compute only swathes between aRowBegin..aRowBegin+bh + * exclusive. it seems like a deficiency but in fact i think it + * will balance itself out: either A is dense and then we + * shouldn't have more than one pass and therefore filter + * conditions will never kick in. Or, the only situation where we + * can't fit Y_i block in memory is when A input is much sparser + * than k+p per row. But if this is the case, then we'd be looking + * at very few elements without engaging them in any operations so + * even then it should be ok. + */ + if (j < aRowBegin) { + continue; + } + if (j >= aRowBegin + bh) { + break; + } + + /* + * assume btVec is dense + */ + if (xi != null) { + /* + * MAHOUT-817: PCA correction for B'. 
I rewrite the whole + * computation loop so i don't have to check if PCA correction + * is needed at individual element level. It looks bulkier this + * way but perhaps less wasteful on cpu. + */ + for (int s = 0; s < kp; s++) { + // code defensively against shortened xi + double xii = xi.size() > btIndex ? xi.get(btIndex) : 0.0; + yiCols[s][j - aRowBegin] += + aEl.get() * (btVec.getQuick(s) - xii * sq.get(s)); + } + } else { + /* + * no PCA correction + */ + for (int s = 0; s < kp; s++) { + yiCols[s][j - aRowBegin] += aEl.get() * btVec.getQuick(s); + } + } + + } + if (lastRowIndex < j) { + lastRowIndex = j; + } + } + + /* + * so now we have stuff in yi + */ + dbw.setBlock(yiCols); + outKey.setTaskItemOrdinal(pass); + context.write(outKey, dbw); + + closeables.remove(btInput); + btInput.close(); + } + + } finally { + IOUtils.close(closeables); + } + } + + @Override + protected void setup(Context context) throws IOException, + InterruptedException { + + Configuration conf = context.getConfiguration(); + int k = Integer.parseInt(conf.get(QRFirstStep.PROP_K)); + int p = Integer.parseInt(conf.get(QRFirstStep.PROP_P)); + kp = k + p; + + outKey = new SplitPartitionedWritable(context); + + blockHeight = conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1); + distributedBt = conf.get(PROP_BT_BROADCAST) != null; + if (distributedBt) { + btLocalPath = HadoopUtil.getCachedFiles(conf); + localFsConfig = new Configuration(); + localFsConfig.set("fs.default.name", "file:///"); + } + + /* + * PCA -related corrections (MAHOUT-817) + */ + String xiPathStr = conf.get(PROP_XI_PATH); + if (xiPathStr != null) { + xi = SSVDHelper.loadAndSumUpVectors(new Path(xiPathStr), conf); + sq = + SSVDHelper.loadAndSumUpVectors(new Path(conf.get(PROP_SQ_PATH)), conf); + } + + } + } + + /** + * QR first step pushed down to reducer. 
+   *
+   */
+  public static class QRReducer
+    extends Reducer<SplitPartitionedWritable, DenseBlockWritable, SplitPartitionedWritable, VectorWritable> {
+
+    /*
+     * HACK: partition number formats in hadoop, copied. this may stop working
+     * if it gets out of sync with newer hadoop version. But unfortunately rules
+     * of forming output file names are not sufficiently exposed so we need to
+     * hack it if we write the same split output from either mapper or reducer.
+     * alternatively, we probably can replace it by our own output file naming
+     * management completely and bypass MultipleOutputs entirely.
+     */
+
+    private static final NumberFormat NUMBER_FORMAT =
+      NumberFormat.getInstance();
+    static {
+      NUMBER_FORMAT.setMinimumIntegerDigits(5);
+      NUMBER_FORMAT.setGroupingUsed(false);
+    }
+
+    // Open writers and the QRFirstStep for the split being reduced; closed by
+    // IOUtils.close in setupBlock() and cleanup().
+    private final Deque<Closeable> closeables = Lists.newLinkedList();
+
+    // Rows per Y_i block (BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1 if unset);
+    // used to turn a block ordinal back into absolute row ordinals.
+    protected int blockHeight;
+
+    // Task id whose Q-hat/R-hat outputs are currently open; -1 before the first key.
+    protected int lastTaskId = -1;
+
+    protected OutputCollector<Writable, DenseBlockWritable> qhatCollector;
+    protected OutputCollector<Writable, VectorWritable> rhatCollector;
+    protected QRFirstStep qr;
+    // Reused buffer holding one row of the current Y_i block.
+    protected Vector yiRow;
+    // Optional PCA offset subtracted from every row (MAHOUT-817); null when
+    // PROP_SB_PATH is not configured.
+    protected Vector sb;
+
+    /** Reads the block height and, if configured, the PCA offset vector sb. */
+    @Override
+    protected void setup(Context context) throws IOException,
+      InterruptedException {
+      Configuration conf = context.getConfiguration();
+      blockHeight = conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
+      String sbPathStr = conf.get(PROP_SB_PATH);
+
+      /*
+       * PCA -related corrections (MAHOUT-817)
+       */
+      if (sbPathStr != null) {
+        sb = SSVDHelper.loadAndSumUpVectors(new Path(sbPathStr), conf);
+      }
+    }
+
+    /**
+     * Closes the previous split's outputs and opens Q-hat and R-hat writers
+     * plus a fresh QRFirstStep for the split identified by spw.
+     */
+    protected void setupBlock(Context context, SplitPartitionedWritable spw)
+      throws InterruptedException, IOException {
+      IOUtils.close(closeables);
+      qhatCollector =
+        createOutputCollector(QJob.OUTPUT_QHAT,
+                              spw,
+                              context,
+                              DenseBlockWritable.class);
+      rhatCollector =
+        createOutputCollector(QJob.OUTPUT_RHAT,
+                              spw,
+                              context,
+                              VectorWritable.class);
+      qr =
+        new QRFirstStep(context.getConfiguration(),
+                        qhatCollector,
+                        rhatCollector);
+      closeables.addFirst(qr);
+      lastTaskId = spw.getTaskId();
+
+    }
+
+    /**
+     * Feeds one Y_i block, row by row, into the QR first step. The block is
+     * stored column-wise (yiCols[column][row]), and exactly one block is
+     * expected per key; a second one raises an IOException. Each row is tagged
+     * with its absolute ordinal, block ordinal * blockHeight + row, and has sb
+     * subtracted when the PCA offset is configured.
+     */
+    @Override
+    protected void reduce(SplitPartitionedWritable key,
+                          Iterable<DenseBlockWritable> values,
+                          Context context) throws IOException,
+      InterruptedException {
+
+      // A new task id means a new split: switch output files.
+      if (key.getTaskId() != lastTaskId) {
+        setupBlock(context, key);
+      }
+
+      Iterator<DenseBlockWritable> iter = values.iterator();
+      DenseBlockWritable dbw = iter.next();
+      double[][] yiCols = dbw.getBlock();
+      if (iter.hasNext()) {
+        throw new IOException("Unexpected extra Y_i block in reducer input.");
+      }
+
+      // NOTE(review): if getTaskItemOrdinal() returns int, this product is
+      // evaluated in int before widening to long -- confirm.
+      long blockBase = key.getTaskItemOrdinal() * blockHeight;
+      int bh = yiCols[0].length;
+      if (yiRow == null) {
+        yiRow = new DenseVector(yiCols.length);
+      }
+
+      for (int k = 0; k < bh; k++) {
+        for (int j = 0; j < yiCols.length; j++) {
+          yiRow.setQuick(j, yiCols[j][k]);
+        }
+
+        key.setTaskItemOrdinal(blockBase + k);
+
+        // pca offset correction if any
+        if (sb != null) {
+          yiRow.assign(sb, Functions.MINUS);
+        }
+
+        qr.collect(key, yiRow);
+      }
+
+    }
+
+    /**
+     * Builds an output path in the task's work directory. The name follows the
+     * map-output form ("-m-" instead of "-r-"), and the trailing digits are
+     * replaced by the originating split's task id.
+     */
+    private Path getSplitFilePath(String name,
+                                  SplitPartitionedWritable spw,
+                                  Context context) throws InterruptedException,
+      IOException {
+      String uniqueFileName = FileOutputFormat.getUniqueFile(context, name, "");
+      uniqueFileName = uniqueFileName.replaceFirst("-r-", "-m-");
+      uniqueFileName =
+        uniqueFileName.replaceFirst("\\d+$",
+                                    Matcher.quoteReplacement(NUMBER_FORMAT.format(spw.getTaskId())));
+      return new Path(FileOutputFormat.getWorkOutputPath(context),
+                      uniqueFileName);
+    }
+
+    /**
+     * key doesn't matter here, only value does. key always gets substituted by
+     * SPW. The returned collector appends to a SequenceFile of
+     * (SplitPartitionedWritable, valueClass) records. That writer is registered
+     * in closeables so setupBlock()/cleanup() close it.
+     *
+     * @param <K>
+     *          bogus
+     * @param <V>
+     *          value type written to the file
+     */
+    private <K, V> OutputCollector<K, V> createOutputCollector(String name,
+                              final SplitPartitionedWritable spw,
+                              Context ctx,
+                              Class<V> valueClass) throws IOException, InterruptedException {
+      Path outputPath = getSplitFilePath(name, spw, ctx);
+      final SequenceFile.Writer w =
+        SequenceFile.createWriter(FileSystem.get(outputPath.toUri(), ctx.getConfiguration()),
+                                  ctx.getConfiguration(),
+                                  outputPath,
+                                  SplitPartitionedWritable.class,
+                                  valueClass);
+      closeables.addFirst(w);
+      return new OutputCollector<K, V>() {
+        @Override
+        public void collect(K key, V val) throws IOException {
+          w.append(spw, val);
+        }
+      };
+    }
+
+    /** Closes whatever outputs and QR state are still open for the last split. */
+    @Override
+    protected void cleanup(Context context) throws IOException,
+      InterruptedException {
+
+      IOUtils.close(closeables);
+    }
+
+  }
+
+  public static void run(Configuration conf,
+                         Path[] inputAPaths,
+                         Path inputBtGlob,
+                         Path xiPath,
+                         Path sqPath,
+                         Path sbPath,
+                         Path outputPath,
+                         int aBlockRows,
+                         int minSplitSize,
+                         int k,
+                         int p,
+                         int outerProdBlockHeight,
+                         int numReduceTasks,
+                         boolean broadcastBInput)
+    throws ClassNotFoundException, InterruptedException, IOException {
+
+    JobConf oldApiJob = new JobConf(conf);
+
+    Job job = new Job(oldApiJob);
+    job.setJobName("ABt-job");
+    job.setJarByClass(ABtDenseOutJob.class);
+
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    FileInputFormat.setInputPaths(job, inputAPaths);
+    if (minSplitSize > 0) {
+      FileInputFormat.setMinInputSplitSize(job, minSplitSize);
+    }
+
+    FileOutputFormat.setOutputPath(job, outputPath);
+
+    SequenceFileOutputFormat.setOutputCompressionType(job,
+                                                      CompressionType.BLOCK);
+
+    job.setMapOutputKeyClass(SplitPartitionedWritable.class);
+    job.setMapOutputValueClass(DenseBlockWritable.class);
+
+    job.setOutputKeyClass(SplitPartitionedWritable.class);
+    job.setOutputValueClass(VectorWritable.class);
+
+    job.setMapperClass(ABtMapper.class);
+    job.setReducerClass(QRReducer.class);
+
job.getConfiguration().setInt(QJob.PROP_AROWBLOCK_SIZE, aBlockRows); + job.getConfiguration().setInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, + outerProdBlockHeight); + job.getConfiguration().setInt(QRFirstStep.PROP_K, k); + job.getConfiguration().setInt(QRFirstStep.PROP_P, p); + job.getConfiguration().set(PROP_BT_PATH, inputBtGlob.toString()); + + /* + * PCA-related options, MAHOUT-817 + */ + if (xiPath != null) { + job.getConfiguration().set(PROP_XI_PATH, xiPath.toString()); + job.getConfiguration().set(PROP_SB_PATH, sbPath.toString()); + job.getConfiguration().set(PROP_SQ_PATH, sqPath.toString()); + } + + job.setNumReduceTasks(numReduceTasks); + + // broadcast Bt files if required. + if (broadcastBInput) { + job.getConfiguration().set(PROP_BT_BROADCAST, "y"); + + FileSystem fs = FileSystem.get(inputBtGlob.toUri(), conf); + FileStatus[] fstats = fs.globStatus(inputBtGlob); + if (fstats != null) { + for (FileStatus fstat : fstats) { + /* + * new api is not enabled yet in our dependencies at this time, still + * using deprecated one + */ + DistributedCache.addCacheFile(fstat.getPath().toUri(), + job.getConfiguration()); + } + } + } + + job.submit(); + job.waitForCompletion(false); + + if (!job.isSuccessful()) { + throw new IOException("ABt job unsuccessful."); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java new file mode 100644 index 0000000..afa1463 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java @@ -0,0 +1,494 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop.stochasticsvd; + +import java.io.Closeable; +import java.io.IOException; +import java.text.NumberFormat; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Deque; +import java.util.regex.Matcher; + +import com.google.common.collect.Lists; +import org.apache.commons.lang3.Validate; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.SequenceFile.CompressionType; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.OutputCollector; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import 
org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRFirstStep;
+
+/**
+ * Computes ABt products, then the first step of QR, which is pushed down to
+ * the reducer.
+ */
+@SuppressWarnings("deprecation")
+public final class ABtJob {
+
+  public static final String PROP_BT_PATH = "ssvd.Bt.path";
+  public static final String PROP_BT_BROADCAST = "ssvd.Bt.broadcast";
+
+  private ABtJob() {
+  }
+
+  /**
+   * Preloads the A split into memory column-wise and, in {@code cleanup},
+   * streams Bt to emit the partial products of A * Bt as sparse row blocks.
+   * <P>
+   * A sparse matrix seems ideal for buffering A, but there are two reasons why
+   * it is not used:
+   * <UL>
+   * <LI>1) The full block height is not known in advance, so it may need to be
+   * reallocated from time to time (although this is probably not a
+   * showstopper).
+   * <LI>2) RandomAccessSparseVectors seem to take much more memory than
+   * SequentialAccessSparseVectors.
+   * </UL>
+   */
+  public static class ABtMapper
+      extends
+      Mapper<Writable, VectorWritable, SplitPartitionedWritable, SparseRowBlockWritable> {
+
+    private SplitPartitionedWritable outKey;
+    private final Deque<Closeable> closeables = new ArrayDeque<>();
+    private SequenceFileDirIterator<IntWritable, VectorWritable> btInput;
+    /** Buffered columns of the A split; entries are allocated lazily by extendAColIfNeeded. */
+    private Vector[] aCols;
+    /** Number of A rows buffered so far. */
+    private int aRowCount;
+    private int kp;
+    private int blockHeight;
+    private SparseRowBlockAccumulator yiCollector;
+
+    /** Appends one A row to the column-wise buffer. */
+    @Override
+    protected void map(Writable key, VectorWritable value, Context context)
+      throws IOException, InterruptedException {
+
+      Vector vec = value.get();
+
+      int vecSize = vec.size();
+      if (aCols == null) {
+        aCols = new Vector[vecSize];
+      } else if (aCols.length < vecSize) {
+        aCols = Arrays.copyOf(aCols, vecSize);
+      }
+
+      if (vec.isDense()) {
+        for (int i = 0; i < vecSize; i++) {
+          extendAColIfNeeded(i, aRowCount + 1);
+          aCols[i].setQuick(aRowCount, vec.getQuick(i));
+        }
+      } else {
+        for (Vector.Element vecEl : vec.nonZeroes()) {
+          int i = vecEl.index();
+          extendAColIfNeeded(i, aRowCount + 1);
+          aCols[i].setQuick(aRowCount, vecEl.get());
+        }
+      }
+      aRowCount++;
+    }
+
+    /**
+     * Makes sure column {@code col} exists and can hold at least
+     * {@code rowCount} rows. New columns start with size max(rowCount, 10000);
+     * outgrown columns are copied into a vector of size 2 * rowCount.
+     */
+    private void extendAColIfNeeded(int col, int rowCount) {
+      if (aCols[col] == null) {
+        aCols[col] =
+          new SequentialAccessSparseVector(rowCount < 10000 ? 10000 : rowCount,
+                                           1);
+      } else if (aCols[col].size() < rowCount) {
+        Vector newVec =
+          new SequentialAccessSparseVector(rowCount << 1,
+                                           aCols[col].getNumNondefaultElements() << 1);
+        newVec.viewPart(0, aCols[col].size()).assign(aCols[col]);
+        aCols[col] = newVec;
+      }
+    }
+
+    /**
+     * Multiplies the buffered A split by Bt: for every Bt record i and every
+     * non-zero A[j, i], the row Bt[i] * A[j, i] is accumulated into Y row j.
+     * Then an empty row is emitted for every split row that received no partial
+     * product, and all resources are closed.
+     */
+    @Override
+    protected void cleanup(Context context) throws IOException,
+      InterruptedException {
+      try {
+        // rows of this split that received at least one partial product
+        java.util.BitSet emittedRows = new java.util.BitSet(aRowCount);
+
+        // aCols stays null when the split had no rows (map() was never called).
+        if (aCols != null) {
+          while (btInput.hasNext()) {
+            Pair<IntWritable, VectorWritable> btRec = btInput.next();
+            int btIndex = btRec.getFirst().get();
+            Vector btVec = btRec.getSecond().get();
+            // Bt rows for columns this split never populated contribute nothing.
+            if (btIndex < 0 || btIndex >= aCols.length || aCols[btIndex] == null) {
+              continue;
+            }
+            for (Vector.Element aEl : aCols[btIndex].nonZeroes()) {
+              int j = aEl.index();
+              yiCollector.collect((long) j, btVec.times(aEl.get()));
+              emittedRows.set(j);
+            }
+          }
+          aCols = null;
+        }
+
+        // Output empty rows for every row we never output partial products for
+        // (all-zero rows of a sparse A, anywhere in the split, not just at its
+        // end); otherwise the Q matrix ends up with a shorter row count than A.
+        Vector yDummy = new SequentialAccessSparseVector(kp);
+        for (int row = emittedRows.nextClearBit(0); row < aRowCount;
+             row = emittedRows.nextClearBit(row + 1)) {
+          yiCollector.collect((long) row, yDummy);
+        }
+
+      } finally {
+        IOUtils.close(closeables);
+      }
+    }
+
+    /**
+     * Reads k and p, opens the Bt input (from the distributed cache when
+     * PROP_BT_BROADCAST is set, otherwise from the PROP_BT_PATH glob) and sets
+     * up the Y_i block accumulator that writes to the map output.
+     */
+    @Override
+    protected void setup(final Context context) throws IOException,
+      InterruptedException {
+
+      int k =
+        Integer.parseInt(context.getConfiguration().get(QRFirstStep.PROP_K));
+      int p =
+        Integer.parseInt(context.getConfiguration().get(QRFirstStep.PROP_P));
+      kp = k + p;
+
+      outKey = new SplitPartitionedWritable(context);
+      String propBtPathStr = context.getConfiguration().get(PROP_BT_PATH);
+      Validate.notNull(propBtPathStr, "Bt input is not set");
+      Path btPath = new Path(propBtPathStr);
+
+      boolean distributedBt =
+        context.getConfiguration().get(PROP_BT_BROADCAST) != null;
+
+      if (distributedBt) {
+
+        Path[] btFiles = HadoopUtil.getCachedFiles(context.getConfiguration());
+        Validate.notNull(btFiles,
+                         "no Bt files in distributed cache job definition");
+
+        // NOTE(review): cached paths are joined with '/' into a single Path that
+        // is then listed (PathType.LIST); this looks right only when exactly one
+        // entry is cached -- verify for multi-file Bt inputs.
+        StringBuilder btLocalPath = new StringBuilder();
+        for (Path btFile : btFiles) {
+          if (btLocalPath.length() > 0) {
+            btLocalPath.append(Path.SEPARATOR_CHAR);
+          }
+          btLocalPath.append(btFile);
+        }
+
+        btInput =
+          new SequenceFileDirIterator<>(new Path(btLocalPath.toString()),
+                                        PathType.LIST,
+                                        null,
+                                        null,
+                                        true,
+                                        context.getConfiguration());
+
+      } else {
+
+        btInput =
+          new SequenceFileDirIterator<>(btPath, PathType.GLOB, null, null, true, context.getConfiguration());
+      }
+      // btInput and yiCollector are released by cleanup() via closeables.
+      closeables.addFirst(btInput);
+      OutputCollector<LongWritable, SparseRowBlockWritable> yiBlockCollector =
+        new OutputCollector<LongWritable, SparseRowBlockWritable>() {
+
+          @Override
+          public void collect(LongWritable blockKey,
+                              SparseRowBlockWritable block) throws IOException {
+            // the block index is a long; do not truncate it
+            outKey.setTaskItemOrdinal(blockKey.get());
+            try {
+              context.write(outKey, block);
+            } catch (InterruptedException exc) {
+              throw new IOException("Interrupted", exc);
+            }
+          }
+        };
+      blockHeight =
+        context.getConfiguration().getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT,
+                                          -1);
+      yiCollector =
+        new SparseRowBlockAccumulator(blockHeight, yiBlockCollector);
+      closeables.addFirst(yiCollector);
+    }
+
+  }
+
+  /**
+   * QR first step pushed down to reducer: sums the sparse Y_i blocks of a key
+   * and feeds their rows to {@link QRFirstStep}, writing Q-hat and R-hat side
+   * files per split.
+   */
+  public static class QRReducer
+      extends
+      Reducer<SplitPartitionedWritable, SparseRowBlockWritable, SplitPartitionedWritable, VectorWritable> {
+
+    /*
+     * HACK: partition number formats in hadoop, copied. This may stop working
+     * if it gets out of sync with newer hadoop versions, but the rules for
+     * forming output file names are not sufficiently exposed, so we need to
+     * hack it if we write the same split output from either mapper or reducer.
+     * Alternatively, we could replace it with our own output file naming
+     * management and bypass MultipleOutputs entirely.
+     */
+
+    private static final NumberFormat NUMBER_FORMAT =
+      NumberFormat.getInstance();
+    static {
+      NUMBER_FORMAT.setMinimumIntegerDigits(5);
+      NUMBER_FORMAT.setGroupingUsed(false);
+    }
+
+    /** Open side-file writers and QR step of the current split; closed on split change and in cleanup. */
+    private final Deque<Closeable> closeables = Lists.newLinkedList();
+    /** Sum of all Y_i partial blocks received for the current key. */
+    protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
+
+    protected int blockHeight;
+
+    /** Task id of the split whose outputs are currently open; -1 before the first key. */
+    protected int lastTaskId = -1;
+
+    protected OutputCollector<Writable, DenseBlockWritable> qhatCollector;
+    protected OutputCollector<Writable, VectorWritable> rhatCollector;
+    protected QRFirstStep qr;
+
+    @Override
+    protected void setup(Context context) throws IOException,
+      InterruptedException {
+      blockHeight =
+        context.getConfiguration().getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT,
+                                          -1);
+
+    }
+
+    /**
+     * Closes the outputs of the previous split (if any) and opens the Q-hat and
+     * R-hat side files plus a fresh QR first step for the split of {@code spw}.
+     */
+    protected void setupBlock(Context context, SplitPartitionedWritable spw)
+      throws InterruptedException, IOException {
+      IOUtils.close(closeables);
+      qhatCollector =
+        createOutputCollector(QJob.OUTPUT_QHAT,
+                              spw,
+                              context,
+                              DenseBlockWritable.class);
+      rhatCollector =
+        createOutputCollector(QJob.OUTPUT_RHAT,
+                              spw,
+                              context,
+                              VectorWritable.class);
+      qr =
+        new QRFirstStep(context.getConfiguration(),
+                        qhatCollector,
+                        rhatCollector);
+      closeables.addFirst(qr);
+      lastTaskId = spw.getTaskId();
+
+    }
+
+    @Override
+    protected void reduce(SplitPartitionedWritable key,
+                          Iterable<SparseRowBlockWritable> values,
+                          Context context) throws IOException,
+      InterruptedException {
+
+      accum.clear();
+      for (SparseRowBlockWritable bw : values) {
+        accum.plusBlock(bw);
+      }
+
+      if (key.getTaskId() != lastTaskId) {
+        setupBlock(context, key);
+      }
+
+      long blockBase = key.getTaskItemOrdinal() * blockHeight;
+      for (int k = 0; k < accum.getNumRows(); k++) {
+        Vector yiRow = accum.getRows()[k];
+        key.setTaskItemOrdinal(blockBase + accum.getRowIndices()[k]);
+        qr.collect(key, yiRow);
+      }
+
+    }
+
+    /**
+     * Builds the work-output path of side output {@code name} for the split of
+     * {@code spw}: the reducer's unique file name is rewritten to the mapper
+     * ("-m-") form and its trailing partition number is replaced by the split's
+     * task id.
+     */
+    private Path getSplitFilePath(String name,
+                                  SplitPartitionedWritable spw,
+                                  Context context) throws InterruptedException,
+      IOException {
+      String uniqueFileName = FileOutputFormat.getUniqueFile(context, name, "");
+      uniqueFileName = uniqueFileName.replaceFirst("-r-", "-m-");
+      uniqueFileName =
+        uniqueFileName.replaceFirst("\\d+$",
+                                    Matcher.quoteReplacement(NUMBER_FORMAT.format(spw.getTaskId())));
+      return new Path(FileOutputFormat.getWorkOutputPath(context),
+                      uniqueFileName);
+    }
+
+    /**
+     * Opens a sequence-file writer for side output {@code name} and wraps it in
+     * a collector. Only the value matters: every record is written with
+     * {@code spw} as its key, whatever key the caller passes.
+     */
+    private <K,V> OutputCollector<K,V> createOutputCollector(String name,
+                                                             final SplitPartitionedWritable spw,
+                                                             Context ctx,
+                                                             Class<V> valueClass)
+      throws IOException, InterruptedException {
+      Path outputPath = getSplitFilePath(name, spw, ctx);
+      final SequenceFile.Writer w =
+        SequenceFile.createWriter(FileSystem.get(outputPath.toUri(), ctx.getConfiguration()),
+                                  ctx.getConfiguration(),
+                                  outputPath,
+                                  SplitPartitionedWritable.class,
+                                  valueClass);
+      closeables.addFirst(w);
+      return new OutputCollector<K, V>() {
+        @Override
+        public void collect(K key, V val) throws IOException {
+          w.append(spw, val);
+        }
+      };
+    }
+
+    @Override
+    protected void cleanup(Context context) throws IOException, InterruptedException {
+      IOUtils.close(closeables);
+    }
+
+  }
+
+  /**
+   * Configures and runs the ABt job (map-side A * Bt partial products, combiner
+   * summing them, reduce-side QR first step) and waits for it to finish.
+   *
+   * @param broadcastBInput if true, the Bt files matching {@code inputBtGlob}
+   *        are shipped to the mappers through the distributed cache
+   * @throws IOException if the job does not complete successfully
+   */
+  public static void run(Configuration conf,
+                         Path[] inputAPaths,
+                         Path inputBtGlob,
+                         Path outputPath,
+                         int aBlockRows,
+                         int minSplitSize,
+                         int k,
+                         int p,
+                         int outerProdBlockHeight,
+                         int numReduceTasks,
+                         boolean broadcastBInput)
+    throws ClassNotFoundException, InterruptedException, IOException {
+
+    JobConf oldApiJob = new JobConf(conf);
+
+    Job job = new Job(oldApiJob);
+    job.setJobName("ABt-job");
+    job.setJarByClass(ABtJob.class);
+
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    FileInputFormat.setInputPaths(job, inputAPaths);
+    if (minSplitSize > 0) {
+      FileInputFormat.setMinInputSplitSize(job, minSplitSize);
+    }
+
+    FileOutputFormat.setOutputPath(job, outputPath);
+
+    SequenceFileOutputFormat.setOutputCompressionType(job,
+                                                      CompressionType.BLOCK);
+
+    job.setMapOutputKeyClass(SplitPartitionedWritable.class);
+    job.setMapOutputValueClass(SparseRowBlockWritable.class);
+
+    job.setOutputKeyClass(SplitPartitionedWritable.class);
+    job.setOutputValueClass(VectorWritable.class);
+
+    job.setMapperClass(ABtMapper.class);
+    job.setCombinerClass(BtJob.OuterProductCombiner.class);
+    job.setReducerClass(QRReducer.class);
+
+    job.getConfiguration().setInt(QJob.PROP_AROWBLOCK_SIZE, aBlockRows);
+    job.getConfiguration().setInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT,
+                                  outerProdBlockHeight);
+    job.getConfiguration().setInt(QRFirstStep.PROP_K, k);
+    job.getConfiguration().setInt(QRFirstStep.PROP_P, p);
+    job.getConfiguration().set(PROP_BT_PATH, inputBtGlob.toString());
+
+    // Reducers run the QR first step (QRReducer) and write Q-hat/R-hat side
+    // files named after the originating split rather than after the reducer.
+    job.setNumReduceTasks(numReduceTasks);
+
+    // broadcast Bt files if required.
+    if (broadcastBInput) {
+      job.getConfiguration().set(PROP_BT_BROADCAST, "y");
+
+      FileSystem fs = FileSystem.get(inputBtGlob.toUri(), conf);
+      FileStatus[] fstats = fs.globStatus(inputBtGlob);
+      if (fstats != null) {
+        for (FileStatus fstat : fstats) {
+          /*
+           * new api is not enabled yet in our dependencies at this time, still
+           * using deprecated one. The entry must go into the job's own
+           * configuration: the Job copied 'conf' when it was created, so
+           * adding it to 'conf' would never reach the mappers.
+           */
+          DistributedCache.addCacheFile(fstat.getPath().toUri(),
+                                        job.getConfiguration());
+        }
+      }
+    }
+
+    job.submit();
+    job.waitForCompletion(false);
+
+    if (!job.isSuccessful()) {
+      throw new IOException("ABt job unsuccessful.");
+    }
+
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java
new file mode 100644
index 0000000..1277bae
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java
@@ -0,0 +1,628 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.mahout.math.hadoop.stochasticsvd; + +import org.apache.commons.lang3.Validate; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile.CompressionType; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.compress.DefaultCodec; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.OutputCollector; +import org.apache.hadoop.mapred.lib.MultipleOutputs; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.IOUtils; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.UpperTriangular; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.function.PlusMult; +import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRLastStep; + +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; + +/** + * Bt job. For details, see working notes in MAHOUT-376. 
+ * <p/> + * <p/> + * Uses hadoop deprecated API wherever new api has not been updated + * (MAHOUT-593), hence @SuppressWarning("deprecation"). + * <p/> + * <p/> + * This job outputs either Bt in its standard output, or upper triangular + * matrices representing BBt partial sums if that's requested . If the latter + * mode is enabled, then we accumulate BBt outer product sums in upper + * triangular accumulator and output it at the end of the job, thus saving space + * and BBt job. + * <p/> + * <p/> + * This job also outputs Q and Bt and optionally BBt. Bt is output to standard + * job output (part-*) and Q and BBt use named multiple outputs. + * <p/> + * <p/> + */ +@SuppressWarnings("deprecation") +public final class BtJob { + + public static final String OUTPUT_Q = "Q"; + public static final String OUTPUT_BT = "part"; + public static final String OUTPUT_BBT = "bbt"; + public static final String OUTPUT_SQ = "sq"; + public static final String OUTPUT_SB = "sb"; + + public static final String PROP_QJOB_PATH = "ssvd.QJob.path"; + public static final String PROP_OUPTUT_BBT_PRODUCTS = + "ssvd.BtJob.outputBBtProducts"; + public static final String PROP_OUTER_PROD_BLOCK_HEIGHT = + "ssvd.outerProdBlockHeight"; + public static final String PROP_RHAT_BROADCAST = "ssvd.rhat.broadcast"; + public static final String PROP_XI_PATH = "ssvdpca.xi.path"; + public static final String PROP_NV = "ssvd.nv"; + + private BtJob() { + } + + public static class BtMapper extends + Mapper<Writable, VectorWritable, LongWritable, SparseRowBlockWritable> { + + private QRLastStep qr; + private final Deque<Closeable> closeables = new ArrayDeque<>(); + + private int blockNum; + private MultipleOutputs outputs; + private final VectorWritable qRowValue = new VectorWritable(); + private Vector btRow; + private SparseRowBlockAccumulator btCollector; + private Context mapContext; + private boolean nv; + + // pca stuff + private Vector sqAccum; + private boolean computeSq; + + /** + * We maintain A and 
QtHat inputs partitioned the same way, so we + * essentially are performing map-side merge here of A and QtHats except + * QtHat is stored not row-wise but block-wise. + */ + @Override + protected void map(Writable key, VectorWritable value, Context context) + throws IOException, InterruptedException { + + mapContext = context; + // output Bt outer products + Vector aRow = value.get(); + + Vector qRow = qr.next(); + int kp = qRow.size(); + + // make sure Qs are inheriting A row labels. + outputQRow(key, qRow, aRow); + + // MAHOUT-817 + if (computeSq) { + if (sqAccum == null) { + sqAccum = new DenseVector(kp); + } + sqAccum.assign(qRow, Functions.PLUS); + } + + if (btRow == null) { + btRow = new DenseVector(kp); + } + + if (!aRow.isDense()) { + for (Vector.Element el : aRow.nonZeroes()) { + double mul = el.get(); + for (int j = 0; j < kp; j++) { + btRow.setQuick(j, mul * qRow.getQuick(j)); + } + btCollector.collect((long) el.index(), btRow); + } + } else { + int n = aRow.size(); + for (int i = 0; i < n; i++) { + double mul = aRow.getQuick(i); + for (int j = 0; j < kp; j++) { + btRow.setQuick(j, mul * qRow.getQuick(j)); + } + btCollector.collect((long) i, btRow); + } + } + } + + @Override + protected void setup(Context context) throws IOException, + InterruptedException { + super.setup(context); + + Configuration conf = context.getConfiguration(); + + Path qJobPath = new Path(conf.get(PROP_QJOB_PATH)); + + /* + * actually this is kind of dangerous because this routine thinks we need + * to create file name for our current job and this will use -m- so it's + * just serendipity we are calling it from the mapper too as the QJob did. 
+ */ + Path qInputPath = + new Path(qJobPath, FileOutputFormat.getUniqueFile(context, + QJob.OUTPUT_QHAT, + "")); + blockNum = context.getTaskAttemptID().getTaskID().getId(); + + SequenceFileValueIterator<DenseBlockWritable> qhatInput = + new SequenceFileValueIterator<>(qInputPath, + true, + conf); + closeables.addFirst(qhatInput); + + /* + * read all r files _in order of task ids_, i.e. partitions (aka group + * nums). + * + * Note: if broadcast option is used, this comes from distributed cache + * files rather than hdfs path. + */ + + SequenceFileDirValueIterator<VectorWritable> rhatInput; + + boolean distributedRHat = conf.get(PROP_RHAT_BROADCAST) != null; + if (distributedRHat) { + + Path[] rFiles = HadoopUtil.getCachedFiles(conf); + + Validate.notNull(rFiles, + "no RHat files in distributed cache job definition"); + //TODO: this probably can be replaced w/ local fs makeQualified + Configuration lconf = new Configuration(); + lconf.set("fs.default.name", "file:///"); + + rhatInput = + new SequenceFileDirValueIterator<>(rFiles, + SSVDHelper.PARTITION_COMPARATOR, + true, + lconf); + + } else { + Path rPath = new Path(qJobPath, QJob.OUTPUT_RHAT + "-*"); + rhatInput = + new SequenceFileDirValueIterator<>(rPath, + PathType.GLOB, + null, + SSVDHelper.PARTITION_COMPARATOR, + true, + conf); + } + + Validate.isTrue(rhatInput.hasNext(), "Empty R-hat input!"); + + closeables.addFirst(rhatInput); + outputs = new MultipleOutputs(new JobConf(conf)); + closeables.addFirst(new IOUtils.MultipleOutputsCloseableAdapter(outputs)); + + qr = new QRLastStep(qhatInput, rhatInput, blockNum); + closeables.addFirst(qr); + /* + * it's so happens that current QRLastStep's implementation preloads R + * sequence into memory in the constructor so it's ok to close rhat input + * now. 
+       */
+      if (!rhatInput.hasNext()) {
+        closeables.remove(rhatInput);
+        rhatInput.close();
+      }
+
+      OutputCollector<LongWritable, SparseRowBlockWritable> btBlockCollector =
+        new OutputCollector<LongWritable, SparseRowBlockWritable>() {
+
+          @Override
+          public void collect(LongWritable blockKey,
+                              SparseRowBlockWritable block) throws IOException {
+            try {
+              mapContext.write(blockKey, block);
+            } catch (InterruptedException exc) { // NOTE(review): interrupt status is not restored before rethrowing as IOException
+              throw new IOException("Interrupted.", exc);
+            }
+          }
+        };
+
+      btCollector =
+        new SparseRowBlockAccumulator(conf.getInt(PROP_OUTER_PROD_BLOCK_HEIGHT,
+                                                  -1), btBlockCollector);
+      closeables.addFirst(btCollector);
+
+      // MAHOUT-817
+      computeSq = conf.get(PROP_XI_PATH) != null;
+
+      // MAHOUT-1067
+      nv = conf.getBoolean(PROP_NV, false);
+
+    }
+
+    @Override
+    protected void cleanup(Context context) throws IOException,
+      InterruptedException {
+      try {
+        if (sqAccum != null) {
+          /*
+           * hack: we will output sq partial sums with index -1 for summation.
+           */
+          SparseRowBlockWritable sbrw = new SparseRowBlockWritable(1);
+          sbrw.plusRow(0, sqAccum);
+          LongWritable lw = new LongWritable(-1); // key -1 is recognized by OuterProductReducer.reduce
+          context.write(lw, sbrw);
+        }
+      } finally {
+        IOUtils.close(closeables);
+      }
+    }
+
+    @SuppressWarnings("unchecked")
+    private void outputQRow(Writable key, Vector qRow, Vector aRow) throws IOException { // writes a Q row to the OUTPUT_Q named output, keeping the A row name when nv is set
+      if (nv && (aRow instanceof NamedVector)) {
+        qRowValue.set(new NamedVector(qRow, ((NamedVector) aRow).getName()));
+      } else {
+        qRowValue.set(qRow);
+      }
+      outputs.getCollector(OUTPUT_Q, null).collect(key, qRowValue);
+    }
+  }
+
+  public static class OuterProductCombiner // registered as the combiner in run(): sums all partial blocks sharing a key
+    extends
+    Reducer<Writable, SparseRowBlockWritable, Writable, SparseRowBlockWritable> {
+
+    protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
+    protected final Deque<Closeable> closeables = new ArrayDeque<>();
+    protected int blockHeight;
+
+    @Override
+    protected void setup(Context context) throws IOException,
+      InterruptedException {
+      blockHeight = // NOTE(review): blockHeight is not used by any other method of this combiner
+        context.getConfiguration().getInt(PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
+    }
+
+    @Override
+    protected void reduce(Writable key,
+                          Iterable<SparseRowBlockWritable> values,
+                          Context context) throws IOException,
+      InterruptedException {
+      for (SparseRowBlockWritable bw : values) {
+        accum.plusBlock(bw);
+      }
+      context.write(key, accum); // emit the summed block for this key
+      accum.clear(); // reset the reusable accumulator before the next key
+    }
+
+    @Override
+    protected void cleanup(Context context) throws IOException,
+      InterruptedException {
+
+      IOUtils.close(closeables); // NOTE(review): nothing in this class adds to closeables
+    }
+  }
+
+  public static class OuterProductReducer // sums partial Bt blocks per key and emits Bt rows; optionally accumulates BBt and s_b
+    extends
+    Reducer<LongWritable, SparseRowBlockWritable, IntWritable, VectorWritable> {
+
+    protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
+    protected final Deque<Closeable> closeables = new ArrayDeque<>();
+
+    protected int blockHeight;
+    private boolean outputBBt;
+    private UpperTriangular mBBt; // (k + p) x (k + p) upper triangle, allocated in setup only when outputBBt
+    private MultipleOutputs outputs; // created in setup only when outputBBt or xi != null
+    private final IntWritable btKey = new IntWritable();
+    private final VectorWritable btValue = new VectorWritable();
+
+    // MAHOUT-817
+    private Vector xi;
+    private final PlusMult pmult = new PlusMult(0);
+    private Vector sbAccum; // lazily sized to the first Bt row seen
+
+    @Override
+    protected void setup(Context context) throws IOException,
+      InterruptedException {
+
+      Configuration conf = context.getConfiguration();
+      blockHeight = conf.getInt(PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
+
+      outputBBt = conf.getBoolean(PROP_OUPTUT_BBT_PRODUCTS, false);
+
+      if (outputBBt) {
+        int k = conf.getInt(QJob.PROP_K, -1);
+        int p = conf.getInt(QJob.PROP_P, -1);
+
+        Validate.isTrue(k > 0, "invalid k parameter");
+        Validate.isTrue(p >= 0, "invalid p parameter");
+        mBBt = new UpperTriangular(k + p);
+
+      }
+
+      String xiPathStr = conf.get(PROP_XI_PATH);
+      if (xiPathStr != null) {
+        xi = SSVDHelper.loadAndSumUpVectors(new Path(xiPathStr), conf);
+        if (xi == null) {
+          throw new IOException(String.format("unable to load mean path xi from %s.",
+                                              xiPathStr));
+        }
+      }
+
+      if (outputBBt || xi != null) {
+        outputs = new MultipleOutputs(new JobConf(conf));
+        closeables.addFirst(new IOUtils.MultipleOutputsCloseableAdapter(outputs));
+      }
+
+    }
+
+    @Override
+    protected void reduce(LongWritable key,
+                          Iterable<SparseRowBlockWritable> values,
+                          Context context) throws IOException,
+      InterruptedException {
+
+      accum.clear();
+      for (SparseRowBlockWritable bw : values) {
+        accum.plusBlock(bw);
+      }
+
+      // MAHOUT-817:
+      if (key.get() == -1L) { // key -1 carries the sq partial sums written by BtMapper.cleanup
+
+        Vector sq = accum.getRows()[0];
+
+        @SuppressWarnings("unchecked")
+        OutputCollector<IntWritable, VectorWritable> sqOut =
+          outputs.getCollector(OUTPUT_SQ, null);
+
+        sqOut.collect(new IntWritable(0), new VectorWritable(sq));
+        return;
+      }
+
+      /*
+       * at this point, sum of rows should be in accum, so we just generate
+       * outer self product of it and add to BBt accumulator.
+       */
+
+      for (int k = 0; k < accum.getNumRows(); k++) {
+        Vector btRow = accum.getRows()[k];
+        btKey.set((int) (key.get() * blockHeight + accum.getRowIndices()[k])); // global row = block key * blockHeight + offset; NOTE(review): narrowing cast to int
+        btValue.set(btRow);
+        context.write(btKey, btValue);
+
+        if (outputBBt) {
+          int kp = mBBt.numRows();
+          // accumulate partial BBt sum
+          for (int i = 0; i < kp; i++) {
+            double vi = btRow.get(i);
+            if (vi != 0.0) { // zero entries contribute nothing and are skipped
+              for (int j = i; j < kp; j++) { // upper triangle only (j >= i)
+                double vj = btRow.get(j);
+                if (vj != 0.0) {
+                  mBBt.setQuick(i, j, mBBt.getQuick(i, j) + vi * vj);
+                }
+              }
+            }
+          }
+        }
+
+        // MAHOUT-817
+        if (xi != null) {
+          // code defensively against shortened xi
+          int btIndex = btKey.get();
+          double xii = xi.size() > btIndex ? xi.getQuick(btIndex) : 0.0;
+          // compute s_b
+          pmult.setMultiplicator(xii);
+          if (sbAccum == null) {
+            sbAccum = new DenseVector(btRow.size());
+          }
+          sbAccum.assign(btRow, pmult); // presumably sbAccum += xii * btRow via PlusMult; confirm
+        }
+
+      }
+    }
+
+    @Override
+    protected void cleanup(Context context) throws IOException,
+      InterruptedException {
+
+      // if we output BBt instead of Bt then we need to do it.
+      try {
+        if (outputBBt) {
+
+          @SuppressWarnings("unchecked")
+          OutputCollector<Writable, Writable> collector =
+            outputs.getCollector(OUTPUT_BBT, null);
+
+          collector.collect(new IntWritable(), // the whole upper-triangular partial BBt as one dense vector
+                            new VectorWritable(new DenseVector(mBBt.getData())));
+        }
+
+        // MAHOUT-817
+        if (sbAccum != null) {
+          @SuppressWarnings("unchecked")
+          OutputCollector<IntWritable, VectorWritable> collector =
+            outputs.getCollector(OUTPUT_SB, null);
+
+          collector.collect(new IntWritable(), new VectorWritable(sbAccum));
+
+        }
+      } finally {
+        IOUtils.close(closeables);
+      }
+
+    }
+  }
+
+  public static void run(Configuration conf, // configures the Bt job, runs it to completion, throws IOException if it fails
+                         Path[] inputPathA,
+                         Path inputPathQJob,
+                         Path xiPath,
+                         Path outputPath,
+                         int minSplitSize,
+                         int k,
+                         int p,
+                         int btBlockHeight,
+                         int numReduceTasks,
+                         boolean broadcast,
+                         Class<? extends Writable> labelClass,
+                         boolean outputBBtProducts)
+    throws ClassNotFoundException, InterruptedException, IOException {
+
+    JobConf oldApiJob = new JobConf(conf);
+
+    MultipleOutputs.addNamedOutput(oldApiJob,
+                                   OUTPUT_Q,
+                                   org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+                                   labelClass,
+                                   VectorWritable.class);
+
+    if (outputBBtProducts) {
+      MultipleOutputs.addNamedOutput(oldApiJob,
+                                     OUTPUT_BBT,
+                                     org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+                                     IntWritable.class,
+                                     VectorWritable.class);
+      /*
+       * MAHOUT-1067: if we are asked to output BBT products then named vector
+       * names should be propagated to Q too so that UJob could pick them up
+       * from there.
+       */
+      oldApiJob.setBoolean(PROP_NV, true);
+    }
+    if (xiPath != null) {
+      // compute PCA-related stuff as well
+      MultipleOutputs.addNamedOutput(oldApiJob,
+                                     OUTPUT_SQ,
+                                     org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+                                     IntWritable.class,
+                                     VectorWritable.class);
+      MultipleOutputs.addNamedOutput(oldApiJob,
+                                     OUTPUT_SB,
+                                     org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+                                     IntWritable.class,
+                                     VectorWritable.class);
+    }
+
+    /*
+     * HACK: we use old api multiple outputs since they are not available in the
+     * new api of either 0.20.2 or 0.20.203 but wrap it into a new api job so we
+     * can use new api interfaces.
+     */
+
+    Job job = new Job(oldApiJob); // NOTE(review): this constructor is deprecated in newer Hadoop in favor of Job.getInstance
+    job.setJobName("Bt-job");
+    job.setJarByClass(BtJob.class);
+
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    FileInputFormat.setInputPaths(job, inputPathA);
+    if (minSplitSize > 0) {
+      FileInputFormat.setMinInputSplitSize(job, minSplitSize);
+    }
+    FileOutputFormat.setOutputPath(job, outputPath);
+
+    // WARN: tight hadoop integration here:
+    job.getConfiguration().set("mapreduce.output.basename", OUTPUT_BT);
+
+    FileOutputFormat.setOutputCompressorClass(job, DefaultCodec.class);
+    SequenceFileOutputFormat.setOutputCompressionType(job,
+                                                      CompressionType.BLOCK);
+
+    job.setMapOutputKeyClass(LongWritable.class);
+    job.setMapOutputValueClass(SparseRowBlockWritable.class);
+
+    job.setOutputKeyClass(IntWritable.class);
+    job.setOutputValueClass(VectorWritable.class);
+
+    job.setMapperClass(BtMapper.class);
+    job.setCombinerClass(OuterProductCombiner.class);
+    job.setReducerClass(OuterProductReducer.class);
+
+    job.getConfiguration().setInt(QJob.PROP_K, k);
+    job.getConfiguration().setInt(QJob.PROP_P, p);
+    job.getConfiguration().set(PROP_QJOB_PATH, inputPathQJob.toString());
+    job.getConfiguration().setBoolean(PROP_OUPTUT_BBT_PRODUCTS,
+                                      outputBBtProducts);
+    job.getConfiguration().setInt(PROP_OUTER_PROD_BLOCK_HEIGHT, btBlockHeight);
+    job.setNumReduceTasks(numReduceTasks);
+
+    /*
+     * PCA-related options, MAHOUT-817
+     */
+    if (xiPath != null) {
+      job.getConfiguration().set(PROP_XI_PATH, xiPath.toString());
+    }
+
+    /*
+     * we can broadcast Rhat files since all of them are required by each job,
+     * but not Q files which correspond to splits of A (so each split of A will
+     * require only particular Q file, each time different one).
+     */
+
+    if (broadcast) {
+      job.getConfiguration().set(PROP_RHAT_BROADCAST, "y");
+
+      FileSystem fs = FileSystem.get(inputPathQJob.toUri(), conf);
+      FileStatus[] fstats =
+        fs.globStatus(new Path(inputPathQJob, QJob.OUTPUT_RHAT + "-*"));
+      if (fstats != null) {
+        for (FileStatus fstat : fstats) {
+          /*
+           * new api is not enabled yet in our dependencies at this time, still
+           * using deprecated one
+           */
+          DistributedCache.addCacheFile(fstat.getPath().toUri(),
+                                        job.getConfiguration());
+        }
+      }
+    }
+
+    job.submit();
+    job.waitForCompletion(false); // blocks until the submitted job finishes; false = no progress printing
+
+    if (!job.isSuccessful()) {
+      throw new IOException("Bt job unsuccessful.");
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/DenseBlockWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/DenseBlockWritable.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/DenseBlockWritable.java
new file mode 100644
index 0000000..6a9b352
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/DenseBlockWritable.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.math.hadoop.stochasticsvd; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Arrays; + +import org.apache.hadoop.io.Writable; + +/** + * Ad-hoc substitution for {@link org.apache.mahout.math.MatrixWritable}. + * Perhaps more useful for situations with mostly dense data (such as Q-blocks) + * but reduces GC by reusing the same block memory between loads and writes. + * <p> + * + * in case of Q blocks, it doesn't even matter if they this data is dense cause + * we need to unpack it into dense for fast access in computations anyway and + * even if it is not so dense the block compressor in sequence files will take + * care of it for the serialized size. + * <p> + */ +public class DenseBlockWritable implements Writable { + private double[][] block; + + public void setBlock(double[][] block) { + this.block = block; + } + + public double[][] getBlock() { + return block; + } + + @Override + public void readFields(DataInput in) throws IOException { + int m = in.readInt(); + int n = in.readInt(); + if (block == null) { + block = new double[m][0]; + } else if (block.length != m) { + block = Arrays.copyOf(block, m); + } + for (int i = 0; i < m; i++) { + if (block[i] == null || block[i].length != n) { + block[i] = new double[n]; + } + for (int j = 0; j < n; j++) { + block[i][j] = in.readDouble(); + } + + } + } + + @Override + public void write(DataOutput out) throws IOException { + int m = block.length; + int n = block.length == 0 ? 
0 : block[0].length; + + out.writeInt(m); + out.writeInt(n); + for (double[] aBlock : block) { + for (int j = 0; j < n; j++) { + out.writeDouble(aBlock[j]); + } + } + } + +}
