Author: ssc
Date: Tue Apr 12 09:40:58 2011
New Revision: 1091345

URL: http://svn.apache.org/viewvc?rev=1091345&view=rev
Log:
MAHOUT-542 missing evaluation classes

Added:
    mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/
    
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/DatasetSplitter.java
    
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/InMemoryFactorizationEvaluator.java
    
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluator.java
    mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/
    
mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluatorTest.java

Added: 
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/DatasetSplitter.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/DatasetSplitter.java?rev=1091345&view=auto
==============================================================================
--- 
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/DatasetSplitter.java
 (added)
+++ 
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/DatasetSplitter.java
 Tue Apr 12 09:40:58 2011
@@ -0,0 +1,149 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.eval;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BooleanWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.RandomUtils;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * <p>Split a recommendation dataset into a training and a test set</p>
+ *
+  * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--input (path): Directory containing one or more text files with the 
dataset</li>
+ * <li>--output (path): path where output should go</li>
+ * <li>--trainingPercentage (double): percentage of the data to use as 
training set (optional, default 0.9)</li>
+ * <li>--probePercentage (double): percentage of the data to use as probe set 
(optional, default 0.1)</li>
+ * </ol>
+ */
+public class DatasetSplitter extends AbstractJob {
+
+  private static final String TRAINING_PERCENTAGE = 
DatasetSplitter.class.getName() + ".trainingPercentage";
+  private static final String PROBE_PERCENTAGE = 
DatasetSplitter.class.getName() + ".probePercentage";
+  private static final String PART_TO_USE = DatasetSplitter.class.getName() + 
".partToUse";
+
+  private static final Text INTO_TRAINING_SET = new Text("T");
+  private static final Text INTO_PROBE_SET = new Text("P");
+
+  private static final double DEFAULT_TRAINING_PERCENTAGE = 0.9;
+  private static final double DEFAULT_PROBE_PERCENTAGE = 0.1;
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new DatasetSplitter(), args);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption("trainingPercentage", "t", "percentage of the data to use as 
training set (default: " +
+        DEFAULT_TRAINING_PERCENTAGE + ")", 
String.valueOf(DEFAULT_TRAINING_PERCENTAGE));
+    addOption("probePercentage", "p", "percentage of the data to use as probe 
set (default: " +
+        DEFAULT_PROBE_PERCENTAGE +")", 
String.valueOf(DEFAULT_PROBE_PERCENTAGE));
+
+    Map<String, String> parsedArgs = parseArguments(args);
+    double trainingPercentage = 
Double.parseDouble(parsedArgs.get("--trainingPercentage"));
+    double probePercentage = 
Double.parseDouble(parsedArgs.get("--probePercentage"));
+    String tempDir = parsedArgs.get("--tempDir");
+
+    Path markedPrefs = new Path(tempDir, "markedPreferences");
+    Path trainingSetPath = new Path(getOutputPath(), "trainingSet");
+    Path probeSetPath = new Path(getOutputPath(), "probeSet");
+
+    Job markPreferences = prepareJob(getInputPath(), markedPrefs, 
TextInputFormat.class, MarkPreferencesMapper.class,
+        Text.class, Text.class, Reducer.class, Text.class, Text.class,
+        SequenceFileOutputFormat.class);
+    markPreferences.getConfiguration().set(TRAINING_PERCENTAGE, 
String.valueOf(trainingPercentage));
+    markPreferences.getConfiguration().set(PROBE_PERCENTAGE, 
String.valueOf(probePercentage));
+    markPreferences.waitForCompletion(true);
+
+    Job createTrainingSet = prepareJob(markedPrefs, trainingSetPath, 
SequenceFileInputFormat.class,
+        WritePrefsMapper.class, NullWritable.class, Text.class, Reducer.class, 
NullWritable.class, Text.class,
+        TextOutputFormat.class);
+    createTrainingSet.getConfiguration().set(PART_TO_USE, 
INTO_TRAINING_SET.toString());
+    createTrainingSet.waitForCompletion(true);
+
+    Job createProbeSet = prepareJob(markedPrefs, probeSetPath, 
SequenceFileInputFormat.class,
+        WritePrefsMapper.class, NullWritable.class, Text.class, Reducer.class, 
NullWritable.class, Text.class,
+        TextOutputFormat.class);
+    createProbeSet.getConfiguration().set(PART_TO_USE, 
INTO_PROBE_SET.toString());
+    createProbeSet.waitForCompletion(true);
+
+    return 0;
+  }
+
+  static class MarkPreferencesMapper extends 
Mapper<LongWritable,Text,Text,Text> {
+
+    private Random random;
+    private double trainingBound;
+    private double probeBound;
+
+    @Override
+    protected void setup(Context ctx) throws IOException, InterruptedException 
{
+      random = RandomUtils.getRandom();
+      trainingBound = 
Double.parseDouble(ctx.getConfiguration().get(TRAINING_PERCENTAGE));
+      probeBound = trainingBound + 
Double.parseDouble(ctx.getConfiguration().get(PROBE_PERCENTAGE));
+    }
+
+    @Override
+    protected void map(LongWritable key, Text text, Context ctx) throws 
IOException, InterruptedException {
+      double randomValue = random.nextDouble();
+      if (randomValue <= trainingBound) {
+        ctx.write(INTO_TRAINING_SET, text);
+      } else if (randomValue <= probeBound) {
+        ctx.write(INTO_PROBE_SET, text);
+      }
+    }
+  }
+
+  static class WritePrefsMapper extends Mapper<Text,Text,NullWritable,Text> {
+
+    private String partToUse;
+
+    @Override
+    protected void setup(Context ctx) throws IOException, InterruptedException 
{
+      partToUse = ctx.getConfiguration().get(PART_TO_USE);
+    }
+
+    @Override
+    protected void map(Text key, Text text, Context ctx) throws IOException, 
InterruptedException {
+      if (partToUse.equals(key.toString())) {
+        ctx.write(NullWritable.get(), text);
+      }
+    }
+  }
+}
\ No newline at end of file

