http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java new file mode 100644 index 0000000..e61c3eb --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures; + +import org.apache.mahout.math.Vector; + +public class EuclideanDistanceSimilarity implements VectorSimilarityMeasure { + + @Override + public Vector normalize(Vector vector) { + return vector; + } + + @Override + public double norm(Vector vector) { + double norm = 0; + for (Vector.Element e : vector.nonZeroes()) { + double value = e.get(); + norm += value * value; + } + return norm; + } + + @Override + public double aggregate(double valueA, double nonZeroValueB) { + return valueA * nonZeroValueB; + } + + @Override + public double similarity(double dots, double normA, double normB, int numberOfColumns) { + // Arg can't be negative in theory, but can in practice due to rounding, so cap it. + // Also note that normA / normB are actually the squares of the norms. + double euclideanDistance = Math.sqrt(Math.max(0.0, normA - 2 * dots + normB)); + return 1.0 / (1.0 + euclideanDistance); + } + + @Override + public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB, + double threshold) { + return true; + } +}
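The measure above plugs into the VectorSimilarityMeasure contract that appears later in this commit: norm() returns the squared L2 norm of a (normalized) row, summing aggregate() over co-occurring elements yields the dot product, and similarity() converts the reconstructed Euclidean distance into a score in (0, 1] via 1 / (1 + distance). A minimal sketch of how the hooks compose outside of map/reduce — assuming mahout-math's DenseVector is on the classpath and using made-up input vectors — might look like this:

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class EuclideanSimilarityExample {
  public static void main(String[] args) {
    VectorSimilarityMeasure measure = new EuclideanDistanceSimilarity();

    Vector a = new DenseVector(new double[] {1.0, 0.0, 3.0});
    Vector b = new DenseVector(new double[] {2.0, 2.0, 1.0});

    // norm() is the squared L2 norm (see the comment in similarity() above)
    double normA = measure.norm(measure.normalize(a));
    double normB = measure.norm(measure.normalize(b));

    // summing aggregate() over co-occurring elements gives the dot product
    double dots = 0.0;
    for (int i = 0; i < a.size(); i++) {
      dots += measure.aggregate(a.getQuick(i), b.getQuick(i));
    }

    // ||a - b||^2 = normA - 2 * dots + normB, so this prints 1 / (1 + ||a - b||) = 0.25
    System.out.println(measure.similarity(dots, normA, normB, a.size()));
  }
}

In the distributed pipeline these three quantities are accumulated by the map/reduce stages rather than by a local loop; the sketch only shows how the per-measure hooks fit together.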
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java new file mode 100644 index 0000000..7544b5d --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures; + +import org.apache.mahout.math.stats.LogLikelihood; + +public class LoglikelihoodSimilarity extends CountbasedMeasure { + + @Override + public double similarity(double summedAggregations, double normA, double normB, int numberOfColumns) { + double logLikelihood = + LogLikelihood.logLikelihoodRatio((long) summedAggregations, + (long) (normB - summedAggregations), + (long) (normA - summedAggregations), + (long) (numberOfColumns - normA - normB + summedAggregations)); + return 1.0 - 1.0 / (1.0 + logLikelihood); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java new file mode 100644 index 0000000..c650d8f --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java @@ -0,0 +1,37 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures; + +import org.apache.mahout.math.Vector; + +public class PearsonCorrelationSimilarity extends CosineSimilarity { + + @Override + public Vector normalize(Vector vector) { + if (vector.getNumNondefaultElements() == 0) { + return vector; + } + + // center non-zero elements + double average = vector.norm(1) / vector.getNumNonZeroElements(); + for (Vector.Element e : vector.nonZeroes()) { + e.set(e.get() - average); + } + return super.normalize(vector); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java new file mode 100644 index 0000000..e000579 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures; + +public class TanimotoCoefficientSimilarity extends CountbasedMeasure { + + @Override + public double similarity(double dots, double normA, double normB, int numberOfColumns) { + // Return 0 even when dots == 0 since this will cause it to be ignored -- not NaN + return dots / (normA + normB - dots); + } + + @Override + public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB, + double threshold) { + return numNonZeroEntriesA >= numNonZeroEntriesB * threshold + && numNonZeroEntriesB >= numNonZeroEntriesA * threshold; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java new file mode 100644 index 0000000..77125c2 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java @@ -0,0 +1,32 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures; + +import org.apache.mahout.math.Vector; + +public interface VectorSimilarityMeasure { + + double NO_NORM = 0.0; + + Vector normalize(Vector vector); + double norm(Vector vector); + double aggregate(double nonZeroValueA, double nonZeroValueB); + double similarity(double summedAggregations, double normA, double normB, int numberOfColumns); + boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB, + double threshold); +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java new file mode 100644 index 0000000..9d1160e --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures; + +import java.util.Arrays; + +public enum VectorSimilarityMeasures { + + SIMILARITY_COOCCURRENCE(CooccurrenceCountSimilarity.class), + SIMILARITY_LOGLIKELIHOOD(LoglikelihoodSimilarity.class), + SIMILARITY_TANIMOTO_COEFFICIENT(TanimotoCoefficientSimilarity.class), + SIMILARITY_CITY_BLOCK(CityBlockSimilarity.class), + SIMILARITY_COSINE(CosineSimilarity.class), + SIMILARITY_PEARSON_CORRELATION(PearsonCorrelationSimilarity.class), + SIMILARITY_EUCLIDEAN_DISTANCE(EuclideanDistanceSimilarity.class); + + private final Class<? extends VectorSimilarityMeasure> implementingClass; + + VectorSimilarityMeasures(Class<? 
extends VectorSimilarityMeasure> impl) { + this.implementingClass = impl; + } + + public String getClassname() { + return implementingClass.getName(); + } + + public static String list() { + return Arrays.toString(values()); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java new file mode 100644 index 0000000..dd38971 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop.solver; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.util.Tool; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.DistributedRowMatrix; +import org.apache.mahout.math.solver.ConjugateGradientSolver; +import org.apache.mahout.math.solver.Preconditioner; + +/** + * Distributed implementation of the conjugate gradient solver. More or less, this is just the standard solver + * but wrapped with some methods that make it easy to run it on a DistributedRowMatrix. + */ +public class DistributedConjugateGradientSolver extends ConjugateGradientSolver implements Tool { + + private Configuration conf; + private Map<String, List<String>> parsedArgs; + + /** + * + * Runs the distributed conjugate gradient solver programmatically to solve the system (A + lambda*I)x = b. + * + * @param inputPath Path to the matrix A + * @param tempPath Path to scratch output path, deleted after the solver completes + * @param numRows Number of rows in A + * @param numCols Number of columns in A + * @param b Vector b + * @param preconditioner Optional preconditioner for the system + * @param maxIterations Maximum number of iterations to run, defaults to numCols + * @param maxError Maximum error tolerated in the result. If the norm of the residual falls below this, + * then the algorithm stops and returns. + * @return The vector that solves the system. 
+ */ + public Vector runJob(Path inputPath, + Path tempPath, + int numRows, + int numCols, + Vector b, + Preconditioner preconditioner, + int maxIterations, + double maxError) { + DistributedRowMatrix matrix = new DistributedRowMatrix(inputPath, tempPath, numRows, numCols); + matrix.setConf(conf); + + return solve(matrix, b, preconditioner, maxIterations, maxError); + } + + @Override + public Configuration getConf() { + return conf; + } + + @Override + public void setConf(Configuration conf) { + this.conf = conf; + } + + @Override + public int run(String[] strings) throws Exception { + Path inputPath = new Path(AbstractJob.getOption(parsedArgs, "--input")); + Path outputPath = new Path(AbstractJob.getOption(parsedArgs, "--output")); + Path tempPath = new Path(AbstractJob.getOption(parsedArgs, "--tempDir")); + Path vectorPath = new Path(AbstractJob.getOption(parsedArgs, "--vector")); + int numRows = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--numRows")); + int numCols = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--numCols")); + int maxIterations = parsedArgs.containsKey("--maxIter") + ? Integer.parseInt(AbstractJob.getOption(parsedArgs, "--maxIter")) + : numCols + 2; + double maxError = parsedArgs.containsKey("--maxError") + ? Double.parseDouble(AbstractJob.getOption(parsedArgs, "--maxError")) + : ConjugateGradientSolver.DEFAULT_MAX_ERROR; + + Vector b = loadInputVector(vectorPath); + Vector x = runJob(inputPath, tempPath, numRows, numCols, b, null, maxIterations, maxError); + saveOutputVector(outputPath, x); + tempPath.getFileSystem(conf).delete(tempPath, true); + + return 0; + } + + public DistributedConjugateGradientSolverJob job() { + return new DistributedConjugateGradientSolverJob(); + } + + private Vector loadInputVector(Path path) throws IOException { + FileSystem fs = path.getFileSystem(conf); + try (SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf)) { + VectorWritable value = new VectorWritable(); + if (!reader.next(new IntWritable(), value)) { + throw new IOException("Input vector file is empty."); + } + return value.get(); + } + } + + private void saveOutputVector(Path path, Vector v) throws IOException { + FileSystem fs = path.getFileSystem(conf); + try (SequenceFile.Writer writer = + new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class)) { + writer.append(new IntWritable(0), new VectorWritable(v)); + } + } + + public class DistributedConjugateGradientSolverJob extends AbstractJob { + @Override + public void setConf(Configuration conf) { + DistributedConjugateGradientSolver.this.setConf(conf); + } + + @Override + public Configuration getConf() { + return DistributedConjugateGradientSolver.this.getConf(); + } + + @Override + public int run(String[] args) throws Exception { + addInputOption(); + addOutputOption(); + addOption("numRows", "nr", "Number of rows in the input matrix", true); + addOption("numCols", "nc", "Number of columns in the input matrix", true); + addOption("vector", "b", "Vector to solve against", true); + addOption("lambda", "l", "Scalar in A + lambda * I [default = 0]", "0.0"); + addOption("symmetric", "sym", "Is the input matrix square and symmetric?", "true"); + addOption("maxIter", "x", "Maximum number of iterations to run"); + addOption("maxError", "err", "Maximum residual error to allow before stopping"); + + DistributedConjugateGradientSolver.this.parsedArgs = parseArguments(args); + if (DistributedConjugateGradientSolver.this.parsedArgs == null) { + return -1; + } else { + Configuration 
conf = getConf(); + if (conf == null) { + conf = new Configuration(); + } + DistributedConjugateGradientSolver.this.setConf(conf); + return DistributedConjugateGradientSolver.this.run(args); + } + } + } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new DistributedConjugateGradientSolver().job(), args); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java new file mode 100644 index 0000000..ad0baf3 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java @@ -0,0 +1,148 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.hadoop.stats; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; + +import java.io.IOException; + +/** + * Methods for calculating basic stats (mean, variance, stdDev, etc.) 
in map/reduce + */ +public final class BasicStats { + + private BasicStats() { + } + + /** + * Calculate the variance of values stored as + * + * @param input The input file containing the key and the count + * @param output The output to store the intermediate values + * @param baseConf + * @return The variance (based on sample estimation) + */ + public static double variance(Path input, Path output, + Configuration baseConf) + throws IOException, InterruptedException, ClassNotFoundException { + VarianceTotals varianceTotals = computeVarianceTotals(input, output, baseConf); + return varianceTotals.computeVariance(); + } + + /** + * Calculate the variance by a predefined mean of values stored as + * + * @param input The input file containing the key and the count + * @param output The output to store the intermediate values + * @param mean The mean based on which to compute the variance + * @param baseConf + * @return The variance (based on sample estimation) + */ + public static double varianceForGivenMean(Path input, Path output, double mean, + Configuration baseConf) + throws IOException, InterruptedException, ClassNotFoundException { + VarianceTotals varianceTotals = computeVarianceTotals(input, output, baseConf); + return varianceTotals.computeVarianceForGivenMean(mean); + } + + private static VarianceTotals computeVarianceTotals(Path input, Path output, + Configuration baseConf) throws IOException, InterruptedException, + ClassNotFoundException { + Configuration conf = new Configuration(baseConf); + conf.set("io.serializations", + "org.apache.hadoop.io.serializer.JavaSerialization," + + "org.apache.hadoop.io.serializer.WritableSerialization"); + Job job = HadoopUtil.prepareJob(input, output, SequenceFileInputFormat.class, + StandardDeviationCalculatorMapper.class, IntWritable.class, DoubleWritable.class, + StandardDeviationCalculatorReducer.class, IntWritable.class, DoubleWritable.class, + SequenceFileOutputFormat.class, conf); + HadoopUtil.delete(conf, output); + job.setCombinerClass(StandardDeviationCalculatorReducer.class); + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + + // Now extract the computed sum + Path filesPattern = new Path(output, "part-*"); + double sumOfSquares = 0; + double sum = 0; + double totalCount = 0; + for (Pair<Writable, Writable> record : new SequenceFileDirIterable<>( + filesPattern, PathType.GLOB, null, null, true, conf)) { + + int key = ((IntWritable) record.getFirst()).get(); + if (key == StandardDeviationCalculatorMapper.SUM_OF_SQUARES.get()) { + sumOfSquares += ((DoubleWritable) record.getSecond()).get(); + } else if (key == StandardDeviationCalculatorMapper.TOTAL_COUNT + .get()) { + totalCount += ((DoubleWritable) record.getSecond()).get(); + } else if (key == StandardDeviationCalculatorMapper.SUM + .get()) { + sum += ((DoubleWritable) record.getSecond()).get(); + } + } + + VarianceTotals varianceTotals = new VarianceTotals(); + varianceTotals.setSum(sum); + varianceTotals.setSumOfSquares(sumOfSquares); + varianceTotals.setTotalCount(totalCount); + + return varianceTotals; + } + + /** + * Calculate the standard deviation + * + * @param input The input file containing the key and the count + * @param output The output file to write the counting results to + * @param baseConf The base configuration + * @return The standard deviation + */ + public static double stdDev(Path input, Path output, + Configuration baseConf) throws IOException, InterruptedException, + 
ClassNotFoundException { + return Math.sqrt(variance(input, output, baseConf)); + } + + /** + * Calculate the standard deviation given a predefined mean + * + * @param input The input file containing the key and the count + * @param output The output file to write the counting results to + * @param mean The mean based on which to compute the standard deviation + * @param baseConf The base configuration + * @return The standard deviation + */ + public static double stdDevForGivenMean(Path input, Path output, double mean, + Configuration baseConf) throws IOException, InterruptedException, + ClassNotFoundException { + return Math.sqrt(varianceForGivenMean(input, output, mean, baseConf)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java new file mode 100644 index 0000000..03271da --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java @@ -0,0 +1,55 @@ +package org.apache.mahout.math.hadoop.stats; +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Mapper; + +import java.io.IOException; + +public class StandardDeviationCalculatorMapper extends + Mapper<IntWritable, Writable, IntWritable, DoubleWritable> { + + public static final IntWritable SUM_OF_SQUARES = new IntWritable(1); + public static final IntWritable SUM = new IntWritable(2); + public static final IntWritable TOTAL_COUNT = new IntWritable(-1); + + @Override + protected void map(IntWritable key, Writable value, Context context) + throws IOException, InterruptedException { + if (key.get() == -1) { + return; + } + //Kind of ugly, but such is life + double df = Double.NaN; + if (value instanceof LongWritable) { + df = ((LongWritable)value).get(); + } else if (value instanceof DoubleWritable) { + df = ((DoubleWritable)value).get(); + } + if (!Double.isNaN(df)) { + // For calculating the sum of squares + context.write(SUM_OF_SQUARES, new DoubleWritable(df * df)); + context.write(SUM, new DoubleWritable(df)); + // For calculating the total number of entries + context.write(TOTAL_COUNT, new DoubleWritable(1)); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java new file mode 100644 index 0000000..0a27eec --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java @@ -0,0 +1,37 @@ +package org.apache.mahout.math.hadoop.stats; +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Reducer; + +import java.io.IOException; + +public class StandardDeviationCalculatorReducer extends + Reducer<IntWritable, DoubleWritable, IntWritable, DoubleWritable> { + + @Override + protected void reduce(IntWritable key, Iterable<DoubleWritable> values, + Context context) throws IOException, InterruptedException { + double sum = 0.0; + for (DoubleWritable value : values) { + sum += value.get(); + } + context.write(key, new DoubleWritable(sum)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java new file mode 100644 index 0000000..87448bc --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java @@ -0,0 +1,68 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.hadoop.stats; + +/** + * Holds the total values needed to compute mean and standard deviation + * Provides methods for their computation + */ +public final class VarianceTotals { + + private double sumOfSquares; + private double sum; + private double totalCount; + + public double getSumOfSquares() { + return sumOfSquares; + } + + public void setSumOfSquares(double sumOfSquares) { + this.sumOfSquares = sumOfSquares; + } + + public double getSum() { + return sum; + } + + public void setSum(double sum) { + this.sum = sum; + } + + public double getTotalCount() { + return totalCount; + } + + public void setTotalCount(double totalCount) { + this.totalCount = totalCount; + } + + public double computeMean() { + return sum / totalCount; + } + + public double computeVariance() { + return ((totalCount * sumOfSquares) - (sum * sum)) + / (totalCount * (totalCount - 1.0)); + } + + public double computeVarianceForGivenMean(double mean) { + return (sumOfSquares - totalCount * mean * mean) + / (totalCount - 1.0); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java new file mode 100644 index 0000000..359b281 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java @@ -0,0 +1,585 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.hadoop.stochasticsvd; + +import java.io.Closeable; +import java.io.IOException; +import java.text.NumberFormat; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Deque; +import java.util.Iterator; +import java.util.regex.Matcher; + +import com.google.common.collect.Lists; +import org.apache.commons.lang3.Validate; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.SequenceFile.CompressionType; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.OutputCollector; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.IOUtils; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRFirstStep; + +/** + * Computes ABt products, then first step of QR which is pushed down to the + * reducer. + */ +@SuppressWarnings("deprecation") +public final class ABtDenseOutJob { + + public static final String PROP_BT_PATH = "ssvd.Bt.path"; + public static final String PROP_BT_BROADCAST = "ssvd.Bt.broadcast"; + public static final String PROP_SB_PATH = "ssvdpca.sb.path"; + public static final String PROP_SQ_PATH = "ssvdpca.sq.path"; + public static final String PROP_XI_PATH = "ssvdpca.xi.path"; + + private ABtDenseOutJob() { + } + + /** + * So, here, i preload A block into memory. + * <P> + * + * A sparse matrix seems to be ideal for that but there are two reasons why i + * am not using it: + * <UL> + * <LI>1) I don't know the full block height. so i may need to reallocate it + * from time to time. Although this probably not a showstopper. + * <LI>2) I found that RandomAccessSparseVectors seem to take much more memory + * than the SequentialAccessSparseVectors. 
+ * </UL> + * <P> + * + */ + public static class ABtMapper + extends + Mapper<Writable, VectorWritable, SplitPartitionedWritable, DenseBlockWritable> { + + private SplitPartitionedWritable outKey; + private final Deque<Closeable> closeables = new ArrayDeque<>(); + private SequenceFileDirIterator<IntWritable, VectorWritable> btInput; + private Vector[] aCols; + private double[][] yiCols; + private int aRowCount; + private int kp; + private int blockHeight; + private boolean distributedBt; + private Path[] btLocalPath; + private Configuration localFsConfig; + /* + * xi and s_q are PCA-related corrections, per MAHOUT-817 + */ + protected Vector xi; + protected Vector sq; + + @Override + protected void map(Writable key, VectorWritable value, Context context) + throws IOException, InterruptedException { + + Vector vec = value.get(); + + int vecSize = vec.size(); + if (aCols == null) { + aCols = new Vector[vecSize]; + } else if (aCols.length < vecSize) { + aCols = Arrays.copyOf(aCols, vecSize); + } + + if (vec.isDense()) { + for (int i = 0; i < vecSize; i++) { + extendAColIfNeeded(i, aRowCount + 1); + aCols[i].setQuick(aRowCount, vec.getQuick(i)); + } + } else if (vec.size() > 0) { + for (Vector.Element vecEl : vec.nonZeroes()) { + int i = vecEl.index(); + extendAColIfNeeded(i, aRowCount + 1); + aCols[i].setQuick(aRowCount, vecEl.get()); + } + } + aRowCount++; + } + + private void extendAColIfNeeded(int col, int rowCount) { + if (aCols[col] == null) { + aCols[col] = + new SequentialAccessSparseVector(rowCount < blockHeight ? blockHeight + : rowCount, 1); + } else if (aCols[col].size() < rowCount) { + Vector newVec = + new SequentialAccessSparseVector(rowCount + blockHeight, + aCols[col].getNumNondefaultElements() << 1); + newVec.viewPart(0, aCols[col].size()).assign(aCols[col]); + aCols[col] = newVec; + } + } + + @Override + protected void cleanup(Context context) throws IOException, + InterruptedException { + try { + + yiCols = new double[kp][]; + + for (int i = 0; i < kp; i++) { + yiCols[i] = new double[Math.min(aRowCount, blockHeight)]; + } + + int numPasses = (aRowCount - 1) / blockHeight + 1; + + String propBtPathStr = context.getConfiguration().get(PROP_BT_PATH); + Validate.notNull(propBtPathStr, "Bt input is not set"); + Path btPath = new Path(propBtPathStr); + DenseBlockWritable dbw = new DenseBlockWritable(); + + /* + * so it turns out that it may be much more efficient to do a few + * independent passes over Bt accumulating the entire block in memory + * than pass huge amount of blocks out to combiner. so we aim of course + * to fit entire s x (k+p) dense block in memory where s is the number + * of A rows in this split. If A is much sparser than (k+p) avg # of + * elements per row then the block may exceed the split size. if this + * happens, and if the given blockHeight is not high enough to + * accomodate this (because of memory constraints), then we start + * splitting s into several passes. since computation is cpu-bound + * anyway, it should be o.k. for supersparse inputs. (as ok it can be + * that projection is thicker than the original anyway, why would one + * use that many k+p then). 
+ */ + int lastRowIndex = -1; + for (int pass = 0; pass < numPasses; pass++) { + + if (distributedBt) { + + btInput = + new SequenceFileDirIterator<>(btLocalPath, true, localFsConfig); + + } else { + + btInput = + new SequenceFileDirIterator<>(btPath, PathType.GLOB, null, null, true, context.getConfiguration()); + } + closeables.addFirst(btInput); + Validate.isTrue(btInput.hasNext(), "Empty B' input!"); + + int aRowBegin = pass * blockHeight; + int bh = Math.min(blockHeight, aRowCount - aRowBegin); + + /* + * check if we need to trim block allocation + */ + if (pass > 0) { + if (bh == blockHeight) { + for (int i = 0; i < kp; i++) { + Arrays.fill(yiCols[i], 0.0); + } + } else { + + for (int i = 0; i < kp; i++) { + yiCols[i] = null; + } + for (int i = 0; i < kp; i++) { + yiCols[i] = new double[bh]; + } + } + } + + while (btInput.hasNext()) { + Pair<IntWritable, VectorWritable> btRec = btInput.next(); + int btIndex = btRec.getFirst().get(); + Vector btVec = btRec.getSecond().get(); + Vector aCol; + if (btIndex > aCols.length || (aCol = aCols[btIndex]) == null + || aCol.size() == 0) { + + /* 100% zero A column in the block, skip it as sparse */ + continue; + } + int j = -1; + for (Vector.Element aEl : aCol.nonZeroes()) { + j = aEl.index(); + + /* + * now we compute only swathes between aRowBegin..aRowBegin+bh + * exclusive. it seems like a deficiency but in fact i think it + * will balance itself out: either A is dense and then we + * shouldn't have more than one pass and therefore filter + * conditions will never kick in. Or, the only situation where we + * can't fit Y_i block in memory is when A input is much sparser + * than k+p per row. But if this is the case, then we'd be looking + * at very few elements without engaging them in any operations so + * even then it should be ok. + */ + if (j < aRowBegin) { + continue; + } + if (j >= aRowBegin + bh) { + break; + } + + /* + * assume btVec is dense + */ + if (xi != null) { + /* + * MAHOUT-817: PCA correction for B'. I rewrite the whole + * computation loop so i don't have to check if PCA correction + * is needed at individual element level. It looks bulkier this + * way but perhaps less wasteful on cpu. + */ + for (int s = 0; s < kp; s++) { + // code defensively against shortened xi + double xii = xi.size() > btIndex ? 
xi.get(btIndex) : 0.0; + yiCols[s][j - aRowBegin] += + aEl.get() * (btVec.getQuick(s) - xii * sq.get(s)); + } + } else { + /* + * no PCA correction + */ + for (int s = 0; s < kp; s++) { + yiCols[s][j - aRowBegin] += aEl.get() * btVec.getQuick(s); + } + } + + } + if (lastRowIndex < j) { + lastRowIndex = j; + } + } + + /* + * so now we have stuff in yi + */ + dbw.setBlock(yiCols); + outKey.setTaskItemOrdinal(pass); + context.write(outKey, dbw); + + closeables.remove(btInput); + btInput.close(); + } + + } finally { + IOUtils.close(closeables); + } + } + + @Override + protected void setup(Context context) throws IOException, + InterruptedException { + + Configuration conf = context.getConfiguration(); + int k = Integer.parseInt(conf.get(QRFirstStep.PROP_K)); + int p = Integer.parseInt(conf.get(QRFirstStep.PROP_P)); + kp = k + p; + + outKey = new SplitPartitionedWritable(context); + + blockHeight = conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1); + distributedBt = conf.get(PROP_BT_BROADCAST) != null; + if (distributedBt) { + btLocalPath = HadoopUtil.getCachedFiles(conf); + localFsConfig = new Configuration(); + localFsConfig.set("fs.default.name", "file:///"); + } + + /* + * PCA -related corrections (MAHOUT-817) + */ + String xiPathStr = conf.get(PROP_XI_PATH); + if (xiPathStr != null) { + xi = SSVDHelper.loadAndSumUpVectors(new Path(xiPathStr), conf); + sq = + SSVDHelper.loadAndSumUpVectors(new Path(conf.get(PROP_SQ_PATH)), conf); + } + + } + } + + /** + * QR first step pushed down to reducer. + * + */ + public static class QRReducer + extends Reducer<SplitPartitionedWritable, DenseBlockWritable, SplitPartitionedWritable, VectorWritable> { + + /* + * HACK: partition number formats in hadoop, copied. this may stop working + * if it gets out of sync with newer hadoop version. But unfortunately rules + * of forming output file names are not sufficiently exposed so we need to + * hack it if we write the same split output from either mapper or reducer. + * alternatively, we probably can replace it by our own output file naming + * management completely and bypass MultipleOutputs entirely. 
+ */ + + private static final NumberFormat NUMBER_FORMAT = + NumberFormat.getInstance(); + static { + NUMBER_FORMAT.setMinimumIntegerDigits(5); + NUMBER_FORMAT.setGroupingUsed(false); + } + + private final Deque<Closeable> closeables = Lists.newLinkedList(); + + protected int blockHeight; + + protected int lastTaskId = -1; + + protected OutputCollector<Writable, DenseBlockWritable> qhatCollector; + protected OutputCollector<Writable, VectorWritable> rhatCollector; + protected QRFirstStep qr; + protected Vector yiRow; + protected Vector sb; + + @Override + protected void setup(Context context) throws IOException, + InterruptedException { + Configuration conf = context.getConfiguration(); + blockHeight = conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1); + String sbPathStr = conf.get(PROP_SB_PATH); + + /* + * PCA -related corrections (MAHOUT-817) + */ + if (sbPathStr != null) { + sb = SSVDHelper.loadAndSumUpVectors(new Path(sbPathStr), conf); + } + } + + protected void setupBlock(Context context, SplitPartitionedWritable spw) + throws InterruptedException, IOException { + IOUtils.close(closeables); + qhatCollector = + createOutputCollector(QJob.OUTPUT_QHAT, + spw, + context, + DenseBlockWritable.class); + rhatCollector = + createOutputCollector(QJob.OUTPUT_RHAT, + spw, + context, + VectorWritable.class); + qr = + new QRFirstStep(context.getConfiguration(), + qhatCollector, + rhatCollector); + closeables.addFirst(qr); + lastTaskId = spw.getTaskId(); + + } + + @Override + protected void reduce(SplitPartitionedWritable key, + Iterable<DenseBlockWritable> values, + Context context) throws IOException, + InterruptedException { + + if (key.getTaskId() != lastTaskId) { + setupBlock(context, key); + } + + Iterator<DenseBlockWritable> iter = values.iterator(); + DenseBlockWritable dbw = iter.next(); + double[][] yiCols = dbw.getBlock(); + if (iter.hasNext()) { + throw new IOException("Unexpected extra Y_i block in reducer input."); + } + + long blockBase = key.getTaskItemOrdinal() * blockHeight; + int bh = yiCols[0].length; + if (yiRow == null) { + yiRow = new DenseVector(yiCols.length); + } + + for (int k = 0; k < bh; k++) { + for (int j = 0; j < yiCols.length; j++) { + yiRow.setQuick(j, yiCols[j][k]); + } + + key.setTaskItemOrdinal(blockBase + k); + + // pca offset correction if any + if (sb != null) { + yiRow.assign(sb, Functions.MINUS); + } + + qr.collect(key, yiRow); + } + + } + + private Path getSplitFilePath(String name, + SplitPartitionedWritable spw, + Context context) throws InterruptedException, + IOException { + String uniqueFileName = FileOutputFormat.getUniqueFile(context, name, ""); + uniqueFileName = uniqueFileName.replaceFirst("-r-", "-m-"); + uniqueFileName = + uniqueFileName.replaceFirst("\\d+$", + Matcher.quoteReplacement(NUMBER_FORMAT.format(spw.getTaskId()))); + return new Path(FileOutputFormat.getWorkOutputPath(context), + uniqueFileName); + } + + /** + * key doesn't matter here, only value does. key always gets substituted by + * SPW. 
+ * + * @param <K> + * bogus + */ + private <K, V> OutputCollector<K, V> createOutputCollector(String name, + final SplitPartitionedWritable spw, + Context ctx, + Class<V> valueClass) throws IOException, InterruptedException { + Path outputPath = getSplitFilePath(name, spw, ctx); + final SequenceFile.Writer w = + SequenceFile.createWriter(FileSystem.get(outputPath.toUri(), ctx.getConfiguration()), + ctx.getConfiguration(), + outputPath, + SplitPartitionedWritable.class, + valueClass); + closeables.addFirst(w); + return new OutputCollector<K, V>() { + @Override + public void collect(K key, V val) throws IOException { + w.append(spw, val); + } + }; + } + + @Override + protected void cleanup(Context context) throws IOException, + InterruptedException { + + IOUtils.close(closeables); + } + + } + + public static void run(Configuration conf, + Path[] inputAPaths, + Path inputBtGlob, + Path xiPath, + Path sqPath, + Path sbPath, + Path outputPath, + int aBlockRows, + int minSplitSize, + int k, + int p, + int outerProdBlockHeight, + int numReduceTasks, + boolean broadcastBInput) + throws ClassNotFoundException, InterruptedException, IOException { + + JobConf oldApiJob = new JobConf(conf); + + Job job = new Job(oldApiJob); + job.setJobName("ABt-job"); + job.setJarByClass(ABtDenseOutJob.class); + + job.setInputFormatClass(SequenceFileInputFormat.class); + FileInputFormat.setInputPaths(job, inputAPaths); + if (minSplitSize > 0) { + FileInputFormat.setMinInputSplitSize(job, minSplitSize); + } + + FileOutputFormat.setOutputPath(job, outputPath); + + SequenceFileOutputFormat.setOutputCompressionType(job, + CompressionType.BLOCK); + + job.setMapOutputKeyClass(SplitPartitionedWritable.class); + job.setMapOutputValueClass(DenseBlockWritable.class); + + job.setOutputKeyClass(SplitPartitionedWritable.class); + job.setOutputValueClass(VectorWritable.class); + + job.setMapperClass(ABtMapper.class); + job.setReducerClass(QRReducer.class); + + job.getConfiguration().setInt(QJob.PROP_AROWBLOCK_SIZE, aBlockRows); + job.getConfiguration().setInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, + outerProdBlockHeight); + job.getConfiguration().setInt(QRFirstStep.PROP_K, k); + job.getConfiguration().setInt(QRFirstStep.PROP_P, p); + job.getConfiguration().set(PROP_BT_PATH, inputBtGlob.toString()); + + /* + * PCA-related options, MAHOUT-817 + */ + if (xiPath != null) { + job.getConfiguration().set(PROP_XI_PATH, xiPath.toString()); + job.getConfiguration().set(PROP_SB_PATH, sbPath.toString()); + job.getConfiguration().set(PROP_SQ_PATH, sqPath.toString()); + } + + job.setNumReduceTasks(numReduceTasks); + + // broadcast Bt files if required. 
+ if (broadcastBInput) { + job.getConfiguration().set(PROP_BT_BROADCAST, "y"); + + FileSystem fs = FileSystem.get(inputBtGlob.toUri(), conf); + FileStatus[] fstats = fs.globStatus(inputBtGlob); + if (fstats != null) { + for (FileStatus fstat : fstats) { + /* + * new api is not enabled yet in our dependencies at this time, still + * using deprecated one + */ + DistributedCache.addCacheFile(fstat.getPath().toUri(), + job.getConfiguration()); + } + } + } + + job.submit(); + job.waitForCompletion(false); + + if (!job.isSuccessful()) { + throw new IOException("ABt job unsuccessful."); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java new file mode 100644 index 0000000..afa1463 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java @@ -0,0 +1,494 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.hadoop.stochasticsvd; + +import java.io.Closeable; +import java.io.IOException; +import java.text.NumberFormat; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Deque; +import java.util.regex.Matcher; + +import com.google.common.collect.Lists; +import org.apache.commons.lang3.Validate; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.SequenceFile.CompressionType; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.OutputCollector; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.IOUtils; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRFirstStep; + +/** + * Computes ABt products, then first step of QR which is pushed down to the + * reducer. + * + */ +@SuppressWarnings("deprecation") +public final class ABtJob { + + public static final String PROP_BT_PATH = "ssvd.Bt.path"; + public static final String PROP_BT_BROADCAST = "ssvd.Bt.broadcast"; + + private ABtJob() { + } + + /** + * So, here, i preload A block into memory. + * <P> + * + * A sparse matrix seems to be ideal for that but there are two reasons why i + * am not using it: + * <UL> + * <LI>1) I don't know the full block height. so i may need to reallocate it + * from time to time. Although this probably not a showstopper. + * <LI>2) I found that RandomAccessSparseVectors seem to take much more memory + * than the SequentialAccessSparseVectors. 
+ * </UL> + * <P> + * + */ + public static class ABtMapper + extends + Mapper<Writable, VectorWritable, SplitPartitionedWritable, SparseRowBlockWritable> { + + private SplitPartitionedWritable outKey; + private final Deque<Closeable> closeables = new ArrayDeque<>(); + private SequenceFileDirIterator<IntWritable, VectorWritable> btInput; + private Vector[] aCols; + // private Vector[] yiRows; + // private VectorWritable outValue = new VectorWritable(); + private int aRowCount; + private int kp; + private int blockHeight; + private SparseRowBlockAccumulator yiCollector; + + @Override + protected void map(Writable key, VectorWritable value, Context context) + throws IOException, InterruptedException { + + Vector vec = value.get(); + + int vecSize = vec.size(); + if (aCols == null) { + aCols = new Vector[vecSize]; + } else if (aCols.length < vecSize) { + aCols = Arrays.copyOf(aCols, vecSize); + } + + if (vec.isDense()) { + for (int i = 0; i < vecSize; i++) { + extendAColIfNeeded(i, aRowCount + 1); + aCols[i].setQuick(aRowCount, vec.getQuick(i)); + } + } else { + for (Vector.Element vecEl : vec.nonZeroes()) { + int i = vecEl.index(); + extendAColIfNeeded(i, aRowCount + 1); + aCols[i].setQuick(aRowCount, vecEl.get()); + } + } + aRowCount++; + } + + private void extendAColIfNeeded(int col, int rowCount) { + if (aCols[col] == null) { + aCols[col] = + new SequentialAccessSparseVector(rowCount < 10000 ? 10000 : rowCount, + 1); + } else if (aCols[col].size() < rowCount) { + Vector newVec = + new SequentialAccessSparseVector(rowCount << 1, + aCols[col].getNumNondefaultElements() << 1); + newVec.viewPart(0, aCols[col].size()).assign(aCols[col]); + aCols[col] = newVec; + } + } + + @Override + protected void cleanup(Context context) throws IOException, + InterruptedException { + try { + // yiRows= new Vector[aRowCount]; + + int lastRowIndex = -1; + + while (btInput.hasNext()) { + Pair<IntWritable, VectorWritable> btRec = btInput.next(); + int btIndex = btRec.getFirst().get(); + Vector btVec = btRec.getSecond().get(); + Vector aCol; + if (btIndex > aCols.length || (aCol = aCols[btIndex]) == null) { + continue; + } + int j = -1; + for (Vector.Element aEl : aCol.nonZeroes()) { + j = aEl.index(); + + // outKey.setTaskItemOrdinal(j); + // outValue.set(btVec.times(aEl.get())); // assign might work better + // // with memory after all. + // context.write(outKey, outValue); + yiCollector.collect((long) j, btVec.times(aEl.get())); + } + if (lastRowIndex < j) { + lastRowIndex = j; + } + } + aCols = null; + + // output empty rows if we never output partial products for them + // this happens in sparse matrices when last rows are all zeros + // and is subsequently causing shorter Q matrix row count which we + // probably don't want to repair there but rather here. 
+        Vector yDummy = new SequentialAccessSparseVector(kp);
+        // outValue.set(yDummy);
+        for (lastRowIndex += 1; lastRowIndex < aRowCount; lastRowIndex++) {
+          // outKey.setTaskItemOrdinal(lastRowIndex);
+          // context.write(outKey, outValue);
+
+          yiCollector.collect((long) lastRowIndex, yDummy);
+        }
+
+      } finally {
+        IOUtils.close(closeables);
+      }
+    }
+
+    @Override
+    protected void setup(final Context context) throws IOException,
+      InterruptedException {
+
+      int k =
+        Integer.parseInt(context.getConfiguration().get(QRFirstStep.PROP_K));
+      int p =
+        Integer.parseInt(context.getConfiguration().get(QRFirstStep.PROP_P));
+      kp = k + p;
+
+      outKey = new SplitPartitionedWritable(context);
+      String propBtPathStr = context.getConfiguration().get(PROP_BT_PATH);
+      Validate.notNull(propBtPathStr, "Bt input is not set");
+      Path btPath = new Path(propBtPathStr);
+
+      boolean distributedBt =
+        context.getConfiguration().get(PROP_BT_BROADCAST) != null;
+
+      if (distributedBt) {
+
+        Path[] btFiles = HadoopUtil.getCachedFiles(context.getConfiguration());
+
+        // DEBUG: stdout
+        //System.out.printf("list of files: " + btFiles);
+
+        StringBuilder btLocalPath = new StringBuilder();
+        for (Path btFile : btFiles) {
+          if (btLocalPath.length() > 0) {
+            btLocalPath.append(Path.SEPARATOR_CHAR);
+          }
+          btLocalPath.append(btFile);
+        }
+
+        btInput =
+          new SequenceFileDirIterator<>(new Path(btLocalPath.toString()),
+                                        PathType.LIST,
+                                        null,
+                                        null,
+                                        true,
+                                        context.getConfiguration());
+
+      } else {
+
+        btInput =
+          new SequenceFileDirIterator<>(btPath, PathType.GLOB, null, null, true, context.getConfiguration());
+      }
+      // TODO: is there a better way to manage the release of all these resources?
+      closeables.addFirst(btInput);
+      OutputCollector<LongWritable, SparseRowBlockWritable> yiBlockCollector =
+        new OutputCollector<LongWritable, SparseRowBlockWritable>() {
+
+          @Override
+          public void collect(LongWritable blockKey,
+                              SparseRowBlockWritable block) throws IOException {
+            outKey.setTaskItemOrdinal((int) blockKey.get());
+            try {
+              context.write(outKey, block);
+            } catch (InterruptedException exc) {
+              throw new IOException("Interrupted", exc);
+            }
+          }
+        };
+      blockHeight =
+        context.getConfiguration().getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT,
+                                          -1);
+      yiCollector =
+        new SparseRowBlockAccumulator(blockHeight, yiBlockCollector);
+      closeables.addFirst(yiCollector);
+    }
+
+  }
+
+  /**
+   * QR first step, pushed down to the reducer.
+   */
+  public static class QRReducer
+      extends
+      Reducer<SplitPartitionedWritable, SparseRowBlockWritable, SplitPartitionedWritable, VectorWritable> {
+
+    // Hack: partition number formats in Hadoop, copied. This may stop working
+    // if it gets out of sync with a newer Hadoop version. Unfortunately, the
+    // rules for forming output file names are not sufficiently exposed, so we
+    // need to hack this if we write the same split output from either mapper
+    // or reducer. Alternatively, we could probably replace it with our own
+    // output file naming management entirely and bypass MultipleOutputs.
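+    // NUMBER_FORMAT below mirrors Hadoop's zero-padded part-file convention
+    // (e.g. "part-m-00005"), so the split outputs written here line up with
+    // the mapper-side file names.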
+
+    private static final NumberFormat NUMBER_FORMAT =
+      NumberFormat.getInstance();
+    static {
+      NUMBER_FORMAT.setMinimumIntegerDigits(5);
+      NUMBER_FORMAT.setGroupingUsed(false);
+    }
+
+    private final Deque<Closeable> closeables = Lists.newLinkedList();
+    protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
+
+    protected int blockHeight;
+
+    protected int lastTaskId = -1;
+
+    protected OutputCollector<Writable, DenseBlockWritable> qhatCollector;
+    protected OutputCollector<Writable, VectorWritable> rhatCollector;
+    protected QRFirstStep qr;
+
+    @Override
+    protected void setup(Context context) throws IOException,
+      InterruptedException {
+      blockHeight =
+        context.getConfiguration().getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT,
+                                          -1);
+    }
+
+    protected void setupBlock(Context context, SplitPartitionedWritable spw)
+      throws InterruptedException, IOException {
+      IOUtils.close(closeables);
+      qhatCollector =
+        createOutputCollector(QJob.OUTPUT_QHAT,
+                              spw,
+                              context,
+                              DenseBlockWritable.class);
+      rhatCollector =
+        createOutputCollector(QJob.OUTPUT_RHAT,
+                              spw,
+                              context,
+                              VectorWritable.class);
+      qr =
+        new QRFirstStep(context.getConfiguration(),
+                        qhatCollector,
+                        rhatCollector);
+      closeables.addFirst(qr);
+      lastTaskId = spw.getTaskId();
+    }
+
+    @Override
+    protected void reduce(SplitPartitionedWritable key,
+                          Iterable<SparseRowBlockWritable> values,
+                          Context context) throws IOException,
+      InterruptedException {
+
+      accum.clear();
+      for (SparseRowBlockWritable bw : values) {
+        accum.plusBlock(bw);
+      }
+
+      if (key.getTaskId() != lastTaskId) {
+        setupBlock(context, key);
+      }
+
+      long blockBase = key.getTaskItemOrdinal() * blockHeight;
+      for (int k = 0; k < accum.getNumRows(); k++) {
+        Vector yiRow = accum.getRows()[k];
+        key.setTaskItemOrdinal(blockBase + accum.getRowIndices()[k]);
+        qr.collect(key, yiRow);
+      }
+    }
+
+    private Path getSplitFilePath(String name,
+                                  SplitPartitionedWritable spw,
+                                  Context context) throws InterruptedException,
+      IOException {
+      String uniqueFileName = FileOutputFormat.getUniqueFile(context, name, "");
+      uniqueFileName = uniqueFileName.replaceFirst("-r-", "-m-");
+      uniqueFileName =
+        uniqueFileName.replaceFirst("\\d+$",
+                                    Matcher.quoteReplacement(NUMBER_FORMAT.format(spw.getTaskId())));
+      return new Path(FileOutputFormat.getWorkOutputPath(context),
+                      uniqueFileName);
+    }
+
+    /**
+     * The key doesn't matter here, only the value does; the key always gets
+     * substituted by the SPW.
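+     * The returned collector appends to a split-specific SequenceFile; the
+     * writer is registered in {@code closeables} and closed when the next
+     * block is set up or on cleanup.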
+     */
+    private <K,V> OutputCollector<K,V> createOutputCollector(String name,
+                                                             final SplitPartitionedWritable spw,
+                                                             Context ctx,
+                                                             Class<V> valueClass)
+      throws IOException, InterruptedException {
+      Path outputPath = getSplitFilePath(name, spw, ctx);
+      final SequenceFile.Writer w =
+        SequenceFile.createWriter(FileSystem.get(outputPath.toUri(), ctx.getConfiguration()),
+                                  ctx.getConfiguration(),
+                                  outputPath,
+                                  SplitPartitionedWritable.class,
+                                  valueClass);
+      closeables.addFirst(w);
+      return new OutputCollector<K, V>() {
+        @Override
+        public void collect(K key, V val) throws IOException {
+          w.append(spw, val);
+        }
+      };
+    }
+
+    @Override
+    protected void cleanup(Context context) throws IOException, InterruptedException {
+      IOUtils.close(closeables);
+    }
+
+  }
+
+  public static void run(Configuration conf,
+                         Path[] inputAPaths,
+                         Path inputBtGlob,
+                         Path outputPath,
+                         int aBlockRows,
+                         int minSplitSize,
+                         int k,
+                         int p,
+                         int outerProdBlockHeight,
+                         int numReduceTasks,
+                         boolean broadcastBInput)
+    throws ClassNotFoundException, InterruptedException, IOException {
+
+    JobConf oldApiJob = new JobConf(conf);
+
+    // MultipleOutputs
+    // .addNamedOutput(oldApiJob,
+    // QJob.OUTPUT_QHAT,
+    // org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+    // SplitPartitionedWritable.class,
+    // DenseBlockWritable.class);
+    //
+    // MultipleOutputs
+    // .addNamedOutput(oldApiJob,
+    // QJob.OUTPUT_RHAT,
+    // org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+    // SplitPartitionedWritable.class,
+    // VectorWritable.class);
+
+    Job job = new Job(oldApiJob);
+    job.setJobName("ABt-job");
+    job.setJarByClass(ABtJob.class);
+
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    FileInputFormat.setInputPaths(job, inputAPaths);
+    if (minSplitSize > 0) {
+      FileInputFormat.setMinInputSplitSize(job, minSplitSize);
+    }
+
+    FileOutputFormat.setOutputPath(job, outputPath);
+
+    SequenceFileOutputFormat.setOutputCompressionType(job,
+                                                      CompressionType.BLOCK);
+
+    job.setMapOutputKeyClass(SplitPartitionedWritable.class);
+    job.setMapOutputValueClass(SparseRowBlockWritable.class);
+
+    job.setOutputKeyClass(SplitPartitionedWritable.class);
+    job.setOutputValueClass(VectorWritable.class);
+
+    job.setMapperClass(ABtMapper.class);
+    job.setCombinerClass(BtJob.OuterProductCombiner.class);
+    job.setReducerClass(QRReducer.class);
+
+    job.getConfiguration().setInt(QJob.PROP_AROWBLOCK_SIZE, aBlockRows);
+    job.getConfiguration().setInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT,
+                                  outerProdBlockHeight);
+    job.getConfiguration().setInt(QRFirstStep.PROP_K, k);
+    job.getConfiguration().setInt(QRFirstStep.PROP_P, p);
+    job.getConfiguration().set(PROP_BT_PATH, inputBtGlob.toString());
+
+    // The number of reduce tasks controls the parallelism of the QR first
+    // step pushed down to the reducers.
+
+    job.setNumReduceTasks(numReduceTasks);
+
+    // Broadcast Bt files if requested.
+    if (broadcastBInput) {
+      job.getConfiguration().set(PROP_BT_BROADCAST, "y");
+
+      FileSystem fs = FileSystem.get(inputBtGlob.toUri(), conf);
+      FileStatus[] fstats = fs.globStatus(inputBtGlob);
+      if (fstats != null) {
+        for (FileStatus fstat : fstats) {
+          /*
+           * The new API is not enabled yet in our dependencies at this time,
+           * so we are still using the deprecated one.
+           */
+          DistributedCache.addCacheFile(fstat.getPath().toUri(), conf);
+        }
+      }
+    }
+
+    job.submit();
+    job.waitForCompletion(false);
+
+    if (!job.isSuccessful()) {
+      throw new IOException("ABt job unsuccessful.");
+    }
+  }
+
+}
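
Note: the following is a minimal, self-contained sketch (not part of the patch) of the outer-product scheme ABtMapper implements, accumulating Y = A * Bt one nonzero of A at a time; the class name and dimensions are hypothetical, chosen only for illustration.

import java.util.Arrays;

// Sketch only: mirrors ABtMapper.cleanup(), which scales the Bt row matching
// each column of A by that column's nonzero entries and adds the result into
// the corresponding row of Y.
public class ABtSketch {
  public static void main(String[] args) {
    double[][] a  = {{1, 0}, {0, 2}, {3, 0}}; // A is m x n (3 x 2)
    double[][] bt = {{1, 4}, {2, 5}};         // Bt row j pairs with column j of A
    int m = a.length, n = bt.length, kp = bt[0].length;

    double[][] y = new double[m][kp];         // Y = A * Bt, built incrementally
    for (int j = 0; j < n; j++) {             // j plays the role of btIndex
      for (int i = 0; i < m; i++) {           // walk the nonzeros of column j of A
        if (a[i][j] == 0.0) {
          continue;                           // sparse iteration skips zeros
        }
        for (int c = 0; c < kp; c++) {
          y[i][c] += a[i][j] * bt[j][c];      // yiCollector.collect(i, btVec.times(a[i][j]))
        }
      }
    }

    // Prints [1.0, 4.0], [4.0, 10.0], [3.0, 12.0]: one Y row per row of A.
    for (double[] row : y) {
      System.out.println(Arrays.toString(row));
    }
  }
}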
