http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java b/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java new file mode 100644 index 0000000..d02d974 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java @@ -0,0 +1,296 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.tools; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.JobContext; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.classifier.df.DFUtils; +import org.apache.mahout.classifier.df.data.DataConverter; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; + +/** + * Temporary class used to compute the frequency distribution of the "class attribute".<br> + * This class can be used when the criterion variable is the categorical attribute. 
+ */ +public class FrequenciesJob { + + private static final Logger log = LoggerFactory.getLogger(FrequenciesJob.class); + + /** directory that will hold this job's output */ + private final Path outputPath; + + /** file that contains the serialized dataset */ + private final Path datasetPath; + + /** directory that contains the data used in the first step */ + private final Path dataPath; + + /** + * @param base + * base directory + * @param dataPath + * data used in the first step + */ + public FrequenciesJob(Path base, Path dataPath, Path datasetPath) { + this.outputPath = new Path(base, "frequencies.output"); + this.dataPath = dataPath; + this.datasetPath = datasetPath; + } + + /** + * @return counts[partition][label] = num tuples from 'partition' with class == label + */ + public int[][] run(Configuration conf) throws IOException, ClassNotFoundException, InterruptedException { + + // check the output + FileSystem fs = outputPath.getFileSystem(conf); + if (fs.exists(outputPath)) { + throw new IOException("Output path already exists : " + outputPath); + } + + // put the dataset into the DistributedCache + URI[] files = {datasetPath.toUri()}; + DistributedCache.setCacheFiles(files, conf); + + Job job = new Job(conf); + job.setJarByClass(FrequenciesJob.class); + + FileInputFormat.setInputPaths(job, dataPath); + FileOutputFormat.setOutputPath(job, outputPath); + + job.setMapOutputKeyClass(LongWritable.class); + job.setMapOutputValueClass(IntWritable.class); + job.setOutputKeyClass(LongWritable.class); + job.setOutputValueClass(Frequencies.class); + + job.setMapperClass(FrequenciesMapper.class); + job.setReducerClass(FrequenciesReducer.class); + + job.setInputFormatClass(TextInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + // run the job + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + + int[][] counts = parseOutput(job); + + HadoopUtil.delete(conf, 
outputPath); + + return counts; + } + + /** + * Extracts the output and processes it + * + * @return counts[partition][label] = num tuples from 'partition' with class == label + */ + int[][] parseOutput(JobContext job) throws IOException { + Configuration conf = job.getConfiguration(); + + int numMaps = conf.getInt("mapred.map.tasks", -1); + log.info("mapred.map.tasks = {}", numMaps); + + FileSystem fs = outputPath.getFileSystem(conf); + + Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath); + + Frequencies[] values = new Frequencies[numMaps]; + + // read all the outputs + int index = 0; + for (Path path : outfiles) { + for (Frequencies value : new SequenceFileValueIterable<Frequencies>(path, conf)) { + values[index++] = value; + } + } + + if (index < numMaps) { + throw new IllegalStateException("number of output Frequencies (" + index + + ") is lesser than the number of mappers!"); + } + + // sort the frequencies using the firstIds + Arrays.sort(values); + return Frequencies.extractCounts(values); + } + + /** + * Outputs the first key and the label of each tuple + * + */ + private static class FrequenciesMapper extends Mapper<LongWritable,Text,LongWritable,IntWritable> { + + private LongWritable firstId; + + private DataConverter converter; + private Dataset dataset; + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + Configuration conf = context.getConfiguration(); + + dataset = Builder.loadDataset(conf); + setup(dataset); + } + + /** + * Useful when testing + */ + void setup(Dataset dataset) { + converter = new DataConverter(dataset); + } + + @Override + protected void map(LongWritable key, Text value, Context context) throws IOException, + InterruptedException { + if (firstId == null) { + firstId = new LongWritable(key.get()); + } + + Instance instance = converter.convert(value.toString()); + + context.write(firstId, new IntWritable((int) dataset.getLabel(instance))); + } + + } + + private static class 
FrequenciesReducer extends Reducer<LongWritable,IntWritable,LongWritable,Frequencies> { + + private int nblabels; + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + Configuration conf = context.getConfiguration(); + Dataset dataset = Builder.loadDataset(conf); + setup(dataset.nblabels()); + } + + /** + * Useful when testing + */ + void setup(int nblabels) { + this.nblabels = nblabels; + } + + @Override + protected void reduce(LongWritable key, Iterable<IntWritable> values, Context context) + throws IOException, InterruptedException { + int[] counts = new int[nblabels]; + for (IntWritable value : values) { + counts[value.get()]++; + } + + context.write(key, new Frequencies(key.get(), counts)); + } + } + + /** + * Output of the job + * + */ + private static class Frequencies implements Writable, Comparable<Frequencies>, Cloneable { + + /** first key of the partition used to sort the partitions */ + private long firstId; + + /** counts[c] = num tuples from the partition with label == c */ + private int[] counts; + + Frequencies() { } + + Frequencies(long firstId, int[] counts) { + this.firstId = firstId; + this.counts = Arrays.copyOf(counts, counts.length); + } + + @Override + public void readFields(DataInput in) throws IOException { + firstId = in.readLong(); + counts = DFUtils.readIntArray(in); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeLong(firstId); + DFUtils.writeArray(out, counts); + } + + @Override + public boolean equals(Object other) { + return other instanceof Frequencies && firstId == ((Frequencies) other).firstId; + } + + @Override + public int hashCode() { + return (int) firstId; + } + + @Override + protected Frequencies clone() { + return new Frequencies(firstId, counts); + } + + @Override + public int compareTo(Frequencies obj) { + if (firstId < obj.firstId) { + return -1; + } else if (firstId > obj.firstId) { + return 1; + } else { + return 0; + } + } + + public 
static int[][] extractCounts(Frequencies[] partitions) { + int[][] counts = new int[partitions.length][]; + for (int p = 0; p < partitions.length; p++) { + counts[p] = partitions[p].counts; + } + return counts; + } + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java b/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java new file mode 100644 index 0000000..d82b383 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java @@ -0,0 +1,263 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.tools; + +import java.lang.reflect.Field; +import java.text.DecimalFormat; +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.node.CategoricalNode; +import org.apache.mahout.classifier.df.node.Leaf; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.classifier.df.node.NumericalNode; + +/** + * This tool is to visualize the Decision tree + */ +public final class TreeVisualizer { + + private TreeVisualizer() {} + + private static String doubleToString(double value) { + DecimalFormat df = new DecimalFormat("0.##"); + return df.format(value); + } + + private static String toStringNode(Node node, Dataset dataset, + String[] attrNames, Map<String,Field> fields, int layer) { + + StringBuilder buff = new StringBuilder(); + + try { + if (node instanceof CategoricalNode) { + CategoricalNode cnode = (CategoricalNode) node; + int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode); + double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode); + Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode); + String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset); + for (int i = 0; i < attrValues[attr].length; i++) { + int index = ArrayUtils.indexOf(values, i); + if (index < 0) { + continue; + } + buff.append('\n'); + for (int j = 0; j < layer; j++) { + buff.append("| "); + } + buff.append(attrNames == null ? 
attr : attrNames[attr]).append(" = ") + .append(attrValues[attr][i]); + buff.append(toStringNode(childs[index], dataset, attrNames, fields, layer + 1)); + } + } else if (node instanceof NumericalNode) { + NumericalNode nnode = (NumericalNode) node; + int attr = (Integer) fields.get("NumericalNode.attr").get(nnode); + double split = (Double) fields.get("NumericalNode.split").get(nnode); + Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode); + Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode); + buff.append('\n'); + for (int j = 0; j < layer; j++) { + buff.append("| "); + } + buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ") + .append(doubleToString(split)); + buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1)); + buff.append('\n'); + for (int j = 0; j < layer; j++) { + buff.append("| "); + } + buff.append(attrNames == null ? attr : attrNames[attr]).append(" >= ") + .append(doubleToString(split)); + buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1)); + } else if (node instanceof Leaf) { + Leaf leaf = (Leaf) node; + double label = (Double) fields.get("Leaf.label").get(leaf); + if (dataset.isNumerical(dataset.getLabelId())) { + buff.append(" : ").append(doubleToString(label)); + } else { + buff.append(" : ").append(dataset.getLabelString(label)); + } + } + } catch (IllegalAccessException iae) { + throw new IllegalStateException(iae); + } + + return buff.toString(); + } + + private static Map<String,Field> getReflectMap() { + Map<String,Field> fields = new HashMap<String,Field>(); + + try { + Field m = CategoricalNode.class.getDeclaredField("attr"); + m.setAccessible(true); + fields.put("CategoricalNode.attr", m); + m = CategoricalNode.class.getDeclaredField("values"); + m.setAccessible(true); + fields.put("CategoricalNode.values", m); + m = CategoricalNode.class.getDeclaredField("childs"); + m.setAccessible(true); + fields.put("CategoricalNode.childs", m); + m = 
NumericalNode.class.getDeclaredField("attr"); + m.setAccessible(true); + fields.put("NumericalNode.attr", m); + m = NumericalNode.class.getDeclaredField("split"); + m.setAccessible(true); + fields.put("NumericalNode.split", m); + m = NumericalNode.class.getDeclaredField("loChild"); + m.setAccessible(true); + fields.put("NumericalNode.loChild", m); + m = NumericalNode.class.getDeclaredField("hiChild"); + m.setAccessible(true); + fields.put("NumericalNode.hiChild", m); + m = Leaf.class.getDeclaredField("label"); + m.setAccessible(true); + fields.put("Leaf.label", m); + m = Dataset.class.getDeclaredField("values"); + m.setAccessible(true); + fields.put("Dataset.values", m); + } catch (NoSuchFieldException nsfe) { + throw new IllegalStateException(nsfe); + } + + return fields; + } + + /** + * Decision tree to String + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static String toString(Node tree, Dataset dataset, String[] attrNames) { + return toStringNode(tree, dataset, attrNames, getReflectMap(), 0); + } + + /** + * Print Decision tree + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static void print(Node tree, Dataset dataset, String[] attrNames) { + System.out.println(toString(tree, dataset, attrNames)); + } + + private static String toStringPredict(Node node, Instance instance, + Dataset dataset, String[] attrNames, Map<String,Field> fields) { + StringBuilder buff = new StringBuilder(); + + try { + if (node instanceof CategoricalNode) { + CategoricalNode cnode = (CategoricalNode) node; + int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode); + double[] values = (double[]) fields.get("CategoricalNode.values").get( + cnode); + Node[] childs = (Node[]) fields.get("CategoricalNode.childs") + .get(cnode); + String[][] attrValues = (String[][]) fields.get("Dataset.values").get( + dataset); + + int index = ArrayUtils.indexOf(values, instance.get(attr)); + if (index >= 0) 
{ + buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ") + .append(attrValues[attr][(int) instance.get(attr)]); + buff.append(" -> "); + buff.append(toStringPredict(childs[index], instance, dataset, + attrNames, fields)); + } + } else if (node instanceof NumericalNode) { + NumericalNode nnode = (NumericalNode) node; + int attr = (Integer) fields.get("NumericalNode.attr").get(nnode); + double split = (Double) fields.get("NumericalNode.split").get(nnode); + Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode); + Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode); + + if (instance.get(attr) < split) { + buff.append('(').append(attrNames == null ? attr : attrNames[attr]) + .append(" = ").append(doubleToString(instance.get(attr))) + .append(") < ").append(doubleToString(split)); + buff.append(" -> "); + buff.append(toStringPredict(loChild, instance, dataset, attrNames, + fields)); + } else { + buff.append('(').append(attrNames == null ? attr : attrNames[attr]) + .append(" = ").append(doubleToString(instance.get(attr))) + .append(") >= ").append(doubleToString(split)); + buff.append(" -> "); + buff.append(toStringPredict(hiChild, instance, dataset, attrNames, + fields)); + } + } else if (node instanceof Leaf) { + Leaf leaf = (Leaf) node; + double label = (Double) fields.get("Leaf.label").get(leaf); + if (dataset.isNumerical(dataset.getLabelId())) { + buff.append(doubleToString(label)); + } else { + buff.append(dataset.getLabelString(label)); + } + } + } catch (IllegalAccessException iae) { + throw new IllegalStateException(iae); + } + + return buff.toString(); + } + + /** + * Predict trace to String + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static String[] predictTrace(Node tree, Data data, String[] attrNames) { + Map<String,Field> reflectMap = getReflectMap(); + String[] prediction = new String[data.size()]; + for (int i = 0; i < data.size(); i++) { + prediction[i] = 
toStringPredict(tree, data.get(i), data.getDataset(), + attrNames, reflectMap); + } + return prediction; + } + + /** + * Print predict trace + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static void predictTracePrint(Node tree, Data data, String[] attrNames) { + Map<String,Field> reflectMap = getReflectMap(); + for (int i = 0; i < data.size(); i++) { + System.out.println(toStringPredict(tree, data.get(i), data.getDataset(), + attrNames, reflectMap)); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java b/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java new file mode 100644 index 0000000..06876e1 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java @@ -0,0 +1,211 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.tools; + +import java.io.File; +import java.io.IOException; +import java.util.Locale; +import java.util.Random; +import java.util.Scanner; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.FileUtil; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.classifier.df.data.DataConverter; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This tool is used to uniformly distribute the class of all the tuples of the dataset over a given number of + * partitions.<br> + * This class can be used when the criterion variable is the categorical attribute. + */ +public final class UDistrib { + + private static final Logger log = LoggerFactory.getLogger(UDistrib.class); + + private UDistrib() {} + + /** + * Launch the uniform distribution tool. 
Requires the following command line arguments:<br> + * + * data : data path dataset : dataset path numpartitions : num partitions output : output path + * + * @throws java.io.IOException + */ + public static void main(String[] args) throws IOException { + + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument( + abuilder.withName("data").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create(); + + Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument( + abuilder.withName("dataset").withMinimum(1).create()).withDescription("Dataset path").create(); + + Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument( + abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription( + "Path to generated files").create(); + + Option partitionsOpt = obuilder.withLongName("numpartitions").withShortName("p").withRequired(true) + .withArgument(abuilder.withName("numparts").withMinimum(1).withMinimum(1).create()).withDescription( + "Number of partitions to create").create(); + Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") + .create(); + + Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(outputOpt).withOption( + datasetOpt).withOption(partitionsOpt).withOption(helpOpt).create(); + + try { + Parser parser = new Parser(); + parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption(helpOpt)) { + CommandLineUtil.printHelp(group); + return; + } + + String data = cmdLine.getValue(dataOpt).toString(); + String dataset = cmdLine.getValue(datasetOpt).toString(); + int numPartitions = 
Integer.parseInt(cmdLine.getValue(partitionsOpt).toString()); + String output = cmdLine.getValue(outputOpt).toString(); + + runTool(data, dataset, output, numPartitions); + } catch (OptionException e) { + log.warn(e.toString(), e); + CommandLineUtil.printHelp(group); + } + + } + + private static void runTool(String dataStr, String datasetStr, String output, int numPartitions) throws IOException { + + Preconditions.checkArgument(numPartitions > 0, "numPartitions <= 0"); + + // make sure the output file does not exist + Path outputPath = new Path(output); + Configuration conf = new Configuration(); + FileSystem fs = outputPath.getFileSystem(conf); + + Preconditions.checkArgument(!fs.exists(outputPath), "Output path already exists"); + + // create a new file corresponding to each partition + // Path workingDir = fs.getWorkingDirectory(); + // FileSystem wfs = workingDir.getFileSystem(conf); + // File parentFile = new File(workingDir.toString()); + // File tempFile = FileUtil.createLocalTempFile(parentFile, "Parts", true); + // File tempFile = File.createTempFile("df.tools.UDistrib",""); + // tempFile.deleteOnExit(); + File tempFile = FileUtil.createLocalTempFile(new File(""), "df.tools.UDistrib", true); + Path partsPath = new Path(tempFile.toString()); + FileSystem pfs = partsPath.getFileSystem(conf); + + Path[] partPaths = new Path[numPartitions]; + FSDataOutputStream[] files = new FSDataOutputStream[numPartitions]; + for (int p = 0; p < numPartitions; p++) { + partPaths[p] = new Path(partsPath, String.format(Locale.ENGLISH, "part.%03d", p)); + files[p] = pfs.create(partPaths[p]); + } + + Path datasetPath = new Path(datasetStr); + Dataset dataset = Dataset.load(conf, datasetPath); + + // currents[label] = next partition file where to place the tuple + int[] currents = new int[dataset.nblabels()]; + + // currents is initialized randomly in the range [0, numpartitions[ + Random random = RandomUtils.getRandom(); + for (int c = 0; c < currents.length; c++) { + 
currents[c] = random.nextInt(numPartitions); + } + + // foreach tuple of the data + Path dataPath = new Path(dataStr); + FileSystem ifs = dataPath.getFileSystem(conf); + FSDataInputStream input = ifs.open(dataPath); + Scanner scanner = new Scanner(input, "UTF-8"); + DataConverter converter = new DataConverter(dataset); + + int id = 0; + while (scanner.hasNextLine()) { + if (id % 1000 == 0) { + log.info("progress : {}", id); + } + + String line = scanner.nextLine(); + if (line.isEmpty()) { + continue; // skip empty lines + } + + // write the tuple in files[tuple.label] + Instance instance = converter.convert(line); + int label = (int) dataset.getLabel(instance); + files[currents[label]].writeBytes(line); + files[currents[label]].writeChar('\n'); + + // update currents + currents[label]++; + if (currents[label] == numPartitions) { + currents[label] = 0; + } + } + + // close all the files. + scanner.close(); + for (FSDataOutputStream file : files) { + Closeables.close(file, false); + } + + // merge all output files + FileUtil.copyMerge(pfs, partsPath, fs, outputPath, true, conf, null); + /* + * FSDataOutputStream joined = fs.create(new Path(outputPath, "uniform.data")); for (int p = 0; p < + * numPartitions; p++) {log.info("Joining part : {}", p); FSDataInputStream partStream = + * fs.open(partPaths[p]); + * + * IOUtils.copyBytes(partStream, joined, conf, false); + * + * partStream.close(); } + * + * joined.close(); + * + * fs.delete(partsPath, true); + */ + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java b/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java new file mode 100644 index 0000000..049f9bf --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java @@ -0,0 +1,233 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.evaluation; + +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.list.DoubleArrayList; + +import com.google.common.base.Preconditions; + +import java.util.Random; + +/** + * Computes AUC and a few other accuracy statistics without storing huge amounts of data. This is + * done by keeping uniform samples of the positive and negative scores. Then, when AUC is to be + * computed, the remaining scores are sorted and a rank-sum statistic is used to compute the AUC. + * Since AUC is invariant with respect to down-sampling of either positives or negatives, this is + * close to correct and is exactly correct if maxBufferSize or fewer positive and negative scores + * are examined. 
+ */ +public class Auc { + + private int maxBufferSize = 10000; + private final DoubleArrayList[] scores = {new DoubleArrayList(), new DoubleArrayList()}; + private final Random rand; + private int samples; + private final double threshold; + private final Matrix confusion; + private final DenseMatrix entropy; + + private boolean probabilityScore = true; + + private boolean hasScore; + + /** + * Allocates a new data-structure for accumulating information about AUC and a few other accuracy + * measures. + * @param threshold The threshold to use in computing the confusion matrix. + */ + public Auc(double threshold) { + confusion = new DenseMatrix(2, 2); + entropy = new DenseMatrix(2, 2); + this.rand = RandomUtils.getRandom(); + this.threshold = threshold; + } + + public Auc() { + this(0.5); + } + + /** + * Adds a score to the AUC buffers. + * + * @param trueValue Whether this score is for a true-positive or a true-negative example. + * @param score The score for this example. + */ + public void add(int trueValue, double score) { + Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1"); + hasScore = true; + + int predictedClass = score > threshold ? 1 : 0; + confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1); + + samples++; + if (isProbabilityScore()) { + double limited = Math.max(1.0e-20, Math.min(score, 1 - 1.0e-20)); + double v0 = entropy.get(trueValue, 0); + entropy.set(trueValue, 0, (Math.log1p(-limited) - v0) / samples + v0); + + double v1 = entropy.get(trueValue, 1); + entropy.set(trueValue, 1, (Math.log(limited) - v1) / samples + v1); + } + + // add to buffers + DoubleArrayList buf = scores[trueValue]; + if (buf.size() >= maxBufferSize) { + // but if too many points are seen, we insert into a random + // place and discard the predecessor. The random place could + // be anywhere, possibly not even in the buffer. 
+ // this is a special case of Knuth's permutation algorithm + // but since we don't ever shuffle the first maxBufferSize + // samples, the result isn't just a fair sample of the prefixes + // of all permutations. The CONTENTs of the result, however, + // will be a fair and uniform sample of maxBufferSize elements + // chosen from all elements without replacement + int index = rand.nextInt(samples); + if (index < buf.size()) { + buf.set(index, score); + } + } else { + // for small buffers, we collect all points without permuting + // since we sort the data later, permuting now would just be + // pedantic + buf.add(score); + } + } + + public void add(int trueValue, int predictedClass) { + hasScore = false; + Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1"); + confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1); + } + + /** + * Computes the AUC of points seen so far. This can be moderately expensive since it requires + * that all points that have been retained be sorted. + * + * @return The value of the Area Under the receiver operating Curve. + */ + public double auc() { + Preconditions.checkArgument(hasScore, "Can't compute AUC for classifier without a score"); + scores[0].sort(); + scores[1].sort(); + + double n0 = scores[0].size(); + double n1 = scores[1].size(); + + if (n0 == 0 || n1 == 0) { + return 0.5; + } + + // scan the data + int i0 = 0; + int i1 = 0; + int rank = 1; + double rankSum = 0; + while (i0 < n0 && i1 < n1) { + + double v0 = scores[0].get(i0); + double v1 = scores[1].get(i1); + + if (v0 < v1) { + i0++; + rank++; + } else if (v1 < v0) { + i1++; + rankSum += rank; + rank++; + } else { + // ties have to be handled delicately + double tieScore = v0; + + // how many negatives are tied? 
+ int k0 = 0; + while (i0 < n0 && scores[0].get(i0) == tieScore) { + k0++; + i0++; + } + + // and how many positives + int k1 = 0; + while (i1 < n1 && scores[1].get(i1) == tieScore) { + k1++; + i1++; + } + + // we found k0 + k1 tied values which have + // ranks in the half open interval [rank, rank + k0 + k1) + // the average rank is assigned to all + rankSum += (rank + (k0 + k1 - 1) / 2.0) * k1; + rank += k0 + k1; + } + } + + if (i1 < n1) { + rankSum += (rank + (n1 - i1 - 1) / 2.0) * (n1 - i1); + rank += (int) (n1 - i1); + } + + return (rankSum / n1 - (n1 + 1) / 2) / n0; + } + + /** + * Returns the confusion matrix for the classifier supposing that we were to use a particular + * threshold. + * @return The confusion matrix. + */ + public Matrix confusion() { + return confusion; + } + + /** + * Returns a matrix related to the confusion matrix and to the log-likelihood. For a + * pretty accurate classifier, N + entropy is nearly the same as the confusion matrix + * because log(1-eps) \approx -eps if eps is small. + * + * For lower accuracy classifiers, this measure will give us a better picture of how + * things work our. 
+ * + * Also, by definition, log-likelihood = sum(diag(entropy)) + * @return Returns a cell by cell break-down of the log-likelihood + */ + public Matrix entropy() { + if (!hasScore) { + // find a constant score that would optimize log-likelihood, but use a dash of Bayesian + // conservatism to avoid dividing by zero or taking log(0) + double p = (0.5 + confusion.get(1, 1)) / (1 + confusion.get(0, 0) + confusion.get(1, 1)); + entropy.set(0, 0, confusion.get(0, 0) * Math.log1p(-p)); + entropy.set(0, 1, confusion.get(0, 1) * Math.log(p)); + entropy.set(1, 0, confusion.get(1, 0) * Math.log1p(-p)); + entropy.set(1, 1, confusion.get(1, 1) * Math.log(p)); + } + return entropy; + } + + public void setMaxBufferSize(int maxBufferSize) { + this.maxBufferSize = maxBufferSize; + } + + public boolean isProbabilityScore() { + return probabilityScore; + } + + public void setProbabilityScore(boolean probabilityScore) { + this.probabilityScore = probabilityScore; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java new file mode 100644 index 0000000..d3e9ff3 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java @@ -0,0 +1,90 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.classifier.mlp; + +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; + +import java.io.IOException; + +/** + * A Multilayer Perceptron (MLP) is a kind of feed-forward artificial neural + * network, which is a mathematical model inspired by the biological neural + * network. The Multilayer Perceptron can be used for various machine learning + * tasks such as classification and regression. + * + * A detailed introduction about MLP can be found at + * http://ufldl.stanford.edu/wiki/index.php/Neural_Networks. + * + * For this particular implementation, the users can freely control the topology + * of the MLP, including: 1. The size of the input layer; 2. The number of + * hidden layers; 3. The size of each hidden layer; 4. The size of the output + * layer. 5. The cost function. 6. The squashing function. + * + * The model is trained in an online learning approach, where the weights of + * neurons in the MLP is updated incremented using backPropagation algorithm + * proposed by (Rumelhart, D. E., Hinton, G. E., and Williams, R. J. (1986) + * Learning representations by back-propagating errors. Nature, 323, 533--536.) + */ +public class MultilayerPerceptron extends NeuralNetwork implements OnlineLearner { + + /** + * The default constructor. + */ + public MultilayerPerceptron() { + super(); + } + + /** + * Initialize the MLP by specifying the location of the model. + * + * @param modelPath The path of the model. 
+ */ + public MultilayerPerceptron(String modelPath) throws IOException { + super(modelPath); + } + + @Override + public void train(int actual, Vector instance) { + // construct the training instance, where append the actual to instance + Vector trainingInstance = new DenseVector(instance.size() + 1); + for (int i = 0; i < instance.size(); ++i) { + trainingInstance.setQuick(i, instance.getQuick(i)); + } + trainingInstance.setQuick(instance.size(), actual); + this.trainOnline(trainingInstance); + } + + @Override + public void train(long trackingKey, String groupKey, int actual, + Vector instance) { + throw new UnsupportedOperationException(); + } + + @Override + public void train(long trackingKey, int actual, Vector instance) { + throw new UnsupportedOperationException(); + } + + @Override + public void close() { + // DO NOTHING + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java new file mode 100644 index 0000000..cfbe5c4 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java @@ -0,0 +1,743 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.classifier.mlp; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.WritableUtils; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.RandomWrapper; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixWritable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.DoubleDoubleFunction; +import org.apache.mahout.math.function.DoubleFunction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; + +/** + * AbstractNeuralNetwork defines the general operations for a neural network + * based model. Typically, all derivative models such as Multilayer Perceptron + * and Autoencoder consist of neurons and the weights between neurons. 
+ */ +public abstract class NeuralNetwork { + + private static final Logger log = LoggerFactory.getLogger(NeuralNetwork.class); + + /* The default learning rate */ + public static final double DEFAULT_LEARNING_RATE = 0.5; + /* The default regularization weight */ + public static final double DEFAULT_REGULARIZATION_WEIGHT = 0; + /* The default momentum weight */ + public static final double DEFAULT_MOMENTUM_WEIGHT = 0.1; + + public static enum TrainingMethod { GRADIENT_DESCENT } + + /* The name of the model */ + protected String modelType; + + /* The path to store the model */ + protected String modelPath; + + protected double learningRate; + + /* The weight of regularization */ + protected double regularizationWeight; + + /* The momentum weight */ + protected double momentumWeight; + + /* The cost function of the model */ + protected String costFunctionName; + + /* Record the size of each layer */ + protected List<Integer> layerSizeList; + + /* Training method used for training the model */ + protected TrainingMethod trainingMethod; + + /* Weights between neurons at adjacent layers */ + protected List<Matrix> weightMatrixList; + + /* Previous weight updates between neurons at adjacent layers */ + protected List<Matrix> prevWeightUpdatesList; + + /* Different layers can have different squashing function */ + protected List<String> squashingFunctionList; + + /* The index of the final layer */ + protected int finalLayerIndex; + + /** + * The default constructor that initializes the learning rate, regularization + * weight, and momentum weight by default. 
+ */ + public NeuralNetwork() { + log.info("Initialize model..."); + learningRate = DEFAULT_LEARNING_RATE; + regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT; + momentumWeight = DEFAULT_MOMENTUM_WEIGHT; + trainingMethod = TrainingMethod.GRADIENT_DESCENT; + costFunctionName = "Minus_Squared"; + modelType = getClass().getSimpleName(); + + layerSizeList = Lists.newArrayList(); + layerSizeList = Lists.newArrayList(); + weightMatrixList = Lists.newArrayList(); + prevWeightUpdatesList = Lists.newArrayList(); + squashingFunctionList = Lists.newArrayList(); + } + + /** + * Initialize the NeuralNetwork by specifying learning rate, momentum weight + * and regularization weight. + * + * @param learningRate The learning rate. + * @param momentumWeight The momentum weight. + * @param regularizationWeight The regularization weight. + */ + public NeuralNetwork(double learningRate, double momentumWeight, double regularizationWeight) { + this(); + setLearningRate(learningRate); + setMomentumWeight(momentumWeight); + setRegularizationWeight(regularizationWeight); + } + + /** + * Initialize the NeuralNetwork by specifying the location of the model. + * + * @param modelPath The location that the model is stored. + */ + public NeuralNetwork(String modelPath) throws IOException { + this.modelPath = modelPath; + readFromModel(); + } + + /** + * Get the type of the model. + * + * @return The name of the model. + */ + public String getModelType() { + return this.modelType; + } + + /** + * Set the degree of aggression during model training, a large learning rate + * can increase the training speed, but it also decreases the chance of model + * converge. + * + * @param learningRate Learning rate must be a non-negative value. Recommend in range (0, 0.5). + * @return The model instance. 
+ */ + public final NeuralNetwork setLearningRate(double learningRate) { + Preconditions.checkArgument(learningRate > 0, "Learning rate must be larger than 0."); + this.learningRate = learningRate; + return this; + } + + /** + * Get the value of learning rate. + * + * @return The value of learning rate. + */ + public double getLearningRate() { + return learningRate; + } + + /** + * Set the regularization weight. More complex the model is, less weight the + * regularization is. + * + * @param regularizationWeight regularization must be in the range [0, 0.1). + * @return The model instance. + */ + public final NeuralNetwork setRegularizationWeight(double regularizationWeight) { + Preconditions.checkArgument(regularizationWeight >= 0 + && regularizationWeight < 0.1, "Regularization weight must be in range [0, 0.1)"); + this.regularizationWeight = regularizationWeight; + return this; + } + + /** + * Get the weight of the regularization. + * + * @return The weight of regularization. + */ + public double getRegularizationWeight() { + return regularizationWeight; + } + + /** + * Set the momentum weight for the model. + * + * @param momentumWeight momentumWeight must be in range [0, 0.5]. + * @return The model instance. + */ + public final NeuralNetwork setMomentumWeight(double momentumWeight) { + Preconditions.checkArgument(momentumWeight >= 0 && momentumWeight <= 1.0, + "Momentum weight must be in range [0, 1.0]"); + this.momentumWeight = momentumWeight; + return this; + } + + /** + * Get the momentum weight. + * + * @return The value of momentum. + */ + public double getMomentumWeight() { + return momentumWeight; + } + + /** + * Set the training method. + * + * @param method The training method, currently supports GRADIENT_DESCENT. + * @return The instance of the model. + */ + public NeuralNetwork setTrainingMethod(TrainingMethod method) { + this.trainingMethod = method; + return this; + } + + /** + * Get the training method. 
+ * + * @return The training method enumeration. + */ + public TrainingMethod getTrainingMethod() { + return trainingMethod; + } + + /** + * Set the cost function for the model. + * + * @param costFunction the name of the cost function. Currently supports + * "Minus_Squared", "Cross_Entropy". + */ + public NeuralNetwork setCostFunction(String costFunction) { + this.costFunctionName = costFunction; + return this; + } + + /** + * Add a layer of neurons with specified size. If the added layer is not the + * first layer, it will automatically connect the neurons between with the + * previous layer. + * + * @param size The size of the layer. (bias neuron excluded) + * @param isFinalLayer If false, add a bias neuron. + * @param squashingFunctionName The squashing function for this layer, input + * layer is f(x) = x by default. + * @return The layer index, starts with 0. + */ + public int addLayer(int size, boolean isFinalLayer, String squashingFunctionName) { + Preconditions.checkArgument(size > 0, "Size of layer must be larger than 0."); + log.info("Add layer with size {} and squashing function {}", size, squashingFunctionName); + int actualSize = size; + if (!isFinalLayer) { + actualSize += 1; + } + + layerSizeList.add(actualSize); + int layerIndex = layerSizeList.size() - 1; + if (isFinalLayer) { + finalLayerIndex = layerIndex; + } + + // Add weights between current layer and previous layer, and input layer has no squashing function + if (layerIndex > 0) { + int sizePrevLayer = layerSizeList.get(layerIndex - 1); + // Row count equals to size of current size and column count equal to size of previous layer + int row = isFinalLayer ? 
actualSize : actualSize - 1; + Matrix weightMatrix = new DenseMatrix(row, sizePrevLayer); + // Initialize weights + final RandomWrapper rnd = RandomUtils.getRandom(); + weightMatrix.assign(new DoubleFunction() { + @Override + public double apply(double value) { + return rnd.nextDouble() - 0.5; + } + }); + weightMatrixList.add(weightMatrix); + prevWeightUpdatesList.add(new DenseMatrix(row, sizePrevLayer)); + squashingFunctionList.add(squashingFunctionName); + } + return layerIndex; + } + + /** + * Get the size of a particular layer. + * + * @param layer The index of the layer, starting from 0. + * @return The size of the corresponding layer. + */ + public int getLayerSize(int layer) { + Preconditions.checkArgument(layer >= 0 && layer < this.layerSizeList.size(), + String.format("Input must be in range [0, %d]\n", this.layerSizeList.size() - 1)); + return layerSizeList.get(layer); + } + + /** + * Get the layer size list. + * + * @return The sizes of the layers. + */ + protected List<Integer> getLayerSizeList() { + return layerSizeList; + } + + /** + * Get the weights between layer layerIndex and layerIndex + 1 + * + * @param layerIndex The index of the layer. + * @return The weights in form of {@link Matrix}. + */ + public Matrix getWeightsByLayer(int layerIndex) { + return weightMatrixList.get(layerIndex); + } + + /** + * Update the weight matrices with given matrices. + * + * @param matrices The weight matrices, must be the same dimension as the + * existing matrices. + */ + public void updateWeightMatrices(Matrix[] matrices) { + for (int i = 0; i < matrices.length; ++i) { + Matrix matrix = weightMatrixList.get(i); + weightMatrixList.set(i, matrix.plus(matrices[i])); + } + } + + /** + * Set the weight matrices. + * + * @param matrices The weight matrices, must be the same dimension of the + * existing matrices. 
+ */ + public void setWeightMatrices(Matrix[] matrices) { + weightMatrixList = Lists.newArrayList(); + Collections.addAll(weightMatrixList, matrices); + } + + /** + * Set the weight matrix for a specified layer. + * + * @param index The index of the matrix, starting from 0 (between layer 0 and 1). + * @param matrix The instance of {@link Matrix}. + */ + public void setWeightMatrix(int index, Matrix matrix) { + Preconditions.checkArgument(0 <= index && index < weightMatrixList.size(), + String.format("index [%s] should be in range [%s, %s).", index, 0, weightMatrixList.size())); + weightMatrixList.set(index, matrix); + } + + /** + * Get all the weight matrices. + * + * @return The weight matrices. + */ + public Matrix[] getWeightMatrices() { + Matrix[] matrices = new Matrix[weightMatrixList.size()]; + weightMatrixList.toArray(matrices); + return matrices; + } + + /** + * Get the output calculated by the model. + * + * @param instance The feature instance in form of {@link Vector}, each dimension contains the value of the corresponding feature. + * @return The output vector. + */ + public Vector getOutput(Vector instance) { + Preconditions.checkArgument(layerSizeList.get(0) == instance.size() + 1, + String.format("The dimension of input instance should be %d, but the input has dimension %d.", + layerSizeList.get(0) - 1, instance.size())); + + // add bias feature + Vector instanceWithBias = new DenseVector(instance.size() + 1); + // set bias to be a little bit less than 1.0 + instanceWithBias.set(0, 0.99999); + for (int i = 1; i < instanceWithBias.size(); ++i) { + instanceWithBias.set(i, instance.get(i - 1)); + } + + List<Vector> outputCache = getOutputInternal(instanceWithBias); + // return the output of the last layer + Vector result = outputCache.get(outputCache.size() - 1); + // remove bias + return result.viewPart(1, result.size() - 1); + } + + /** + * Calculate output internally, the intermediate output of each layer will be + * stored. 
+ * + * @param instance The feature instance in form of {@link Vector}, each dimension contains the value of the corresponding feature. + * @return Cached output of each layer. + */ + protected List<Vector> getOutputInternal(Vector instance) { + List<Vector> outputCache = Lists.newArrayList(); + // fill with instance + Vector intermediateOutput = instance; + outputCache.add(intermediateOutput); + + for (int i = 0; i < layerSizeList.size() - 1; ++i) { + intermediateOutput = forward(i, intermediateOutput); + outputCache.add(intermediateOutput); + } + return outputCache; + } + + /** + * Forward the calculation for one layer. + * + * @param fromLayer The index of the previous layer. + * @param intermediateOutput The intermediate output of previous layer. + * @return The intermediate results of the current layer. + */ + protected Vector forward(int fromLayer, Vector intermediateOutput) { + Matrix weightMatrix = weightMatrixList.get(fromLayer); + + Vector vec = weightMatrix.times(intermediateOutput); + vec = vec.assign(NeuralNetworkFunctions.getDoubleFunction(squashingFunctionList.get(fromLayer))); + + // add bias + Vector vecWithBias = new DenseVector(vec.size() + 1); + vecWithBias.set(0, 1); + for (int i = 0; i < vec.size(); ++i) { + vecWithBias.set(i + 1, vec.get(i)); + } + return vecWithBias; + } + + /** + * Train the neural network incrementally with given training instance. + * + * @param trainingInstance An training instance, including the features and the label(s). Its dimension must equals + * to the size of the input layer (bias neuron excluded) + the size + * of the output layer (a.k.a. the dimension of the labels). + */ + public void trainOnline(Vector trainingInstance) { + Matrix[] matrices = trainByInstance(trainingInstance); + updateWeightMatrices(matrices); + } + + /** + * Get the updated weights using one training instance. + * + * @param trainingInstance An training instance, including the features and the label(s). 
Its dimension must equals + * to the size of the input layer (bias neuron excluded) + the size + * of the output layer (a.k.a. the dimension of the labels). + * @return The update of each weight, in form of {@link Matrix} list. + */ + public Matrix[] trainByInstance(Vector trainingInstance) { + // validate training instance + int inputDimension = layerSizeList.get(0) - 1; + int outputDimension = layerSizeList.get(this.layerSizeList.size() - 1); + Preconditions.checkArgument(inputDimension + outputDimension == trainingInstance.size(), + String.format("The dimension of training instance is %d, but requires %d.", trainingInstance.size(), + inputDimension + outputDimension)); + + if (trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) { + return trainByInstanceGradientDescent(trainingInstance); + } + throw new IllegalArgumentException("Training method is not supported."); + } + + /** + * Train by gradient descent. Get the updated weights using one training + * instance. + * + * @param trainingInstance An training instance, including the features and the label(s). Its dimension must equals + * to the size of the input layer (bias neuron excluded) + the size + * of the output layer (a.k.a. the dimension of the labels). + * @return The weight update matrices. 
+ */ + private Matrix[] trainByInstanceGradientDescent(Vector trainingInstance) { + int inputDimension = layerSizeList.get(0) - 1; + + Vector inputInstance = new DenseVector(layerSizeList.get(0)); + inputInstance.set(0, 1); // add bias + for (int i = 0; i < inputDimension; ++i) { + inputInstance.set(i + 1, trainingInstance.get(i)); + } + + Vector labels = + trainingInstance.viewPart(inputInstance.size() - 1, trainingInstance.size() - inputInstance.size() + 1); + + // initialize weight update matrices + Matrix[] weightUpdateMatrices = new Matrix[weightMatrixList.size()]; + for (int m = 0; m < weightUpdateMatrices.length; ++m) { + weightUpdateMatrices[m] = + new DenseMatrix(weightMatrixList.get(m).rowSize(), weightMatrixList.get(m).columnSize()); + } + + List<Vector> internalResults = getOutputInternal(inputInstance); + + Vector deltaVec = new DenseVector(layerSizeList.get(layerSizeList.size() - 1)); + Vector output = internalResults.get(internalResults.size() - 1); + + final DoubleFunction derivativeSquashingFunction = + NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(squashingFunctionList.size() - 1)); + + final DoubleDoubleFunction costFunction = + NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction(costFunctionName); + + Matrix lastWeightMatrix = weightMatrixList.get(weightMatrixList.size() - 1); + + for (int i = 0; i < deltaVec.size(); ++i) { + double costFuncDerivative = costFunction.apply(labels.get(i), output.get(i + 1)); + // Add regularization + costFuncDerivative += regularizationWeight * lastWeightMatrix.viewRow(i).zSum(); + deltaVec.set(i, costFuncDerivative); + deltaVec.set(i, deltaVec.get(i) * derivativeSquashingFunction.apply(output.get(i + 1))); + } + + // Start from previous layer of output layer + for (int layer = layerSizeList.size() - 2; layer >= 0; --layer) { + deltaVec = backPropagate(layer, deltaVec, internalResults, weightUpdateMatrices[layer]); + } + + prevWeightUpdatesList = 
Arrays.asList(weightUpdateMatrices); + + return weightUpdateMatrices; + } + + /** + * Back-propagate the errors to from next layer to current layer. The weight + * updated information will be stored in the weightUpdateMatrices, and the + * delta of the prevLayer will be returned. + * + * @param currentLayerIndex Index of current layer. + * @param nextLayerDelta Delta of next layer. + * @param outputCache The output cache to store intermediate results. + * @param weightUpdateMatrix The weight update, in form of {@link Matrix}. + * @return The weight updates. + */ + private Vector backPropagate(int currentLayerIndex, Vector nextLayerDelta, + List<Vector> outputCache, Matrix weightUpdateMatrix) { + + // Get layer related information + final DoubleFunction derivativeSquashingFunction = + NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(currentLayerIndex)); + Vector curLayerOutput = outputCache.get(currentLayerIndex); + Matrix weightMatrix = weightMatrixList.get(currentLayerIndex); + Matrix prevWeightMatrix = prevWeightUpdatesList.get(currentLayerIndex); + + // Next layer is not output layer, remove the delta of bias neuron + if (currentLayerIndex != layerSizeList.size() - 2) { + nextLayerDelta = nextLayerDelta.viewPart(1, nextLayerDelta.size() - 1); + } + + Vector delta = weightMatrix.transpose().times(nextLayerDelta); + + delta = delta.assign(curLayerOutput, new DoubleDoubleFunction() { + @Override + public double apply(double deltaElem, double curLayerOutputElem) { + return deltaElem * derivativeSquashingFunction.apply(curLayerOutputElem); + } + }); + + // Update weights + for (int i = 0; i < weightUpdateMatrix.rowSize(); ++i) { + for (int j = 0; j < weightUpdateMatrix.columnSize(); ++j) { + weightUpdateMatrix.set(i, j, -learningRate * nextLayerDelta.get(i) * + curLayerOutput.get(j) + this.momentumWeight * prevWeightMatrix.get(i, j)); + } + } + + return delta; + } + + /** + * Read the model meta-data from the specified location. 
+ * + * @throws IOException + */ + protected void readFromModel() throws IOException { + log.info("Load model from {}", modelPath); + Preconditions.checkArgument(modelPath != null, "Model path has not been set."); + FSDataInputStream is = null; + try { + Path path = new Path(modelPath); + FileSystem fs = path.getFileSystem(new Configuration()); + is = new FSDataInputStream(fs.open(path)); + readFields(is); + } finally { + Closeables.close(is, true); + } + } + + /** + * Write the model data to specified location. + * + * @throws IOException + */ + public void writeModelToFile() throws IOException { + log.info("Write model to {}.", modelPath); + Preconditions.checkArgument(modelPath != null, "Model path has not been set."); + FSDataOutputStream stream = null; + try { + Path path = new Path(modelPath); + FileSystem fs = path.getFileSystem(new Configuration()); + stream = fs.create(path, true); + write(stream); + } finally { + Closeables.close(stream, false); + } + } + + /** + * Set the model path. + * + * @param modelPath The path of the model. + */ + public void setModelPath(String modelPath) { + this.modelPath = modelPath; + } + + /** + * Get the model path. + * + * @return The path of the model. + */ + public String getModelPath() { + return modelPath; + } + + /** + * Write the fields of the model to output. + * + * @param output The output instance. 
+ * @throws IOException + */ + public void write(DataOutput output) throws IOException { + // Write model type + WritableUtils.writeString(output, modelType); + // Write learning rate + output.writeDouble(learningRate); + // Write model path + if (modelPath != null) { + WritableUtils.writeString(output, modelPath); + } else { + WritableUtils.writeString(output, "null"); + } + + // Write regularization weight + output.writeDouble(regularizationWeight); + // Write momentum weight + output.writeDouble(momentumWeight); + + // Write cost function + WritableUtils.writeString(output, costFunctionName); + + // Write layer size list + output.writeInt(layerSizeList.size()); + for (Integer aLayerSizeList : layerSizeList) { + output.writeInt(aLayerSizeList); + } + + WritableUtils.writeEnum(output, trainingMethod); + + // Write squashing functions + output.writeInt(squashingFunctionList.size()); + for (String aSquashingFunctionList : squashingFunctionList) { + WritableUtils.writeString(output, aSquashingFunctionList); + } + + // Write weight matrices + output.writeInt(this.weightMatrixList.size()); + for (Matrix aWeightMatrixList : weightMatrixList) { + MatrixWritable.writeMatrix(output, aWeightMatrixList); + } + } + + /** + * Read the fields of the model from input. + * + * @param input The input instance. 
+ * @throws IOException + */ + public void readFields(DataInput input) throws IOException { + // Read model type + modelType = WritableUtils.readString(input); + if (!modelType.equals(this.getClass().getSimpleName())) { + throw new IllegalArgumentException("The specified location does not contains the valid NeuralNetwork model."); + } + // Read learning rate + learningRate = input.readDouble(); + // Read model path + modelPath = WritableUtils.readString(input); + if (modelPath.equals("null")) { + modelPath = null; + } + + // Read regularization weight + regularizationWeight = input.readDouble(); + // Read momentum weight + momentumWeight = input.readDouble(); + + // Read cost function + costFunctionName = WritableUtils.readString(input); + + // Read layer size list + int numLayers = input.readInt(); + layerSizeList = Lists.newArrayList(); + for (int i = 0; i < numLayers; i++) { + layerSizeList.add(input.readInt()); + } + + trainingMethod = WritableUtils.readEnum(input, TrainingMethod.class); + + // Read squash functions + int squashingFunctionSize = input.readInt(); + squashingFunctionList = Lists.newArrayList(); + for (int i = 0; i < squashingFunctionSize; i++) { + squashingFunctionList.add(WritableUtils.readString(input)); + } + + // Read weights and construct matrices of previous updates + int numOfMatrices = input.readInt(); + weightMatrixList = Lists.newArrayList(); + prevWeightUpdatesList = Lists.newArrayList(); + for (int i = 0; i < numOfMatrices; i++) { + Matrix matrix = MatrixWritable.readMatrix(input); + weightMatrixList.add(matrix); + prevWeightUpdatesList.add(new DenseMatrix(matrix.rowSize(), matrix.columnSize())); + } + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java 
b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java new file mode 100644 index 0000000..8fd0176 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java @@ -0,0 +1,150 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.classifier.mlp; + +import org.apache.mahout.math.function.DoubleDoubleFunction; +import org.apache.mahout.math.function.DoubleFunction; +import org.apache.mahout.math.function.Functions;
+
+/**
+ * The functions that will be used by NeuralNetwork.
+ *
+ * <p>Holds single-argument functions (and their derivatives), two-argument
+ * {@code (target, output)} functions (and their derivatives), and lookups that
+ * resolve each of them from a case-insensitive name.
+ */
+public class NeuralNetworkFunctions {
+
+  /**
+   * The derivative of the identity function (f(x) = x), i.e. the constant 1.
+   */
+  public static DoubleFunction derivativeIdentityFunction = new DoubleFunction() {
+    @Override
+    public double apply(double x) {
+      return 1;
+    }
+  };
+
+  /**
+   * The derivative, with respect to the output o, of the minus squared function
+   * (f(t, o) = (o - t)^2), i.e. 2 * (o - t).
+   */
+  public static DoubleDoubleFunction derivativeMinusSquared = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      return 2 * (output - target);
+    }
+  };
+
+  /**
+   * The cross entropy function (f(t, o) = -t * log(o) - (1 - t) * log(1 - o)).
+   * The arguments are used as given (no clamping), so an output of exactly 0 or 1
+   * makes one of the logarithms infinite.
+   */
+  public static DoubleDoubleFunction crossEntropy = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      return -target * Math.log(output) - (1 - target) * Math.log(1 - output);
+    }
+  };
+
+  /**
+   * The derivative, with respect to the output o, of the cross entropy function
+   * (f(t, o) = -t * log(o) - (1 - t) * log(1 - o)), i.e. -t / o + (1 - t) / (1 - o).
+   * A target or output of exactly 0 or 1 is replaced by 0.001 or 0.999 before the
+   * expression is evaluated.
+   */
+  public static DoubleDoubleFunction derivativeCrossEntropy = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      double adjustedTarget = target;
+      double adjustedActual = output;
+      // Keep the output away from 0 and 1 so that neither denominator below is zero.
+      if (adjustedActual == 1) {
+        adjustedActual = 0.999;
+      } else if (adjustedActual == 0) {
+        adjustedActual = 0.001;
+      }
+      if (adjustedTarget == 1) {
+        adjustedTarget = 0.999;
+      } else if (adjustedTarget == 0) {
+        adjustedTarget = 0.001;
+      }
+      return -adjustedTarget / adjustedActual + (1 - adjustedTarget) / (1 - adjustedActual);
+    }
+  };
+
+  /**
+   * Get the corresponding function by its name.
+   * Currently supports: "Identity", "Sigmoid" (case-insensitive).
+   *
+   * @param function The name of the function.
+   * @return The corresponding double function.
+   * @throws IllegalArgumentException if the name is not one of the supported names.
+   */
+  public static DoubleFunction getDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Identity")) {
+      return Functions.IDENTITY;
+    } else if (function.equalsIgnoreCase("Sigmoid")) {
+      return Functions.SIGMOID;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the derivative of a double function by the name of that function.
+   * Currently supports: "Identity", "Sigmoid" (case-insensitive).
+   *
+   * @param function The name of the function.
+   * @return The derivative of the named double function.
+   * @throws IllegalArgumentException if the name is not one of the supported names.
+   */
+  public static DoubleFunction getDerivativeDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Identity")) {
+      return derivativeIdentityFunction;
+    } else if (function.equalsIgnoreCase("Sigmoid")) {
+      return Functions.SIGMOIDGRADIENT;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the corresponding double-double function by the name.
+   * Currently supports: "Minus_Squared", "Cross_Entropy" (case-insensitive).
+   *
+   * @param function The name of the function.
+   * @return The double-double function.
+   * @throws IllegalArgumentException if the name is not one of the supported names.
+   */
+  public static DoubleDoubleFunction getDoubleDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Minus_Squared")) {
+      return Functions.MINUS_SQUARED;
+    } else if (function.equalsIgnoreCase("Cross_Entropy")) {
+      // Return the function itself; its derivative is served by
+      // getDerivativeDoubleDoubleFunction.
+      return crossEntropy;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the derivative of a double-double function by the name of that function.
+   * Currently supports: "Minus_Squared", "Cross_Entropy" (case-insensitive).
+   *
+   * @param function The name of the function.
+   * @return The derivative of the named double-double function.
+   * @throws IllegalArgumentException if the name is not one of the supported names.
+   */
+  public static DoubleDoubleFunction getDerivativeDoubleDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Minus_Squared")) {
+      return derivativeMinusSquared;
+    } else if (function.equalsIgnoreCase("Cross_Entropy")) {
+      return derivativeCrossEntropy;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+}
\ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java new file mode 100644 index 0000000..36d6792 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java @@ -0,0 +1,227 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.mlp; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.csv.CSVUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; + +/** Run {@link MultilayerPerceptron} classification. 
*/ +public class RunMultilayerPerceptron { + + private static final Logger log = LoggerFactory.getLogger(RunMultilayerPerceptron.class); + + static class Parameters { + String inputFilePathStr; + String inputFileFormat; + String modelFilePathStr; + String outputFilePathStr; + int columnStart; + int columnEnd; + boolean skipHeader; + } + + public static void main(String[] args) throws Exception { + + Parameters parameters = new Parameters(); + + if (parseArgs(args, parameters)) { + log.info("Load model from {}.", parameters.modelFilePathStr); + MultilayerPerceptron mlp = new MultilayerPerceptron(parameters.modelFilePathStr); + + log.info("Topology of MLP: {}.", Arrays.toString(mlp.getLayerSizeList().toArray())); + + // validate the data + log.info("Read the data..."); + Path inputFilePath = new Path(parameters.inputFilePathStr); + FileSystem inputFS = inputFilePath.getFileSystem(new Configuration()); + if (!inputFS.exists(inputFilePath)) { + log.error("Input file '{}' does not exists!", parameters.inputFilePathStr); + mlp.close(); + return; + } + + Path outputFilePath = new Path(parameters.outputFilePathStr); + FileSystem outputFS = inputFilePath.getFileSystem(new Configuration()); + if (outputFS.exists(outputFilePath)) { + log.error("Output file '{}' already exists!", parameters.outputFilePathStr); + mlp.close(); + return; + } + + if (!parameters.inputFileFormat.equals("csv")) { + log.error("Currently only supports for csv format."); + mlp.close(); + return; // current only supports csv format + } + + log.info("Read from column {} to column {}.", parameters.columnStart, parameters.columnEnd); + + BufferedWriter writer = null; + BufferedReader reader = null; + + try { + writer = new BufferedWriter(new OutputStreamWriter(outputFS.create(outputFilePath))); + reader = new BufferedReader(new InputStreamReader(inputFS.open(inputFilePath))); + + String line; + + if (parameters.skipHeader) { + reader.readLine(); + } + + while ((line = reader.readLine()) != null) { + 
String[] tokens = CSVUtils.parseLine(line); + double[] features = new double[Math.min(parameters.columnEnd, tokens.length) - parameters.columnStart + 1]; + + for (int i = parameters.columnStart, j = 0; i < Math.min(parameters.columnEnd + 1, tokens.length); ++i, ++j) { + features[j] = Double.parseDouble(tokens[i]); + } + Vector featureVec = new DenseVector(features); + Vector res = mlp.getOutput(featureVec); + int mostProbablyLabelIndex = res.maxValueIndex(); + writer.write(String.valueOf(mostProbablyLabelIndex)); + } + mlp.close(); + log.info("Labeling finished."); + } finally { + Closeables.close(reader, true); + Closeables.close(writer, true); + } + } + } + + /** + * Parse the arguments. + * + * @param args The input arguments. + * @param parameters The parameters need to be filled. + * @return true or false + * @throws Exception + */ + private static boolean parseArgs(String[] args, Parameters parameters) throws Exception { + // build the options + log.info("Validate and parse arguments..."); + DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder(); + GroupBuilder groupBuilder = new GroupBuilder(); + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + + Option inputFileFormatOption = optionBuilder + .withLongName("format") + .withShortName("f") + .withArgument(argumentBuilder.withName("file type").withDefault("csv").withMinimum(1).withMaximum(1).create()) + .withDescription("type of input file, currently support 'csv'") + .create(); + + List<Integer> columnRangeDefault = Lists.newArrayList(); + columnRangeDefault.add(0); + columnRangeDefault.add(Integer.MAX_VALUE); + + Option skipHeaderOption = optionBuilder.withLongName("skipHeader") + .withShortName("sh").withRequired(false) + .withDescription("whether to skip the first row of the input file") + .create(); + + Option inputColumnRangeOption = optionBuilder + .withLongName("columnRange") + .withShortName("cr") + .withDescription("the column range of the input file, start from 0") + 
.withArgument(argumentBuilder.withName("range").withMinimum(2).withMaximum(2) + .withDefaults(columnRangeDefault).create()).create(); + + Group inputFileTypeGroup = groupBuilder.withOption(skipHeaderOption) + .withOption(inputColumnRangeOption).withOption(inputFileFormatOption) + .create(); + + Option inputOption = optionBuilder + .withLongName("input") + .withShortName("i") + .withRequired(true) + .withArgument(argumentBuilder.withName("file path").withMinimum(1).withMaximum(1).create()) + .withDescription("the file path of unlabelled dataset") + .withChildren(inputFileTypeGroup).create(); + + Option modelOption = optionBuilder + .withLongName("model") + .withShortName("mo") + .withRequired(true) + .withArgument(argumentBuilder.withName("model file").withMinimum(1).withMaximum(1).create()) + .withDescription("the file path of the model").create(); + + Option labelsOption = optionBuilder + .withLongName("labels") + .withShortName("labels") + .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create()) + .withDescription("an ordered list of label names").create(); + + Group labelsGroup = groupBuilder.withOption(labelsOption).create(); + + Option outputOption = optionBuilder + .withLongName("output") + .withShortName("o") + .withRequired(true) + .withArgument(argumentBuilder.withConsumeRemaining("file path").withMinimum(1).withMaximum(1).create()) + .withDescription("the file path of labelled results").withChildren(labelsGroup).create(); + + // parse the input + Parser parser = new Parser(); + Group normalOption = groupBuilder.withOption(inputOption).withOption(modelOption).withOption(outputOption).create(); + parser.setGroup(normalOption); + CommandLine commandLine = parser.parseAndHelp(args); + if (commandLine == null) { + return false; + } + + // obtain the arguments + parameters.inputFilePathStr = TrainMultilayerPerceptron.getString(commandLine, inputOption); + parameters.inputFileFormat = TrainMultilayerPerceptron.getString(commandLine, 
inputFileFormatOption); + parameters.skipHeader = commandLine.hasOption(skipHeaderOption); + parameters.modelFilePathStr = TrainMultilayerPerceptron.getString(commandLine, modelOption); + parameters.outputFilePathStr = TrainMultilayerPerceptron.getString(commandLine, outputOption); + + List<?> columnRange = commandLine.getValues(inputColumnRangeOption); + parameters.columnStart = Integer.parseInt(columnRange.get(0).toString()); + parameters.columnEnd = Integer.parseInt(columnRange.get(1).toString()); + + return true; + } + +}