Added: 
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/InMemoryFactorizationEvaluator.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/InMemoryFactorizationEvaluator.java?rev=1091345&view=auto
==============================================================================
--- 
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/InMemoryFactorizationEvaluator.java
 (added)
+++ 
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/InMemoryFactorizationEvaluator.java
 Tue Apr 12 09:40:58 2011
@@ -0,0 +1,168 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.eval;
+
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.nio.charset.Charset;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * <p>Measures the root-mean-squared error of a ratring matrix factorization 
against a test set.</p>
+ *
+ * <p>the factorization matrices are read into memory, which makes this job 
pretty fast, if you get OutOfMemoryErrors,
+ * use {@link ParallelFactorizationEvaluator} instead</p>
+ *
+  * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--output (path): path where output should go</li>
+ * <li>--pairs (path): path containing the test ratings, each line must be 
userID,itemID,rating</li>
+ * <li>--userFeatures (path): path to the user feature matrix</li>
+ * <li>--itemFeatures (path): path to the item feature matrix</li>
+ * </ol>
+ */
+public class InMemoryFactorizationEvaluator extends AbstractJob {
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new InMemoryFactorizationEvaluator(), args);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addOption("pairs", "p", "path containing the test ratings, each line must 
be userID,itemID,rating", true);
+    addOption("userFeatures", "u", "path to the user feature matrix", true);
+    addOption("itemFeatures", "i", "path to the item feature matrix", true);
+    addOutputOption();
+
+    Map<String,String> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
+    }
+
+    Path pairs = new Path(parsedArgs.get("--pairs"));
+    Path userFeatures = new Path(parsedArgs.get("--userFeatures"));
+    Path itemFeatures = new Path(parsedArgs.get("--itemFeatures"));
+
+    Matrix u = readMatrix(userFeatures);
+    Matrix m = readMatrix(itemFeatures);
+
+    FullRunningAverage rmseAvg = new FullRunningAverage();
+    FullRunningAverage maeAvg = new FullRunningAverage();
+    int pairsUsed = 1;
+    Writer writer = new OutputStreamWriter(System.out);
+    try {
+      for (Preference pref : readProbePreferences(pairs)) {
+        int userID = (int) pref.getUserID();
+        int itemID = (int) pref.getItemID();
+
+        double rating = pref.getValue();
+        double estimate = u.getRow(userID).dot(m.getRow(itemID));
+        double err = rating - estimate;
+        rmseAvg.addDatum(err * err);
+        maeAvg.addDatum(Math.abs(err));
+        writer.write("Probe [" + pairsUsed + "], rating of user [" + userID + 
"] towards item [" + itemID + "], " +
+            "[" + rating + "] estimated [" + estimate + "]\n");
+        pairsUsed++;
+      }
+      double rmse = Math.sqrt(rmseAvg.getAverage());
+      double mae = maeAvg.getAverage();
+      writer.write("RMSE: " + rmse + ", MAE: " + mae + "\n");
+    } finally {
+      IOUtils.quietClose(writer);
+    }
+    return 0;
+  }
+
+  private Matrix readMatrix(Path dir) throws IOException {
+
+    Matrix matrix = new SparseMatrix(new int[] { Integer.MAX_VALUE, 
Integer.MAX_VALUE });
+
+    FileSystem fs = dir.getFileSystem(getConf());
+    for (FileStatus seqFile : fs.globStatus(new Path(dir, "part-*"))) {
+      Path path = seqFile.getPath();
+      SequenceFile.Reader reader = null;
+      try {
+        reader = new SequenceFile.Reader(fs, path, getConf());
+        IntWritable key = new IntWritable();
+        VectorWritable value = new VectorWritable();
+        while (reader.next(key, value)) {
+          int row = key.get();
+          Iterator<Vector.Element> elementsIterator = 
value.get().iterateNonZero();
+          while (elementsIterator.hasNext()) {
+            Vector.Element element = elementsIterator.next();
+            matrix.set(row, element.index(), element.get());
+          }
+        }
+      } finally {
+        IOUtils.quietClose(reader);
+      }
+    }
+    return matrix;
+  }
+
+  private List<Preference> readProbePreferences(Path dir) throws IOException {
+
+    List<Preference> preferences = new LinkedList<Preference>();
+    FileSystem fs = dir.getFileSystem(getConf());
+    for (FileStatus seqFile : fs.globStatus(new Path(dir, "part-*"))) {
+      Path path = seqFile.getPath();
+      InputStream in = null;
+      try  {
+        in = fs.open(path);
+        BufferedReader reader = new BufferedReader(new InputStreamReader(in, 
Charset.forName("UTF-8")));
+        String line;
+        while ((line = reader.readLine()) != null) {
+          String[] tokens = TasteHadoopUtils.splitPrefTokens(line);
+          long userID = Long.parseLong(tokens[0]);
+          long itemID = Long.parseLong(tokens[1]);
+          float value = Float.parseFloat(tokens[2]);
+          preferences.add(new GenericPreference(userID, itemID, value));
+        }
+      } finally {
+        IOUtils.quietClose(in);
+      }
+    }
+    return preferences;
+  }
+}
\ No newline at end of file

