http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
new file mode 100644
index 0000000..c37af4e
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.conf.Configured;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Compute the frequency distribution of the "class label"<br>
+ * This class can be used when the criterion variable is a categorical attribute.
+ */
+@Deprecated
+public final class Frequencies extends Configured implements Tool {
+  
+  private static final Logger log = LoggerFactory.getLogger(Frequencies.class);
+  
+  private Frequencies() { }
+  
+  @Override
+  public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
+    
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+    
+    Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+      abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+    
+    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+      abuilder.withName("path").withMinimum(1).create()).withDescription("dataset path").create();
+    
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+    
+    Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt).withOption(helpOpt)
+        .create();
+    
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+      
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return 0;
+      }
+      
+      String dataPath = cmdLine.getValue(dataOpt).toString();
+      String datasetPath = cmdLine.getValue(datasetOpt).toString();
+      
+      log.debug("Data path : {}", dataPath);
+      log.debug("Dataset path : {}", datasetPath);
+      
+      runTool(dataPath, datasetPath);
+    } catch (OptionException e) {
+      log.warn(e.toString(), e);
+      CommandLineUtil.printHelp(group);
+    }
+    
+    return 0;
+  }
+  
+  private void runTool(String data, String dataset) throws IOException,
+                                                   ClassNotFoundException,
+                                                   InterruptedException {
+    
+    FileSystem fs = FileSystem.get(getConf());
+    Path workingDir = fs.getWorkingDirectory();
+    
+    Path dataPath = new Path(data);
+    Path datasetPath = new Path(dataset);
+    
+    log.info("Computing the frequencies...");
+    FrequenciesJob job = new FrequenciesJob(new Path(workingDir, "output"), dataPath, datasetPath);
+    
+    int[][] counts = job.run(getConf());
+    
+    // outputting the frequencies
+    log.info("counts[partition][class]");
+    for (int[] count : counts) {
+      log.info(Arrays.toString(count));
+    }
+  }
+  
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new Frequencies(), args);
+  }
+  
+}

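As a usage note for the tool above: Frequencies is a standard Hadoop Tool whose public main() delegates to ToolRunner, so it can also be driven programmatically. A minimal sketch, assuming hypothetical input paths (the serialized Dataset descriptor would typically be produced by a separate describe step):

    // Hypothetical driver; both paths are placeholders.
    public final class FrequenciesExample {
      public static void main(String[] args) throws Exception {
        org.apache.mahout.classifier.df.tools.Frequencies.main(new String[] {
            "-d", "/user/me/glass/glass.data",    // data path (placeholder)
            "-ds", "/user/me/glass/glass.info"    // serialized Dataset path (placeholder)
        });
      }
    }
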
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
new file mode 100644
index 0000000..9d7e2ff
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
@@ -0,0 +1,297 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.net.URI;
+import java.util.Arrays;
+
+/**
+ * Temporary class used to compute the frequency distribution of the "class attribute".<br>
+ * This class can be used when the criterion variable is a categorical attribute.
+ */
+@Deprecated
+public class FrequenciesJob {
+  
+  private static final Logger log = LoggerFactory.getLogger(FrequenciesJob.class);
+  
+  /** directory that will hold this job's output */
+  private final Path outputPath;
+  
+  /** file that contains the serialized dataset */
+  private final Path datasetPath;
+  
+  /** directory that contains the data used in the first step */
+  private final Path dataPath;
+  
+  /**
+   * @param base
+   *          base directory
+   * @param dataPath
+   *          data used in the first step
+   */
+  public FrequenciesJob(Path base, Path dataPath, Path datasetPath) {
+    this.outputPath = new Path(base, "frequencies.output");
+    this.dataPath = dataPath;
+    this.datasetPath = datasetPath;
+  }
+  
+  /**
+   * @return counts[partition][label] = num tuples from 'partition' with class == label
+   */
+  public int[][] run(Configuration conf) throws IOException, ClassNotFoundException, InterruptedException {
+    
+    // check the output
+    FileSystem fs = outputPath.getFileSystem(conf);
+    if (fs.exists(outputPath)) {
+      throw new IOException("Output path already exists : " + outputPath);
+    }
+    
+    // put the dataset into the DistributedCache
+    URI[] files = {datasetPath.toUri()};
+    DistributedCache.setCacheFiles(files, conf);
+    
+    Job job = new Job(conf);
+    job.setJarByClass(FrequenciesJob.class);
+    
+    FileInputFormat.setInputPaths(job, dataPath);
+    FileOutputFormat.setOutputPath(job, outputPath);
+    
+    job.setMapOutputKeyClass(LongWritable.class);
+    job.setMapOutputValueClass(IntWritable.class);
+    job.setOutputKeyClass(LongWritable.class);
+    job.setOutputValueClass(Frequencies.class);
+    
+    job.setMapperClass(FrequenciesMapper.class);
+    job.setReducerClass(FrequenciesReducer.class);
+    
+    job.setInputFormatClass(TextInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    
+    // run the job
+    boolean succeeded = job.waitForCompletion(true);
+    if (!succeeded) {
+      throw new IllegalStateException("Job failed!");
+    }
+    
+    int[][] counts = parseOutput(job);
+
+    HadoopUtil.delete(conf, outputPath);
+    
+    return counts;
+  }
+  
+  /**
+   * Extracts the output and processes it
+   * 
+   * @return counts[partition][label] = num tuples from 'partition' with class == label
+   */
+  int[][] parseOutput(JobContext job) throws IOException {
+    Configuration conf = job.getConfiguration();
+    
+    int numMaps = conf.getInt("mapred.map.tasks", -1);
+    log.info("mapred.map.tasks = {}", numMaps);
+    
+    FileSystem fs = outputPath.getFileSystem(conf);
+    
+    Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);
+    
+    Frequencies[] values = new Frequencies[numMaps];
+    
+    // read all the outputs
+    int index = 0;
+    for (Path path : outfiles) {
+      for (Frequencies value : new SequenceFileValueIterable<Frequencies>(path, conf)) {
+        values[index++] = value;
+      }
+    }
+    
+    if (index < numMaps) {
+      throw new IllegalStateException("number of output Frequencies (" + index
+          + ") is less than the number of mappers!");
+    }
+    
+    // sort the frequencies using the firstIds
+    Arrays.sort(values);
+    return Frequencies.extractCounts(values);
+  }
+  
+  /**
+   * Outputs the first key and the label of each tuple
+   * 
+   */
+  private static class FrequenciesMapper extends Mapper<LongWritable,Text,LongWritable,IntWritable> {
+    
+    private LongWritable firstId;
+    
+    private DataConverter converter;
+    private Dataset dataset;
+    
+    @Override
+    protected void setup(Context context) throws IOException, InterruptedException {
+      Configuration conf = context.getConfiguration();
+      
+      dataset = Builder.loadDataset(conf);
+      setup(dataset);
+    }
+    
+    /**
+     * Useful when testing
+     */
+    void setup(Dataset dataset) {
+      converter = new DataConverter(dataset);
+    }
+    
+    @Override
+    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
+      if (firstId == null) {
+        firstId = new LongWritable(key.get());
+      }
+      
+      Instance instance = converter.convert(value.toString());
+      
+      context.write(firstId, new IntWritable((int) dataset.getLabel(instance)));
+    }
+    
+  }
+  
+  private static class FrequenciesReducer extends Reducer<LongWritable,IntWritable,LongWritable,Frequencies> {
+    
+    private int nblabels;
+    
+    @Override
+    protected void setup(Context context) throws IOException, InterruptedException {
+      Configuration conf = context.getConfiguration();
+      Dataset dataset = Builder.loadDataset(conf);
+      setup(dataset.nblabels());
+    }
+    
+    /**
+     * Useful when testing
+     */
+    void setup(int nblabels) {
+      this.nblabels = nblabels;
+    }
+    
+    @Override
+    protected void reduce(LongWritable key, Iterable<IntWritable> values, Context context)
+      throws IOException, InterruptedException {
+      int[] counts = new int[nblabels];
+      for (IntWritable value : values) {
+        counts[value.get()]++;
+      }
+      
+      context.write(key, new Frequencies(key.get(), counts));
+    }
+  }
+  
+  /**
+   * Output of the job
+   * 
+   */
+  private static class Frequencies implements Writable, Comparable<Frequencies>, Cloneable {
+    
+    /** first key of the partition used to sort the partitions */
+    private long firstId;
+    
+    /** counts[c] = num tuples from the partition with label == c */
+    private int[] counts;
+    
+    Frequencies() { }
+    
+    Frequencies(long firstId, int[] counts) {
+      this.firstId = firstId;
+      this.counts = Arrays.copyOf(counts, counts.length);
+    }
+    
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      firstId = in.readLong();
+      counts = DFUtils.readIntArray(in);
+    }
+    
+    @Override
+    public void write(DataOutput out) throws IOException {
+      out.writeLong(firstId);
+      DFUtils.writeArray(out, counts);
+    }
+    
+    @Override
+    public boolean equals(Object other) {
+      return other instanceof Frequencies && firstId == ((Frequencies) other).firstId;
+    }
+    
+    @Override
+    public int hashCode() {
+      return (int) firstId;
+    }
+    
+    @Override
+    protected Frequencies clone() {
+      return new Frequencies(firstId, counts);
+    }
+    
+    @Override
+    public int compareTo(Frequencies obj) {
+      if (firstId < obj.firstId) {
+        return -1;
+      } else if (firstId > obj.firstId) {
+        return 1;
+      } else {
+        return 0;
+      }
+    }
+    
+    public static int[][] extractCounts(Frequencies[] partitions) {
+      int[][] counts = new int[partitions.length][];
+      for (int p = 0; p < partitions.length; p++) {
+        counts[p] = partitions[p].counts;
+      }
+      return counts;
+    }
+  }
+}

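The counts matrix returned by run() above is indexed first by map partition (sorted by each partition's first record id) and then by label. A self-contained sketch with invented numbers, showing how per-label totals fall out of it:

    import java.util.Arrays;

    // Collapse counts[partition][label] into per-label totals.
    // The matrix below is invented, not real job output.
    public final class CountsExample {
      public static void main(String[] args) {
        int[][] counts = { {10, 3}, {7, 5} };   // 2 partitions, 2 labels
        int[] totals = new int[counts[0].length];
        for (int[] partition : counts) {
          for (int label = 0; label < partition.length; label++) {
            totals[label] += partition[label];
          }
        }
        System.out.println(Arrays.toString(totals));   // prints [17, 8]
      }
    }
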
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
new file mode 100644
index 0000000..a2a3458
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
@@ -0,0 +1,264 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.lang.reflect.Field;
+import java.text.DecimalFormat;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+
+/**
+ * This tool visualizes a decision tree.
+ */
+@Deprecated
+public final class TreeVisualizer {
+  
+  private TreeVisualizer() {}
+  
+  private static String doubleToString(double value) {
+    DecimalFormat df = new DecimalFormat("0.##");
+    return df.format(value);
+  }
+  
+  private static String toStringNode(Node node, Dataset dataset,
+      String[] attrNames, Map<String,Field> fields, int layer) {
+    
+    StringBuilder buff = new StringBuilder();
+    
+    try {
+      if (node instanceof CategoricalNode) {
+        CategoricalNode cnode = (CategoricalNode) node;
+        int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
+        double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
+        Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
+        String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
+        for (int i = 0; i < attrValues[attr].length; i++) {
+          int index = ArrayUtils.indexOf(values, i);
+          if (index < 0) {
+            continue;
+          }
+          buff.append('\n');
+          for (int j = 0; j < layer; j++) {
+            buff.append("|   ");
+          }
+          buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
+              .append(attrValues[attr][i]);
+          buff.append(toStringNode(childs[index], dataset, attrNames, fields, layer + 1));
+        }
+      } else if (node instanceof NumericalNode) {
+        NumericalNode nnode = (NumericalNode) node;
+        int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
+        double split = (Double) fields.get("NumericalNode.split").get(nnode);
+        Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
+        Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
+        buff.append('\n');
+        for (int j = 0; j < layer; j++) {
+          buff.append("|   ");
+        }
+        buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ")
+            .append(doubleToString(split));
+        buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1));
+        buff.append('\n');
+        for (int j = 0; j < layer; j++) {
+          buff.append("|   ");
+        }
+        buff.append(attrNames == null ? attr : attrNames[attr]).append(" >= ")
+            .append(doubleToString(split));
+        buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1));
+      } else if (node instanceof Leaf) {
+        Leaf leaf = (Leaf) node;
+        double label = (Double) fields.get("Leaf.label").get(leaf);
+        if (dataset.isNumerical(dataset.getLabelId())) {
+          buff.append(" : ").append(doubleToString(label));
+        } else {
+          buff.append(" : ").append(dataset.getLabelString(label));
+        }
+      }
+    } catch (IllegalAccessException iae) {
+      throw new IllegalStateException(iae);
+    }
+    
+    return buff.toString();
+  }
+  
+  private static Map<String,Field> getReflectMap() {
+    Map<String,Field> fields = new HashMap<>();
+    
+    try {
+      Field m = CategoricalNode.class.getDeclaredField("attr");
+      m.setAccessible(true);
+      fields.put("CategoricalNode.attr", m);
+      m = CategoricalNode.class.getDeclaredField("values");
+      m.setAccessible(true);
+      fields.put("CategoricalNode.values", m);
+      m = CategoricalNode.class.getDeclaredField("childs");
+      m.setAccessible(true);
+      fields.put("CategoricalNode.childs", m);
+      m = NumericalNode.class.getDeclaredField("attr");
+      m.setAccessible(true);
+      fields.put("NumericalNode.attr", m);
+      m = NumericalNode.class.getDeclaredField("split");
+      m.setAccessible(true);
+      fields.put("NumericalNode.split", m);
+      m = NumericalNode.class.getDeclaredField("loChild");
+      m.setAccessible(true);
+      fields.put("NumericalNode.loChild", m);
+      m = NumericalNode.class.getDeclaredField("hiChild");
+      m.setAccessible(true);
+      fields.put("NumericalNode.hiChild", m);
+      m = Leaf.class.getDeclaredField("label");
+      m.setAccessible(true);
+      fields.put("Leaf.label", m);
+      m = Dataset.class.getDeclaredField("values");
+      m.setAccessible(true);
+      fields.put("Dataset.values", m);
+    } catch (NoSuchFieldException nsfe) {
+      throw new IllegalStateException(nsfe);
+    }
+    
+    return fields;
+  }
+  
+  /**
+   * Decision tree to String
+   * 
+   * @param tree
+   *          Node of tree
+   * @param attrNames
+   *          attribute names
+   */
+  public static String toString(Node tree, Dataset dataset, String[] attrNames) {
+    return toStringNode(tree, dataset, attrNames, getReflectMap(), 0);
+  }
+  
+  /**
+   * Print Decision tree
+   * 
+   * @param tree
+   *          Node of tree
+   * @param attrNames
+   *          attribute names
+   */
+  public static void print(Node tree, Dataset dataset, String[] attrNames) {
+    System.out.println(toString(tree, dataset, attrNames));
+  }
+  
+  private static String toStringPredict(Node node, Instance instance,
+      Dataset dataset, String[] attrNames, Map<String,Field> fields) {
+    StringBuilder buff = new StringBuilder();
+    
+    try {
+      if (node instanceof CategoricalNode) {
+        CategoricalNode cnode = (CategoricalNode) node;
+        int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
+        double[] values = (double[]) fields.get("CategoricalNode.values").get(
+            cnode);
+        Node[] childs = (Node[]) fields.get("CategoricalNode.childs")
+            .get(cnode);
+        String[][] attrValues = (String[][]) fields.get("Dataset.values").get(
+            dataset);
+        
+        int index = ArrayUtils.indexOf(values, instance.get(attr));
+        if (index >= 0) {
+          buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
+              .append(attrValues[attr][(int) instance.get(attr)]);
+          buff.append(" -> ");
+          buff.append(toStringPredict(childs[index], instance, dataset,
+              attrNames, fields));
+        }
+      } else if (node instanceof NumericalNode) {
+        NumericalNode nnode = (NumericalNode) node;
+        int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
+        double split = (Double) fields.get("NumericalNode.split").get(nnode);
+        Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
+        Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
+        
+        if (instance.get(attr) < split) {
+          buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+              .append(" = ").append(doubleToString(instance.get(attr)))
+              .append(") < ").append(doubleToString(split));
+          buff.append(" -> ");
+          buff.append(toStringPredict(loChild, instance, dataset, attrNames,
+              fields));
+        } else {
+          buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+              .append(" = ").append(doubleToString(instance.get(attr)))
+              .append(") >= ").append(doubleToString(split));
+          buff.append(" -> ");
+          buff.append(toStringPredict(hiChild, instance, dataset, attrNames,
+              fields));
+        }
+      } else if (node instanceof Leaf) {
+        Leaf leaf = (Leaf) node;
+        double label = (Double) fields.get("Leaf.label").get(leaf);
+        if (dataset.isNumerical(dataset.getLabelId())) {
+          buff.append(doubleToString(label));
+        } else {
+          buff.append(dataset.getLabelString(label));
+        }
+      }
+    } catch (IllegalAccessException iae) {
+      throw new IllegalStateException(iae);
+    }
+    
+    return buff.toString();
+  }
+  
+  /**
+   * Predict trace to String
+   * 
+   * @param tree
+   *          Node of tree
+   * @param attrNames
+   *          attribute names
+   */
+  public static String[] predictTrace(Node tree, Data data, String[] attrNames) {
+    Map<String,Field> reflectMap = getReflectMap();
+    String[] prediction = new String[data.size()];
+    for (int i = 0; i < data.size(); i++) {
+      prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(),
+          attrNames, reflectMap);
+    }
+    return prediction;
+  }
+  
+  /**
+   * Print predict trace
+   * 
+   * @param tree
+   *          Node of tree
+   * @param attrNames
+   *          attribute names
+   */
+  public static void predictTracePrint(Node tree, Data data, String[] attrNames) {
+    Map<String,Field> reflectMap = getReflectMap();
+    for (int i = 0; i < data.size(); i++) {
+      System.out.println(toStringPredict(tree, data.get(i), data.getDataset(),
+          attrNames, reflectMap));
+    }
+  }
+}

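For orientation, a sketch of calling the visualizer once a tree exists. The tree and dataset would come from Mahout's decision-forest training code; the attribute names here are invented placeholders:

    import org.apache.mahout.classifier.df.data.Dataset;
    import org.apache.mahout.classifier.df.node.Node;
    import org.apache.mahout.classifier.df.tools.TreeVisualizer;

    public final class TreeVisualizerExample {
      // `tree` and `dataset` are assumed to come from training; the attribute
      // names are hypothetical and must match the dataset's attribute order.
      public static void printTree(Node tree, Dataset dataset) {
        String[] attrNames = {"outlook", "temperature", "humidity", "windy"};
        TreeVisualizer.print(tree, dataset, attrNames);   // indented view on stdout
      }
    }
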
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
new file mode 100644
index 0000000..e1b55ab
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
@@ -0,0 +1,212 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Random;
+import java.util.Scanner;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This tool uniformly distributes the class of all the tuples of the dataset over a given number of
+ * partitions.<br>
+ * This class can be used when the criterion variable is a categorical attribute.
+ */
+@Deprecated
+public final class UDistrib {
+  
+  private static final Logger log = LoggerFactory.getLogger(UDistrib.class);
+  
+  private UDistrib() {}
+  
+  /**
+   * Launch the uniform distribution tool. Requires the following command line arguments:<br>
+   * 
+   * data : data path<br>
+   * dataset : dataset path<br>
+   * numpartitions : num partitions<br>
+   * output : output path
+   *
+   * @throws java.io.IOException
+   */
+  public static void main(String[] args) throws IOException {
+    
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+    
+    Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+      abuilder.withName("data").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+    
+    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+      abuilder.withName("dataset").withMinimum(1).create()).withDescription("Dataset path").create();
+    
+    Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument(
+      abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+      "Path to generated files").create();
+    
+    Option partitionsOpt = obuilder.withLongName("numpartitions").withShortName("p").withRequired(true)
+        .withArgument(abuilder.withName("numparts").withMinimum(1).withMaximum(1).create()).withDescription(
+          "Number of partitions to create").create();
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+    
+    Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(outputOpt).withOption(
+      datasetOpt).withOption(partitionsOpt).withOption(helpOpt).create();
+    
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+      
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+      
+      String data = cmdLine.getValue(dataOpt).toString();
+      String dataset = cmdLine.getValue(datasetOpt).toString();
+      int numPartitions = Integer.parseInt(cmdLine.getValue(partitionsOpt).toString());
+      String output = cmdLine.getValue(outputOpt).toString();
+      
+      runTool(data, dataset, output, numPartitions);
+    } catch (OptionException e) {
+      log.warn(e.toString(), e);
+      CommandLineUtil.printHelp(group);
+    }
+    
+  }
+  
+  private static void runTool(String dataStr, String datasetStr, String output, int numPartitions) throws IOException {
+
+    Preconditions.checkArgument(numPartitions > 0, "numPartitions <= 0");
+    
+    // make sure the output file does not exist
+    Path outputPath = new Path(output);
+    Configuration conf = new Configuration();
+    FileSystem fs = outputPath.getFileSystem(conf);
+
+    Preconditions.checkArgument(!fs.exists(outputPath), "Output path already exists");
+    
+    // create a new file corresponding to each partition
+    // Path workingDir = fs.getWorkingDirectory();
+    // FileSystem wfs = workingDir.getFileSystem(conf);
+    // File parentFile = new File(workingDir.toString());
+    // File tempFile = FileUtil.createLocalTempFile(parentFile, "Parts", true);
+    // File tempFile = File.createTempFile("df.tools.UDistrib","");
+    // tempFile.deleteOnExit();
+    File tempFile = FileUtil.createLocalTempFile(new File(""), "df.tools.UDistrib", true);
+    Path partsPath = new Path(tempFile.toString());
+    FileSystem pfs = partsPath.getFileSystem(conf);
+    
+    Path[] partPaths = new Path[numPartitions];
+    FSDataOutputStream[] files = new FSDataOutputStream[numPartitions];
+    for (int p = 0; p < numPartitions; p++) {
+      partPaths[p] = new Path(partsPath, String.format(Locale.ENGLISH, "part.%03d", p));
+      files[p] = pfs.create(partPaths[p]);
+    }
+    
+    Path datasetPath = new Path(datasetStr);
+    Dataset dataset = Dataset.load(conf, datasetPath);
+    
+    // currents[label] = next partition file where to place the tuple
+    int[] currents = new int[dataset.nblabels()];
+    
+    // currents is initialized randomly in the range [0, numPartitions)
+    Random random = RandomUtils.getRandom();
+    for (int c = 0; c < currents.length; c++) {
+      currents[c] = random.nextInt(numPartitions);
+    }
+    
+    // foreach tuple of the data
+    Path dataPath = new Path(dataStr);
+    FileSystem ifs = dataPath.getFileSystem(conf);
+    FSDataInputStream input = ifs.open(dataPath);
+    Scanner scanner = new Scanner(input, "UTF-8");
+    DataConverter converter = new DataConverter(dataset);
+    
+    int id = 0;
+    while (scanner.hasNextLine()) {
+      if (id % 1000 == 0) {
+        log.info("progress : {}", id);
+      }
+      
+      String line = scanner.nextLine();
+      if (line.isEmpty()) {
+        continue; // skip empty lines
+      }
+      
+      // write the tuple in files[tuple.label]
+      Instance instance = converter.convert(line);
+      int label = (int) dataset.getLabel(instance);
+      files[currents[label]].writeBytes(line);
+      files[currents[label]].writeChar('\n');
+      
+      // update currents
+      currents[label]++;
+      if (currents[label] == numPartitions) {
+        currents[label] = 0;
+      }
+    }
+    
+    // close all the files.
+    scanner.close();
+    for (FSDataOutputStream file : files) {
+      Closeables.close(file, false);
+    }
+    
+    // merge all output files
+    FileUtil.copyMerge(pfs, partsPath, fs, outputPath, true, conf, null);
+    /*
+     * FSDataOutputStream joined = fs.create(new Path(outputPath, 
"uniform.data")); for (int p = 0; p <
+     * numPartitions; p++) {log.info("Joining part : {}", p); 
FSDataInputStream partStream =
+     * fs.open(partPaths[p]);
+     * 
+     * IOUtils.copyBytes(partStream, joined, conf, false);
+     * 
+     * partStream.close(); }
+     * 
+     * joined.close();
+     * 
+     * fs.delete(partsPath, true);
+     */
+  }
+  
+}

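The heart of runTool above is a per-label round-robin cursor: each label keeps its own next-partition index, so every class ends up spread evenly across partitions. A stripped-down sketch of just that rule, with invented labels in place of Mahout Instances:

    import java.util.Arrays;
    import java.util.List;
    import java.util.Random;

    public final class RoundRobinSketch {
      public static void main(String[] args) {
        int numPartitions = 3;
        int[] currents = new int[2];                       // one cursor per label
        Random random = new Random(42);                    // UDistrib uses RandomUtils.getRandom()
        for (int c = 0; c < currents.length; c++) {
          currents[c] = random.nextInt(numPartitions);     // random starting partition
        }
        List<Integer> labels = Arrays.asList(0, 0, 1, 0, 1, 1, 0);   // invented class labels
        for (int label : labels) {
          System.out.println("label " + label + " -> partition " + currents[label]);
          currents[label] = (currents[label] + 1) % numPartitions;   // advance this label's cursor
        }
      }
    }
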
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
new file mode 100644
index 0000000..049f9bf
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.evaluation;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.list.DoubleArrayList;
+
+import com.google.common.base.Preconditions;
+
+import java.util.Random;
+
+/**
+ * Computes AUC and a few other accuracy statistics without storing huge amounts of data.  This is
+ * done by keeping uniform samples of the positive and negative scores.  Then, when AUC is to be
+ * computed, the remaining scores are sorted and a rank-sum statistic is used to compute the AUC.
+ * Since AUC is invariant with respect to down-sampling of either positives or negatives, this is
+ * close to correct and is exactly correct if maxBufferSize or fewer positive and negative scores
+ * are examined.
+ */
+public class Auc {
+
+  private int maxBufferSize = 10000;
+  private final DoubleArrayList[] scores = {new DoubleArrayList(), new DoubleArrayList()};
+  private final Random rand;
+  private int samples;
+  private final double threshold;
+  private final Matrix confusion;
+  private final DenseMatrix entropy;
+
+  private boolean probabilityScore = true;
+
+  private boolean hasScore;
+
+  /**
+   * Allocates a new data-structure for accumulating information about AUC and a few other accuracy
+   * measures.
+   * @param threshold The threshold to use in computing the confusion matrix.
+   */
+  public Auc(double threshold) {
+    confusion = new DenseMatrix(2, 2);
+    entropy = new DenseMatrix(2, 2);
+    this.rand = RandomUtils.getRandom();
+    this.threshold = threshold;
+  }
+
+  public Auc() {
+    this(0.5);
+  }
+
+  /**
+   * Adds a score to the AUC buffers.
+   *
+   * @param trueValue Whether this score is for a true-positive or a true-negative example.
+   * @param score     The score for this example.
+   */
+  public void add(int trueValue, double score) {
+    Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1");
+    hasScore = true;
+
+    int predictedClass = score > threshold ? 1 : 0;
+    confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1);
+
+    samples++;
+    if (isProbabilityScore()) {
+      double limited = Math.max(1.0e-20, Math.min(score, 1 - 1.0e-20));
+      double v0 = entropy.get(trueValue, 0);
+      entropy.set(trueValue, 0, (Math.log1p(-limited) - v0) / samples + v0);
+
+      double v1 = entropy.get(trueValue, 1);
+      entropy.set(trueValue, 1, (Math.log(limited) - v1) / samples + v1);
+    }
+
+    // add to buffers
+    DoubleArrayList buf = scores[trueValue];
+    if (buf.size() >= maxBufferSize) {
+      // but if too many points are seen, we insert into a random
+      // place and discard the predecessor.  The random place could
+      // be anywhere, possibly not even in the buffer.
+      // this is a special case of Knuth's permutation algorithm
+      // but since we don't ever shuffle the first maxBufferSize
+      // samples, the result isn't just a fair sample of the prefixes
+      // of all permutations.  The CONTENTs of the result, however,
+      // will be a fair and uniform sample of maxBufferSize elements
+      // chosen from all elements without replacement
+      int index = rand.nextInt(samples);
+      if (index < buf.size()) {
+        buf.set(index, score);
+      }
+    } else {
+      // for small buffers, we collect all points without permuting
+      // since we sort the data later, permuting now would just be
+      // pedantic
+      buf.add(score);
+    }
+  }
+
+  public void add(int trueValue, int predictedClass) {
+    hasScore = false;
+    Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1");
+    confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1);
+  }
+
+  /**
+   * Computes the AUC of points seen so far.  This can be moderately expensive since it requires
+   * that all points that have been retained be sorted.
+   *
+   * @return The value of the Area Under the receiver operating Curve.
+   */
+  public double auc() {
+    Preconditions.checkArgument(hasScore, "Can't compute AUC for classifier without a score");
+    scores[0].sort();
+    scores[1].sort();
+
+    double n0 = scores[0].size();
+    double n1 = scores[1].size();
+
+    if (n0 == 0 || n1 == 0) {
+      return 0.5;
+    }
+
+    // scan the data
+    int i0 = 0;
+    int i1 = 0;
+    int rank = 1;
+    double rankSum = 0;
+    while (i0 < n0 && i1 < n1) {
+
+      double v0 = scores[0].get(i0);
+      double v1 = scores[1].get(i1);
+
+      if (v0 < v1) {
+        i0++;
+        rank++;
+      } else if (v1 < v0) {
+        i1++;
+        rankSum += rank;
+        rank++;
+      } else {
+        // ties have to be handled delicately
+        double tieScore = v0;
+
+        // how many negatives are tied?
+        int k0 = 0;
+        while (i0 < n0 && scores[0].get(i0) == tieScore) {
+          k0++;
+          i0++;
+        }
+
+        // and how many positives
+        int k1 = 0;
+        while (i1 < n1 && scores[1].get(i1) == tieScore) {
+          k1++;
+          i1++;
+        }
+
+        // we found k0 + k1 tied values which have
+        // ranks in the half open interval [rank, rank + k0 + k1)
+        // the average rank is assigned to all
+        rankSum += (rank + (k0 + k1 - 1) / 2.0) * k1;
+        rank += k0 + k1;
+      }
+    }
+
+    if (i1 < n1) {
+      rankSum += (rank + (n1 - i1 - 1) / 2.0) * (n1 - i1);
+      rank += (int) (n1 - i1);
+    }
+
+    return (rankSum / n1 - (n1 + 1) / 2) / n0;
+  }
+
+  /**
+   * Returns the confusion matrix for the classifier supposing that we were to use a particular
+   * threshold.
+   * @return The confusion matrix.
+   */
+  public Matrix confusion() {
+    return confusion;
+  }
+
+  /**
+   * Returns a matrix related to the confusion matrix and to the log-likelihood.  For a
+   * pretty accurate classifier, N + entropy is nearly the same as the confusion matrix
+   * because log(1-eps) \approx -eps if eps is small.
+   *
+   * For lower accuracy classifiers, this measure will give us a better picture of how
+   * things work out.
+   *
+   * Also, by definition, log-likelihood = sum(diag(entropy))
+   * @return Returns a cell by cell break-down of the log-likelihood
+   */
+  public Matrix entropy() {
+    if (!hasScore) {
+      // find a constant score that would optimize log-likelihood, but use a dash of Bayesian
+      // conservatism to avoid dividing by zero or taking log(0)
+      double p = (0.5 + confusion.get(1, 1)) / (1 + confusion.get(0, 0) + confusion.get(1, 1));
+      entropy.set(0, 0, confusion.get(0, 0) * Math.log1p(-p));
+      entropy.set(0, 1, confusion.get(0, 1) * Math.log(p));
+      entropy.set(1, 0, confusion.get(1, 0) * Math.log1p(-p));
+      entropy.set(1, 1, confusion.get(1, 1) * Math.log(p));
+    }
+    return entropy;
+  }
+
+  public void setMaxBufferSize(int maxBufferSize) {
+    this.maxBufferSize = maxBufferSize;
+  }
+
+  public boolean isProbabilityScore() {
+    return probabilityScore;
+  }
+
+  public void setProbabilityScore(boolean probabilityScore) {
+    this.probabilityScore = probabilityScore;
+  }
+}

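A short, self-contained illustration of the accumulator above; the (label, score) pairs are invented:

    import org.apache.mahout.classifier.evaluation.Auc;

    public final class AucExample {
      public static void main(String[] args) {
        Auc auc = new Auc(0.5);    // threshold for the confusion matrix
        auc.add(1, 0.9);           // positive example, high score
        auc.add(1, 0.4);           // positive example, low score
        auc.add(0, 0.6);           // negative example, mid score
        auc.add(0, 0.2);           // negative example, low score
        // 3 of the 4 positive/negative score pairs are correctly ordered
        System.out.println("AUC = " + auc.auc());   // 0.75
        System.out.println(auc.confusion());        // 2x2 counts at threshold 0.5
      }
    }
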
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
new file mode 100644
index 0000000..f0794b3
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * Class implementing the Naive Bayes Classifier Algorithm. Note that this class
+ * supports {@link #classifyFull}, but not {@code classify} or
+ * {@code classifyScalar}. The reason these two methods are not
+ * supported is that the scores computed by a NaiveBayesClassifier do not
+ * represent probabilities.
+ */
+public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
+
+  private final NaiveBayesModel model;
+  
+  protected AbstractNaiveBayesClassifier(NaiveBayesModel model) {
+    this.model = model;
+  }
+
+  protected NaiveBayesModel getModel() {
+    return model;
+  }
+  
+  protected abstract double getScoreForLabelFeature(int label, int feature);
+
+  protected double getScoreForLabelInstance(int label, Vector instance) {
+    double result = 0.0;
+    for (Element e : instance.nonZeroes()) {
+      result += e.get() * getScoreForLabelFeature(label, e.index());
+    }
+    return result;
+  }
+  
+  @Override
+  public int numCategories() {
+    return model.numLabels();
+  }
+
+  @Override
+  public Vector classifyFull(Vector instance) {
+    return classifyFull(model.createScoringVector(), instance);
+  }
+  
+  @Override
+  public Vector classifyFull(Vector r, Vector instance) {
+    for (int label = 0; label < model.numLabels(); label++) {
+      r.setQuick(label, getScoreForLabelInstance(label, instance));
+    }
+    return r;
+  }
+
+  /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+  @Override
+  public double classifyScalar(Vector instance) {
+    throw new UnsupportedOperationException("Not supported in Naive Bayes");
+  }
+  
+  /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+  @Override
+  public Vector classify(Vector instance) {
+    throw new UnsupportedOperationException("probabilities not supported in Naive Bayes");
+  }
+}

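For orientation, a sketch of how a concrete subclass is typically used. It assumes the sibling StandardNaiveBayesClassifier and the static NaiveBayesModel.materialize() loader, neither of which appears in this excerpt; the model path is a placeholder:

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
    import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
    import org.apache.mahout.math.RandomAccessSparseVector;
    import org.apache.mahout.math.Vector;

    public final class ClassifyExample {
      public static void main(String[] args) throws Exception {
        // placeholder path to a previously trained model directory
        NaiveBayesModel model = NaiveBayesModel.materialize(new Path("/tmp/model"), new Configuration());
        StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);
        Vector instance = new RandomAccessSparseVector(model.numFeatures());
        instance.set(42, 1.0);                             // hypothetical feature weight
        Vector scores = classifier.classifyFull(instance); // one score per label
        System.out.println("best label = " + scores.maxValueIndex());
      }
    }
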
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
new file mode 100644
index 0000000..4db8b17
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.regex.Pattern;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.naivebayes.training.ThetaMapper;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public final class BayesUtils {
+
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  private BayesUtils() {}
+
+  public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) {
+
+    float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);
+    boolean isComplementary = conf.getBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, true);
+
+    // read feature sums and label sums
+    Vector scoresPerLabel = null;
+    Vector scoresPerFeature = null;
+    for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      String key = record.getFirst().toString();
+      VectorWritable value = record.getSecond();
+      if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
+        scoresPerFeature = value.get();
+      } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
+        scoresPerLabel = value.get();
+      }
+    }
+
+    Preconditions.checkNotNull(scoresPerFeature);
+    Preconditions.checkNotNull(scoresPerLabel);
+
+    Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size());
+    for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
+    }
+    
+    // perLabelThetaNormalizer is only used by the complementary model; we do not instantiate it for the standard model
+    Vector perLabelThetaNormalizer = null;
+    if (isComplementary) {
+      perLabelThetaNormalizer = scoresPerLabel.like();
+      for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(
+          new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) {
+        if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
+          perLabelThetaNormalizer = entry.getSecond().get();
+        }
+      }
+      Preconditions.checkNotNull(perLabelThetaNormalizer);
+    }
+     
+    return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer,
+        alphaI, isComplementary);
+  }
+
+  /** Write the list of labels into a map file */
+  public static int writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath)
+    throws IOException {
+    FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+    int i = 0;
+    try (SequenceFile.Writer writer =
+           SequenceFile.createWriter(fs.getConf(), SequenceFile.Writer.file(indexPath),
+             SequenceFile.Writer.keyClass(Text.class), SequenceFile.Writer.valueClass(IntWritable.class))) {
+      for (String label : labels) {
+        writer.append(new Text(label), new IntWritable(i++));
+      }
+    }
+    return i;
+  }
+
+  public static int writeLabelIndex(Configuration conf, Path indexPath,
+                                    Iterable<Pair<Text,IntWritable>> labels) throws IOException {
+    FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+    Collection<String> seen = new HashSet<>();
+    int i = 0;
+    try (SequenceFile.Writer writer =
+           SequenceFile.createWriter(fs.getConf(), SequenceFile.Writer.file(indexPath),
+             SequenceFile.Writer.keyClass(Text.class), SequenceFile.Writer.valueClass(IntWritable.class))) {
+      for (Object label : labels) {
+        String theLabel = SLASH.split(((Pair<?, ?>) label).getFirst().toString())[1];
+        if (!seen.contains(theLabel)) {
+          writer.append(new Text(theLabel), new IntWritable(i++));
+          seen.add(theLabel);
+        }
+      }
+    }
+    return i;
+  }
+
+  public static Map<Integer, String> readLabelIndex(Configuration conf, Path indexPath) {
+    Map<Integer, String> labelMap = new HashMap<>();
+    for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(indexPath, true, conf)) {
+      labelMap.put(pair.getSecond().get(), pair.getFirst().toString());
+    }
+    return labelMap;
+  }
+
+  public static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException {
+    OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<>();
+    for (Pair<Writable,IntWritable> entry
+        : new SequenceFileIterable<Writable,IntWritable>(HadoopUtil.getSingleCachedFile(conf), conf)) {
+      index.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return index;
+  }
+
+  public static Map<String,Vector> readScoresFromCache(Configuration conf) throws IOException {
+    Map<String,Vector> sumVectors = new HashMap<>();
+    for (Pair<Text,VectorWritable> entry
+        : new SequenceFileDirIterable<Text,VectorWritable>(HadoopUtil.getSingleCachedFile(conf),
+          PathType.LIST, PathFilters.partFilter(), conf)) {
+      sumVectors.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return sumVectors;
+  }
+
+
+}

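A small sketch of the label-index round trip provided by writeLabelIndex and readLabelIndex above; the index path and label names are placeholders:

    import java.util.Arrays;
    import java.util.Map;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.classifier.naivebayes.BayesUtils;

    public final class LabelIndexExample {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        Path indexPath = new Path("/tmp/labelindex");     // placeholder
        int n = BayesUtils.writeLabelIndex(conf, Arrays.asList("spam", "ham"), indexPath);
        System.out.println(n + " labels written");        // 2 labels written
        Map<Integer, String> labels = BayesUtils.readLabelIndex(conf, indexPath);
        System.out.println(labels.get(0));                // spam
      }
    }
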
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
new file mode 100644
index 0000000..18bd3d6
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/** Implementation of the Complementary Naive Bayes Classifier Algorithm */
+public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+  public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    NaiveBayesModel model = getModel();
+    double weight = computeWeight(model.featureWeight(feature), model.weight(label, feature),
+        model.totalWeightSum(), model.labelWeight(label), model.alphaI(), model.numFeatures());
+    // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+    return weight / model.thetaNormalizer(label);
+  }
+
+  // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.1, Skewed Data bias
+  public static double computeWeight(double featureWeight, double featureLabelWeight,
+      double totalWeight, double labelWeight, double alphaI, double numFeatures) {
+    double numerator = featureWeight - featureLabelWeight + alphaI;
+    double denominator = totalWeight - labelWeight + alphaI * numFeatures;
+    return -Math.log(numerator / denominator);
+  }
+}
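
A worked example of the complement weight with invented counts (not from any real corpus):
suppose a feature has total weight 10, of which 4 belongs to the label, the corpus weight is
100, the label weight is 30, alphaI is 1 and there are 50 features:

    double w = ComplementaryNaiveBayesClassifier.computeWeight(10.0, 4.0, 100.0, 30.0, 1.0, 50.0);
    // numerator   = 10 - 4 + 1        = 7     (smoothed weight of the feature outside this label)
    // denominator = 100 - 30 + 1 * 50 = 120   (smoothed total weight outside this label)
    // w = -ln(7 / 120) ≈ 2.84; a feature that is rare outside the label gets a large weight,
    // which favors that label once the per-feature scores are summed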

http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
----------------------------------------------------------------------
diff --git 
a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
new file mode 100644
index 0000000..9f85aab
--- /dev/null
+++ 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
@@ -0,0 +1,170 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.base.Preconditions;
+
+/** NaiveBayesModel holds the weight matrix, the feature and label sums, and the weight normalizer vectors. */
+public class NaiveBayesModel {
+
+  private final Vector weightsPerLabel;
+  private final Vector perlabelThetaNormalizer;
+  private final Vector weightsPerFeature;
+  private final Matrix weightsPerLabelAndFeature;
+  private final float alphaI;
+  private final double numFeatures;
+  private final double totalWeightSum;
+  private final boolean isComplementary;
+
+  public static final String COMPLEMENTARY_MODEL = "COMPLEMENTARY_MODEL";
+
+  public NaiveBayesModel(Matrix weightMatrix, Vector weightsPerFeature, Vector weightsPerLabel,
+                         Vector thetaNormalizer, float alphaI, boolean isComplementary) {
+    this.weightsPerLabelAndFeature = weightMatrix;
+    this.weightsPerFeature = weightsPerFeature;
+    this.weightsPerLabel = weightsPerLabel;
+    this.perlabelThetaNormalizer = thetaNormalizer;
+    this.numFeatures = weightsPerFeature.getNumNondefaultElements();
+    this.totalWeightSum = weightsPerLabel.zSum();
+    this.alphaI = alphaI;
+    this.isComplementary = isComplementary;
+  }
+
+  public double labelWeight(int label) {
+    return weightsPerLabel.getQuick(label);
+  }
+
+  public double thetaNormalizer(int label) {
+    return perlabelThetaNormalizer.get(label); 
+  }
+
+  public double featureWeight(int feature) {
+    return weightsPerFeature.getQuick(feature);
+  }
+
+  public double weight(int label, int feature) {
+    return weightsPerLabelAndFeature.getQuick(label, feature);
+  }
+
+  public float alphaI() {
+    return alphaI;
+  }
+
+  public double numFeatures() {
+    return numFeatures;
+  }
+
+  public double totalWeightSum() {
+    return totalWeightSum;
+  }
+  
+  public int numLabels() {
+    return weightsPerLabel.size();
+  }
+
+  public Vector createScoringVector() {
+    return weightsPerLabel.like();
+  }
+  
+  public boolean isComplementary() {
+    return isComplementary;
+  }
+  
+  public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException {
+    FileSystem fs = output.getFileSystem(conf);
+
+    Vector weightsPerLabel;
+    Vector perLabelThetaNormalizer = null;
+    Vector weightsPerFeature;
+    Matrix weightsPerLabelAndFeature;
+    float alphaI;
+    boolean isComplementary;
+
+    try (FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"))) {
+      alphaI = in.readFloat();
+      isComplementary = in.readBoolean();
+      weightsPerFeature = VectorWritable.readVector(in);
+      weightsPerLabel = new DenseVector(VectorWritable.readVector(in));
+      if (isComplementary) {
+        perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in));
+      }
+      weightsPerLabelAndFeature = new SparseRowMatrix(weightsPerLabel.size(), weightsPerFeature.size());
+      for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
+        weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
+      }
+    }
+
+    NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel,
+        perLabelThetaNormalizer, alphaI, isComplementary);
+    model.validate();
+    return model;
+  }
+
+  public void serialize(Path output, Configuration conf) throws IOException {
+    FileSystem fs = output.getFileSystem(conf);
+    try (FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"))) {
+      out.writeFloat(alphaI);
+      out.writeBoolean(isComplementary);
+      VectorWritable.writeVector(out, weightsPerFeature);
+      VectorWritable.writeVector(out, weightsPerLabel);
+      if (isComplementary) {
+        VectorWritable.writeVector(out, perlabelThetaNormalizer);
+      }
+      for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
+        VectorWritable.writeVector(out, weightsPerLabelAndFeature.viewRow(row));
+      }
+      }
+    }
+  }
+  
+  public void validate() {
+    Preconditions.checkState(alphaI > 0, "alphaI has to be greater than 0!");
+    Preconditions.checkArgument(numFeatures > 0, "the vocab count has to be greater than 0!");
+    Preconditions.checkArgument(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!");
+    Preconditions.checkNotNull(weightsPerLabel, "the number of labels has to be defined!");
+    Preconditions.checkArgument(weightsPerLabel.getNumNondefaultElements() > 0,
+        "the number of labels has to be greater than 0!");
+    Preconditions.checkNotNull(weightsPerFeature, "the feature sums have to be defined");
+    Preconditions.checkArgument(weightsPerFeature.getNumNondefaultElements() > 0,
+        "the feature sums have to be greater than 0!");
+    if (isComplementary) {
+      Preconditions.checkArgument(perlabelThetaNormalizer != null, "the theta normalizers have to be defined");
+      Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0,
+          "the number of theta normalizers has to be greater than 0!");
+      Preconditions.checkArgument(Math.signum(perlabelThetaNormalizer.minValue())
+          == Math.signum(perlabelThetaNormalizer.maxValue()),
+          "Theta normalizers do not all have the same sign");
+      Preconditions.checkArgument(perlabelThetaNormalizer.getNumNonZeroElements()
+          == perlabelThetaNormalizer.size(),
+          "Theta normalizers cannot have zero values.");
+    }
+  }
+}
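
A sketch of the persistence round trip under the binary format serialize() defines above;
the model variable and directory are assumed:

    Configuration conf = new Configuration();
    Path modelDir = new Path("/tmp/nb-model");   // illustrative; naiveBayesModel.bin lands inside it
    model.serialize(modelDir, conf);             // alphaI, flag, vectors, then one matrix row per label
    NaiveBayesModel restored = NaiveBayesModel.materialize(modelDir, conf);
    // materialize() reads the fields back in the same order and calls validate() before returning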

http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git 
a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
new file mode 100644
index 0000000..e4ce8aa
--- /dev/null
+++ 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/** Implementation of the standard Naive Bayes Classifier Algorithm */
+public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+  public StandardNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    NaiveBayesModel model = getModel();
+    // Standard Naive Bayes does not use weight normalization
+    return computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI(), model.numFeatures());
+  }
+
+  public static double computeWeight(double featureLabelWeight, double labelWeight,
+      double alphaI, double numFeatures) {
+    double numerator = featureLabelWeight + alphaI;
+    double denominator = labelWeight + alphaI * numFeatures;
+    return Math.log(numerator / denominator);
+  }
+}
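
For contrast with the complementary case, the same invented counts through the standard
weight:

    double w = StandardNaiveBayesClassifier.computeWeight(4.0, 30.0, 1.0, 50.0);
    // w = ln((4 + 1) / (30 + 1 * 50)) = ln(5 / 80) ≈ -2.77
    // this is the smoothed log-likelihood ln P(feature | label); summed over a document's
    // features it gives the (negative) score that the classifier maximizes across labels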

http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
----------------------------------------------------------------------
diff --git 
a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
new file mode 100644
index 0000000..37a3b71
--- /dev/null
+++ 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.test;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+/**
+ * Run the input through the model and record the scores.
+ * <p/>
+ * The output key is the expected label; the output value is the vector of per-label scores.
+ */
+public class BayesTestMapper extends Mapper<Text, VectorWritable, Text, VectorWritable> {
+
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  private AbstractNaiveBayesClassifier classifier;
+
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    Path modelPath = HadoopUtil.getSingleCachedFile(conf);
+    NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
+    boolean isComplementary = Boolean.parseBoolean(conf.get(TestNaiveBayesDriver.COMPLEMENTARY));
+
+    // Ensure that if we are testing in complementary mode, the model was trained complementary:
+    // a complementary model will work for standard classification, but a standard model will
+    // not work for complementary classification.
+    if (isComplementary) {
+      Preconditions.checkArgument(model.isComplementary(),
+          "Complementary mode in model is different from test mode");
+      classifier = new ComplementaryNaiveBayesClassifier(model);
+    } else {
+      classifier = new StandardNaiveBayesClassifier(model);
+    }
+  }
+
+  @Override
+  protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException {
+    Vector result = classifier.classifyFull(value.get());
+    // the key holds the expected label, e.g. "/label/docId"; emit it with the full score vector
+    context.write(new Text(SLASH.split(key.toString())[1]), new VectorWritable(result));
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
----------------------------------------------------------------------
diff --git 
a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
new file mode 100644
index 0000000..d9eedcf
--- /dev/null
+++ 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
@@ -0,0 +1,176 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.test;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Test the (Complementary) Naive Bayes model that was built during training
+ * by iterating over the test set and scoring each instance against the model
+ */
+public class TestNaiveBayesDriver extends AbstractJob {
+
+  private static final Logger log = LoggerFactory.getLogger(TestNaiveBayesDriver.class);
+
+  public static final String COMPLEMENTARY = "class"; // conf key holding the complementary flag ("true"/"false")
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), args);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.overwriteOption().create());
+    addOption("model", "m", "The path to the model built during training", true);
+    addOption(buildOption("testComplementary", "c", "test complementary?", false, false, String.valueOf(false)));
+    addOption(buildOption("runSequential", "seq", "run sequential?", false, false, String.valueOf(false)));
+    addOption("labelIndex", "l", "The path to the location of the label index", true);
+    Map<String, List<String>> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
+    }
+    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+      HadoopUtil.delete(getConf(), getOutputPath());
+    }
+
+    boolean sequential = hasOption("runSequential");
+    if (sequential) {
+      runSequential();
+    } else {
+      boolean succeeded = runMapReduce();
+      if (!succeeded) {
+        return -1;
+      }
+    }
+
+    // load the labels
+    Map<Integer, String> labelMap = BayesUtils.readLabelIndex(getConf(), new Path(getOption("labelIndex")));
+
+    // loop over the results and create the confusion matrix
+    SequenceFileDirIterable<Text, VectorWritable> dirIterable =
+        new SequenceFileDirIterable<>(getOutputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
+    ResultAnalyzer analyzer = new ResultAnalyzer(labelMap.values(), "DEFAULT");
+    analyzeResults(labelMap, dirIterable, analyzer);
+
+    log.info("{} Results: {}", hasOption("testComplementary") ? "Complementary NB" : "Standard NB", analyzer);
+    return 0;
+  }
+
+  private void runSequential() throws IOException {
+    boolean complementary = hasOption("testComplementary");
+    FileSystem fs = FileSystem.get(getConf());
+    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());
+
+    // Ensure that if we are testing in complementary mode, the model was trained complementary:
+    // a complementary model will work for standard classification, but a standard model will
+    // not work for complementary classification.
+    AbstractNaiveBayesClassifier classifier;
+    if (complementary) {
+      Preconditions.checkArgument(model.isComplementary(),
+          "Complementary mode in model is different from test mode");
+      classifier = new ComplementaryNaiveBayesClassifier(model);
+    } else {
+      classifier = new StandardNaiveBayesClassifier(model);
+    }
+
+    try (SequenceFile.Writer writer =
+             SequenceFile.createWriter(fs, getConf(), new Path(getOutputPath(), "part-r-00000"),
+                 Text.class, VectorWritable.class)) {
+      SequenceFileDirIterable<Text, VectorWritable> dirIterable =
+          new SequenceFileDirIterable<>(getInputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
+      // loop through the part-r-* files in getInputPath() and get classification scores for all entries
+      for (Pair<Text, VectorWritable> pair : dirIterable) {
+        writer.append(new Text(SLASH.split(pair.getFirst().toString())[1]),
+            new VectorWritable(classifier.classifyFull(pair.getSecond().get())));
+      }
+    }
+  }
+
+  private boolean runMapReduce() throws IOException, InterruptedException, ClassNotFoundException {
+    Path model = new Path(getOption("model"));
+    HadoopUtil.cacheFiles(model, getConf());
+    // the output key is the expected label, the output value holds the scores for all labels
+    Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
+        Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+    boolean complementary = hasOption("testComplementary");
+    testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
+    return testJob.waitForCompletion(true);
+  }
+
+  private static void analyzeResults(Map<Integer, String> labelMap,
+                                     SequenceFileDirIterable<Text, VectorWritable> dirIterable,
+                                     ResultAnalyzer analyzer) {
+    for (Pair<Text, VectorWritable> pair : dirIterable) {
+      int bestIdx = Integer.MIN_VALUE;
+      double bestScore = Double.NEGATIVE_INFINITY;
+      for (Vector.Element element : pair.getSecond().get().all()) {
+        if (element.get() > bestScore) {
+          bestScore = element.get();
+          bestIdx = element.index();
+        }
+      }
+      if (bestIdx != Integer.MIN_VALUE) {
+        ClassifierResult classifierResult = new ClassifierResult(labelMap.get(bestIdx), bestScore);
+        analyzer.addInstance(pair.getFirst().toString(), classifierResult);
+      }
+    }
+  }
+}
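
A hypothetical programmatic invocation of the driver; all paths are placeholders, and the
input is assumed to be vectorized test data (SequenceFile of Text keys and VectorWritable
values, e.g. from seq2sparse):

    String[] testArgs = {
        "--input", "/data/test-vectors",
        "--output", "/data/test-results",
        "--model", "/data/model",            // directory containing naiveBayesModel.bin
        "--labelIndex", "/data/labelindex",
        "--testComplementary"                // only if the model was trained complementary
    };
    ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), testArgs);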

http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
----------------------------------------------------------------------
diff --git 
a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
new file mode 100644
index 0000000..2b8ee1e
--- /dev/null
+++ 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.math.Vector;
+
+public class ComplementaryThetaTrainer {
+
+  private final Vector weightsPerFeature;
+  private final Vector weightsPerLabel;
+  private final Vector perLabelThetaNormalizer;
+  private final double alphaI;
+  private final double totalWeightSum;
+  private final double numFeatures;
+
+  public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+    Preconditions.checkNotNull(weightsPerFeature);
+    Preconditions.checkNotNull(weightsPerLabel);
+    this.weightsPerFeature = weightsPerFeature;
+    this.weightsPerLabel = weightsPerLabel;
+    this.alphaI = alphaI;
+    perLabelThetaNormalizer = weightsPerLabel.like();
+    totalWeightSum = weightsPerLabel.zSum();
+    numFeatures = weightsPerFeature.getNumNondefaultElements();
+  }
+
+  public void train(int label, Vector perLabelWeight) {
+    double labelWeight = labelWeight(label);
+    // sum weights for each label including those with zero word counts
+    for (int i = 0; i < perLabelWeight.size(); i++) {
+      Vector.Element perLabelWeightElement = perLabelWeight.getElement(i);
+      updatePerLabelThetaNormalizer(label,
+          ComplementaryNaiveBayesClassifier.computeWeight(featureWeight(perLabelWeightElement.index()),
+              perLabelWeightElement.get(), totalWeightSum(), labelWeight, alphaI(), numFeatures()));
+    }
+  }
+
+  protected double alphaI() {
+    return alphaI;
+  }
+
+  protected double numFeatures() {
+    return numFeatures;
+  }
+
+  protected double labelWeight(int label) {
+    return weightsPerLabel.get(label);
+  }
+
+  protected double totalWeightSum() {
+    return totalWeightSum;
+  }
+
+  protected double featureWeight(int feature) {
+    return weightsPerFeature.get(feature);
+  }
+
+  // http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+  protected void updatePerLabelThetaNormalizer(int label, double weight) {
+    perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + Math.abs(weight));
+  }
+
+  public Vector retrievePerLabelThetaNormalizer() {
+    return perLabelThetaNormalizer.clone();
+  }
+}
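
A toy training pass, assuming two labels and three features with invented weight sums;
each perLabelWeight argument is meant to be that label's row of the label/feature weight
matrix (requires org.apache.mahout.math.DenseVector):

    Vector weightsPerFeature = new DenseVector(new double[] {10, 6, 4});  // per-feature sums
    Vector weightsPerLabel = new DenseVector(new double[] {12, 8});       // per-label sums
    ComplementaryThetaTrainer trainer = new ComplementaryThetaTrainer(weightsPerFeature, weightsPerLabel, 1.0);
    trainer.train(0, new DenseVector(new double[] {7, 3, 2}));            // row for label 0
    trainer.train(1, new DenseVector(new double[] {3, 3, 2}));            // row for label 1
    Vector normalizers = trainer.retrievePerLabelThetaNormalizer();
    // each entry accumulates |computeWeight(...)| over that label's features (Rennie et al., Section 3.2)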

http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
----------------------------------------------------------------------
diff --git 
a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
new file mode 100644
index 0000000..4df869e
--- /dev/null
+++ 
b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  enum Counter { SKIPPED_INSTANCES }
+
+  private OpenObjectIntHashMap<String> labelIndex;
+
+  @Override
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    super.setup(ctx);
+    labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
+  }
+
+  @Override
+  protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
+    String label = SLASH.split(labelText.toString())[1];
+    if (labelIndex.containsKey(label)) {
+      ctx.write(new IntWritable(labelIndex.get(label)), instance);
+    } else {
+      ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
+    }
+  }
+}
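
The SLASH split assumes keys of the form /label/docId, as Mahout's seqdirectory step
typically writes them; a small illustration with an invented key:

    Pattern slash = Pattern.compile("/");
    String[] parts = slash.split("/spam/doc42");   // -> ["", "spam", "doc42"]
    String label = parts[1];                       // "spam"; labels absent from the index
                                                   // are counted as SKIPPED_INSTANCES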
