http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
----------------------------------------------------------------------
diff --git 
a/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java 
b/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
new file mode 100644
index 0000000..d02d974
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
@@ -0,0 +1,296 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.common.HadoopUtil;
+import 
org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.net.URI;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Temporary class used to compute the frequency distribution of the "class 
attribute".<br>
+ * This class can be used when the criterion variable is the categorical 
attribute.
+ */
+public class FrequenciesJob {
+  
+  private static final Logger log = 
LoggerFactory.getLogger(FrequenciesJob.class);
+  
+  /** directory that will hold this job's output */
+  private final Path outputPath;
+  
+  /** file that contains the serialized dataset */
+  private final Path datasetPath;
+  
+  /** directory that contains the data used in the first step */
+  private final Path dataPath;
+  
+  /**
+   * @param base
+   *          base directory
+   * @param dataPath
+   *          data used in the first step
+   */
+  public FrequenciesJob(Path base, Path dataPath, Path datasetPath) {
+    this.outputPath = new Path(base, "frequencies.output");
+    this.dataPath = dataPath;
+    this.datasetPath = datasetPath;
+  }
+  
+  /**
+   * @return counts[partition][label] = num tuples from 'partition' with class 
== label
+   */
+  public int[][] run(Configuration conf) throws IOException, 
ClassNotFoundException, InterruptedException {
+    
+    // check the output
+    FileSystem fs = outputPath.getFileSystem(conf);
+    if (fs.exists(outputPath)) {
+      throw new IOException("Output path already exists : " + outputPath);
+    }
+    
+    // put the dataset into the DistributedCache
+    URI[] files = {datasetPath.toUri()};
+    DistributedCache.setCacheFiles(files, conf);
+    
+    Job job = new Job(conf);
+    job.setJarByClass(FrequenciesJob.class);
+    
+    FileInputFormat.setInputPaths(job, dataPath);
+    FileOutputFormat.setOutputPath(job, outputPath);
+    
+    job.setMapOutputKeyClass(LongWritable.class);
+    job.setMapOutputValueClass(IntWritable.class);
+    job.setOutputKeyClass(LongWritable.class);
+    job.setOutputValueClass(Frequencies.class);
+    
+    job.setMapperClass(FrequenciesMapper.class);
+    job.setReducerClass(FrequenciesReducer.class);
+    
+    job.setInputFormatClass(TextInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    
+    // run the job
+    boolean succeeded = job.waitForCompletion(true);
+    if (!succeeded) {
+      throw new IllegalStateException("Job failed!");
+    }
+    
+    int[][] counts = parseOutput(job);
+
+    HadoopUtil.delete(conf, outputPath);
+    
+    return counts;
+  }
+  
+  /**
+   * Extracts the output and processes it
+   * 
+   * @return counts[partition][label] = num tuples from 'partition' with class 
== label
+   */
+  int[][] parseOutput(JobContext job) throws IOException {
+    Configuration conf = job.getConfiguration();
+    
+    int numMaps = conf.getInt("mapred.map.tasks", -1);
+    log.info("mapred.map.tasks = {}", numMaps);
+    
+    FileSystem fs = outputPath.getFileSystem(conf);
+    
+    Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);
+    
+    Frequencies[] values = new Frequencies[numMaps];
+    
+    // read all the outputs
+    int index = 0;
+    for (Path path : outfiles) {
+      for (Frequencies value : new 
SequenceFileValueIterable<Frequencies>(path, conf)) {
+        values[index++] = value;
+      }
+    }
+    
+    if (index < numMaps) {
+      throw new IllegalStateException("number of output Frequencies (" + index
+          + ") is lesser than the number of mappers!");
+    }
+    
+    // sort the frequencies using the firstIds
+    Arrays.sort(values);
+    return Frequencies.extractCounts(values);
+  }
+  
+  /**
+   * Outputs the first key and the label of each tuple
+   * 
+   */
+  private static class FrequenciesMapper extends 
Mapper<LongWritable,Text,LongWritable,IntWritable> {
+    
+    private LongWritable firstId;
+    
+    private DataConverter converter;
+    private Dataset dataset;
+    
+    @Override
+    protected void setup(Context context) throws IOException, 
InterruptedException {
+      Configuration conf = context.getConfiguration();
+      
+      dataset = Builder.loadDataset(conf);
+      setup(dataset);
+    }
+    
+    /**
+     * Useful when testing
+     */
+    void setup(Dataset dataset) {
+      converter = new DataConverter(dataset);
+    }
+    
+    @Override
+    protected void map(LongWritable key, Text value, Context context) throws 
IOException,
+                                                                     
InterruptedException {
+      if (firstId == null) {
+        firstId = new LongWritable(key.get());
+      }
+      
+      Instance instance = converter.convert(value.toString());
+      
+      context.write(firstId, new IntWritable((int) 
dataset.getLabel(instance)));
+    }
+    
+  }
+  
+  private static class FrequenciesReducer extends 
Reducer<LongWritable,IntWritable,LongWritable,Frequencies> {
+    
+    private int nblabels;
+    
+    @Override
+    protected void setup(Context context) throws IOException, 
InterruptedException {
+      Configuration conf = context.getConfiguration();
+      Dataset dataset = Builder.loadDataset(conf);
+      setup(dataset.nblabels());
+    }
+    
+    /**
+     * Useful when testing
+     */
+    void setup(int nblabels) {
+      this.nblabels = nblabels;
+    }
+    
+    @Override
+    protected void reduce(LongWritable key, Iterable<IntWritable> values, 
Context context)
+      throws IOException, InterruptedException {
+      int[] counts = new int[nblabels];
+      for (IntWritable value : values) {
+        counts[value.get()]++;
+      }
+      
+      context.write(key, new Frequencies(key.get(), counts));
+    }
+  }
+  
+  /**
+   * Output of the job
+   * 
+   */
+  private static class Frequencies implements Writable, 
Comparable<Frequencies>, Cloneable {
+    
+    /** first key of the partition used to sort the partitions */
+    private long firstId;
+    
+    /** counts[c] = num tuples from the partition with label == c */
+    private int[] counts;
+    
+    Frequencies() { }
+    
+    Frequencies(long firstId, int[] counts) {
+      this.firstId = firstId;
+      this.counts = Arrays.copyOf(counts, counts.length);
+    }
+    
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      firstId = in.readLong();
+      counts = DFUtils.readIntArray(in);
+    }
+    
+    @Override
+    public void write(DataOutput out) throws IOException {
+      out.writeLong(firstId);
+      DFUtils.writeArray(out, counts);
+    }
+    
+    @Override
+    public boolean equals(Object other) {
+      return other instanceof Frequencies && firstId == ((Frequencies) 
other).firstId;
+    }
+    
+    @Override
+    public int hashCode() {
+      return (int) firstId;
+    }
+    
+    @Override
+    protected Frequencies clone() {
+      return new Frequencies(firstId, counts);
+    }
+    
+    @Override
+    public int compareTo(Frequencies obj) {
+      if (firstId < obj.firstId) {
+        return -1;
+      } else if (firstId > obj.firstId) {
+        return 1;
+      } else {
+        return 0;
+      }
+    }
+    
+    public static int[][] extractCounts(Frequencies[] partitions) {
+      int[][] counts = new int[partitions.length][];
+      for (int p = 0; p < partitions.length; p++) {
+        counts[p] = partitions[p].counts;
+      }
+      return counts;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
----------------------------------------------------------------------
diff --git 
a/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java 
b/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
new file mode 100644
index 0000000..d82b383
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
@@ -0,0 +1,263 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.lang.reflect.Field;
+import java.text.DecimalFormat;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+
+/**
+ * This tool is to visualize the Decision tree
+ */
+public final class TreeVisualizer {
+
+  /** indentation unit appended once per tree level */
+  private static final String INDENT = "|   ";
+
+  private TreeVisualizer() {}
+
+  /**
+   * Formats {@code value} with at most two fraction digits. English symbols are used so the
+   * decimal separator is always '.', independently of the JVM's default locale.
+   */
+  private static String doubleToString(double value) {
+    DecimalFormat df = new DecimalFormat("0.##", new java.text.DecimalFormatSymbols(java.util.Locale.ENGLISH));
+    return df.format(value);
+  }
+
+  /** Starts a new line in {@code buff}, indented for the given tree {@code layer}. */
+  private static void newLine(StringBuilder buff, int layer) {
+    buff.append('\n');
+    for (int j = 0; j < layer; j++) {
+      buff.append(INDENT);
+    }
+  }
+
+  /** Returns the name of attribute {@code attr}, or its index when no names are given. */
+  private static String attrLabel(String[] attrNames, int attr) {
+    return attrNames == null ? String.valueOf(attr) : attrNames[attr];
+  }
+
+  /**
+   * Recursively renders {@code node} and its subtree, one branch per line, indented by
+   * {@code layer}. Node fields are read through the reflective accessors in {@code fields}.
+   */
+  private static String toStringNode(Node node, Dataset dataset,
+      String[] attrNames, Map<String,Field> fields, int layer) {
+
+    StringBuilder buff = new StringBuilder();
+
+    try {
+      if (node instanceof CategoricalNode) {
+        CategoricalNode cnode = (CategoricalNode) node;
+        int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
+        double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
+        Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
+        String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
+        // walk the attribute's values in dataset order; values with no child branch are skipped
+        for (int i = 0; i < attrValues[attr].length; i++) {
+          int index = ArrayUtils.indexOf(values, i);
+          if (index < 0) {
+            continue;
+          }
+          newLine(buff, layer);
+          buff.append(attrLabel(attrNames, attr)).append(" = ").append(attrValues[attr][i]);
+          buff.append(toStringNode(childs[index], dataset, attrNames, fields, layer + 1));
+        }
+      } else if (node instanceof NumericalNode) {
+        NumericalNode nnode = (NumericalNode) node;
+        int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
+        double split = (Double) fields.get("NumericalNode.split").get(nnode);
+        Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
+        Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
+        newLine(buff, layer);
+        buff.append(attrLabel(attrNames, attr)).append(" < ").append(doubleToString(split));
+        buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1));
+        newLine(buff, layer);
+        buff.append(attrLabel(attrNames, attr)).append(" >= ").append(doubleToString(split));
+        buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1));
+      } else if (node instanceof Leaf) {
+        Leaf leaf = (Leaf) node;
+        double label = (Double) fields.get("Leaf.label").get(leaf);
+        if (dataset.isNumerical(dataset.getLabelId())) {
+          buff.append(" : ").append(doubleToString(label));
+        } else {
+          buff.append(" : ").append(dataset.getLabelString(label));
+        }
+      }
+    } catch (IllegalAccessException iae) {
+      throw new IllegalStateException(iae);
+    }
+
+    return buff.toString();
+  }
+
+  /**
+   * Makes {@code owner.fieldName} accessible and registers it in {@code fields} under {@code key}.
+   */
+  private static void putField(Map<String,Field> fields, String key, Class<?> owner, String fieldName)
+    throws NoSuchFieldException {
+    Field field = owner.getDeclaredField(fieldName);
+    field.setAccessible(true);
+    fields.put(key, field);
+  }
+
+  /**
+   * Builds the reflective accessors for the private node and dataset fields read by this tool.
+   *
+   * @throws IllegalStateException if one of the expected fields does not exist
+   */
+  private static Map<String,Field> getReflectMap() {
+    Map<String,Field> fields = new HashMap<String,Field>();
+
+    try {
+      putField(fields, "CategoricalNode.attr", CategoricalNode.class, "attr");
+      putField(fields, "CategoricalNode.values", CategoricalNode.class, "values");
+      putField(fields, "CategoricalNode.childs", CategoricalNode.class, "childs");
+      putField(fields, "NumericalNode.attr", NumericalNode.class, "attr");
+      putField(fields, "NumericalNode.split", NumericalNode.class, "split");
+      putField(fields, "NumericalNode.loChild", NumericalNode.class, "loChild");
+      putField(fields, "NumericalNode.hiChild", NumericalNode.class, "hiChild");
+      putField(fields, "Leaf.label", Leaf.class, "label");
+      putField(fields, "Dataset.values", Dataset.class, "values");
+    } catch (NoSuchFieldException nsfe) {
+      throw new IllegalStateException(nsfe);
+    }
+
+    return fields;
+  }
+
+  /**
+   * Decision tree to String
+   *
+   * @param tree
+   *          Node of tree
+   * @param dataset
+   *          dataset describing the attributes and the label
+   * @param attrNames
+   *          attribute names, or null to print attribute indexes
+   * @return the rendered tree, one branch per line
+   */
+  public static String toString(Node tree, Dataset dataset, String[] attrNames) {
+    return toStringNode(tree, dataset, attrNames, getReflectMap(), 0);
+  }
+
+  /**
+   * Print Decision tree
+   *
+   * @param tree
+   *          Node of tree
+   * @param dataset
+   *          dataset describing the attributes and the label
+   * @param attrNames
+   *          attribute names, or null to print attribute indexes
+   */
+  public static void print(Node tree, Dataset dataset, String[] attrNames) {
+    System.out.println(toString(tree, dataset, attrNames));
+  }
+
+  /**
+   * Renders the path {@code instance} follows from {@code node} down to a leaf, as
+   * "condition -> condition -> ... -> label". A categorical value with no matching branch ends the
+   * trace early.
+   */
+  private static String toStringPredict(Node node, Instance instance,
+      Dataset dataset, String[] attrNames, Map<String,Field> fields) {
+    StringBuilder buff = new StringBuilder();
+
+    try {
+      if (node instanceof CategoricalNode) {
+        CategoricalNode cnode = (CategoricalNode) node;
+        int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
+        double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
+        Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
+        String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
+
+        int index = ArrayUtils.indexOf(values, instance.get(attr));
+        if (index >= 0) {
+          buff.append(attrLabel(attrNames, attr)).append(" = ")
+              .append(attrValues[attr][(int) instance.get(attr)]);
+          buff.append(" -> ");
+          buff.append(toStringPredict(childs[index], instance, dataset, attrNames, fields));
+        }
+      } else if (node instanceof NumericalNode) {
+        NumericalNode nnode = (NumericalNode) node;
+        int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
+        double split = (Double) fields.get("NumericalNode.split").get(nnode);
+        Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
+        Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
+
+        boolean low = instance.get(attr) < split;
+        buff.append('(').append(attrLabel(attrNames, attr))
+            .append(" = ").append(doubleToString(instance.get(attr)))
+            .append(low ? ") < " : ") >= ").append(doubleToString(split));
+        buff.append(" -> ");
+        buff.append(toStringPredict(low ? loChild : hiChild, instance, dataset, attrNames, fields));
+      } else if (node instanceof Leaf) {
+        Leaf leaf = (Leaf) node;
+        double label = (Double) fields.get("Leaf.label").get(leaf);
+        if (dataset.isNumerical(dataset.getLabelId())) {
+          buff.append(doubleToString(label));
+        } else {
+          buff.append(dataset.getLabelString(label));
+        }
+      }
+    } catch (IllegalAccessException iae) {
+      throw new IllegalStateException(iae);
+    }
+
+    return buff.toString();
+  }
+
+  /**
+   * Predict trace to String
+   *
+   * @param tree
+   *          Node of tree
+   * @param data
+   *          instances to trace, together with their dataset
+   * @param attrNames
+   *          attribute names, or null to print attribute indexes
+   * @return one trace per instance of {@code data}, in order
+   */
+  public static String[] predictTrace(Node tree, Data data, String[] attrNames) {
+    Map<String,Field> reflectMap = getReflectMap();
+    String[] prediction = new String[data.size()];
+    for (int i = 0; i < data.size(); i++) {
+      prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(), attrNames, reflectMap);
+    }
+    return prediction;
+  }
+
+  /**
+   * Print predict trace
+   *
+   * @param tree
+   *          Node of tree
+   * @param data
+   *          instances to trace, together with their dataset
+   * @param attrNames
+   *          attribute names, or null to print attribute indexes
+   */
+  public static void predictTracePrint(Node tree, Data data, String[] attrNames) {
+    Map<String,Field> reflectMap = getReflectMap();
+    for (int i = 0; i < data.size(); i++) {
+      System.out.println(toStringPredict(tree, data.get(i), data.getDataset(), attrNames, reflectMap));
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
----------------------------------------------------------------------
diff --git 
a/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java 
b/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
new file mode 100644
index 0000000..06876e1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
@@ -0,0 +1,211 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Random;
+import java.util.Scanner;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This tool is used to uniformly distribute the class of all the tuples of 
the dataset over a given number of
+ * partitions.<br>
+ * This class can be used when the criterion variable is the categorical 
attribute.
+ */
+public final class UDistrib {
+  
+  private static final Logger log = LoggerFactory.getLogger(UDistrib.class);
+  
+  private UDistrib() {}
+  
+  /**
+   * Launch the uniform distribution tool. Requires the following command line 
arguments:<br>
+   * 
+   * data : data path dataset : dataset path numpartitions : num partitions 
output : output path
+   *
+   * @throws java.io.IOException
+   */
+  public static void main(String[] args) throws IOException {
+    
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+    
+    Option dataOpt = 
obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+      
abuilder.withName("data").withMinimum(1).withMaximum(1).create()).withDescription("Data
 path").create();
+    
+    Option datasetOpt = 
obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+      
abuilder.withName("dataset").withMinimum(1).create()).withDescription("Dataset 
path").create();
+    
+    Option outputOpt = 
obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument(
+      
abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+      "Path to generated files").create();
+    
+    Option partitionsOpt = 
obuilder.withLongName("numpartitions").withShortName("p").withRequired(true)
+        
.withArgument(abuilder.withName("numparts").withMinimum(1).withMinimum(1).create()).withDescription(
+          "Number of partitions to create").create();
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out 
help").withShortName("h")
+        .create();
+    
+    Group group = 
gbuilder.withName("Options").withOption(dataOpt).withOption(outputOpt).withOption(
+      datasetOpt).withOption(partitionsOpt).withOption(helpOpt).create();
+    
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+      
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+      
+      String data = cmdLine.getValue(dataOpt).toString();
+      String dataset = cmdLine.getValue(datasetOpt).toString();
+      int numPartitions = 
Integer.parseInt(cmdLine.getValue(partitionsOpt).toString());
+      String output = cmdLine.getValue(outputOpt).toString();
+      
+      runTool(data, dataset, output, numPartitions);
+    } catch (OptionException e) {
+      log.warn(e.toString(), e);
+      CommandLineUtil.printHelp(group);
+    }
+    
+  }
+  
+  private static void runTool(String dataStr, String datasetStr, String 
output, int numPartitions) throws IOException {
+
+    Preconditions.checkArgument(numPartitions > 0, "numPartitions <= 0");
+    
+    // make sure the output file does not exist
+    Path outputPath = new Path(output);
+    Configuration conf = new Configuration();
+    FileSystem fs = outputPath.getFileSystem(conf);
+
+    Preconditions.checkArgument(!fs.exists(outputPath), "Output path already 
exists");
+    
+    // create a new file corresponding to each partition
+    // Path workingDir = fs.getWorkingDirectory();
+    // FileSystem wfs = workingDir.getFileSystem(conf);
+    // File parentFile = new File(workingDir.toString());
+    // File tempFile = FileUtil.createLocalTempFile(parentFile, "Parts", true);
+    // File tempFile = File.createTempFile("df.tools.UDistrib","");
+    // tempFile.deleteOnExit();
+    File tempFile = FileUtil.createLocalTempFile(new File(""), 
"df.tools.UDistrib", true);
+    Path partsPath = new Path(tempFile.toString());
+    FileSystem pfs = partsPath.getFileSystem(conf);
+    
+    Path[] partPaths = new Path[numPartitions];
+    FSDataOutputStream[] files = new FSDataOutputStream[numPartitions];
+    for (int p = 0; p < numPartitions; p++) {
+      partPaths[p] = new Path(partsPath, String.format(Locale.ENGLISH, 
"part.%03d", p));
+      files[p] = pfs.create(partPaths[p]);
+    }
+    
+    Path datasetPath = new Path(datasetStr);
+    Dataset dataset = Dataset.load(conf, datasetPath);
+    
+    // currents[label] = next partition file where to place the tuple
+    int[] currents = new int[dataset.nblabels()];
+    
+    // currents is initialized randomly in the range [0, numpartitions[
+    Random random = RandomUtils.getRandom();
+    for (int c = 0; c < currents.length; c++) {
+      currents[c] = random.nextInt(numPartitions);
+    }
+    
+    // foreach tuple of the data
+    Path dataPath = new Path(dataStr);
+    FileSystem ifs = dataPath.getFileSystem(conf);
+    FSDataInputStream input = ifs.open(dataPath);
+    Scanner scanner = new Scanner(input, "UTF-8");
+    DataConverter converter = new DataConverter(dataset);
+    
+    int id = 0;
+    while (scanner.hasNextLine()) {
+      if (id % 1000 == 0) {
+        log.info("progress : {}", id);
+      }
+      
+      String line = scanner.nextLine();
+      if (line.isEmpty()) {
+        continue; // skip empty lines
+      }
+      
+      // write the tuple in files[tuple.label]
+      Instance instance = converter.convert(line);
+      int label = (int) dataset.getLabel(instance);
+      files[currents[label]].writeBytes(line);
+      files[currents[label]].writeChar('\n');
+      
+      // update currents
+      currents[label]++;
+      if (currents[label] == numPartitions) {
+        currents[label] = 0;
+      }
+    }
+    
+    // close all the files.
+    scanner.close();
+    for (FSDataOutputStream file : files) {
+      Closeables.close(file, false);
+    }
+    
+    // merge all output files
+    FileUtil.copyMerge(pfs, partsPath, fs, outputPath, true, conf, null);
+    /*
+     * FSDataOutputStream joined = fs.create(new Path(outputPath, 
"uniform.data")); for (int p = 0; p <
+     * numPartitions; p++) {log.info("Joining part : {}", p); 
FSDataInputStream partStream =
+     * fs.open(partPaths[p]);
+     * 
+     * IOUtils.copyBytes(partStream, joined, conf, false);
+     * 
+     * partStream.close(); }
+     * 
+     * joined.close();
+     * 
+     * fs.delete(partsPath, true);
+     */
+  }
+  
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java 
b/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
new file mode 100644
index 0000000..049f9bf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.evaluation;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.list.DoubleArrayList;
+
+import com.google.common.base.Preconditions;
+
+import java.util.Random;
+
+/**
+ * Computes AUC and a few other accuracy statistics without storing huge 
amounts of data.  This is
+ * done by keeping uniform samples of the positive and negative scores.  Then, 
when AUC is to be
+ * computed, the remaining scores are sorted and a rank-sum statistic is used 
to compute the AUC.
+ * Since AUC is invariant with respect to down-sampling of either positives or 
negatives, this is
+ * close to correct and is exactly correct if maxBufferSize or fewer positive 
and negative scores
+ * are examined.
+ */
+public class Auc {
+
+  private int maxBufferSize = 10000;
+  private final DoubleArrayList[] scores = {new DoubleArrayList(), new 
DoubleArrayList()};
+  private final Random rand;
+  private int samples;
+  private final double threshold;
+  private final Matrix confusion;
+  private final DenseMatrix entropy;
+
+  private boolean probabilityScore = true;
+
+  private boolean hasScore;
+
+  /**
+   * Allocates a new data-structure for accumulating information about AUC and 
a few other accuracy
+   * measures.
+   * @param threshold The threshold to use in computing the confusion matrix.
+   */
+  public Auc(double threshold) {
+    confusion = new DenseMatrix(2, 2);
+    entropy = new DenseMatrix(2, 2);
+    this.rand = RandomUtils.getRandom();
+    this.threshold = threshold;
+  }
+
+  public Auc() {
+    this(0.5);
+  }
+
+  /**
+   * Adds a score to the AUC buffers.
+   *
+   * @param trueValue Whether this score is for a true-positive or a 
true-negative example.
+   * @param score     The score for this example.
+   */
+  public void add(int trueValue, double score) {
+    Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value 
must be 0 or 1");
+    hasScore = true;
+
+    int predictedClass = score > threshold ? 1 : 0;
+    confusion.set(trueValue, predictedClass, confusion.get(trueValue, 
predictedClass) + 1);
+
+    samples++;
+    if (isProbabilityScore()) {
+      double limited = Math.max(1.0e-20, Math.min(score, 1 - 1.0e-20));
+      double v0 = entropy.get(trueValue, 0);
+      entropy.set(trueValue, 0, (Math.log1p(-limited) - v0) / samples + v0);
+
+      double v1 = entropy.get(trueValue, 1);
+      entropy.set(trueValue, 1, (Math.log(limited) - v1) / samples + v1);
+    }
+
+    // add to buffers
+    DoubleArrayList buf = scores[trueValue];
+    if (buf.size() >= maxBufferSize) {
+      // but if too many points are seen, we insert into a random
+      // place and discard the predecessor.  The random place could
+      // be anywhere, possibly not even in the buffer.
+      // this is a special case of Knuth's permutation algorithm
+      // but since we don't ever shuffle the first maxBufferSize
+      // samples, the result isn't just a fair sample of the prefixes
+      // of all permutations.  The CONTENTs of the result, however,
+      // will be a fair and uniform sample of maxBufferSize elements
+      // chosen from all elements without replacement
+      int index = rand.nextInt(samples);
+      if (index < buf.size()) {
+        buf.set(index, score);
+      }
+    } else {
+      // for small buffers, we collect all points without permuting
+      // since we sort the data later, permuting now would just be
+      // pedantic
+      buf.add(score);
+    }
+  }
+
+  public void add(int trueValue, int predictedClass) {
+    hasScore = false;
+    Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value 
must be 0 or 1");
+    confusion.set(trueValue, predictedClass, confusion.get(trueValue, 
predictedClass) + 1);
+  }
+
+  /**
+   * Computes the AUC of points seen so far.  This can be moderately expensive 
since it requires
+   * that all points that have been retained be sorted.
+   *
+   * @return The value of the Area Under the receiver operating Curve.
+   */
+  public double auc() {
+    Preconditions.checkArgument(hasScore, "Can't compute AUC for classifier 
without a score");
+    scores[0].sort();
+    scores[1].sort();
+
+    double n0 = scores[0].size();
+    double n1 = scores[1].size();
+
+    if (n0 == 0 || n1 == 0) {
+      return 0.5;
+    }
+
+    // scan the data
+    int i0 = 0;
+    int i1 = 0;
+    int rank = 1;
+    double rankSum = 0;
+    while (i0 < n0 && i1 < n1) {
+
+      double v0 = scores[0].get(i0);
+      double v1 = scores[1].get(i1);
+
+      if (v0 < v1) {
+        i0++;
+        rank++;
+      } else if (v1 < v0) {
+        i1++;
+        rankSum += rank;
+        rank++;
+      } else {
+        // ties have to be handled delicately
+        double tieScore = v0;
+
+        // how many negatives are tied?
+        int k0 = 0;
+        while (i0 < n0 && scores[0].get(i0) == tieScore) {
+          k0++;
+          i0++;
+        }
+
+        // and how many positives
+        int k1 = 0;
+        while (i1 < n1 && scores[1].get(i1) == tieScore) {
+          k1++;
+          i1++;
+        }
+
+        // we found k0 + k1 tied values which have
+        // ranks in the half open interval [rank, rank + k0 + k1)
+        // the average rank is assigned to all
+        rankSum += (rank + (k0 + k1 - 1) / 2.0) * k1;
+        rank += k0 + k1;
+      }
+    }
+
+    if (i1 < n1) {
+      rankSum += (rank + (n1 - i1 - 1) / 2.0) * (n1 - i1);
+      rank += (int) (n1 - i1);
+    }
+
+    return (rankSum / n1 - (n1 + 1) / 2) / n0;
+  }
+
+  /**
+   * Returns the confusion matrix for the classifier supposing that we were to 
use a particular
+   * threshold.
+   * @return The confusion matrix.
+   */
+  public Matrix confusion() {
+    return confusion;
+  }
+
+  /**
+   * Returns a matrix related to the confusion matrix and to the 
log-likelihood.  For a
+   * pretty accurate classifier, N + entropy is nearly the same as the 
confusion matrix
+   * because log(1-eps) \approx -eps if eps is small.
+   *
+   * For lower accuracy classifiers, this measure will give us a better 
picture of how
+   * things work our.
+   *
+   * Also, by definition, log-likelihood = sum(diag(entropy))
+   * @return Returns a cell by cell break-down of the log-likelihood
+   */
+  public Matrix entropy() {
+    if (!hasScore) {
+      // find a constant score that would optimize log-likelihood, but use a 
dash of Bayesian
+      // conservatism to avoid dividing by zero or taking log(0)
+      double p = (0.5 + confusion.get(1, 1)) / (1 + confusion.get(0, 0) + 
confusion.get(1, 1));
+      entropy.set(0, 0, confusion.get(0, 0) * Math.log1p(-p));
+      entropy.set(0, 1, confusion.get(0, 1) * Math.log(p));
+      entropy.set(1, 0, confusion.get(1, 0) * Math.log1p(-p));
+      entropy.set(1, 1, confusion.get(1, 1) * Math.log(p));
+    }
+    return entropy;
+  }
+
+  public void setMaxBufferSize(int maxBufferSize) {
+    this.maxBufferSize = maxBufferSize;
+  }
+
+  public boolean isProbabilityScore() {
+    return probabilityScore;
+  }
+
+  public void setProbabilityScore(boolean probabilityScore) {
+    this.probabilityScore = probabilityScore;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
----------------------------------------------------------------------
diff --git 
a/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java 
b/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
new file mode 100644
index 0000000..d3e9ff3
--- /dev/null
+++ 
b/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+import java.io.IOException;
+
+/**
+ * A Multilayer Perceptron (MLP) is a kind of feed-forward artificial neural
+ * network, which is a mathematical model inspired by the biological neural
+ * network. The Multilayer Perceptron can be used for various machine learning
+ * tasks such as classification and regression.
+ * 
+ * A detailed introduction about MLP can be found at
+ * http://ufldl.stanford.edu/wiki/index.php/Neural_Networks.
+ * 
+ * For this particular implementation, the users can freely control the 
topology
+ * of the MLP, including: 1. The size of the input layer; 2. The number of
+ * hidden layers; 3. The size of each hidden layer; 4. The size of the output
+ * layer. 5. The cost function. 6. The squashing function.
+ * 
+ * The model is trained in an online learning approach, where the weights of
+ * neurons in the MLP is updated incremented using backPropagation algorithm
+ * proposed by (Rumelhart, D. E., Hinton, G. E., and Williams, R. J. (1986)
+ * Learning representations by back-propagating errors. Nature, 323, 533--536.)
+ */
+public class MultilayerPerceptron extends NeuralNetwork implements 
OnlineLearner {
+
+  /**
+   * The default constructor.
+   */
+  public MultilayerPerceptron() {
+    super();
+  }
+
+  /**
+   * Initialize the MLP by specifying the location of the model.
+   * 
+   * @param modelPath The path of the model.
+   */
+  public MultilayerPerceptron(String modelPath) throws IOException {
+    super(modelPath);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    // construct the training instance, where append the actual to instance
+    Vector trainingInstance = new DenseVector(instance.size() + 1);
+    for (int i = 0; i < instance.size(); ++i) {
+      trainingInstance.setQuick(i, instance.getQuick(i));
+    }
+    trainingInstance.setQuick(instance.size(), actual);
+    this.trainOnline(trainingInstance);
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual,
+      Vector instance) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void close() {
+    // DO NOTHING
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
----------------------------------------------------------------------
diff --git 
a/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java 
b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
new file mode 100644
index 0000000..cfbe5c4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
@@ -0,0 +1,743 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * AbstractNeuralNetwork defines the general operations for a neural network
+ * based model. Typically, all derivative models such as Multilayer Perceptron
+ * and Autoencoder consist of neurons and the weights between neurons.
+ */
+public abstract class NeuralNetwork {
+  
  private static final Logger log = LoggerFactory.getLogger(NeuralNetwork.class);

  /** The default learning rate. */
  public static final double DEFAULT_LEARNING_RATE = 0.5;
  /** The default regularization weight. */
  public static final double DEFAULT_REGULARIZATION_WEIGHT = 0;
  /** The default momentum weight. */
  public static final double DEFAULT_MOMENTUM_WEIGHT = 0.1;

  /** The supported training algorithms; only gradient descent is implemented. */
  public static enum TrainingMethod { GRADIENT_DESCENT }

  /** The name of the model; the no-arg constructor sets it to the simple class name. */
  protected String modelType;

  /** The path to store the model. */
  protected String modelPath;

  /** The step size that scales every gradient-descent weight update. */
  protected double learningRate;

  /** The weight of regularization. */
  protected double regularizationWeight;

  /** The momentum weight: fraction of the previous weight update added to the current one. */
  protected double momentumWeight;

  /** The name of the cost function of the model. */
  protected String costFunctionName;

  /** The size of each layer; every non-final layer includes one extra bias neuron. */
  protected List<Integer> layerSizeList;

  /** Training method used for training the model. */
  protected TrainingMethod trainingMethod;

  /** Weights between neurons at adjacent layers; entry i connects layer i to layer i + 1. */
  protected List<Matrix> weightMatrixList;

  /** Previous weight updates between neurons at adjacent layers, used for the momentum term. */
  protected List<Matrix> prevWeightUpdatesList;

  /** Squashing function names; entry i is applied when computing the output of layer i + 1. */
  protected List<String> squashingFunctionList;

  /** The index of the final layer. */
  protected int finalLayerIndex;
+
+  /**
+   * The default constructor that initializes the learning rate, regularization
+   * weight, and momentum weight by default.
+   */
+  public NeuralNetwork() {
+    log.info("Initialize model...");
+    learningRate = DEFAULT_LEARNING_RATE;
+    regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT;
+    momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
+    trainingMethod = TrainingMethod.GRADIENT_DESCENT;
+    costFunctionName = "Minus_Squared";
+    modelType = getClass().getSimpleName();
+
+    layerSizeList = Lists.newArrayList();
+    layerSizeList = Lists.newArrayList();
+    weightMatrixList = Lists.newArrayList();
+    prevWeightUpdatesList = Lists.newArrayList();
+    squashingFunctionList = Lists.newArrayList();
+  }
+
+  /**
+   * Initialize the NeuralNetwork by specifying learning rate, momentum weight
+   * and regularization weight.
+   * 
+   * @param learningRate The learning rate.
+   * @param momentumWeight The momentum weight.
+   * @param regularizationWeight The regularization weight.
+   */
+  public NeuralNetwork(double learningRate, double momentumWeight, double 
regularizationWeight) {
+    this();
+    setLearningRate(learningRate);
+    setMomentumWeight(momentumWeight);
+    setRegularizationWeight(regularizationWeight);
+  }
+
+  /**
+   * Initialize the NeuralNetwork by specifying the location of the model.
+   * 
+   * @param modelPath The location that the model is stored.
+   */
+  public NeuralNetwork(String modelPath) throws IOException {
+    this.modelPath = modelPath;
+    readFromModel();
+  }
+
+  /**
+   * Get the type of the model.
+   * 
+   * @return The name of the model.
+   */
+  public String getModelType() {
+    return this.modelType;
+  }
+
+  /**
+   * Set the degree of aggression during model training, a large learning rate
+   * can increase the training speed, but it also decreases the chance of model
+   * converge.
+   * 
+   * @param learningRate Learning rate must be a non-negative value. Recommend 
in range (0, 0.5).
+   * @return The model instance.
+   */
+  public final NeuralNetwork setLearningRate(double learningRate) {
+    Preconditions.checkArgument(learningRate > 0, "Learning rate must be 
larger than 0.");
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Get the value of learning rate.
+   * 
+   * @return The value of learning rate.
+   */
+  public double getLearningRate() {
+    return learningRate;
+  }
+
+  /**
+   * Set the regularization weight. More complex the model is, less weight the
+   * regularization is.
+   * 
+   * @param regularizationWeight regularization must be in the range [0, 0.1).
+   * @return The model instance.
+   */
+  public final NeuralNetwork setRegularizationWeight(double 
regularizationWeight) {
+    Preconditions.checkArgument(regularizationWeight >= 0
+        && regularizationWeight < 0.1, "Regularization weight must be in range 
[0, 0.1)");
+    this.regularizationWeight = regularizationWeight;
+    return this;
+  }
+
+  /**
+   * Get the weight of the regularization.
+   * 
+   * @return The weight of regularization.
+   */
+  public double getRegularizationWeight() {
+    return regularizationWeight;
+  }
+
+  /**
+   * Set the momentum weight for the model.
+   * 
+   * @param momentumWeight momentumWeight must be in range [0, 0.5].
+   * @return The model instance.
+   */
+  public final NeuralNetwork setMomentumWeight(double momentumWeight) {
+    Preconditions.checkArgument(momentumWeight >= 0 && momentumWeight <= 1.0,
+        "Momentum weight must be in range [0, 1.0]");
+    this.momentumWeight = momentumWeight;
+    return this;
+  }
+
+  /**
+   * Get the momentum weight.
+   * 
+   * @return The value of momentum.
+   */
+  public double getMomentumWeight() {
+    return momentumWeight;
+  }
+
+  /**
+   * Set the training method.
+   * 
+   * @param method The training method, currently supports GRADIENT_DESCENT.
+   * @return The instance of the model.
+   */
+  public NeuralNetwork setTrainingMethod(TrainingMethod method) {
+    this.trainingMethod = method;
+    return this;
+  }
+
+  /**
+   * Get the training method.
+   * 
+   * @return The training method enumeration.
+   */
+  public TrainingMethod getTrainingMethod() {
+    return trainingMethod;
+  }
+
+  /**
+   * Set the cost function for the model.
+   * 
+   * @param costFunction the name of the cost function. Currently supports
+   *          "Minus_Squared", "Cross_Entropy".
+   */
+  public NeuralNetwork setCostFunction(String costFunction) {
+    this.costFunctionName = costFunction;
+    return this;
+  }
+
+  /**
+   * Add a layer of neurons with specified size. If the added layer is not the
+   * first layer, it will automatically connect the neurons between with the
+   * previous layer.
+   * 
+   * @param size The size of the layer. (bias neuron excluded)
+   * @param isFinalLayer If false, add a bias neuron.
+   * @param squashingFunctionName The squashing function for this layer, input
+   *          layer is f(x) = x by default.
+   * @return The layer index, starts with 0.
+   */
+  public int addLayer(int size, boolean isFinalLayer, String 
squashingFunctionName) {
+    Preconditions.checkArgument(size > 0, "Size of layer must be larger than 
0.");
+    log.info("Add layer with size {} and squashing function {}", size, 
squashingFunctionName);
+    int actualSize = size;
+    if (!isFinalLayer) {
+      actualSize += 1;
+    }
+
+    layerSizeList.add(actualSize);
+    int layerIndex = layerSizeList.size() - 1;
+    if (isFinalLayer) {
+      finalLayerIndex = layerIndex;
+    }
+
+    // Add weights between current layer and previous layer, and input layer 
has no squashing function
+    if (layerIndex > 0) {
+      int sizePrevLayer = layerSizeList.get(layerIndex - 1);
+      // Row count equals to size of current size and column count equal to 
size of previous layer
+      int row = isFinalLayer ? actualSize : actualSize - 1;
+      Matrix weightMatrix = new DenseMatrix(row, sizePrevLayer);
+      // Initialize weights
+      final RandomWrapper rnd = RandomUtils.getRandom();
+      weightMatrix.assign(new DoubleFunction() {
+        @Override
+        public double apply(double value) {
+          return rnd.nextDouble() - 0.5;
+        }
+      });
+      weightMatrixList.add(weightMatrix);
+      prevWeightUpdatesList.add(new DenseMatrix(row, sizePrevLayer));
+      squashingFunctionList.add(squashingFunctionName);
+    }
+    return layerIndex;
+  }
+
+  /**
+   * Get the size of a particular layer.
+   * 
+   * @param layer The index of the layer, starting from 0.
+   * @return The size of the corresponding layer.
+   */
+  public int getLayerSize(int layer) {
+    Preconditions.checkArgument(layer >= 0 && layer < 
this.layerSizeList.size(),
+        String.format("Input must be in range [0, %d]\n", 
this.layerSizeList.size() - 1));
+    return layerSizeList.get(layer);
+  }
+
+  /**
+   * Get the layer size list.
+   * 
+   * @return The sizes of the layers.
+   */
+  protected List<Integer> getLayerSizeList() {
+    return layerSizeList;
+  }
+
+  /**
+   * Get the weights between layer layerIndex and layerIndex + 1
+   * 
+   * @param layerIndex The index of the layer.
+   * @return The weights in form of {@link Matrix}.
+   */
+  public Matrix getWeightsByLayer(int layerIndex) {
+    return weightMatrixList.get(layerIndex);
+  }
+
+  /**
+   * Update the weight matrices with given matrices.
+   * 
+   * @param matrices The weight matrices, must be the same dimension as the
+   *          existing matrices.
+   */
+  public void updateWeightMatrices(Matrix[] matrices) {
+    for (int i = 0; i < matrices.length; ++i) {
+      Matrix matrix = weightMatrixList.get(i);
+      weightMatrixList.set(i, matrix.plus(matrices[i]));
+    }
+  }
+
+  /**
+   * Set the weight matrices.
+   * 
+   * @param matrices The weight matrices, must be the same dimension of the
+   *          existing matrices.
+   */
+  public void setWeightMatrices(Matrix[] matrices) {
+    weightMatrixList = Lists.newArrayList();
+    Collections.addAll(weightMatrixList, matrices);
+  }
+
+  /**
+   * Set the weight matrix for a specified layer.
+   * 
+   * @param index The index of the matrix, starting from 0 (between layer 0 
and 1).
+   * @param matrix The instance of {@link Matrix}.
+   */
+  public void setWeightMatrix(int index, Matrix matrix) {
+    Preconditions.checkArgument(0 <= index && index < weightMatrixList.size(),
+        String.format("index [%s] should be in range [%s, %s).", index, 0, 
weightMatrixList.size()));
+    weightMatrixList.set(index, matrix);
+  }
+
+  /**
+   * Get all the weight matrices.
+   * 
+   * @return The weight matrices.
+   */
+  public Matrix[] getWeightMatrices() {
+    Matrix[] matrices = new Matrix[weightMatrixList.size()];
+    weightMatrixList.toArray(matrices);
+    return matrices;
+  }
+
+  /**
+   * Get the output calculated by the model.
+   * 
+   * @param instance The feature instance in form of {@link Vector}, each 
dimension contains the value of the corresponding feature.
+   * @return The output vector.
+   */
+  public Vector getOutput(Vector instance) {
+    Preconditions.checkArgument(layerSizeList.get(0) == instance.size() + 1,
+        String.format("The dimension of input instance should be %d, but the 
input has dimension %d.",
+            layerSizeList.get(0) - 1, instance.size()));
+
+    // add bias feature
+    Vector instanceWithBias = new DenseVector(instance.size() + 1);
+    // set bias to be a little bit less than 1.0
+    instanceWithBias.set(0, 0.99999);
+    for (int i = 1; i < instanceWithBias.size(); ++i) {
+      instanceWithBias.set(i, instance.get(i - 1));
+    }
+
+    List<Vector> outputCache = getOutputInternal(instanceWithBias);
+    // return the output of the last layer
+    Vector result = outputCache.get(outputCache.size() - 1);
+    // remove bias
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  /**
+   * Calculate output internally, the intermediate output of each layer will be
+   * stored.
+   * 
+   * @param instance The feature instance in form of {@link Vector}, each 
dimension contains the value of the corresponding feature.
+   * @return Cached output of each layer.
+   */
+  protected List<Vector> getOutputInternal(Vector instance) {
+    List<Vector> outputCache = Lists.newArrayList();
+    // fill with instance
+    Vector intermediateOutput = instance;
+    outputCache.add(intermediateOutput);
+
+    for (int i = 0; i < layerSizeList.size() - 1; ++i) {
+      intermediateOutput = forward(i, intermediateOutput);
+      outputCache.add(intermediateOutput);
+    }
+    return outputCache;
+  }
+
+  /**
+   * Forward the calculation for one layer.
+   * 
+   * @param fromLayer The index of the previous layer.
+   * @param intermediateOutput The intermediate output of previous layer.
+   * @return The intermediate results of the current layer.
+   */
+  protected Vector forward(int fromLayer, Vector intermediateOutput) {
+    Matrix weightMatrix = weightMatrixList.get(fromLayer);
+
+    Vector vec = weightMatrix.times(intermediateOutput);
+    vec = 
vec.assign(NeuralNetworkFunctions.getDoubleFunction(squashingFunctionList.get(fromLayer)));
+
+    // add bias
+    Vector vecWithBias = new DenseVector(vec.size() + 1);
+    vecWithBias.set(0, 1);
+    for (int i = 0; i < vec.size(); ++i) {
+      vecWithBias.set(i + 1, vec.get(i));
+    }
+    return vecWithBias;
+  }
+
+  /**
+   * Train the neural network incrementally with given training instance.
+   * 
+   * @param trainingInstance An training instance, including the features and 
the label(s). Its dimension must equals
+   *          to the size of the input layer (bias neuron excluded) + the size
+   *          of the output layer (a.k.a. the dimension of the labels).
+   */
+  public void trainOnline(Vector trainingInstance) {
+    Matrix[] matrices = trainByInstance(trainingInstance);
+    updateWeightMatrices(matrices);
+  }
+
+  /**
+   * Get the updated weights using one training instance.
+   * 
+   * @param trainingInstance An training instance, including the features and 
the label(s). Its dimension must equals
+   *          to the size of the input layer (bias neuron excluded) + the size
+   *          of the output layer (a.k.a. the dimension of the labels).
+   * @return The update of each weight, in form of {@link Matrix} list.
+   */
+  public Matrix[] trainByInstance(Vector trainingInstance) {
+    // validate training instance
+    int inputDimension = layerSizeList.get(0) - 1;
+    int outputDimension = layerSizeList.get(this.layerSizeList.size() - 1);
+    Preconditions.checkArgument(inputDimension + outputDimension == 
trainingInstance.size(),
+        String.format("The dimension of training instance is %d, but requires 
%d.", trainingInstance.size(),
+            inputDimension + outputDimension));
+
+    if (trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
+      return trainByInstanceGradientDescent(trainingInstance);
+    }
+    throw new IllegalArgumentException("Training method is not supported.");
+  }
+
  /**
   * Train by gradient descent. Computes the weight updates for one training instance, and records
   * them in prevWeightUpdatesList so the next call can apply momentum.
   *
   * @param trainingInstance An training instance, including the features and the label(s). Its
   *          dimension must equal the size of the input layer (bias neuron excluded) + the size
   *          of the output layer (a.k.a. the dimension of the labels).
   * @return The weight update matrices, one per weight matrix.
   */
  private Matrix[] trainByInstanceGradientDescent(Vector trainingInstance) {
    int inputDimension = layerSizeList.get(0) - 1;

    // input = [1 (bias), features...]
    Vector inputInstance = new DenseVector(layerSizeList.get(0));
    inputInstance.set(0, 1); // add bias
    for (int i = 0; i < inputDimension; ++i) {
      inputInstance.set(i + 1, trainingInstance.get(i));
    }

    // the labels are the trailing elements, starting right after the inputDimension features
    Vector labels =
        trainingInstance.viewPart(inputInstance.size() - 1, trainingInstance.size() - inputInstance.size() + 1);

    // initialize weight update matrices
    Matrix[] weightUpdateMatrices = new Matrix[weightMatrixList.size()];
    for (int m = 0; m < weightUpdateMatrices.length; ++m) {
      weightUpdateMatrices[m] =
          new DenseMatrix(weightMatrixList.get(m).rowSize(), weightMatrixList.get(m).columnSize());
    }

    List<Vector> internalResults = getOutputInternal(inputInstance);

    // one delta per output neuron (the final layer has no bias in layerSizeList)
    Vector deltaVec = new DenseVector(layerSizeList.get(layerSizeList.size() - 1));
    // forward() prepends a bias, so output neuron i sits at index i + 1
    Vector output = internalResults.get(internalResults.size() - 1);

    // NOTE(review): the derivative is evaluated on the layer's output value, so the derivative
    // functions are presumably expressed in terms of the squashed value — confirm in
    // NeuralNetworkFunctions
    final DoubleFunction derivativeSquashingFunction =
        NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(squashingFunctionList.size() - 1));

    final DoubleDoubleFunction costFunction =
        NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction(costFunctionName);

    Matrix lastWeightMatrix = weightMatrixList.get(weightMatrixList.size() - 1);

    for (int i = 0; i < deltaVec.size(); ++i) {
      double costFuncDerivative = costFunction.apply(labels.get(i), output.get(i + 1));
      // Add regularization
      // NOTE(review): this adds regularizationWeight times the SUM of row i of the last weight
      // matrix, and only for the output layer — confirm this is the intended regularizer
      costFuncDerivative += regularizationWeight * lastWeightMatrix.viewRow(i).zSum();
      deltaVec.set(i, costFuncDerivative);
      deltaVec.set(i, deltaVec.get(i) * derivativeSquashingFunction.apply(output.get(i + 1)));
    }

    // Start from previous layer of output layer
    for (int layer = layerSizeList.size() - 2; layer >= 0; --layer) {
      deltaVec = backPropagate(layer, deltaVec, internalResults, weightUpdateMatrices[layer]);
    }

    // remembered for the momentum term of the next training step
    prevWeightUpdatesList = Arrays.asList(weightUpdateMatrices);

    return weightUpdateMatrices;
  }
+
  /**
   * Back-propagate the errors from the next layer to the current layer. The weight updates for
   * the weights between the current and the next layer are written into weightUpdateMatrix, and
   * the delta of the current layer is returned.
   *
   * @param currentLayerIndex Index of current layer.
   * @param nextLayerDelta Delta of next layer; includes a bias entry at index 0 unless the next
   *          layer is the output layer.
   * @param outputCache The output cache to store intermediate results.
   * @param weightUpdateMatrix  The weight update, in form of {@link Matrix}; filled in place.
   * @return The delta of the current layer, including its bias entry at index 0.
   */
  private Vector backPropagate(int currentLayerIndex, Vector nextLayerDelta,
                               List<Vector> outputCache, Matrix weightUpdateMatrix) {

    // Get layer related information
    // NOTE(review): squashingFunctionList.get(currentLayerIndex) is the function forward() uses
    // to produce layer currentLayerIndex + 1, whereas this delta belongs to layer
    // currentLayerIndex (produced by entry currentLayerIndex - 1). Both coincide only when all
    // layers share one squashing function — likely off by one, confirm.
    final DoubleFunction derivativeSquashingFunction =
        NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(currentLayerIndex));
    Vector curLayerOutput = outputCache.get(currentLayerIndex);
    Matrix weightMatrix = weightMatrixList.get(currentLayerIndex);
    Matrix prevWeightMatrix = prevWeightUpdatesList.get(currentLayerIndex);

    // Next layer is not output layer, remove the delta of bias neuron
    if (currentLayerIndex != layerSizeList.size() - 2) {
      nextLayerDelta = nextLayerDelta.viewPart(1, nextLayerDelta.size() - 1);
    }

    // one entry per neuron of the current layer (bias included, as the weight matrix has a
    // column for it)
    Vector delta = weightMatrix.transpose().times(nextLayerDelta);

    delta = delta.assign(curLayerOutput, new DoubleDoubleFunction() {
      @Override
      public double apply(double deltaElem, double curLayerOutputElem) {
        return deltaElem * derivativeSquashingFunction.apply(curLayerOutputElem);
      }
    });

    // Update weights: gradient step plus momentum from the previous update
    for (int i = 0; i < weightUpdateMatrix.rowSize(); ++i) {
      for (int j = 0; j < weightUpdateMatrix.columnSize(); ++j) {
        weightUpdateMatrix.set(i, j, -learningRate * nextLayerDelta.get(i) *
            curLayerOutput.get(j) + this.momentumWeight * prevWeightMatrix.get(i, j));
      }
    }

    return delta;
  }
+
+  /**
+   * Read the model meta-data from the specified location.
+   * 
+   * @throws IOException
+   */
+  protected void readFromModel() throws IOException {
+    log.info("Load model from {}", modelPath);
+    Preconditions.checkArgument(modelPath != null, "Model path has not been 
set.");
+    FSDataInputStream is = null;
+    try {
+      Path path = new Path(modelPath);
+      FileSystem fs = path.getFileSystem(new Configuration());
+      is = new FSDataInputStream(fs.open(path));
+      readFields(is);
+    } finally {
+      Closeables.close(is, true);
+    }
+  }
+
+  /**
+   * Write the model data to specified location.
+   * 
+   * @throws IOException
+   */
+  public void writeModelToFile() throws IOException {
+    log.info("Write model to {}.", modelPath);
+    Preconditions.checkArgument(modelPath != null, "Model path has not been 
set.");
+    FSDataOutputStream stream = null;
+    try {
+      Path path = new Path(modelPath);
+      FileSystem fs = path.getFileSystem(new Configuration());
+      stream = fs.create(path, true);
+      write(stream);
+    } finally {
+      Closeables.close(stream, false);
+    }
+  }
+
+  /**
+   * Set the model path.
+   * 
+   * @param modelPath The path of the model.
+   */
+  public void setModelPath(String modelPath) {
+    this.modelPath = modelPath;
+  }
+
+  /**
+   * Get the model path.
+   * 
+   * @return The path of the model.
+   */
+  public String getModelPath() {
+    return modelPath;
+  }
+
+  /**
+   * Write the fields of the model to output.
+   * 
+   * @param output The output instance.
+   * @throws IOException
+   */
+  public void write(DataOutput output) throws IOException {
+    // Write model type
+    WritableUtils.writeString(output, modelType);
+    // Write learning rate
+    output.writeDouble(learningRate);
+    // Write model path
+    if (modelPath != null) {
+      WritableUtils.writeString(output, modelPath);
+    } else {
+      WritableUtils.writeString(output, "null");
+    }
+
+    // Write regularization weight
+    output.writeDouble(regularizationWeight);
+    // Write momentum weight
+    output.writeDouble(momentumWeight);
+
+    // Write cost function
+    WritableUtils.writeString(output, costFunctionName);
+
+    // Write layer size list
+    output.writeInt(layerSizeList.size());
+    for (Integer aLayerSizeList : layerSizeList) {
+      output.writeInt(aLayerSizeList);
+    }
+
+    WritableUtils.writeEnum(output, trainingMethod);
+
+    // Write squashing functions
+    output.writeInt(squashingFunctionList.size());
+    for (String aSquashingFunctionList : squashingFunctionList) {
+      WritableUtils.writeString(output, aSquashingFunctionList);
+    }
+
+    // Write weight matrices
+    output.writeInt(this.weightMatrixList.size());
+    for (Matrix aWeightMatrixList : weightMatrixList) {
+      MatrixWritable.writeMatrix(output, aWeightMatrixList);
+    }
+  }
+
+  /**
+   * Read the fields of the model from input.
+   * 
+   * @param input The input instance.
+   * @throws IOException
+   */
+  public void readFields(DataInput input) throws IOException {
+    // Read model type
+    modelType = WritableUtils.readString(input);
+    if (!modelType.equals(this.getClass().getSimpleName())) {
+      throw new IllegalArgumentException("The specified location does not 
contains the valid NeuralNetwork model.");
+    }
+    // Read learning rate
+    learningRate = input.readDouble();
+    // Read model path
+    modelPath = WritableUtils.readString(input);
+    if (modelPath.equals("null")) {
+      modelPath = null;
+    }
+
+    // Read regularization weight
+    regularizationWeight = input.readDouble();
+    // Read momentum weight
+    momentumWeight = input.readDouble();
+
+    // Read cost function
+    costFunctionName = WritableUtils.readString(input);
+
+    // Read layer size list
+    int numLayers = input.readInt();
+    layerSizeList = Lists.newArrayList();
+    for (int i = 0; i < numLayers; i++) {
+      layerSizeList.add(input.readInt());
+    }
+
+    trainingMethod = WritableUtils.readEnum(input, TrainingMethod.class);
+
+    // Read squash functions
+    int squashingFunctionSize = input.readInt();
+    squashingFunctionList = Lists.newArrayList();
+    for (int i = 0; i < squashingFunctionSize; i++) {
+      squashingFunctionList.add(WritableUtils.readString(input));
+    }
+
+    // Read weights and construct matrices of previous updates
+    int numOfMatrices = input.readInt();
+    weightMatrixList = Lists.newArrayList();
+    prevWeightUpdatesList = Lists.newArrayList();
+    for (int i = 0; i < numOfMatrices; i++) {
+      Matrix matrix = MatrixWritable.readMatrix(input);
+      weightMatrixList.add(matrix);
+      prevWeightUpdatesList.add(new DenseMatrix(matrix.rowSize(), 
matrix.columnSize()));
+    }
+  }
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
----------------------------------------------------------------------
diff --git 
a/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java 
b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
new file mode 100644
index 0000000..8fd0176
--- /dev/null
+++ 
b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
@@ -0,0 +1,150 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * The functions that will be used by NeuralNetwork.
+ */
+public class NeuralNetworkFunctions {
+
+  /**
+   * The derivation of identity function (f(x) = x).
+   */
+  public static DoubleFunction derivativeIdentityFunction = new 
DoubleFunction() {
+    @Override
+    public double apply(double x) {
+      return 1;
+    }
+  };
+
+  /**
+   * The derivation of minus squared function (f(t, o) = (o - t)^2).
+   */
+  public static DoubleDoubleFunction derivativeMinusSquared = new 
DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      return 2 * (output - target);
+    }
+  };
+
+  /**
+   * The cross entropy function (f(t, o) = -t * log(o) - (1 - t) * log(1 - o)).
+   */
+  public static DoubleDoubleFunction crossEntropy = new DoubleDoubleFunction() 
{
+    @Override
+    public double apply(double target, double output) {
+      return -target * Math.log(output) - (1 - target) * Math.log(1 - output);
+    }
+  };
+
+  /**
+   * The derivation of cross entropy function (f(t, o) = -t * log(o) - (1 - t) 
*
+   * log(1 - o)).
+   */
+  public static DoubleDoubleFunction derivativeCrossEntropy = new 
DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      double adjustedTarget = target;
+      double adjustedActual = output;
+      if (adjustedActual == 1) {
+        adjustedActual = 0.999;
+      } else if (output == 0) {
+        adjustedActual = 0.001;
+      }
+      if (adjustedTarget == 1) {
+        adjustedTarget = 0.999;
+      } else if (adjustedTarget == 0) {
+        adjustedTarget = 0.001;
+      }
+      return -adjustedTarget / adjustedActual + (1 - adjustedTarget) / (1 - 
adjustedActual);
+    }
+  };
+
+  /**
+   * Get the corresponding function by its name.
+   * Currently supports: "Identity", "Sigmoid".
+   * 
+   * @param function The name of the function.
+   * @return The corresponding double function.
+   */
+  public static DoubleFunction getDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Identity")) {
+      return Functions.IDENTITY;
+    } else if (function.equalsIgnoreCase("Sigmoid")) {
+      return Functions.SIGMOID;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the derivation double function by the name.
+   * Currently supports: "Identity", "Sigmoid".
+   * 
+   * @param function The name of the function.
+   * @return The double function.
+   */
+  public static DoubleFunction getDerivativeDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Identity")) {
+      return derivativeIdentityFunction;
+    } else if (function.equalsIgnoreCase("Sigmoid")) {
+      return Functions.SIGMOIDGRADIENT;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the corresponding double-double function by the name.
+   * Currently supports: "Minus_Squared", "Cross_Entropy".
+   * 
+   * @param function The name of the function.
+   * @return The double-double function.
+   */
+  public static DoubleDoubleFunction getDoubleDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Minus_Squared")) {
+      return Functions.MINUS_SQUARED;
+    } else if (function.equalsIgnoreCase("Cross_Entropy")) {
+      return derivativeCrossEntropy;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the corresponding derivation of double double function by the name.
+   * Currently supports: "Minus_Squared", "Cross_Entropy".
+   * 
+   * @param function The name of the function.
+   * @return The double-double-function.
+   */
+  public static DoubleDoubleFunction getDerivativeDoubleDoubleFunction(String 
function) {
+    if (function.equalsIgnoreCase("Minus_Squared")) {
+      return derivativeMinusSquared;
+    } else if (function.equalsIgnoreCase("Cross_Entropy")) {
+      return derivativeCrossEntropy;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java
----------------------------------------------------------------------
diff --git 
a/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java
 
b/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java
new file mode 100644
index 0000000..36d6792
--- /dev/null
+++ 
b/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java
@@ -0,0 +1,227 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.mlp;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.csv.CSVUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/** Run {@link MultilayerPerceptron} classification. */
+public class RunMultilayerPerceptron {
+
+  private static final Logger log = 
LoggerFactory.getLogger(RunMultilayerPerceptron.class);
+
+  static class Parameters {
+    String inputFilePathStr;
+    String inputFileFormat;
+    String modelFilePathStr;
+    String outputFilePathStr;
+    int columnStart;
+    int columnEnd;
+    boolean skipHeader;
+  }
+  
+  public static void main(String[] args) throws Exception {
+    
+    Parameters parameters = new Parameters();
+    
+    if (parseArgs(args, parameters)) {
+      log.info("Load model from {}.", parameters.modelFilePathStr);
+      MultilayerPerceptron mlp = new 
MultilayerPerceptron(parameters.modelFilePathStr);
+
+      log.info("Topology of MLP: {}.", 
Arrays.toString(mlp.getLayerSizeList().toArray()));
+
+      // validate the data
+      log.info("Read the data...");
+      Path inputFilePath = new Path(parameters.inputFilePathStr);
+      FileSystem inputFS = inputFilePath.getFileSystem(new Configuration());
+      if (!inputFS.exists(inputFilePath)) {
+        log.error("Input file '{}' does not exists!", 
parameters.inputFilePathStr);
+        mlp.close();
+        return;
+      }
+
+      Path outputFilePath = new Path(parameters.outputFilePathStr);
+      FileSystem outputFS = inputFilePath.getFileSystem(new Configuration());
+      if (outputFS.exists(outputFilePath)) {
+        log.error("Output file '{}' already exists!", 
parameters.outputFilePathStr);
+        mlp.close();
+        return;
+      }
+
+      if (!parameters.inputFileFormat.equals("csv")) {
+        log.error("Currently only supports for csv format.");
+        mlp.close();
+        return; // current only supports csv format
+      }
+
+      log.info("Read from column {} to column {}.", parameters.columnStart, 
parameters.columnEnd);
+
+      BufferedWriter writer = null;
+      BufferedReader reader = null;
+
+      try {
+        writer = new BufferedWriter(new 
OutputStreamWriter(outputFS.create(outputFilePath)));
+        reader = new BufferedReader(new 
InputStreamReader(inputFS.open(inputFilePath)));
+        
+        String line;
+
+        if (parameters.skipHeader) {
+          reader.readLine();
+        }
+
+        while ((line = reader.readLine()) != null) {
+          String[] tokens = CSVUtils.parseLine(line);
+          double[] features = new double[Math.min(parameters.columnEnd, 
tokens.length) - parameters.columnStart + 1];
+
+          for (int i = parameters.columnStart, j = 0; i < 
Math.min(parameters.columnEnd + 1, tokens.length); ++i, ++j) {
+            features[j] = Double.parseDouble(tokens[i]);
+          }
+          Vector featureVec = new DenseVector(features);
+          Vector res = mlp.getOutput(featureVec);
+          int mostProbablyLabelIndex = res.maxValueIndex();
+          writer.write(String.valueOf(mostProbablyLabelIndex));
+        }
+        mlp.close();
+        log.info("Labeling finished.");
+      } finally {
+        Closeables.close(reader, true);
+        Closeables.close(writer, true);
+      }
+    }
+  }
+
+  /**
+   * Parse the arguments.
+   *
+   * @param args The input arguments.
+   * @param parameters  The parameters need to be filled.
+   * @return true or false
+   * @throws Exception
+   */
+  private static boolean parseArgs(String[] args, Parameters parameters) 
throws Exception {
+    // build the options
+    log.info("Validate and parse arguments...");
+    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+    GroupBuilder groupBuilder = new GroupBuilder();
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+    Option inputFileFormatOption = optionBuilder
+        .withLongName("format")
+        .withShortName("f")
+        .withArgument(argumentBuilder.withName("file 
type").withDefault("csv").withMinimum(1).withMaximum(1).create())
+        .withDescription("type of input file, currently support 'csv'")
+        .create();
+
+    List<Integer> columnRangeDefault = Lists.newArrayList();
+    columnRangeDefault.add(0);
+    columnRangeDefault.add(Integer.MAX_VALUE);
+
+    Option skipHeaderOption = optionBuilder.withLongName("skipHeader")
+        .withShortName("sh").withRequired(false)
+        .withDescription("whether to skip the first row of the input file")
+        .create();
+
+    Option inputColumnRangeOption = optionBuilder
+        .withLongName("columnRange")
+        .withShortName("cr")
+        .withDescription("the column range of the input file, start from 0")
+        
.withArgument(argumentBuilder.withName("range").withMinimum(2).withMaximum(2)
+            .withDefaults(columnRangeDefault).create()).create();
+
+    Group inputFileTypeGroup = groupBuilder.withOption(skipHeaderOption)
+        .withOption(inputColumnRangeOption).withOption(inputFileFormatOption)
+        .create();
+
+    Option inputOption = optionBuilder
+        .withLongName("input")
+        .withShortName("i")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("file 
path").withMinimum(1).withMaximum(1).create())
+        .withDescription("the file path of unlabelled dataset")
+        .withChildren(inputFileTypeGroup).create();
+
+    Option modelOption = optionBuilder
+        .withLongName("model")
+        .withShortName("mo")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("model 
file").withMinimum(1).withMaximum(1).create())
+        .withDescription("the file path of the model").create();
+
+    Option labelsOption = optionBuilder
+        .withLongName("labels")
+        .withShortName("labels")
+        
.withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
+        .withDescription("an ordered list of label names").create();
+
+    Group labelsGroup = groupBuilder.withOption(labelsOption).create();
+
+    Option outputOption = optionBuilder
+        .withLongName("output")
+        .withShortName("o")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withConsumeRemaining("file 
path").withMinimum(1).withMaximum(1).create())
+        .withDescription("the file path of labelled 
results").withChildren(labelsGroup).create();
+
+    // parse the input
+    Parser parser = new Parser();
+    Group normalOption = 
groupBuilder.withOption(inputOption).withOption(modelOption).withOption(outputOption).create();
+    parser.setGroup(normalOption);
+    CommandLine commandLine = parser.parseAndHelp(args);
+    if (commandLine == null) {
+      return false;
+    }
+
+    // obtain the arguments
+    parameters.inputFilePathStr = 
TrainMultilayerPerceptron.getString(commandLine, inputOption);
+    parameters.inputFileFormat = 
TrainMultilayerPerceptron.getString(commandLine, inputFileFormatOption);
+    parameters.skipHeader = commandLine.hasOption(skipHeaderOption);
+    parameters.modelFilePathStr = 
TrainMultilayerPerceptron.getString(commandLine, modelOption);
+    parameters.outputFilePathStr = 
TrainMultilayerPerceptron.getString(commandLine, outputOption);
+
+    List<?> columnRange = commandLine.getValues(inputColumnRangeOption);
+    parameters.columnStart = Integer.parseInt(columnRange.get(0).toString());
+    parameters.columnEnd = Integer.parseInt(columnRange.get(1).toString());
+
+    return true;
+  }
+
+}

Reply via email to