Added: 
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluator.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluator.java?rev=1091345&view=auto
==============================================================================
--- 
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluator.java
 (added)
+++ 
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluator.java
 Tue Apr 12 09:40:58 2011
@@ -0,0 +1,154 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.eval;
+
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.hadoop.als.PredictionJob;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.IntPairWritable;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.util.Map;
+
+/**
+ * <p>Measures the root-mean-squared error of a ratring matrix factorization 
against a test set.</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--output (path): path where output should go</li>
+ * <li>--pairs (path): path containing the test ratings, each line must be 
userID,itemID,rating</li>
+ * <li>--userFeatures (path): path to the user feature matrix</li>
+ * <li>--itemFeatures (path): path to the item feature matrix</li>
+ * </ol>
+ */
+public class ParallelFactorizationEvaluator extends AbstractJob {
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new ParallelFactorizationEvaluator(), args);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addOption("pairs", "p", "path containing the test ratings, each line must 
be userID,itemID,rating", true);
+    addOption("userFeatures", "u", "path to the user feature matrix", true);
+    addOption("itemFeatures", "i", "path to the item feature matrix", true);
+    addOutputOption();
+
+    Map<String,String> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
+    }
+
+    Path tempDir = new Path(parsedArgs.get("--tempDir"));
+    Path predictions = new Path(tempDir, "predictions");
+    Path errors = new Path(tempDir, "errors");
+
+    ToolRunner.run(getConf(), new PredictionJob(), new String[] { "--output", 
predictions.toString(),
+        "--pairs", parsedArgs.get("--pairs"), "--userFeatures", 
parsedArgs.get("--userFeatures"),
+        "--itemFeatures", parsedArgs.get("--itemFeatures"),
+        "--tempDir", tempDir.toString() });
+
+    Job estimationErrors = prepareJob(new Path(parsedArgs.get("--pairs") + "," 
+ predictions.toString()), errors,
+        TextInputFormat.class, PairsWithRatingMapper.class, 
IntPairWritable.class, DoubleWritable.class,
+        ErrorReducer.class, DoubleWritable.class, NullWritable.class, 
SequenceFileOutputFormat.class);
+    estimationErrors.waitForCompletion(true);
+
+    BufferedWriter writer  = null;
+    try {
+      FileSystem fs = FileSystem.get(getOutputPath().toUri(), getConf());
+      FSDataOutputStream outputStream = fs.create(new Path(getOutputPath(), 
"rmse.txt"));
+      double rmse = computeRmse(errors);
+      writer = new BufferedWriter(new OutputStreamWriter(outputStream));
+      writer.write(String.valueOf(rmse));
+    } finally {
+      IOUtils.quietClose(writer);
+    }
+
+    return 0;
+  }
+
+  protected double computeRmse(Path errors) {
+    RunningAverage average = new FullRunningAverage();
+    for (Pair<DoubleWritable,NullWritable> entry :
+        new SequenceFileDirIterable<DoubleWritable, NullWritable>(errors, 
PathType.LIST, getConf())) {
+      DoubleWritable error = entry.getFirst();
+      average.addDatum(error.get() * error.get());
+    }
+
+    return Math.sqrt(average.getAverage());
+  }
+
+  public static class PairsWithRatingMapper extends 
Mapper<LongWritable,Text,IntPairWritable,DoubleWritable> {
+    @Override
+    protected void map(LongWritable key, Text value, Context ctx) throws 
IOException, InterruptedException {
+      String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString());
+      int userIDIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[0]));
+      int itemIDIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[1]));
+      double rating = Double.parseDouble(tokens[2]);
+      ctx.write(new IntPairWritable(userIDIndex, itemIDIndex), new 
DoubleWritable(rating));
+    }
+  }
+
+  public static class ErrorReducer extends 
Reducer<IntPairWritable,DoubleWritable,DoubleWritable,NullWritable> {
+    @Override
+    protected void reduce(IntPairWritable key, Iterable<DoubleWritable> 
ratingAndEstimate, Context ctx)
+        throws IOException, InterruptedException {
+
+      double error = Double.NaN;
+      boolean bothFound = false;
+      for (DoubleWritable ratingOrEstimate : ratingAndEstimate) {
+        if (Double.isNaN(error)) {
+          error = ratingOrEstimate.get();
+        } else {
+          error -= ratingOrEstimate.get();
+          bothFound = true;
+          break;
+        }
+      }
+
+      if (bothFound) {
+        ctx.write(new DoubleWritable(error), NullWritable.get());
+      }
+    }
+  }
+}
\ No newline at end of file

