http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java new file mode 100644 index 0000000..b061a63 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java @@ -0,0 +1,158 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.hadoop.als; + +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.RandomUtils; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * <p>Split a recommendation dataset into a training and a test set</p> + * + * <p>Command line arguments specific to this class are:</p> + * + * <ol> + * <li>--input (path): Directory containing one or more text files with the dataset</li> + * <li>--output (path): path where output should go</li> + * <li>--trainingPercentage (double): percentage of the data to use as training set (optional, default 0.9)</li> + * <li>--probePercentage (double): percentage of the data to use as probe set (optional, default 0.1)</li> + * </ol> + */ +public class DatasetSplitter extends AbstractJob { + + private static final String TRAINING_PERCENTAGE = DatasetSplitter.class.getName() + ".trainingPercentage"; + private static final String PROBE_PERCENTAGE = DatasetSplitter.class.getName() + ".probePercentage"; + private static final String PART_TO_USE = DatasetSplitter.class.getName() + ".partToUse"; + + private static final Text INTO_TRAINING_SET = new Text("T"); + private static final Text INTO_PROBE_SET = new Text("P"); + + private static final double DEFAULT_TRAINING_PERCENTAGE = 0.9; + private static final double DEFAULT_PROBE_PERCENTAGE = 0.1; + + public static void main(String[] args) throws Exception { + ToolRunner.run(new DatasetSplitter(), args); + } + + @Override + public int run(String[] args) throws Exception { + + 
addInputOption(); + addOutputOption(); + addOption("trainingPercentage", "t", "percentage of the data to use as training set (default: " + + DEFAULT_TRAINING_PERCENTAGE + ')', String.valueOf(DEFAULT_TRAINING_PERCENTAGE)); + addOption("probePercentage", "p", "percentage of the data to use as probe set (default: " + + DEFAULT_PROBE_PERCENTAGE + ')', String.valueOf(DEFAULT_PROBE_PERCENTAGE)); + + Map<String,List<String>> parsedArgs = parseArguments(args); + if (parsedArgs == null) { + return -1; + } + + double trainingPercentage = Double.parseDouble(getOption("trainingPercentage")); + double probePercentage = Double.parseDouble(getOption("probePercentage")); + String tempDir = getOption("tempDir"); + + Path markedPrefs = new Path(tempDir, "markedPreferences"); + Path trainingSetPath = new Path(getOutputPath(), "trainingSet"); + Path probeSetPath = new Path(getOutputPath(), "probeSet"); + + Job markPreferences = prepareJob(getInputPath(), markedPrefs, TextInputFormat.class, MarkPreferencesMapper.class, + Text.class, Text.class, SequenceFileOutputFormat.class); + markPreferences.getConfiguration().set(TRAINING_PERCENTAGE, String.valueOf(trainingPercentage)); + markPreferences.getConfiguration().set(PROBE_PERCENTAGE, String.valueOf(probePercentage)); + boolean succeeded = markPreferences.waitForCompletion(true); + if (!succeeded) { + return -1; + } + + Job createTrainingSet = prepareJob(markedPrefs, trainingSetPath, SequenceFileInputFormat.class, + WritePrefsMapper.class, NullWritable.class, Text.class, TextOutputFormat.class); + createTrainingSet.getConfiguration().set(PART_TO_USE, INTO_TRAINING_SET.toString()); + succeeded = createTrainingSet.waitForCompletion(true); + if (!succeeded) { + return -1; + } + + Job createProbeSet = prepareJob(markedPrefs, probeSetPath, SequenceFileInputFormat.class, + WritePrefsMapper.class, NullWritable.class, Text.class, TextOutputFormat.class); + createProbeSet.getConfiguration().set(PART_TO_USE, INTO_PROBE_SET.toString()); + succeeded = createProbeSet.waitForCompletion(true); + if (!succeeded) { + return -1; + } + + return 0; + } + + static class MarkPreferencesMapper extends Mapper<LongWritable,Text,Text,Text> { + + private Random random; + private double trainingBound; + private double probeBound; + + @Override + protected void setup(Context ctx) throws IOException, InterruptedException { + random = RandomUtils.getRandom(); + trainingBound = Double.parseDouble(ctx.getConfiguration().get(TRAINING_PERCENTAGE)); + probeBound = trainingBound + Double.parseDouble(ctx.getConfiguration().get(PROBE_PERCENTAGE)); + } + + @Override + protected void map(LongWritable key, Text text, Context ctx) throws IOException, InterruptedException { + double randomValue = random.nextDouble(); + if (randomValue <= trainingBound) { + ctx.write(INTO_TRAINING_SET, text); + } else if (randomValue <= probeBound) { + ctx.write(INTO_PROBE_SET, text); + } + } + } + + static class WritePrefsMapper extends Mapper<Text,Text,NullWritable,Text> { + + private String partToUse; + + @Override + protected void setup(Context ctx) throws IOException, InterruptedException { + partToUse = ctx.getConfiguration().get(PART_TO_USE); + } + + @Override + protected void map(Text key, Text text, Context ctx) throws IOException, InterruptedException { + if (partToUse.equals(key.toString())) { + ctx.write(NullWritable.get(), text); + } + } + } +}
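For orientation, a minimal driver for this splitter might look like the sketch below; the input, output and tempDir paths and the percentage values are illustrative placeholders, not part of this commit. It simply passes the documented command line options to ToolRunner:

// Hypothetical driver for DatasetSplitter; all paths and values are example placeholders.
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.als.DatasetSplitter;

public class DatasetSplitterExample {
  public static void main(String[] args) throws Exception {
    // Splits the ratings under /ratings into /split/trainingSet and /split/probeSet,
    // keeping roughly 90% of the preferences for training and 10% for probing.
    ToolRunner.run(new DatasetSplitter(), new String[] {
        "--input", "/ratings",
        "--output", "/split",
        "--tempDir", "/tmp/datasetSplitter",
        "--trainingPercentage", "0.9",
        "--probePercentage", "0.1"
    });
  }
}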
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java new file mode 100644 index 0000000..4e6aaf5 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java @@ -0,0 +1,166 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.hadoop.als; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.util.List; +import java.util.Map; + +import org.apache.commons.io.Charsets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.map.OpenIntObjectHashMap; + +/** + * <p>Measures the root-mean-squared error of a rating matrix factorization against a test set.</p> + * + * <p>Command line arguments specific to this class are:</p> + * + * <ol> + * <li>--output (path): path where output should go</li> + * <li>--input (path): path containing the test ratings, each line must be userID,itemID,rating</li> + * <li>--userFeatures (path): path to the user feature matrix</li> + * <li>--itemFeatures (path): path to the item feature matrix</li> + * </ol> + */ +public class FactorizationEvaluator extends AbstractJob { + + private static final String USER_FEATURES_PATH = RecommenderJob.class.getName() + ".userFeatures"; + private static final String
ITEM_FEATURES_PATH = RecommenderJob.class.getName() + ".itemFeatures"; + + public static void main(String[] args) throws Exception { + ToolRunner.run(new FactorizationEvaluator(), args); + } + + @Override + public int run(String[] args) throws Exception { + + addInputOption(); + addOption("userFeatures", null, "path to the user feature matrix", true); + addOption("itemFeatures", null, "path to the item feature matrix", true); + addOption("usesLongIDs", null, "input contains long IDs that need to be translated"); + addOutputOption(); + + Map<String,List<String>> parsedArgs = parseArguments(args); + if (parsedArgs == null) { + return -1; + } + + Path errors = getTempPath("errors"); + + Job predictRatings = prepareJob(getInputPath(), errors, TextInputFormat.class, PredictRatingsMapper.class, + DoubleWritable.class, NullWritable.class, SequenceFileOutputFormat.class); + + Configuration conf = predictRatings.getConfiguration(); + conf.set(USER_FEATURES_PATH, getOption("userFeatures")); + conf.set(ITEM_FEATURES_PATH, getOption("itemFeatures")); + + boolean usesLongIDs = Boolean.parseBoolean(getOption("usesLongIDs")); + if (usesLongIDs) { + conf.set(ParallelALSFactorizationJob.USES_LONG_IDS, String.valueOf(true)); + } + + + boolean succeeded = predictRatings.waitForCompletion(true); + if (!succeeded) { + return -1; + } + + FileSystem fs = FileSystem.get(getOutputPath().toUri(), getConf()); + FSDataOutputStream outputStream = fs.create(getOutputPath("rmse.txt")); + try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(outputStream, Charsets.UTF_8))){ + double rmse = computeRmse(errors); + writer.write(String.valueOf(rmse)); + } + return 0; + } + + private double computeRmse(Path errors) { + RunningAverage average = new FullRunningAverage(); + for (Pair<DoubleWritable,NullWritable> entry + : new SequenceFileDirIterable<DoubleWritable, NullWritable>(errors, PathType.LIST, PathFilters.logsCRCFilter(), + getConf())) { + DoubleWritable error = entry.getFirst(); + average.addDatum(error.get() * error.get()); + } + + return Math.sqrt(average.getAverage()); + } + + public static class PredictRatingsMapper extends Mapper<LongWritable,Text,DoubleWritable,NullWritable> { + + private OpenIntObjectHashMap<Vector> U; + private OpenIntObjectHashMap<Vector> M; + + private boolean usesLongIDs; + + private final DoubleWritable error = new DoubleWritable(); + + @Override + protected void setup(Context ctx) throws IOException, InterruptedException { + Configuration conf = ctx.getConfiguration(); + + Path pathToU = new Path(conf.get(USER_FEATURES_PATH)); + Path pathToM = new Path(conf.get(ITEM_FEATURES_PATH)); + + U = ALS.readMatrixByRows(pathToU, conf); + M = ALS.readMatrixByRows(pathToM, conf); + + usesLongIDs = conf.getBoolean(ParallelALSFactorizationJob.USES_LONG_IDS, false); + } + + @Override + protected void map(LongWritable key, Text value, Context ctx) throws IOException, InterruptedException { + + String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString()); + + int userID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.USER_ID_POS], usesLongIDs); + int itemID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.ITEM_ID_POS], usesLongIDs); + double rating = Double.parseDouble(tokens[2]); + + if (U.containsKey(userID) && M.containsKey(itemID)) { + double estimate = U.get(userID).dot(M.get(itemID)); + error.set(rating - estimate); + ctx.write(error, NullWritable.get()); + } + } + } + +} 
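Analogously, a sketch of how the evaluator might be driven against the probe set written by DatasetSplitter and the U and M factors written by ParallelALSFactorizationJob; all paths below are illustrative placeholders. The resulting RMSE ends up in <output>/rmse.txt:

// Hypothetical driver for FactorizationEvaluator; all paths are example placeholders.
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.als.FactorizationEvaluator;

public class FactorizationEvaluatorExample {
  public static void main(String[] args) throws Exception {
    // Computes the RMSE of the factorization (user features dotted with item features)
    // on the held-out probe ratings and writes the single value to /evaluation/rmse.txt.
    ToolRunner.run(new FactorizationEvaluator(), new String[] {
        "--input", "/split/probeSet",
        "--output", "/evaluation",
        "--userFeatures", "/als/U",
        "--itemFeatures", "/als/M"
    });
  }
}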
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java new file mode 100644 index 0000000..d93e3a4 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.hadoop.als; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.JobContext; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper; +import org.apache.hadoop.util.ReflectionUtils; + +import java.io.IOException; + +/** + * Multithreaded Mapper for {@link SharingMapper}s. Will call setupSharedInstance() once in the controlling thread + * before executing the mappers using a thread pool. + * + * @param <K1> + * @param <V1> + * @param <K2> + * @param <V2> + */ +public class MultithreadedSharingMapper<K1, V1, K2, V2> extends MultithreadedMapper<K1, V1, K2, V2> { + + @Override + public void run(Context ctx) throws IOException, InterruptedException { + Class<Mapper<K1, V1, K2, V2>> mapperClass = + MultithreadedSharingMapper.getMapperClass((JobContext) ctx); + Preconditions.checkNotNull(mapperClass, "Could not find Multithreaded Mapper class."); + + Configuration conf = ctx.getConfiguration(); + // instantiate the mapper + Mapper<K1, V1, K2, V2> mapper1 = ReflectionUtils.newInstance(mapperClass, conf); + SharingMapper<K1, V1, K2, V2, ?> mapper = null; + if (mapper1 instanceof SharingMapper) { + mapper = (SharingMapper<K1, V1, K2, V2, ?>) mapper1; + } + Preconditions.checkNotNull(mapper, "Could not instantiate SharingMapper. 
Class was: %s", + mapper1.getClass().getName()); + + // single threaded call to setup the sharing mapper + mapper.setupSharedInstance(ctx); + + // multithreaded execution + super.run(ctx); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java new file mode 100644 index 0000000..2ce9b61 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java @@ -0,0 +1,414 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.hadoop.als; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; +import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.mapreduce.MergeVectorsCombiner; +import org.apache.mahout.common.mapreduce.MergeVectorsReducer; +import org.apache.mahout.common.mapreduce.TransposeMapper; +import org.apache.mahout.common.mapreduce.VectorSumCombiner; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.VarIntWritable; +import 
org.apache.mahout.math.VarLongWritable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.similarity.cooccurrence.Vectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * <p>MapReduce implementation of the two factorization algorithms described in:</p> + * + * <p>"Large-scale Parallel Collaborative Filtering for the Netflix Prize" available at + * http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf.</p> + * + * <p>"Collaborative Filtering for Implicit Feedback Datasets" available at + * http://research.yahoo.com/pub/2433</p> + * + * <p>Command line arguments specific to this class are:</p> + * + * <ol> + * <li>--input (path): Directory containing one or more text files with the dataset</li> + * <li>--output (path): path where output should go</li> + * <li>--lambda (double): regularization parameter to avoid overfitting</li> + * <li>--implicitFeedback (boolean): whether the data consists of implicit feedback (default: false)</li> + * <li>--alpha (double): confidence parameter, only used for implicit feedback (default: 40)</li> + * <li>--numFeatures (int): dimension of the feature space</li> + * <li>--numIterations (int): number of iterations</li> + * <li>--numThreadsPerSolver (int): threads to use per solver mapper (default: 1)</li> + * </ol> + */ +public class ParallelALSFactorizationJob extends AbstractJob { + + private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJob.class); + + static final String NUM_FEATURES = ParallelALSFactorizationJob.class.getName() + ".numFeatures"; + static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + ".lambda"; + static final String ALPHA = ParallelALSFactorizationJob.class.getName() + ".alpha"; + static final String NUM_ENTITIES = ParallelALSFactorizationJob.class.getName() + ".numEntities"; + + static final String USES_LONG_IDS = ParallelALSFactorizationJob.class.getName() + ".usesLongIDs"; + static final String TOKEN_POS = ParallelALSFactorizationJob.class.getName() + ".tokenPos"; + + private boolean implicitFeedback; + private int numIterations; + private int numFeatures; + private double lambda; + private double alpha; + private int numThreadsPerSolver; + + enum Stats { NUM_USERS } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new ParallelALSFactorizationJob(), args); + } + + @Override + public int run(String[] args) throws Exception { + + addInputOption(); + addOutputOption(); + addOption("lambda", null, "regularization parameter", true); + addOption("implicitFeedback", null, "data consists of implicit feedback?", String.valueOf(false)); + addOption("alpha", null, "confidence parameter (only used on implicit feedback)", String.valueOf(40)); + addOption("numFeatures", null, "dimension of the feature space", true); + addOption("numIterations", null, "number of iterations", true); + addOption("numThreadsPerSolver", null, "threads per solver mapper", String.valueOf(1)); + addOption("usesLongIDs", null, "input contains long IDs that need to be translated"); + + Map<String,List<String>> parsedArgs = parseArguments(args); + if (parsedArgs == null) { + return -1; + } + + numFeatures = Integer.parseInt(getOption("numFeatures")); + numIterations = Integer.parseInt(getOption("numIterations")); + lambda = Double.parseDouble(getOption("lambda")); + alpha = Double.parseDouble(getOption("alpha")); + implicitFeedback = Boolean.parseBoolean(getOption("implicitFeedback")); + + numThreadsPerSolver = Integer.parseInt(getOption("numThreadsPerSolver")); + boolean usesLongIDs = Boolean.parseBoolean(getOption("usesLongIDs",
String.valueOf(false))); + + /* + * compute the factorization A = U M' + * + * where A (users x items) is the matrix of known ratings + * U (users x features) is the representation of users in the feature space + * M (items x features) is the representation of items in the feature space + */ + + if (usesLongIDs) { + Job mapUsers = prepareJob(getInputPath(), getOutputPath("userIDIndex"), TextInputFormat.class, + MapLongIDsMapper.class, VarIntWritable.class, VarLongWritable.class, IDMapReducer.class, + VarIntWritable.class, VarLongWritable.class, SequenceFileOutputFormat.class); + mapUsers.getConfiguration().set(TOKEN_POS, String.valueOf(TasteHadoopUtils.USER_ID_POS)); + mapUsers.waitForCompletion(true); + + Job mapItems = prepareJob(getInputPath(), getOutputPath("itemIDIndex"), TextInputFormat.class, + MapLongIDsMapper.class, VarIntWritable.class, VarLongWritable.class, IDMapReducer.class, + VarIntWritable.class, VarLongWritable.class, SequenceFileOutputFormat.class); + mapItems.getConfiguration().set(TOKEN_POS, String.valueOf(TasteHadoopUtils.ITEM_ID_POS)); + mapItems.waitForCompletion(true); + } + + /* create A' */ + Job itemRatings = prepareJob(getInputPath(), pathToItemRatings(), + TextInputFormat.class, ItemRatingVectorsMapper.class, IntWritable.class, + VectorWritable.class, VectorSumReducer.class, IntWritable.class, + VectorWritable.class, SequenceFileOutputFormat.class); + itemRatings.setCombinerClass(VectorSumCombiner.class); + itemRatings.getConfiguration().set(USES_LONG_IDS, String.valueOf(usesLongIDs)); + boolean succeeded = itemRatings.waitForCompletion(true); + if (!succeeded) { + return -1; + } + + /* create A */ + Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(), + TransposeMapper.class, IntWritable.class, VectorWritable.class, MergeUserVectorsReducer.class, + IntWritable.class, VectorWritable.class); + userRatings.setCombinerClass(MergeVectorsCombiner.class); + succeeded = userRatings.waitForCompletion(true); + if (!succeeded) { + return -1; + } + + //TODO this could be fiddled into one of the upper jobs + Job averageItemRatings = prepareJob(pathToItemRatings(), getTempPath("averageRatings"), + AverageRatingMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class, + IntWritable.class, VectorWritable.class); + averageItemRatings.setCombinerClass(MergeVectorsCombiner.class); + succeeded = averageItemRatings.waitForCompletion(true); + if (!succeeded) { + return -1; + } + + Vector averageRatings = ALS.readFirstRow(getTempPath("averageRatings"), getConf()); + + int numItems = averageRatings.getNumNondefaultElements(); + int numUsers = (int) userRatings.getCounters().findCounter(Stats.NUM_USERS).getValue(); + + log.info("Found {} users and {} items", numUsers, numItems); + + /* create an initial M */ + initializeM(averageRatings); + + for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) { + /* broadcast M, read A row-wise, recompute U row-wise */ + log.info("Recomputing U (iteration {}/{})", currentIteration, numIterations); + runSolver(pathToUserRatings(), pathToU(currentIteration), pathToM(currentIteration - 1), currentIteration, "U", + numItems); + /* broadcast U, read A' row-wise, recompute M row-wise */ + log.info("Recomputing M (iteration {}/{})", currentIteration, numIterations); + runSolver(pathToItemRatings(), pathToM(currentIteration), pathToU(currentIteration), currentIteration, "M", + numUsers); + } + + return 0; + } + + private void initializeM(Vector averageRatings) throws IOException { + 
Random random = RandomUtils.getRandom(); + + FileSystem fs = FileSystem.get(pathToM(-1).toUri(), getConf()); + try (SequenceFile.Writer writer = + new SequenceFile.Writer(fs, getConf(), new Path(pathToM(-1), "part-m-00000"), + IntWritable.class, VectorWritable.class)) { + IntWritable index = new IntWritable(); + VectorWritable featureVector = new VectorWritable(); + + for (Vector.Element e : averageRatings.nonZeroes()) { + Vector row = new DenseVector(numFeatures); + row.setQuick(0, e.get()); + for (int m = 1; m < numFeatures; m++) { + row.setQuick(m, random.nextDouble()); + } + index.set(e.index()); + featureVector.set(row); + writer.append(index, featureVector); + } + } + } + + static class VectorSumReducer + extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> { + + private final VectorWritable result = new VectorWritable(); + + @Override + protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context ctx) + throws IOException, InterruptedException { + Vector sum = Vectors.sum(values.iterator()); + result.set(new SequentialAccessSparseVector(sum)); + ctx.write(key, result); + } + } + + static class MergeUserVectorsReducer extends + Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> { + + private final VectorWritable result = new VectorWritable(); + + @Override + public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context ctx) + throws IOException, InterruptedException { + Vector merged = VectorWritable.merge(vectors.iterator()).get(); + result.set(new SequentialAccessSparseVector(merged)); + ctx.write(key, result); + ctx.getCounter(Stats.NUM_USERS).increment(1); + } + } + + static class ItemRatingVectorsMapper extends Mapper<LongWritable,Text,IntWritable,VectorWritable> { + + private final IntWritable itemIDWritable = new IntWritable(); + private final VectorWritable ratingsWritable = new VectorWritable(true); + private final Vector ratings = new RandomAccessSparseVector(Integer.MAX_VALUE, 1); + + private boolean usesLongIDs; + + @Override + protected void setup(Context ctx) throws IOException, InterruptedException { + usesLongIDs = ctx.getConfiguration().getBoolean(USES_LONG_IDS, false); + } + + @Override + protected void map(LongWritable offset, Text line, Context ctx) throws IOException, InterruptedException { + String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString()); + int userID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.USER_ID_POS], usesLongIDs); + int itemID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.ITEM_ID_POS], usesLongIDs); + float rating = Float.parseFloat(tokens[2]); + + ratings.setQuick(userID, rating); + + itemIDWritable.set(itemID); + ratingsWritable.set(ratings); + + ctx.write(itemIDWritable, ratingsWritable); + + // prepare instance for reuse + ratings.setQuick(userID, 0.0d); + } + } + + private void runSolver(Path ratings, Path output, Path pathToUorM, int currentIteration, String matrixName, + int numEntities) throws ClassNotFoundException, IOException, InterruptedException { + + // necessary for local execution in the same JVM only + SharingMapper.reset(); + + Class<? 
extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable>> solverMapperClassInternal; + String name; + + if (implicitFeedback) { + solverMapperClassInternal = SolveImplicitFeedbackMapper.class; + name = "Recompute " + matrixName + ", iteration (" + currentIteration + '/' + numIterations + "), " + + '(' + numThreadsPerSolver + " threads, " + numFeatures + " features, implicit feedback)"; + } else { + solverMapperClassInternal = SolveExplicitFeedbackMapper.class; + name = "Recompute " + matrixName + ", iteration (" + currentIteration + '/' + numIterations + "), " + + '(' + numThreadsPerSolver + " threads, " + numFeatures + " features, explicit feedback)"; + } + + Job solverForUorI = prepareJob(ratings, output, SequenceFileInputFormat.class, MultithreadedSharingMapper.class, + IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, name); + Configuration solverConf = solverForUorI.getConfiguration(); + solverConf.set(LAMBDA, String.valueOf(lambda)); + solverConf.set(ALPHA, String.valueOf(alpha)); + solverConf.setInt(NUM_FEATURES, numFeatures); + solverConf.set(NUM_ENTITIES, String.valueOf(numEntities)); + + FileSystem fs = FileSystem.get(pathToUorM.toUri(), solverConf); + FileStatus[] parts = fs.listStatus(pathToUorM, PathFilters.partFilter()); + for (FileStatus part : parts) { + if (log.isDebugEnabled()) { + log.debug("Adding {} to distributed cache", part.getPath().toString()); + } + DistributedCache.addCacheFile(part.getPath().toUri(), solverConf); + } + + MultithreadedMapper.setMapperClass(solverForUorI, solverMapperClassInternal); + MultithreadedMapper.setNumberOfThreads(solverForUorI, numThreadsPerSolver); + + boolean succeeded = solverForUorI.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + } + + static class AverageRatingMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> { + + private final IntWritable firstIndex = new IntWritable(0); + private final Vector featureVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1); + private final VectorWritable featureVectorWritable = new VectorWritable(); + + @Override + protected void map(IntWritable r, VectorWritable v, Context ctx) throws IOException, InterruptedException { + RunningAverage avg = new FullRunningAverage(); + for (Vector.Element e : v.get().nonZeroes()) { + avg.addDatum(e.get()); + } + + featureVector.setQuick(r.get(), avg.getAverage()); + featureVectorWritable.set(featureVector); + ctx.write(firstIndex, featureVectorWritable); + + // prepare instance for reuse + featureVector.setQuick(r.get(), 0.0d); + } + } + + static class MapLongIDsMapper extends Mapper<LongWritable,Text,VarIntWritable,VarLongWritable> { + + private int tokenPos; + private final VarIntWritable index = new VarIntWritable(); + private final VarLongWritable idWritable = new VarLongWritable(); + + @Override + protected void setup(Context ctx) throws IOException, InterruptedException { + tokenPos = ctx.getConfiguration().getInt(TOKEN_POS, -1); + Preconditions.checkState(tokenPos >= 0); + } + + @Override + protected void map(LongWritable key, Text line, Context ctx) throws IOException, InterruptedException { + String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString()); + + long id = Long.parseLong(tokens[tokenPos]); + + index.set(TasteHadoopUtils.idToIndex(id)); + idWritable.set(id); + ctx.write(index, idWritable); + } + } + + static class IDMapReducer extends Reducer<VarIntWritable,VarLongWritable,VarIntWritable,VarLongWritable> { + @Override 
+ protected void reduce(VarIntWritable index, Iterable<VarLongWritable> ids, Context ctx) + throws IOException, InterruptedException { + ctx.write(index, ids.iterator().next()); + } + } + + private Path pathToM(int iteration) { + return iteration == numIterations - 1 ? getOutputPath("M") : getTempPath("M-" + iteration); + } + + private Path pathToU(int iteration) { + return iteration == numIterations - 1 ? getOutputPath("U") : getTempPath("U-" + iteration); + } + + private Path pathToItemRatings() { + return getTempPath("itemRatings"); + } + + private Path pathToUserRatings() { + return getOutputPath("userRatings"); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java new file mode 100644 index 0000000..6e7ea81 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java @@ -0,0 +1,145 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.hadoop.als; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.mahout.cf.taste.hadoop.MutableRecommendedItem; +import org.apache.mahout.cf.taste.hadoop.RecommendedItemsWritable; +import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils; +import org.apache.mahout.cf.taste.hadoop.TopItemsQueue; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; +import org.apache.mahout.common.Pair; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.IntObjectProcedure; +import org.apache.mahout.math.map.OpenIntLongHashMap; +import org.apache.mahout.math.map.OpenIntObjectHashMap; +import org.apache.mahout.math.set.OpenIntHashSet; + +import java.io.IOException; +import java.util.List; + +/** + * a multithreaded mapper that loads the feature matrices U and M into memory. Afterwards it computes recommendations + * from these. Can be executed by a {@link MultithreadedSharingMapper}. 
+ */ +public class PredictionMapper extends SharingMapper<IntWritable,VectorWritable,LongWritable,RecommendedItemsWritable, + Pair<OpenIntObjectHashMap<Vector>,OpenIntObjectHashMap<Vector>>> { + + private int recommendationsPerUser; + private float maxRating; + + private boolean usesLongIDs; + private OpenIntLongHashMap userIDIndex; + private OpenIntLongHashMap itemIDIndex; + + private final LongWritable userIDWritable = new LongWritable(); + private final RecommendedItemsWritable recommendations = new RecommendedItemsWritable(); + + @Override + Pair<OpenIntObjectHashMap<Vector>, OpenIntObjectHashMap<Vector>> createSharedInstance(Context ctx) { + Configuration conf = ctx.getConfiguration(); + Path pathToU = new Path(conf.get(RecommenderJob.USER_FEATURES_PATH)); + Path pathToM = new Path(conf.get(RecommenderJob.ITEM_FEATURES_PATH)); + + OpenIntObjectHashMap<Vector> U = ALS.readMatrixByRows(pathToU, conf); + OpenIntObjectHashMap<Vector> M = ALS.readMatrixByRows(pathToM, conf); + + return new Pair<>(U, M); + } + + @Override + protected void setup(Context ctx) throws IOException, InterruptedException { + Configuration conf = ctx.getConfiguration(); + recommendationsPerUser = conf.getInt(RecommenderJob.NUM_RECOMMENDATIONS, + RecommenderJob.DEFAULT_NUM_RECOMMENDATIONS); + maxRating = Float.parseFloat(conf.get(RecommenderJob.MAX_RATING)); + + usesLongIDs = conf.getBoolean(ParallelALSFactorizationJob.USES_LONG_IDS, false); + if (usesLongIDs) { + userIDIndex = TasteHadoopUtils.readIDIndexMap(conf.get(RecommenderJob.USER_INDEX_PATH), conf); + itemIDIndex = TasteHadoopUtils.readIDIndexMap(conf.get(RecommenderJob.ITEM_INDEX_PATH), conf); + } + } + + @Override + protected void map(IntWritable userIndexWritable, VectorWritable ratingsWritable, Context ctx) + throws IOException, InterruptedException { + + Pair<OpenIntObjectHashMap<Vector>, OpenIntObjectHashMap<Vector>> uAndM = getSharedInstance(); + OpenIntObjectHashMap<Vector> U = uAndM.getFirst(); + OpenIntObjectHashMap<Vector> M = uAndM.getSecond(); + + Vector ratings = ratingsWritable.get(); + int userIndex = userIndexWritable.get(); + final OpenIntHashSet alreadyRatedItems = new OpenIntHashSet(ratings.getNumNondefaultElements()); + + for (Vector.Element e : ratings.nonZeroes()) { + alreadyRatedItems.add(e.index()); + } + + final TopItemsQueue topItemsQueue = new TopItemsQueue(recommendationsPerUser); + final Vector userFeatures = U.get(userIndex); + + M.forEachPair(new IntObjectProcedure<Vector>() { + @Override + public boolean apply(int itemID, Vector itemFeatures) { + if (!alreadyRatedItems.contains(itemID)) { + double predictedRating = userFeatures.dot(itemFeatures); + + MutableRecommendedItem top = topItemsQueue.top(); + if (predictedRating > top.getValue()) { + top.set(itemID, (float) predictedRating); + topItemsQueue.updateTop(); + } + } + return true; + } + }); + + List<RecommendedItem> recommendedItems = topItemsQueue.getTopItems(); + + if (!recommendedItems.isEmpty()) { + + // cap predictions to maxRating + for (RecommendedItem topItem : recommendedItems) { + ((MutableRecommendedItem) topItem).capToMaxValue(maxRating); + } + + if (usesLongIDs) { + long userID = userIDIndex.get(userIndex); + userIDWritable.set(userID); + + for (RecommendedItem topItem : recommendedItems) { + // remap item IDs + long itemID = itemIDIndex.get((int) topItem.getItemID()); + ((MutableRecommendedItem) topItem).setItemID(itemID); + } + + } else { + userIDWritable.set(userIndex); + } + + recommendations.set(recommendedItems); + ctx.write(userIDWritable, 
recommendations); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java new file mode 100644 index 0000000..679d227 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java @@ -0,0 +1,110 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.hadoop.als; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper; +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.cf.taste.hadoop.RecommendedItemsWritable; +import org.apache.mahout.common.AbstractJob; + +import java.util.List; +import java.util.Map; + +/** + * <p>Computes the top-N recommendations per user from a decomposition of the rating matrix</p> + * + * <p>Command line arguments specific to this class are:</p> + * + * <ol> + * <li>--input (path): Directory containing the vectorized user ratings</li> + * <li>--output (path): path where output should go</li> + * <li>--numRecommendations (int): maximum number of recommendations per user (default: 10)</li> + * <li>--maxRating (double): maximum rating of an item</li> + * <li>--numThreads (int): threads to use per mapper (default: 1)</li> + * </ol> + */ +public class RecommenderJob extends AbstractJob { + + static final String NUM_RECOMMENDATIONS = RecommenderJob.class.getName() + ".numRecommendations"; + static final String USER_FEATURES_PATH = RecommenderJob.class.getName() + ".userFeatures"; + static final String ITEM_FEATURES_PATH = RecommenderJob.class.getName() + ".itemFeatures"; + static final String MAX_RATING = RecommenderJob.class.getName() + ".maxRating"; + static final String USER_INDEX_PATH = RecommenderJob.class.getName() + ".userIndex"; + static final String ITEM_INDEX_PATH = RecommenderJob.class.getName() + ".itemIndex"; + + static final int DEFAULT_NUM_RECOMMENDATIONS = 10; + + public static void main(String[] args) throws Exception { + ToolRunner.run(new RecommenderJob(), args); + } + + @Override + public int run(String[] args) throws Exception { + + addInputOption(); + addOption("userFeatures", null, "path to the user feature matrix", true); + addOption("itemFeatures", null,
"path to the item feature matrix", true); + addOption("numRecommendations", null, "number of recommendations per user", + String.valueOf(DEFAULT_NUM_RECOMMENDATIONS)); + addOption("maxRating", null, "maximum rating available", true); + addOption("numThreads", null, "threads per mapper", String.valueOf(1)); + addOption("usesLongIDs", null, "input contains long IDs that need to be translated"); + addOption("userIDIndex", null, "index for user long IDs (necessary if usesLongIDs is true)"); + addOption("itemIDIndex", null, "index for item long IDs (necessary if usesLongIDs is true)"); + addOutputOption(); + + Map<String,List<String>> parsedArgs = parseArguments(args); + if (parsedArgs == null) { + return -1; + } + + Job prediction = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, + MultithreadedSharingMapper.class, IntWritable.class, RecommendedItemsWritable.class, TextOutputFormat.class); + Configuration conf = prediction.getConfiguration(); + + int numThreads = Integer.parseInt(getOption("numThreads")); + + conf.setInt(NUM_RECOMMENDATIONS, Integer.parseInt(getOption("numRecommendations"))); + conf.set(USER_FEATURES_PATH, getOption("userFeatures")); + conf.set(ITEM_FEATURES_PATH, getOption("itemFeatures")); + conf.set(MAX_RATING, getOption("maxRating")); + + boolean usesLongIDs = Boolean.parseBoolean(getOption("usesLongIDs")); + if (usesLongIDs) { + conf.set(ParallelALSFactorizationJob.USES_LONG_IDS, String.valueOf(true)); + conf.set(USER_INDEX_PATH, getOption("userIDIndex")); + conf.set(ITEM_INDEX_PATH, getOption("itemIDIndex")); + } + + MultithreadedMapper.setMapperClass(prediction, PredictionMapper.class); + MultithreadedMapper.setNumberOfThreads(prediction, numThreads); + + boolean succeeded = prediction.waitForCompletion(true); + if (!succeeded) { + return -1; + } + + return 0; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java new file mode 100644 index 0000000..9925807 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.hadoop.als; + +import org.apache.hadoop.mapreduce.Mapper; + +import java.io.IOException; + +/** + * Mapper class to be used by {@link MultithreadedSharingMapper}.
Offers "global" before() and after() methods + * that will typically be used to set up static variables. + * + * Suitable for mappers that need large, read-only in-memory data to operate. + * + * @param <K1> + * @param <V1> + * @param <K2> + * @param <V2> + */ +public abstract class SharingMapper<K1,V1,K2,V2,S> extends Mapper<K1,V1,K2,V2> { + + private static Object SHARED_INSTANCE; + + /** + * Called before the multithreaded execution + * + * @param context mapper's context + */ + abstract S createSharedInstance(Context context) throws IOException; + + final void setupSharedInstance(Context context) throws IOException { + if (SHARED_INSTANCE == null) { + SHARED_INSTANCE = createSharedInstance(context); + } + } + + final S getSharedInstance() { + return (S) SHARED_INSTANCE; + } + + static void reset() { + SHARED_INSTANCE = null; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java new file mode 100644 index 0000000..2569918 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.als; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import org.apache.mahout.math.map.OpenIntObjectHashMap; + +import java.io.IOException; + +/** Solving mapper that can be safely executed using multiple threads */ +public class SolveExplicitFeedbackMapper + extends SharingMapper<IntWritable,VectorWritable,IntWritable,VectorWritable,OpenIntObjectHashMap<Vector>> { + + private double lambda; + private int numFeatures; + private final VectorWritable uiOrmj = new VectorWritable(); + + @Override + OpenIntObjectHashMap<Vector> createSharedInstance(Context ctx) throws IOException { + Configuration conf = ctx.getConfiguration(); + int numEntities = Integer.parseInt(conf.get(ParallelALSFactorizationJob.NUM_ENTITIES)); + return ALS.readMatrixByRowsFromDistributedCache(numEntities, conf); + } + + @Override + protected void setup(Mapper.Context ctx) throws IOException, InterruptedException { + lambda = Double.parseDouble(ctx.getConfiguration().get(ParallelALSFactorizationJob.LAMBDA)); + numFeatures = ctx.getConfiguration().getInt(ParallelALSFactorizationJob.NUM_FEATURES, -1); + Preconditions.checkArgument(numFeatures > 0, "numFeatures must be greater than 0!"); + } + + @Override + protected void map(IntWritable userOrItemID, VectorWritable ratingsWritable, Context ctx) + throws IOException, InterruptedException { + OpenIntObjectHashMap<Vector> uOrM = getSharedInstance(); + uiOrmj.set(ALS.solveExplicit(ratingsWritable, uOrM, lambda, numFeatures)); + ctx.write(userOrItemID, uiOrmj); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java new file mode 100644 index 0000000..fd6657f --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.cf.taste.hadoop.als; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.IntWritable; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver; + +import java.io.IOException; + +/** Solving mapper that can be safely executed using multiple threads */ +public class SolveImplicitFeedbackMapper + extends SharingMapper<IntWritable,VectorWritable,IntWritable,VectorWritable, + ImplicitFeedbackAlternatingLeastSquaresSolver> { + + private final VectorWritable uiOrmj = new VectorWritable(); + + @Override + ImplicitFeedbackAlternatingLeastSquaresSolver createSharedInstance(Context ctx) throws IOException { + Configuration conf = ctx.getConfiguration(); + + double lambda = Double.parseDouble(conf.get(ParallelALSFactorizationJob.LAMBDA)); + double alpha = Double.parseDouble(conf.get(ParallelALSFactorizationJob.ALPHA)); + int numFeatures = conf.getInt(ParallelALSFactorizationJob.NUM_FEATURES, -1); + int numEntities = Integer.parseInt(conf.get(ParallelALSFactorizationJob.NUM_ENTITIES)); + + Preconditions.checkArgument(numFeatures > 0, "numFeatures must be greater than 0!"); + + return new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, + ALS.readMatrixByRowsFromDistributedCache(numEntities, conf), 1); + } + + @Override + protected void map(IntWritable userOrItemID, VectorWritable ratingsWritable, Context ctx) + throws IOException, InterruptedException { + ImplicitFeedbackAlternatingLeastSquaresSolver solver = getSharedInstance(); + uiOrmj.set(solver.solve(ratingsWritable.get())); + ctx.write(userOrItemID, uiOrmj); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java new file mode 100644 index 0000000..b44fd5b --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.cf.taste.hadoop.item; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.cf.taste.hadoop.MutableRecommendedItem; +import org.apache.mahout.cf.taste.hadoop.RecommendedItemsWritable; +import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils; +import org.apache.mahout.cf.taste.hadoop.TopItemsQueue; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.VarLongWritable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.map.OpenIntLongHashMap; + +import java.io.IOException; +import java.util.Iterator; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * <p>Computes prediction values for each user.</p> + * + * <pre> + * u = a user + * i = an item not yet rated by u + * N = all items similar to i (where similarity is usually computed by pairwise comparison of the item-vectors + * of the user-item matrix) + * + * Prediction(u,i) = sum(all n from N: similarity(i,n) * rating(u,n)) / sum(all n from N: abs(similarity(i,n))) + * </pre> + */ +public final class AggregateAndRecommendReducer extends + Reducer<VarLongWritable,PrefAndSimilarityColumnWritable,VarLongWritable,RecommendedItemsWritable> { + + private static final Logger log = LoggerFactory.getLogger(AggregateAndRecommendReducer.class); + + static final String ITEMID_INDEX_PATH = "itemIDIndexPath"; + static final String NUM_RECOMMENDATIONS = "numRecommendations"; + static final int DEFAULT_NUM_RECOMMENDATIONS = 10; + static final String ITEMS_FILE = "itemsFile"; + + private boolean booleanData; + private int recommendationsPerUser; + private IDReader idReader; + private FastIDSet itemsToRecommendFor; + private OpenIntLongHashMap indexItemIDMap; + + private final RecommendedItemsWritable recommendedItems = new RecommendedItemsWritable(); + + private static final float BOOLEAN_PREF_VALUE = 1.0f; + + @Override + protected void setup(Context context) throws IOException { + Configuration conf = context.getConfiguration(); + recommendationsPerUser = conf.getInt(NUM_RECOMMENDATIONS, DEFAULT_NUM_RECOMMENDATIONS); + booleanData = conf.getBoolean(RecommenderJob.BOOLEAN_DATA, false); + indexItemIDMap = TasteHadoopUtils.readIDIndexMap(conf.get(ITEMID_INDEX_PATH), conf); + + idReader = new IDReader(conf); + idReader.readIDs(); + itemsToRecommendFor = idReader.getItemIds(); + } + + @Override + protected void reduce(VarLongWritable userID, + Iterable<PrefAndSimilarityColumnWritable> values, + Context context) throws IOException, InterruptedException { + if (booleanData) { + reduceBooleanData(userID, values, context); + } else { + reduceNonBooleanData(userID, values, context); + } + } + + private void reduceBooleanData(VarLongWritable userID, + Iterable<PrefAndSimilarityColumnWritable> values, + Context context) throws IOException, InterruptedException { + /* having boolean data, each estimated preference can only be 1, + * however we can't use this to rank the recommended items, + * so we use the sum of similarities for that.
*/ + Iterator<PrefAndSimilarityColumnWritable> columns = values.iterator(); + Vector predictions = columns.next().getSimilarityColumn(); + while (columns.hasNext()) { + predictions.assign(columns.next().getSimilarityColumn(), Functions.PLUS); + } + writeRecommendedItems(userID, predictions, context); + } + + private void reduceNonBooleanData(VarLongWritable userID, + Iterable<PrefAndSimilarityColumnWritable> values, + Context context) throws IOException, InterruptedException { + /* each entry here is the sum in the numerator of the prediction formula */ + Vector numerators = null; + /* each entry here is the sum in the denominator of the prediction formula */ + Vector denominators = null; + /* each entry here is the number of similar items used in the prediction formula */ + Vector numberOfSimilarItemsUsed = new RandomAccessSparseVector(Integer.MAX_VALUE, 100); + + for (PrefAndSimilarityColumnWritable prefAndSimilarityColumn : values) { + Vector simColumn = prefAndSimilarityColumn.getSimilarityColumn(); + float prefValue = prefAndSimilarityColumn.getPrefValue(); + /* count the number of items used for each prediction */ + for (Element e : simColumn.nonZeroes()) { + int itemIDIndex = e.index(); + numberOfSimilarItemsUsed.setQuick(itemIDIndex, numberOfSimilarItemsUsed.getQuick(itemIDIndex) + 1); + } + + if (denominators == null) { + denominators = simColumn.clone(); + } else { + denominators.assign(simColumn, Functions.PLUS_ABS); + } + + if (numerators == null) { + numerators = simColumn.clone(); + if (prefValue != BOOLEAN_PREF_VALUE) { + numerators.assign(Functions.MULT, prefValue); + } + } else { + if (prefValue != BOOLEAN_PREF_VALUE) { + simColumn.assign(Functions.MULT, prefValue); + } + numerators.assign(simColumn, Functions.PLUS); + } + + } + + if (numerators == null) { + return; + } + + Vector recommendationVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 100); + for (Element element : numerators.nonZeroes()) { + int itemIDIndex = element.index(); + /* preference estimations must be based on at least 2 datapoints */ + if (numberOfSimilarItemsUsed.getQuick(itemIDIndex) > 1) { + /* compute normalized prediction */ + double prediction = element.get() / denominators.getQuick(itemIDIndex); + recommendationVector.setQuick(itemIDIndex, prediction); + } + } + writeRecommendedItems(userID, recommendationVector, context); + } + + /** + * find the top entries in recommendationVector, map them to the real itemIDs and write back the result + */ + private void writeRecommendedItems(VarLongWritable userID, Vector recommendationVector, Context context) + throws IOException, InterruptedException { + TopItemsQueue topKItems = new TopItemsQueue(recommendationsPerUser); + FastIDSet itemsForUser = null; + + if (idReader != null && idReader.isUserItemFilterSpecified()) { + itemsForUser = idReader.getItemsToRecommendForUser(userID.get()); + } + + for (Element element : recommendationVector.nonZeroes()) { + int index = element.index(); + long itemID; + if (indexItemIDMap != null && !indexItemIDMap.isEmpty()) { + itemID = indexItemIDMap.get(index); + } else { // we don't have any mappings, so just use the original + itemID = index; + } + + if (shouldIncludeItemIntoRecommendations(itemID, itemsToRecommendFor, itemsForUser)) { + + float value = (float) element.get(); + if (!Float.isNaN(value)) { + + MutableRecommendedItem topItem = topKItems.top(); + if (value > topItem.getValue()) { + topItem.set(itemID, value); + topKItems.updateTop(); + } + } + } + } + + List<RecommendedItem> topItems = 
topKItems.getTopItems(); + if (!topItems.isEmpty()) { + recommendedItems.set(topItems); + context.write(userID, recommendedItems); + } + } + + private boolean shouldIncludeItemIntoRecommendations(long itemID, FastIDSet allItemsToRecommendFor, + FastIDSet itemsForUser) { + if (allItemsToRecommendFor == null && itemsForUser == null) { + return true; + } else if (itemsForUser != null) { + return itemsForUser.contains(itemID); + } else { + return allItemsToRecommendFor.contains(itemID); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/IDReader.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/IDReader.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/IDReader.java new file mode 100644 index 0000000..7797fe9 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/IDReader.java @@ -0,0 +1,244 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.hadoop.item; + +import java.io.IOException; +import java.io.InputStream; +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Pattern; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.iterator.FileLineIterable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Reads user ids and item ids from files specified in usersFile, itemsFile or userItemFile options in item-based + * recommender. Composes a list of users and a list of items which can be used by + * {@link org.apache.mahout.cf.taste.hadoop.item.UserVectorSplitterMapper} and + * {@link org.apache.mahout.cf.taste.hadoop.item.AggregateAndRecommendReducer}. 
+ */ +public class IDReader { + + static final String USER_ITEM_FILE = "userItemFile"; + + private static final Logger log = LoggerFactory.getLogger(IDReader.class); + private static final Pattern SEPARATOR = Pattern.compile("[\t,]"); + + private Configuration conf; + + private String usersFile; + private String itemsFile; + private String userItemFile; + + private FastIDSet userIds; + private FastIDSet itemIds; + + private FastIDSet emptySet; + + /* Key - user id, value - a set of item ids to include into recommendations for this user */ + private Map<Long, FastIDSet> userItemFilter; + + /** + * Creates a new IDReader + * + * @param conf Job configuration + */ + public IDReader(Configuration conf) { + this.conf = conf; + emptySet = new FastIDSet(); + + usersFile = conf.get(UserVectorSplitterMapper.USERS_FILE); + itemsFile = conf.get(AggregateAndRecommendReducer.ITEMS_FILE); + userItemFile = conf.get(USER_ITEM_FILE); + } + + /** + * Reads user ids and item ids from files specified in a job configuration + * + * @throws IOException if an error occurs during file read operation + * + * @throws IllegalStateException if userItemFile option is specified together with usersFile or itemsFile + */ + public void readIDs() throws IOException, IllegalStateException { + if (isUserItemFileSpecified()) { + readUserItemFilterIfNeeded(); + } + + if (isUsersFileSpecified() || isUserItemFilterSpecified()) { + readUserIds(); + } + + if (isItemsFileSpecified() || isUserItemFilterSpecified()) { + readItemIds(); + } + } + + /** + * Gets a collection of items which should be recommended for a user + * + * @param userId ID of a user we are interested in + * @return if a userItemFile option is specified, and that file contains at least one item ID for the user, + * then this method returns a {@link FastIDSet} object populated with item IDs. Otherwise, this + * method returns an empty set. 
+ */ + public FastIDSet getItemsToRecommendForUser(Long userId) { + if (isUserItemFilterSpecified() && userItemFilter.containsKey(userId)) { + return userItemFilter.get(userId); + } else { + return emptySet; + } + } + + private void readUserIds() throws IOException, IllegalStateException { + if (isUsersFileSpecified() && !isUserItemFileSpecified()) { + userIds = readIDList(usersFile); + } else if (isUserItemFileSpecified() && !isUsersFileSpecified()) { + readUserItemFilterIfNeeded(); + userIds = extractAllUserIdsFromUserItemFilter(userItemFilter); + } else if (!isUsersFileSpecified()) { + throw new IllegalStateException("Neither usersFile nor userItemFile options are specified"); + } else { + throw new IllegalStateException("usersFile and userItemFile options cannot be used simultaneously"); + } + } + + private void readItemIds() throws IOException, IllegalStateException { + if (isItemsFileSpecified() && !isUserItemFileSpecified()) { + itemIds = readIDList(itemsFile); + } else if (isUserItemFileSpecified() && !isItemsFileSpecified()) { + readUserItemFilterIfNeeded(); + itemIds = extractAllItemIdsFromUserItemFilter(userItemFilter); + } else if (!isItemsFileSpecified()) { + throw new IllegalStateException("Neither itemsFile nor userItemFile options are specified"); + } else { + throw new IllegalStateException("itemsFile and userItemFile options cannot be specified simultaneously"); + } + } + + private void readUserItemFilterIfNeeded() throws IOException { + if (!isUserItemFilterSpecified() && isUserItemFileSpecified()) { + userItemFilter = readUserItemFilter(userItemFile); + } + } + + private Map<Long, FastIDSet> readUserItemFilter(String pathString) throws IOException { + Map<Long, FastIDSet> result = new HashMap<>(); + + try (InputStream in = openFile(pathString)) { + for (String line : new FileLineIterable(in)) { + try { + String[] tokens = SEPARATOR.split(line); + Long userId = Long.parseLong(tokens[0]); + Long itemId = Long.parseLong(tokens[1]); + + addUserAndItemIdToUserItemFilter(result, userId, itemId); + } catch (NumberFormatException nfe) { + log.warn("userItemFile line ignored: {}", line); + } + } + } + + return result; + } + + void addUserAndItemIdToUserItemFilter(Map<Long, FastIDSet> filter, Long userId, Long itemId) { + FastIDSet itemIds; + + if (filter.containsKey(userId)) { + itemIds = filter.get(userId); + } else { + itemIds = new FastIDSet(); + filter.put(userId, itemIds); + } + + itemIds.add(itemId); + } + + static FastIDSet extractAllUserIdsFromUserItemFilter(Map<Long, FastIDSet> filter) { + FastIDSet result = new FastIDSet(); + + for (Long userId : filter.keySet()) { + result.add(userId); + } + + return result; + } + + private FastIDSet extractAllItemIdsFromUserItemFilter(Map<Long, FastIDSet> filter) { + FastIDSet result = new FastIDSet(); + + for (FastIDSet itemIds : filter.values()) { + result.addAll(itemIds); + } + + return result; + } + + private FastIDSet readIDList(String pathString) throws IOException { + FastIDSet result = null; + + if (pathString != null) { + result = new FastIDSet(); + + try (InputStream in = openFile(pathString)){ + for (String line : new FileLineIterable(in)) { + try { + result.add(Long.parseLong(line)); + } catch (NumberFormatException nfe) { + log.warn("line ignored: {}", line); + } + } + } + } + + return result; + } + + private InputStream openFile(String pathString) throws IOException { + return HadoopUtil.openStream(new Path(pathString), conf); + } + + public boolean isUsersFileSpecified () { + return usersFile != null; + } + + public 
boolean isItemsFileSpecified () { + return itemsFile != null; + } + + public boolean isUserItemFileSpecified () { + return userItemFile != null; + } + + public boolean isUserItemFilterSpecified() { + return userItemFilter != null; + } + + public FastIDSet getUserIds() { + return userIds; + } + + public FastIDSet getItemIds() { + return itemIds; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterAsVectorAndPrefsReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterAsVectorAndPrefsReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterAsVectorAndPrefsReducer.java new file mode 100644 index 0000000..4415a55 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterAsVectorAndPrefsReducer.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.item; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.VarIntWritable; +import org.apache.mahout.math.VarLongWritable; +import org.apache.mahout.math.Vector; + +/** + * we use a neat little trick to explicitly filter items for some users: we inject a NaN summand into the preference + * estimation for those items, which makes {@link org.apache.mahout.cf.taste.hadoop.item.AggregateAndRecommendReducer} + * automatically exclude them + */ +public class ItemFilterAsVectorAndPrefsReducer + extends Reducer<VarLongWritable,VarLongWritable,VarIntWritable,VectorAndPrefsWritable> { + + private final VarIntWritable itemIDIndexWritable = new VarIntWritable(); + private final VectorAndPrefsWritable vectorAndPrefs = new VectorAndPrefsWritable(); + + @Override + protected void reduce(VarLongWritable itemID, Iterable<VarLongWritable> values, Context ctx) + throws IOException, InterruptedException { + + int itemIDIndex = TasteHadoopUtils.idToIndex(itemID.get()); + Vector vector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1); + /* artificial NaN summand to exclude this item from the recommendations for all users specified in userIDs */ + vector.set(itemIDIndex, Double.NaN); + + List<Long> userIDs = new ArrayList<>(); + List<Float> prefValues = new ArrayList<>(); + for (VarLongWritable userID : values) { + userIDs.add(userID.get()); + prefValues.add(1.0f); + } + + itemIDIndexWritable.set(itemIDIndex); + vectorAndPrefs.set(vector, userIDs, prefValues); + ctx.write(itemIDIndexWritable, vectorAndPrefs); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterMapper.java new file mode 100644 index 0000000..cdc1ddf --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterMapper.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.item; + +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.math.VarLongWritable; + +import java.io.IOException; +import java.util.regex.Pattern; + +/** + * map out all user/item pairs to filter, keyed by the itemID + */ +public class ItemFilterMapper extends Mapper<LongWritable,Text,VarLongWritable,VarLongWritable> { + + private static final Pattern SEPARATOR = Pattern.compile("[\t,]"); + + private final VarLongWritable itemIDWritable = new VarLongWritable(); + private final VarLongWritable userIDWritable = new VarLongWritable(); + + @Override + protected void map(LongWritable key, Text line, Context ctx) throws IOException, InterruptedException { + String[] tokens = SEPARATOR.split(line.toString()); + long userID = Long.parseLong(tokens[0]); + long itemID = Long.parseLong(tokens[1]); + itemIDWritable.set(itemID); + userIDWritable.set(userID); + ctx.write(itemIDWritable, userIDWritable); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexMapper.java new file mode 100644 index 0000000..ac8597e --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexMapper.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.item; + +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils; +import org.apache.mahout.cf.taste.hadoop.ToEntityPrefsMapper; +import org.apache.mahout.math.VarIntWritable; +import org.apache.mahout.math.VarLongWritable; + +public final class ItemIDIndexMapper extends + Mapper<LongWritable,Text, VarIntWritable, VarLongWritable> { + + private boolean transpose; + + private final VarIntWritable indexWritable = new VarIntWritable(); + private final VarLongWritable itemIDWritable = new VarLongWritable(); + + @Override + protected void setup(Context context) { + Configuration jobConf = context.getConfiguration(); + transpose = jobConf.getBoolean(ToEntityPrefsMapper.TRANSPOSE_USER_ITEM, false); + } + + @Override + protected void map(LongWritable key, + Text value, + Context context) throws IOException, InterruptedException { + String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString()); + long itemID = Long.parseLong(tokens[transpose ? 0 : 1]); + int index = TasteHadoopUtils.idToIndex(itemID); + indexWritable.set(index); + itemIDWritable.set(itemID); + context.write(indexWritable, itemIDWritable); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexReducer.java new file mode 100644 index 0000000..d9ecf5e --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexReducer.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.item; + +import java.io.IOException; + +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.math.VarIntWritable; +import org.apache.mahout.math.VarLongWritable; + +public final class ItemIDIndexReducer extends + Reducer<VarIntWritable, VarLongWritable, VarIntWritable,VarLongWritable> { + + private final VarLongWritable minimumItemIDWritable = new VarLongWritable(); + + @Override + protected void reduce(VarIntWritable index, + Iterable<VarLongWritable> possibleItemIDs, + Context context) throws IOException, InterruptedException { + long minimumItemID = Long.MAX_VALUE; + for (VarLongWritable varLongWritable : possibleItemIDs) { + long itemID = varLongWritable.get(); + if (itemID < minimumItemID) { + minimumItemID = itemID; + } + } + if (minimumItemID != Long.MAX_VALUE) { + minimumItemIDWritable.set(minimumItemID); + context.write(index, minimumItemIDWritable); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyMapper.java new file mode 100644 index 0000000..0e818f3 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyMapper.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.item; + +import java.io.IOException; +import java.util.List; + +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.math.VarIntWritable; +import org.apache.mahout.math.VarLongWritable; +import org.apache.mahout.math.Vector; + +/** + * maps similar items and their preference values per user + */ +public final class PartialMultiplyMapper extends + Mapper<VarIntWritable,VectorAndPrefsWritable,VarLongWritable,PrefAndSimilarityColumnWritable> { + + private final VarLongWritable userIDWritable = new VarLongWritable(); + private final PrefAndSimilarityColumnWritable prefAndSimilarityColumn = new PrefAndSimilarityColumnWritable(); + + @Override + protected void map(VarIntWritable key, + VectorAndPrefsWritable vectorAndPrefsWritable, + Context context) throws IOException, InterruptedException { + + Vector similarityMatrixColumn = vectorAndPrefsWritable.getVector(); + List<Long> userIDs = vectorAndPrefsWritable.getUserIDs(); + List<Float> prefValues = vectorAndPrefsWritable.getValues(); + + for (int i = 0; i < userIDs.size(); i++) { + long userID = userIDs.get(i); + float prefValue = prefValues.get(i); + if (!Float.isNaN(prefValue)) { + prefAndSimilarityColumn.set(prefValue, similarityMatrixColumn); + userIDWritable.set(userID); + context.write(userIDWritable, prefAndSimilarityColumn); + } + } + } + +}
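
Editor's note, not part of the patch: the AggregateAndRecommendReducer javadoc above summarizes the item-based prediction formula, Prediction(u,i) = sum(n in N: similarity(i,n) * rating(u,n)) / sum(n in N: abs(similarity(i,n))). The following is a minimal standalone sketch of that formula in plain Java, using arrays instead of Mahout's sparse Vectors; the class and method names are invented for illustration only.

// Illustrative sketch only -- not part of the Mahout code above.
// Computes the weighted-average estimate described in the reducer's javadoc:
// numerator = sum of similarity * rating over the similar items the user rated,
// denominator = sum of the absolute similarities.
public class WeightedSumPredictionSketch {

  static double predict(double[] similarities, double[] ratings) {
    double numerator = 0.0;
    double denominator = 0.0;
    for (int n = 0; n < similarities.length; n++) {
      numerator += similarities[n] * ratings[n];
      denominator += Math.abs(similarities[n]);
    }
    // No similar items contributed: the estimate is undefined in this sketch.
    return denominator == 0.0 ? Double.NaN : numerator / denominator;
  }

  public static void main(String[] args) {
    // Two similar items rated 4.0 and 2.0, with similarities 0.8 and 0.4:
    // (0.8*4.0 + 0.4*2.0) / (0.8 + 0.4) = 4.0 / 1.2, roughly 3.33
    System.out.println(predict(new double[] {0.8, 0.4}, new double[] {4.0, 2.0}));
  }
}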

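Editor's note, not part of the patch: ItemFilterAsVectorAndPrefsReducer above relies on NaN propagation to exclude filtered items. Once the injected NaN summand enters the numerator of a prediction, the resulting estimate is NaN, and AggregateAndRecommendReducer's writeRecommendedItems skips NaN values before ranking. A tiny hypothetical demonstration of that behaviour:

// Illustrative sketch only -- shows the NaN propagation the filtering trick depends on.
public class NaNExclusionSketch {
  public static void main(String[] args) {
    double numerator = 0.8 * 4.0 + 0.4 * 2.0;  // ordinary partial sums
    numerator += Double.NaN;                   // injected "filter" summand
    double estimate = numerator / 1.2;

    System.out.println(estimate);                        // NaN: the sum is poisoned
    System.out.println(Float.isNaN((float) estimate));   // true, so the item is skipped
  }
}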