http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountReducer.java new file mode 100644 index 0000000..c815692 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountReducer.java @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.term; + +import java.io.IOException; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.mapreduce.Reducer; + +/** + * Can also be used as a local Combiner. This accumulates all the features and the weights and sums them up. + */ +public class TermDocumentCountReducer extends Reducer<IntWritable, LongWritable, IntWritable, LongWritable> { + + @Override + protected void reduce(IntWritable key, Iterable<LongWritable> values, Context context) + throws IOException, InterruptedException { + long sum = 0; + for (LongWritable value : values) { + sum += value.get(); + } + context.write(key, new LongWritable(sum)); + } + +}
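A note on the reducer above: because it only sums LongWritable values, it is safe to plug in as both the combiner and the reducer, which is exactly how the df-count job further down wires it. A minimal unit-test sketch, assuming Apache MRUnit is available as a test dependency (it is not part of this patch) and using a hypothetical test class name:

import java.util.Arrays;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.mrunit.mapreduce.ReduceDriver;
import org.apache.mahout.vectorizer.term.TermDocumentCountReducer;
import org.junit.Test;

public class TermDocumentCountReducerSketchTest {

  @Test
  public void sumsPartialDocumentCounts() throws Exception {
    // Feature 42 was counted in 3 documents by one map task and in 4 by another;
    // whether run as a combiner or as the reducer, the partial counts should sum to 7.
    ReduceDriver.newReduceDriver(new TermDocumentCountReducer())
        .withInput(new IntWritable(42),
            Arrays.asList(new LongWritable(3L), new LongWritable(4L)))
        .withOutput(new IntWritable(42), new LongWritable(7L))
        .runTest();
  }
}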
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFConverter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFConverter.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFConverter.java new file mode 100644 index 0000000..5f9d666 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFConverter.java @@ -0,0 +1,361 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.tfidf; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.vectorizer.common.PartialVectorMerger; +import org.apache.mahout.vectorizer.term.TermDocumentCountMapper; +import org.apache.mahout.vectorizer.term.TermDocumentCountReducer; + +import java.io.IOException; +import java.util.List; + +/** + * This class converts a set of input vectors with term frequencies to TfIdf vectors. The Sequence file input + * should have a {@link org.apache.hadoop.io.WritableComparable} key containing and a + * {@link VectorWritable} value containing the + * term frequency vector. 
This conversion class uses multiple map/reduce jobs to convert the vectors to TfIdf + * format + * + */ +public final class TFIDFConverter { + + public static final String VECTOR_COUNT = "vector.count"; + public static final String FEATURE_COUNT = "feature.count"; + public static final String MIN_DF = "min.df"; + public static final String MAX_DF = "max.df"; + //public static final String TFIDF_OUTPUT_FOLDER = "tfidf"; + + private static final String DOCUMENT_VECTOR_OUTPUT_FOLDER = "tfidf-vectors"; + public static final String FREQUENCY_FILE = "frequency.file-"; + private static final int MAX_CHUNKSIZE = 10000; + private static final int MIN_CHUNKSIZE = 100; + private static final String OUTPUT_FILES_PATTERN = "part-*"; + private static final int SEQUENCEFILE_BYTE_OVERHEAD = 45; + private static final String VECTOR_OUTPUT_FOLDER = "partial-vectors-"; + public static final String WORDCOUNT_OUTPUT_FOLDER = "df-count"; + + /** + * Cannot be instantiated. Use the static functions. + */ + private TFIDFConverter() {} + + /** + * Create Term Frequency-Inverse Document Frequency (Tf-Idf) Vectors from the input set of vectors in + * {@link SequenceFile} format. This job uses a fixed limit on the maximum memory used by the feature chunk + * per node, thereby splitting the process across multiple map/reduces. + * Before using this method, calculateDF should be called. + * + * @param input + * input directory of the vectors in {@link SequenceFile} format + * @param output + * output directory where {@link org.apache.mahout.math.RandomAccessSparseVector}'s of the document + * are generated + * @param datasetFeatures + * Document frequency information calculated by calculateDF + * @param minDf + * The minimum document frequency. Default 1 + * @param maxDF + * The max percentage of vectors for the DF. Can be used to remove really high frequency features. + * Expressed as an integer between 0 and 100. Default 99 + * @param numReducers + * The number of reducers to spawn. This also affects the possible parallelism since each reducer + * will typically produce a single output file containing tf-idf vectors for a subset of the + * documents in the corpus.
+ */ + public static void processTfIdf(Path input, + Path output, + Configuration baseConf, + Pair<Long[], List<Path>> datasetFeatures, + int minDf, + long maxDF, + float normPower, + boolean logNormalize, + boolean sequentialAccessOutput, + boolean namedVector, + int numReducers) throws IOException, InterruptedException, ClassNotFoundException { + Preconditions.checkArgument(normPower == PartialVectorMerger.NO_NORMALIZING || normPower >= 0, + "If specified normPower must be nonnegative", normPower); + Preconditions.checkArgument(normPower == PartialVectorMerger.NO_NORMALIZING + || (normPower > 1 && !Double.isInfinite(normPower)) + || !logNormalize, + "normPower must be > 1 and not infinite if log normalization is chosen", normPower); + + int partialVectorIndex = 0; + List<Path> partialVectorPaths = Lists.newArrayList(); + List<Path> dictionaryChunks = datasetFeatures.getSecond(); + for (Path dictionaryChunk : dictionaryChunks) { + Path partialVectorOutputPath = new Path(output, VECTOR_OUTPUT_FOLDER + partialVectorIndex++); + partialVectorPaths.add(partialVectorOutputPath); + makePartialVectors(input, + baseConf, + datasetFeatures.getFirst()[0], + datasetFeatures.getFirst()[1], + minDf, + maxDF, + dictionaryChunk, + partialVectorOutputPath, + sequentialAccessOutput, + namedVector); + } + + Configuration conf = new Configuration(baseConf); + + Path outputDir = new Path(output, DOCUMENT_VECTOR_OUTPUT_FOLDER); + + PartialVectorMerger.mergePartialVectors(partialVectorPaths, + outputDir, + baseConf, + normPower, + logNormalize, + datasetFeatures.getFirst()[0].intValue(), + sequentialAccessOutput, + namedVector, + numReducers); + HadoopUtil.delete(conf, partialVectorPaths); + + } + + /** + * Calculates the document frequencies of all terms from the input set of vectors in + * {@link SequenceFile} format. This job uses a fixed limit on the maximum memory used by the feature chunk + * per node thereby splitting the process across multiple map/reduces. + * + * @param input + * input directory of the vectors in {@link SequenceFile} format + * @param output + * output directory where document frequencies will be stored + * @param chunkSizeInMegabytes + * the size in MB of the feature => id chunk to be kept in memory at each node during Map/Reduce + * stage. Its recommended you calculated this based on the number of cores and the free memory + * available to you per node. Say, you have 2 cores and around 1GB extra memory to spare we + * recommend you use a split size of around 400-500MB so that two simultaneous reducers can create + * partial vectors without thrashing the system due to increased swapping + */ + public static Pair<Long[],List<Path>> calculateDF(Path input, + Path output, + Configuration baseConf, + int chunkSizeInMegabytes) + throws IOException, InterruptedException, ClassNotFoundException { + + if (chunkSizeInMegabytes < MIN_CHUNKSIZE) { + chunkSizeInMegabytes = MIN_CHUNKSIZE; + } else if (chunkSizeInMegabytes > MAX_CHUNKSIZE) { // 10GB + chunkSizeInMegabytes = MAX_CHUNKSIZE; + } + + Path wordCountPath = new Path(output, WORDCOUNT_OUTPUT_FOLDER); + + startDFCounting(input, wordCountPath, baseConf); + + return createDictionaryChunks(wordCountPath, output, baseConf, chunkSizeInMegabytes); + } + + /** + * Read the document frequency List which is built at the end of the DF Count Job. 
This will use constant + * memory and will run at the speed of your disk read + */ + private static Pair<Long[], List<Path>> createDictionaryChunks(Path featureCountPath, + Path dictionaryPathBase, + Configuration baseConf, + int chunkSizeInMegabytes) throws IOException { + List<Path> chunkPaths = Lists.newArrayList(); + Configuration conf = new Configuration(baseConf); + + FileSystem fs = FileSystem.get(featureCountPath.toUri(), conf); + + long chunkSizeLimit = chunkSizeInMegabytes * 1024L * 1024L; + int chunkIndex = 0; + Path chunkPath = new Path(dictionaryPathBase, FREQUENCY_FILE + chunkIndex); + chunkPaths.add(chunkPath); + SequenceFile.Writer freqWriter = + new SequenceFile.Writer(fs, conf, chunkPath, IntWritable.class, LongWritable.class); + + try { + long currentChunkSize = 0; + long featureCount = 0; + long vectorCount = Long.MAX_VALUE; + Path filesPattern = new Path(featureCountPath, OUTPUT_FILES_PATTERN); + for (Pair<IntWritable,LongWritable> record + : new SequenceFileDirIterable<IntWritable,LongWritable>(filesPattern, + PathType.GLOB, + null, + null, + true, + conf)) { + + if (currentChunkSize > chunkSizeLimit) { + Closeables.close(freqWriter, false); + chunkIndex++; + + chunkPath = new Path(dictionaryPathBase, FREQUENCY_FILE + chunkIndex); + chunkPaths.add(chunkPath); + + freqWriter = new SequenceFile.Writer(fs, conf, chunkPath, IntWritable.class, LongWritable.class); + currentChunkSize = 0; + } + + int fieldSize = SEQUENCEFILE_BYTE_OVERHEAD + Integer.SIZE / 8 + Long.SIZE / 8; + currentChunkSize += fieldSize; + IntWritable key = record.getFirst(); + LongWritable value = record.getSecond(); + if (key.get() >= 0) { + freqWriter.append(key, value); + } else if (key.get() == -1) { + vectorCount = value.get(); + } + featureCount = Math.max(key.get(), featureCount); + + } + featureCount++; + Long[] counts = {featureCount, vectorCount}; + return new Pair<>(counts, chunkPaths); + } finally { + Closeables.close(freqWriter, false); + } + } + + /** + * Create a partial tfidf vector using a chunk of features from the input vectors. The input vectors has to + * be in the {@link SequenceFile} format + * + * @param input + * input directory of the vectors in {@link SequenceFile} format + * @param featureCount + * Number of unique features in the dataset + * @param vectorCount + * Number of vectors in the dataset + * @param minDf + * The minimum document frequency. Default 1 + * @param maxDF + * The max percentage of vectors for the DF. Can be used to remove really high frequency features. + * Expressed as an integer between 0 and 100. 
Default 99 + * @param dictionaryFilePath + * location of the chunk of features and the id's + * @param output + * output directory were the partial vectors have to be created + * @param sequentialAccess + * output vectors should be optimized for sequential access + * @param namedVector + * output vectors should be named, retaining key (doc id) as a label + */ + private static void makePartialVectors(Path input, + Configuration baseConf, + Long featureCount, + Long vectorCount, + int minDf, + long maxDF, + Path dictionaryFilePath, + Path output, + boolean sequentialAccess, + boolean namedVector) + throws IOException, InterruptedException, ClassNotFoundException { + + Configuration conf = new Configuration(baseConf); + // this conf parameter needs to be set enable serialisation of conf values + conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization," + + "org.apache.hadoop.io.serializer.WritableSerialization"); + conf.setLong(FEATURE_COUNT, featureCount); + conf.setLong(VECTOR_COUNT, vectorCount); + conf.setInt(MIN_DF, minDf); + conf.setLong(MAX_DF, maxDF); + conf.setBoolean(PartialVectorMerger.SEQUENTIAL_ACCESS, sequentialAccess); + conf.setBoolean(PartialVectorMerger.NAMED_VECTOR, namedVector); + DistributedCache.addCacheFile(dictionaryFilePath.toUri(), conf); + + Job job = new Job(conf); + job.setJobName(": MakePartialVectors: input-folder: " + input + ", dictionary-file: " + + dictionaryFilePath.toString()); + job.setJarByClass(TFIDFConverter.class); + job.setOutputKeyClass(Text.class); + job.setOutputValueClass(VectorWritable.class); + FileInputFormat.setInputPaths(job, input); + + FileOutputFormat.setOutputPath(job, output); + + job.setMapperClass(Mapper.class); + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setReducerClass(TFIDFPartialVectorReducer.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + HadoopUtil.delete(conf, output); + + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + } + + /** + * Count the document frequencies of features in parallel using Map/Reduce. 
The input documents have to be + * in {@link SequenceFile} format + */ + private static void startDFCounting(Path input, Path output, Configuration baseConf) + throws IOException, InterruptedException, ClassNotFoundException { + + Configuration conf = new Configuration(baseConf); + // this conf parameter needs to be set enable serialisation of conf values + conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization," + + "org.apache.hadoop.io.serializer.WritableSerialization"); + + Job job = new Job(conf); + job.setJobName("VectorTfIdf Document Frequency Count running over input: " + input); + job.setJarByClass(TFIDFConverter.class); + + job.setOutputKeyClass(IntWritable.class); + job.setOutputValueClass(LongWritable.class); + + FileInputFormat.setInputPaths(job, input); + FileOutputFormat.setOutputPath(job, output); + + job.setMapperClass(TermDocumentCountMapper.class); + + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setCombinerClass(TermDocumentCountReducer.class); + job.setReducerClass(TermDocumentCountReducer.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + HadoopUtil.delete(conf, output); + + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFPartialVectorReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFPartialVectorReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFPartialVectorReducer.java new file mode 100644 index 0000000..1e71ed8 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFPartialVectorReducer.java @@ -0,0 +1,114 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.vectorizer.tfidf; + +import java.io.IOException; +import java.net.URI; +import java.util.Iterator; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.map.OpenIntLongHashMap; +import org.apache.mahout.vectorizer.TFIDF; +import org.apache.mahout.vectorizer.common.PartialVectorMerger; + +/** + * Converts a document into a sparse vector + */ +public class TFIDFPartialVectorReducer extends + Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> { + + private final OpenIntLongHashMap dictionary = new OpenIntLongHashMap(); + private final TFIDF tfidf = new TFIDF(); + + private int minDf = 1; + private long maxDf = -1; + private long vectorCount = 1; + private long featureCount; + private boolean sequentialAccess; + private boolean namedVector; + + @Override + protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context context) + throws IOException, InterruptedException { + Iterator<VectorWritable> it = values.iterator(); + if (!it.hasNext()) { + return; + } + Vector value = it.next().get(); + Vector vector = new RandomAccessSparseVector((int) featureCount, value.getNumNondefaultElements()); + for (Vector.Element e : value.nonZeroes()) { + if (!dictionary.containsKey(e.index())) { + continue; + } + long df = dictionary.get(e.index()); + if (maxDf > -1 && (100.0 * df) / vectorCount > maxDf) { + continue; + } + if (df < minDf) { + df = minDf; + } + vector.setQuick(e.index(), tfidf.calculate((int) e.get(), (int) df, (int) featureCount, (int) vectorCount)); + } + if (sequentialAccess) { + vector = new SequentialAccessSparseVector(vector); + } + + if (namedVector) { + vector = new NamedVector(vector, key.toString()); + } + + VectorWritable vectorWritable = new VectorWritable(vector); + context.write(key, vectorWritable); + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + Configuration conf = context.getConfiguration(); + + vectorCount = conf.getLong(TFIDFConverter.VECTOR_COUNT, 1); + featureCount = conf.getLong(TFIDFConverter.FEATURE_COUNT, 1); + minDf = conf.getInt(TFIDFConverter.MIN_DF, 1); + maxDf = conf.getLong(TFIDFConverter.MAX_DF, -1); + sequentialAccess = conf.getBoolean(PartialVectorMerger.SEQUENTIAL_ACCESS, false); + namedVector = conf.getBoolean(PartialVectorMerger.NAMED_VECTOR, false); + + URI[] localFiles = DistributedCache.getCacheFiles(conf); + Path dictionaryFile = HadoopUtil.findInCacheByPartOfFilename(TFIDFConverter.FREQUENCY_FILE, localFiles); + // key is feature, value is the document frequency + for (Pair<IntWritable,LongWritable> record + : new SequenceFileIterable<IntWritable,LongWritable>(dictionaryFile, true, conf)) { + dictionary.put(record.getFirst().get(), record.getSecond().get()); + } + } + +} 
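Taken together, TFIDFConverter, TermDocumentCountMapper/Reducer and TFIDFPartialVectorReducer form a two-pass pipeline: calculateDF() runs the document-frequency count and writes the frequency.file-* dictionary chunks, and processTfIdf() then builds partial tf-idf vectors per chunk and merges them into tfidf-vectors. A minimal driver sketch against the public API shown above; the paths and parameter values are hypothetical and roughly follow what the higher-level seq2sparse driver wires up:

import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.common.Pair;
import org.apache.mahout.vectorizer.common.PartialVectorMerger;
import org.apache.mahout.vectorizer.tfidf.TFIDFConverter;

public class TfIdfPipelineSketch {

  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    Path tfVectors = new Path("/tmp/vectorizer/tf-vectors"); // hypothetical: term-frequency vectors in SequenceFile format
    Path outputBase = new Path("/tmp/vectorizer/tfidf");     // hypothetical output base directory

    // Pass 1: count document frequencies; returns {featureCount, vectorCount}
    // plus the paths of the frequency.file-* dictionary chunks.
    Pair<Long[], List<Path>> dfData =
        TFIDFConverter.calculateDF(tfVectors, outputBase, conf, 100 /* chunk size in MB */);

    // Pass 2: build partial tf-idf vectors per dictionary chunk and merge them
    // into the tfidf-vectors folder under outputBase.
    TFIDFConverter.processTfIdf(tfVectors,
        outputBase,
        conf,
        dfData,
        1,                                  // minDf
        99,                                 // maxDF, as a percentage of documents
        PartialVectorMerger.NO_NORMALIZING, // normPower: skip normalization
        false,                              // logNormalize
        true,                               // sequentialAccessOutput
        false,                              // namedVector
        1);                                 // numReducers
  }
}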
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/resources/supplemental-models.xml ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/resources/supplemental-models.xml b/community/mahout-mr/src/main/resources/supplemental-models.xml new file mode 100644 index 0000000..971c72b --- /dev/null +++ b/community/mahout-mr/src/main/resources/supplemental-models.xml @@ -0,0 +1,279 @@ +<supplementalDataModels> + <!-- missing: Maven Profile Model --> + <supplement> + <project> + <groupId>org.apache.maven</groupId> + <artifactId>maven-profile</artifactId> + <name>Maven Profile Model</name> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://maven.apache.org/ref/2.1.0/maven-profile/license.html</url> + </license> + </licenses> + </project> + </supplement> + <!-- missing: Maven Project Builder --> + <supplement> + <project> + <groupId>org.apache.maven</groupId> + <artifactId>maven-project</artifactId> + <name>Maven Project Builder</name> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://maven.apache.org/ref/2.1.0/maven-project/license.html</url> + </license> + </licenses> + </project> + </supplement> + <!-- missing: Maven Local Settings --> + <supplement> + <project> + <groupId>org.apache.maven</groupId> + <artifactId>maven-settings</artifactId> + <name>Maven Local Settings</name> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://maven.apache.org/ref/2.1.0/maven-settings/license.html</url> + </license> + </licenses> + </project> + </supplement> + <!-- Maven Repository Metadata Model --> + <supplement> + <project> + <groupId>org.apache.maven</groupId> + <artifactId>maven-repository-metadata</artifactId> + <name>Maven Repository Metadata Model</name> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://maven.apache.org/ref/2.1.0/maven-repository-metadata/license.html</url> + </license> + </licenses> + </project> + </supplement> + <!-- Maven Model --> + <supplement> + <project> + <groupId>org.apache.maven</groupId> + <artifactId>maven-model</artifactId> + <name>Maven Model</name> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://maven.apache.org/ref/2.0.8/maven-model/license.html</url> + </license> + </licenses> + </project> + </supplement> + <!-- Maven Artifact --> + <supplement> + <project> + <groupId>org.apache.maven</groupId> + <artifactId>maven-artifact</artifactId> + <name>Maven Artifact</name> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url> + </license> + </licenses> + </project> + </supplement> + <!-- Maven Artifact Manager--> + <supplement> + <project> + <groupId>org.apache.maven</groupId> + <artifactId>maven-artifact-manager</artifactId> + <name>Maven Artifact Manager</name> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url> + </license> + </licenses> + </project> + </supplement> + <!-- Maven Artifact Manager--> + <supplement> + <project> + <groupId>org.apache.maven</groupId> + <artifactId>maven-plugin-api</artifactId> + <name>Maven Plugin API</name> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url> + </license> 
+ </licenses> + </project> + </supplement> + <!-- Maven Wagon API--> + <supplement> + <project> + <groupId>org.apache.maven</groupId> + <artifactId>wagon-provider-api</artifactId> + <name>Maven Wagon API</name> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url> + </license> + </licenses> + </project> + </supplement> + <!-- Shade Maven Plugin --> + <supplement> + <project> + <groupId>org.codehouse.mojo</groupId> + <artifactId>shade-maven-plugin</artifactId> + <name>Shade Maven Plugin</name> + <licenses> + <license> + <name>UNKNOWN</name> + <url>UNKNOWN</url> + </license> + </licenses> + </project> + </supplement> + <!-- junit --> + <supplement> + <project> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <name>Junit Unit testing library</name> + <licenses> + <license> + <name>Common Public License - v 1.0</name> + <url>http://junit.sourceforge.net/cpl-v10.html</url> + </license> + </licenses> + </project> + </supplement> + <!-- jdom --> + <supplement> + <project> + <groupId>jdom</groupId> + <artifactId>jdom</artifactId> + <name>JDom</name> + <licenses> + <license> + <name>UNKOWN</name> + <url>UNKOWN</url> + </license> + </licenses> + </project> + </supplement> + <!-- asm --> + <supplement> + <project> + <groupId>asm</groupId> + <artifactId>asm-all</artifactId> + <name>ASM ALL</name> + <licenses> + <license> + <name>UNKOWN</name> + <url>http://asm.ow2.org/license.html</url> + </license> + </licenses> + </project> + </supplement> + <!-- Default Plexus Container --> + <supplement> + <project> + <groupId>org.codehaus.plexus</groupId> + <artifactId>plexus-container-default</artifactId> + <name>Default Plexus Container</name> + <licenses> + <license> + <name>UNKNOWN</name> + <url>UNKNOWN</url> + </license> + </licenses> + </project> + </supplement> + <!-- Classworlds --> + <supplement> + <project> + <groupId>org.codehouse.classworlds</groupId> + <artifactId>classworlds</artifactId> + <name>Classworlds</name> + <licenses> + <license> + <name></name> + <url>http://classworlds.codehaus.org/license.html</url> + </license> + </licenses> + </project> + </supplement> + <!-- Plexus Common Utilities --> + <supplement> + <project> + <groupId>org.codehouse.plexus</groupId> + <artifactId>plexus-utils</artifactId> + <name>Plexus Common Utilities</name> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://plexus.codehaus.org/plexus-utils/license.html</url> + </license> + </licenses> + </project> + </supplement> + <!-- Commons Codec --> + <supplement> + <project> + <groupId>commons-codec</groupId> + <artifactId>commons-codec</artifactId> + <name>Commons Codec</name> + <url>http://commons.apache.org/codec/</url> + <organization> + <name>Apache Software Foundation</name> + <url>http://www.apache.org/</url> + </organization> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://www.apache.org/licenses/LICENSE-2.0</url> + </license> + </licenses> + </project> + </supplement> + <!-- Commons CLI --> + <supplement> + <project> + <groupId>org.apache.mahout.commons</groupId> + <artifactId>commons-cli</artifactId> + <name>Commons CLI</name> + <url>http://commons.apache.org/cli/</url> + <organization> + <name>Apache Software Foundation</name> + <url>http://www.apache.org/</url> + </organization> + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://www.apache.org/licenses/LICENSE-2.0</url> + 
</license> + </licenses> + </project> + </supplement> + <!-- Xpp3 --> + <supplement> + <project> + <name>Xml Pull Parser 3rd Edition</name> + <groupId>xpp3</groupId> + <artifactId>xpp3_min</artifactId> + <url>http://www.extreme.indiana.edu/xgws/xsoap/xpp/mxp1/</url> + <licenses> + <license> + <name>Public Domain</name> + <url>http://www.xmlpull.org/</url> + </license> + </licenses> + </project> + </supplement> +</supplementalDataModels> http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/resources/version ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/resources/version b/community/mahout-mr/src/main/resources/version new file mode 100644 index 0000000..f2ab45c --- /dev/null +++ b/community/mahout-mr/src/main/resources/version @@ -0,0 +1 @@ +${project.version} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/common/CommonTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/common/CommonTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/common/CommonTest.java new file mode 100644 index 0000000..c37bcd3 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/common/CommonTest.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.common; + +import org.apache.mahout.cf.taste.impl.TasteTestCase; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintStream; +import java.io.PrintWriter; + +/** <p>Tests common classes.</p> */ +public final class CommonTest extends TasteTestCase { + + @Test + public void testTasteException() { + // Just make sure this all doesn't, ah, throw an exception + TasteException te1 = new TasteException(); + TasteException te2 = new TasteException(te1); + TasteException te3 = new TasteException(te2.toString(), te2); + TasteException te4 = new TasteException(te3.toString()); + te4.printStackTrace(new PrintStream(new ByteArrayOutputStream())); + te4.printStackTrace(new PrintWriter(new OutputStreamWriter(new ByteArrayOutputStream()))); + } + + @Test + public void testNSUException() { + // Just make sure this all doesn't, ah, throw an exception + TasteException te1 = new NoSuchUserException(); + TasteException te4 = new NoSuchUserException(te1.toString()); + te4.printStackTrace(new PrintStream(new ByteArrayOutputStream())); + te4.printStackTrace(new PrintWriter(new OutputStreamWriter(new ByteArrayOutputStream()))); + } + + @Test + public void testNSIException() { + // Just make sure this all doesn't, ah, throw an exception + TasteException te1 = new NoSuchItemException(); + TasteException te4 = new NoSuchItemException(te1.toString()); + te4.printStackTrace(new PrintStream(new ByteArrayOutputStream())); + te4.printStackTrace(new PrintWriter(new OutputStreamWriter(new ByteArrayOutputStream()))); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtilsTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtilsTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtilsTest.java new file mode 100644 index 0000000..b299b35 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtilsTest.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop; + +import org.apache.mahout.cf.taste.impl.TasteTestCase; +import org.junit.Test; + +/** <p>Tests {@link TasteHadoopUtils}.</p> */ +public class TasteHadoopUtilsTest extends TasteTestCase { + + @Test + public void testWithinRange() { + assertTrue(TasteHadoopUtils.idToIndex(0) >= 0); + assertTrue(TasteHadoopUtils.idToIndex(0) < Integer.MAX_VALUE); + + assertTrue(TasteHadoopUtils.idToIndex(1) >= 0); + assertTrue(TasteHadoopUtils.idToIndex(1) < Integer.MAX_VALUE); + + assertTrue(TasteHadoopUtils.idToIndex(Long.MAX_VALUE) >= 0); + assertTrue(TasteHadoopUtils.idToIndex(Long.MAX_VALUE) < Integer.MAX_VALUE); + + assertTrue(TasteHadoopUtils.idToIndex(Integer.MAX_VALUE) >= 0); + assertTrue(TasteHadoopUtils.idToIndex(Integer.MAX_VALUE) < Integer.MAX_VALUE); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueueTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueueTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueueTest.java new file mode 100644 index 0000000..9465def --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueueTest.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop; + +import org.apache.mahout.cf.taste.impl.TasteTestCase; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +import java.util.List; + +public class TopItemsQueueTest extends TasteTestCase { + + @Test + public void topK() { + + float[] ratings = {0.5f, 0.6f, 0.7f, 2.0f, 0.0f}; + + List<RecommendedItem> topItems = findTop(ratings, 2); + + assertEquals(2, topItems.size()); + assertEquals(3L, topItems.get(0).getItemID()); + assertEquals(2.0f, topItems.get(0).getValue(), MahoutTestCase.EPSILON); + assertEquals(2L, topItems.get(1).getItemID()); + assertEquals(0.7f, topItems.get(1).getValue(), MahoutTestCase.EPSILON); + } + + @Test + public void topKInputSmallerThanK() { + + float[] ratings = {0.7f, 2.0f}; + + List<RecommendedItem> topItems = findTop(ratings, 3); + + assertEquals(2, topItems.size()); + assertEquals(1L, topItems.get(0).getItemID()); + assertEquals(2.0f, topItems.get(0).getValue(), MahoutTestCase.EPSILON); + assertEquals(0L, topItems.get(1).getItemID()); + assertEquals(0.7f, topItems.get(1).getValue(), MahoutTestCase.EPSILON); + } + + + private static List<RecommendedItem> findTop(float[] ratings, int k) { + TopItemsQueue queue = new TopItemsQueue(k); + + for (int item = 0; item < ratings.length; item++) { + MutableRecommendedItem top = queue.top(); + if (ratings[item] > top.getValue()) { + top.set(item, ratings[item]); + queue.updateTop(); + } + } + + return queue.getTopItems(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java new file mode 100644 index 0000000..9d37da2 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java @@ -0,0 +1,379 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.als; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils; +import org.apache.mahout.cf.taste.impl.TasteTestCase; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixSlice; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.hadoop.MathHelper; +import org.apache.mahout.math.map.OpenIntLongHashMap; +import org.apache.mahout.math.map.OpenIntObjectHashMap; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; + +public class ParallelALSFactorizationJobTest extends TasteTestCase { + + private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJobTest.class); + + private File inputFile; + private File intermediateDir; + private File outputDir; + private File tmpDir; + private Configuration conf; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + inputFile = getTestTempFile("prefs.txt"); + intermediateDir = getTestTempDir("intermediate"); + intermediateDir.delete(); + outputDir = getTestTempDir("output"); + outputDir.delete(); + tmpDir = getTestTempDir("tmp"); + + conf = getConfiguration(); + // reset as we run all tests in the same JVM + SharingMapper.reset(); + } + + @Test + public void completeJobToyExample() throws Exception { + explicitExample(1); + } + + @Test + public void completeJobToyExampleMultithreaded() throws Exception { + explicitExample(2); + } + + /** + * small integration test that runs the full job + * + * <pre> + * + * user-item-matrix + * + * burger hotdog berries icecream + * dog 5 5 2 - + * rabbit 2 - 3 5 + * cow - 5 - 3 + * donkey 3 - - 5 + * + * </pre> + */ + private void explicitExample(int numThreads) throws Exception { + + Double na = Double.NaN; + Matrix preferences = new SparseRowMatrix(4, 4, new Vector[] { + new DenseVector(new double[] { 5.0, 5.0, 2.0, na }), + new DenseVector(new double[] { 2.0, na, 3.0, 5.0 }), + new DenseVector(new double[] { na, 5.0, na, 3.0 }), + new DenseVector(new double[] { 3.0, na, na, 5.0 }) }); + + writeLines(inputFile, preferencesAsText(preferences)); + + ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob(); + alsFactorization.setConf(conf); + + int numFeatures = 3; + int numIterations = 5; + double lambda = 0.065; + + alsFactorization.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), + "--tempDir", tmpDir.getAbsolutePath(), "--lambda", String.valueOf(lambda), + "--numFeatures", String.valueOf(numFeatures), "--numIterations", String.valueOf(numIterations), + "--numThreadsPerSolver", String.valueOf(numThreads) }); + + Matrix u = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "U/part-m-00000"), + preferences.numRows(), numFeatures); + Matrix m = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "M/part-m-00000"), + preferences.numCols(), numFeatures); + + StringBuilder info = new StringBuilder(); + info.append("\nA - users x items\n\n"); + info.append(MathHelper.nice(preferences)); + info.append("\nU - users x 
features\n\n"); + info.append(MathHelper.nice(u)); + info.append("\nM - items x features\n\n"); + info.append(MathHelper.nice(m)); + Matrix Ak = u.times(m.transpose()); + info.append("\nAk - users x items\n\n"); + info.append(MathHelper.nice(Ak)); + info.append('\n'); + + log.info(info.toString()); + + RunningAverage avg = new FullRunningAverage(); + for (MatrixSlice slice : preferences) { + for (Element e : slice.nonZeroes()) { + if (!Double.isNaN(e.get())) { + double pref = e.get(); + double estimate = u.viewRow(slice.index()).dot(m.viewRow(e.index())); + double err = pref - estimate; + avg.addDatum(err * err); + log.info("Comparing preference of user [{}] towards item [{}], was [{}] estimate is [{}]", + slice.index(), e.index(), pref, estimate); + } + } + } + double rmse = Math.sqrt(avg.getAverage()); + log.info("RMSE: {}", rmse); + + assertTrue(rmse < 0.2); + } + + @Test + public void completeJobImplicitToyExample() throws Exception { + implicitExample(1); + } + + @Test + public void completeJobImplicitToyExampleMultithreaded() throws Exception { + implicitExample(2); + } + + public void implicitExample(int numThreads) throws Exception { + Matrix observations = new SparseRowMatrix(4, 4, new Vector[] { + new DenseVector(new double[] { 5.0, 5.0, 2.0, 0 }), + new DenseVector(new double[] { 2.0, 0, 3.0, 5.0 }), + new DenseVector(new double[] { 0, 5.0, 0, 3.0 }), + new DenseVector(new double[] { 3.0, 0, 0, 5.0 }) }); + + Matrix preferences = new SparseRowMatrix(4, 4, new Vector[] { + new DenseVector(new double[] { 1.0, 1.0, 1.0, 0 }), + new DenseVector(new double[] { 1.0, 0, 1.0, 1.0 }), + new DenseVector(new double[] { 0, 1.0, 0, 1.0 }), + new DenseVector(new double[] { 1.0, 0, 0, 1.0 }) }); + + writeLines(inputFile, preferencesAsText(observations)); + + ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob(); + alsFactorization.setConf(conf); + + int numFeatures = 3; + int numIterations = 5; + double lambda = 0.065; + double alpha = 20; + + alsFactorization.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), + "--tempDir", tmpDir.getAbsolutePath(), "--lambda", String.valueOf(lambda), + "--implicitFeedback", String.valueOf(true), "--alpha", String.valueOf(alpha), + "--numFeatures", String.valueOf(numFeatures), "--numIterations", String.valueOf(numIterations), + "--numThreadsPerSolver", String.valueOf(numThreads) }); + + Matrix u = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "U/part-m-00000"), + observations.numRows(), numFeatures); + Matrix m = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "M/part-m-00000"), + observations.numCols(), numFeatures); + + StringBuilder info = new StringBuilder(); + info.append("\nObservations - users x items\n"); + info.append(MathHelper.nice(observations)); + info.append("\nA - users x items\n\n"); + info.append(MathHelper.nice(preferences)); + info.append("\nU - users x features\n\n"); + info.append(MathHelper.nice(u)); + info.append("\nM - items x features\n\n"); + info.append(MathHelper.nice(m)); + Matrix Ak = u.times(m.transpose()); + info.append("\nAk - users x items\n\n"); + info.append(MathHelper.nice(Ak)); + info.append('\n'); + + log.info(info.toString()); + + RunningAverage avg = new FullRunningAverage(); + for (MatrixSlice slice : preferences) { + for (Element e : slice.nonZeroes()) { + if (!Double.isNaN(e.get())) { + double pref = e.get(); + double estimate = u.viewRow(slice.index()).dot(m.viewRow(e.index())); + double confidence = 1 
+ alpha * observations.getQuick(slice.index(), e.index()); + double err = confidence * (pref - estimate) * (pref - estimate); + avg.addDatum(err); + log.info("Comparing preference of user [{}] towards item [{}], was [{}] with confidence [{}] " + + "estimate is [{}]", slice.index(), e.index(), pref, confidence, estimate); + } + } + } + double rmse = Math.sqrt(avg.getAverage()); + log.info("RMSE: {}", rmse); + + assertTrue(rmse < 0.4); + } + + @Test + public void exampleWithIDMapping() throws Exception { + + String[] preferencesWithLongIDs = { + "5568227754922264005,-4758971626494767444,5.0", + "5568227754922264005,3688396615879561990,5.0", + "5568227754922264005,4594226737871995304,2.0", + "550945997885173934,-4758971626494767444,2.0", + "550945997885173934,4594226737871995304,3.0", + "550945997885173934,706816485922781596,5.0", + "2448095297482319463,3688396615879561990,5.0", + "2448095297482319463,706816485922781596,3.0", + "6839920411763636962,-4758971626494767444,3.0", + "6839920411763636962,706816485922781596,5.0" }; + + writeLines(inputFile, preferencesWithLongIDs); + + ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob(); + alsFactorization.setConf(conf); + + int numFeatures = 3; + int numIterations = 5; + double lambda = 0.065; + + alsFactorization.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), + "--tempDir", tmpDir.getAbsolutePath(), "--lambda", String.valueOf(lambda), + "--numFeatures", String.valueOf(numFeatures), "--numIterations", String.valueOf(numIterations), + "--numThreadsPerSolver", String.valueOf(1), "--usesLongIDs", String.valueOf(true) }); + + + OpenIntLongHashMap userIDIndex = + TasteHadoopUtils.readIDIndexMap(outputDir.getAbsolutePath() + "/userIDIndex/part-r-00000", conf); + assertEquals(4, userIDIndex.size()); + + OpenIntLongHashMap itemIDIndex = + TasteHadoopUtils.readIDIndexMap(outputDir.getAbsolutePath() + "/itemIDIndex/part-r-00000", conf); + assertEquals(4, itemIDIndex.size()); + + OpenIntObjectHashMap<Vector> u = + MathHelper.readMatrixRows(conf, new Path(outputDir.getAbsolutePath(), "U/part-m-00000")); + OpenIntObjectHashMap<Vector> m = + MathHelper.readMatrixRows(conf, new Path(outputDir.getAbsolutePath(), "M/part-m-00000")); + + assertEquals(4, u.size()); + assertEquals(4, m.size()); + + RunningAverage avg = new FullRunningAverage(); + for (String line : preferencesWithLongIDs) { + String[] tokens = TasteHadoopUtils.splitPrefTokens(line); + long userID = Long.parseLong(tokens[TasteHadoopUtils.USER_ID_POS]); + long itemID = Long.parseLong(tokens[TasteHadoopUtils.ITEM_ID_POS]); + double rating = Double.parseDouble(tokens[2]); + + Vector userFeatures = u.get(TasteHadoopUtils.idToIndex(userID)); + Vector itemFeatures = m.get(TasteHadoopUtils.idToIndex(itemID)); + + double estimate = userFeatures.dot(itemFeatures); + + double err = rating - estimate; + avg.addDatum(err * err); + } + + double rmse = Math.sqrt(avg.getAverage()); + log.info("RMSE: {}", rmse); + + assertTrue(rmse < 0.2); + } + + protected static String preferencesAsText(Matrix preferences) { + StringBuilder prefsAsText = new StringBuilder(); + String separator = ""; + for (MatrixSlice slice : preferences) { + for (Element e : slice.nonZeroes()) { + if (!Double.isNaN(e.get())) { + prefsAsText.append(separator) + .append(slice.index()).append(',').append(e.index()).append(',').append(e.get()); + separator = "\n"; + } + } + } + System.out.println(prefsAsText.toString()); + return prefsAsText.toString(); + } + + @Test + 
public void recommenderJobWithIDMapping() throws Exception { + + String[] preferencesWithLongIDs = { + "5568227754922264005,-4758971626494767444,5.0", + "5568227754922264005,3688396615879561990,5.0", + "5568227754922264005,4594226737871995304,2.0", + "550945997885173934,-4758971626494767444,2.0", + "550945997885173934,4594226737871995304,3.0", + "550945997885173934,706816485922781596,5.0", + "2448095297482319463,3688396615879561990,5.0", + "2448095297482319463,706816485922781596,3.0", + "6839920411763636962,-4758971626494767444,3.0", + "6839920411763636962,706816485922781596,5.0" }; + + writeLines(inputFile, preferencesWithLongIDs); + + ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob(); + alsFactorization.setConf(conf); + + int numFeatures = 3; + int numIterations = 5; + double lambda = 0.065; + + Configuration conf = getConfiguration(); + + int success = ToolRunner.run(alsFactorization, new String[] { + "-Dhadoop.tmp.dir=" + conf.get("hadoop.tmp.dir"), + "--input", inputFile.getAbsolutePath(), + "--output", intermediateDir.getAbsolutePath(), + "--tempDir", tmpDir.getAbsolutePath(), + "--lambda", String.valueOf(lambda), + "--numFeatures", String.valueOf(numFeatures), + "--numIterations", String.valueOf(numIterations), + "--numThreadsPerSolver", String.valueOf(1), + "--usesLongIDs", String.valueOf(true) }); + + assertEquals(0, success); + + // reset as we run in the same JVM + SharingMapper.reset(); + + RecommenderJob recommender = new RecommenderJob(); + + success = ToolRunner.run(recommender, new String[] { + "-Dhadoop.tmp.dir=" + conf.get("hadoop.tmp.dir"), + "--input", intermediateDir.getAbsolutePath() + "/userRatings/", + "--userFeatures", intermediateDir.getAbsolutePath() + "/U/", + "--itemFeatures", intermediateDir.getAbsolutePath() + "/M/", + "--numRecommendations", String.valueOf(2), + "--maxRating", String.valueOf(5.0), + "--numThreads", String.valueOf(2), + "--usesLongIDs", String.valueOf(true), + "--userIDIndex", intermediateDir.getAbsolutePath() + "/userIDIndex/", + "--itemIDIndex", intermediateDir.getAbsolutePath() + "/itemIDIndex/", + "--output", outputDir.getAbsolutePath() }); + + assertEquals(0, success); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/IDReaderTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/IDReaderTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/IDReaderTest.java new file mode 100644 index 0000000..a1cc648 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/IDReaderTest.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.hadoop.item; + +import org.apache.hadoop.conf.Configuration; +import org.apache.mahout.cf.taste.impl.TasteTestCase; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +public class IDReaderTest extends TasteTestCase { + + static final String USER_ITEM_FILTER_FIELD = "userItemFilter"; + + @Test + public void testUserItemFilter() throws Exception { + Configuration conf = getConfiguration(); + IDReader idReader = new IDReader(conf); + Map<Long, FastIDSet> userItemFilter = new HashMap<>(); + + long user1 = 1; + long user2 = 2; + + idReader.addUserAndItemIdToUserItemFilter(userItemFilter, user1, 100L); + idReader.addUserAndItemIdToUserItemFilter(userItemFilter, user1, 200L); + idReader.addUserAndItemIdToUserItemFilter(userItemFilter, user2, 300L); + + FastIDSet userIds = IDReader.extractAllUserIdsFromUserItemFilter(userItemFilter); + + assertEquals(2, userIds.size()); + assertTrue(userIds.contains(user1)); + assertTrue(userIds.contains(user1)); + + setField(idReader, USER_ITEM_FILTER_FIELD, userItemFilter); + + FastIDSet itemsForUser1 = idReader.getItemsToRecommendForUser(user1); + assertEquals(2, itemsForUser1.size()); + assertTrue(itemsForUser1.contains(100L)); + assertTrue(itemsForUser1.contains(200L)); + + FastIDSet itemsForUser2 = idReader.getItemsToRecommendForUser(user2); + assertEquals(1, itemsForUser2.size()); + assertTrue(itemsForUser2.contains(300L)); + + FastIDSet itemsForNonExistingUser = idReader.getItemsToRecommendForUser(3L); + assertTrue(itemsForNonExistingUser.isEmpty()); + } + +}