Added: 
mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluatorTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluatorTest.java?rev=1091345&view=auto
==============================================================================
--- 
mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluatorTest.java
 (added)
+++ 
mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluatorTest.java
 Tue Apr 12 09:40:58 2011
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.eval;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.junit.Test;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+
+public class ParallelFactorizationEvaluatorTest extends TasteTestCase {
+
+  @Test
+  public void smallIntegration() throws Exception {
+
+    File pairs = getTestTempFile("pairs.txt");
+    File userFeatures = getTestTempFile("userFeatures.seq");
+    File itemFeatures = getTestTempFile("itemFeatures.seq");
+    File tempDir = getTestTempDir("temp");
+    File outputDir = getTestTempDir("out");
+    outputDir.delete();
+
+    Configuration conf = new Configuration();
+    Path inputPath = new Path(pairs.getAbsolutePath());
+    FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
+
+    MathHelper.writeEntries(new double[][] {
+        new double[] {  1.5, -2,   0.3 },
+        new double[] { -0.7,  2,   0.6 },
+        new double[] { -1,    2.5, 3   } }, fs, conf, new 
Path(userFeatures.getAbsolutePath()));
+
+    MathHelper.writeEntries(new double [][] {
+        new double[] {  2.3,  0.5, 0   },
+        new double[] {  4.7, -1,   0.2 },
+        new double[] {  0.6,  2,   1.3 } }, fs, conf, new 
Path(itemFeatures.getAbsolutePath()));
+
+    writeLines(pairs, "0,0,3", "2,1,-7", "1,0,-2");
+
+    ParallelFactorizationEvaluator evaluator = new 
ParallelFactorizationEvaluator();
+    evaluator.setConf(conf);
+    evaluator.run(new String[] { "--output", outputDir.getAbsolutePath(), 
"--pairs", pairs.getAbsolutePath(),
+        "--userFeatures", userFeatures.getAbsolutePath(), "--itemFeatures", 
itemFeatures.getAbsolutePath(),
+        "--tempDir", tempDir.getAbsolutePath() });
+
+    BufferedReader reader = null;
+    try {
+      reader = new BufferedReader(new FileReader(new File(outputDir, 
"rmse.txt")));
+      double rmse = Double.parseDouble(reader.readLine());
+      assertEquals(0.89342, rmse, EPSILON);
+    } finally {
+      IOUtils.quietClose(reader);
+    }
+
+  }
+}
\ No newline at end of file


Reply via email to