http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java
new file mode 100644
index 0000000..32d7b5c
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java
@@ -0,0 +1,333 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+
+/**
+ * Base class for Mapred DecisionForest builders. Takes care of storing the parameters common to the mapred
+ * implementations.<br>
+ * The child classes must implement at least:
+ * <ul>
+ * <li>void configureJob(Job): to further configure the job before its launch; and</li>
+ * <li>DecisionForest parseOutput(Job): to convert the job outputs into a DecisionForest and its
+ * corresponding oob predictions</li>
+ * </ul>
+ */
+@Deprecated
+public abstract class Builder {
+
+ private static final Logger log = LoggerFactory.getLogger(Builder.class);
+
+ private final TreeBuilder treeBuilder;
+ private final Path dataPath;
+ private final Path datasetPath;
+ private final Long seed;
+ private final Configuration conf;
+ private String outputDirName = "output";
+
+ protected Builder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed, Configuration conf) {
+ this.treeBuilder = treeBuilder;
+ this.dataPath = dataPath;
+ this.datasetPath = datasetPath;
+ this.seed = seed;
+ this.conf = new Configuration(conf);
+ }
+
+ protected Path getDataPath() {
+ return dataPath;
+ }
+
+ /**
+ * Return the value of "mapred.map.tasks".
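+ * (Note: PartialBuilder.configureJob() sets "mapred.map.tasks" to the actual number of input
+ * splits, and Step1Mapper later reads it back through this method to work out how many trees
+ * its partition has to grow.)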
+ *
+ * @param conf
+ * configuration
+ * @return number of map tasks
+ */
+ public static int getNumMaps(Configuration conf) {
+ return conf.getInt("mapred.map.tasks", -1);
+ }
+
+ /**
+ * Used only for DEBUG purposes. If false, the mappers don't output anything, so the builder has nothing
+ * to process.
+ *
+ * @param conf
+ * configuration
+ * @return true if the builder has to return output, false otherwise
+ */
+ protected static boolean isOutput(Configuration conf) {
+ return conf.getBoolean("debug.mahout.rf.output", true);
+ }
+
+ /**
+ * Returns the random seed
+ *
+ * @param conf
+ * configuration
+ * @return null if no seed is available
+ */
+ public static Long getRandomSeed(Configuration conf) {
+ String seed = conf.get("mahout.rf.random.seed");
+ if (seed == null) {
+ return null;
+ }
+
+ return Long.valueOf(seed);
+ }
+
+ /**
+ * Sets the random seed value
+ *
+ * @param conf
+ * configuration
+ * @param seed
+ * random seed
+ */
+ private static void setRandomSeed(Configuration conf, long seed) {
+ conf.setLong("mahout.rf.random.seed", seed);
+ }
+
+ public static TreeBuilder getTreeBuilder(Configuration conf) {
+ String string = conf.get("mahout.rf.treebuilder");
+ if (string == null) {
+ return null;
+ }
+
+ return StringUtils.fromString(string);
+ }
+
+ private static void setTreeBuilder(Configuration conf, TreeBuilder treeBuilder) {
+ conf.set("mahout.rf.treebuilder", StringUtils.toString(treeBuilder));
+ }
+
+ /**
+ * Get the number of trees for the map-reduce job.
+ *
+ * @param conf
+ * configuration
+ * @return number of trees to build
+ */
+ public static int getNbTrees(Configuration conf) {
+ return conf.getInt("mahout.rf.nbtrees", -1);
+ }
+
+ /**
+ * Set the number of trees to grow for the map-reduce job
+ *
+ * @param conf
+ * configuration
+ * @param nbTrees
+ * number of trees to build
+ * @throws IllegalArgumentException
+ * if (nbTrees <= 0)
+ */
+ public static void setNbTrees(Configuration conf, int nbTrees) {
+ Preconditions.checkArgument(nbTrees > 0, "nbTrees should be greater than 0");
+
+ conf.setInt("mahout.rf.nbtrees", nbTrees);
+ }
+
+ /**
+ * Sets the output directory name, which will be created in the working directory
+ *
+ * @param name
+ * output dir. name
+ */
+ public void setOutputDirName(String name) {
+ outputDirName = name;
+ }
+
+ /**
+ * Output Directory name
+ *
+ * @param conf
+ * configuration
+ * @return output dir. path (%WORKING_DIRECTORY%/%OUTPUT_DIR_NAME%)
+ * @throws IOException
+ * if we cannot get the default FileSystem
+ */
+ protected Path getOutputPath(Configuration conf) throws IOException {
+ // the output directory is accessed only by this class, so use the default
+ // file system
+ FileSystem fs = FileSystem.get(conf);
+ return new Path(fs.getWorkingDirectory(), outputDirName);
+ }
+
+ /**
+ * Helper method. Get a path from the DistributedCache
+ *
+ * @param conf
+ * configuration
+ * @param index
+ * index of the path in the DistributedCache files
+ * @return path from the DistributedCache
+ * @throws IOException
+ * if no path is found
+ */
+ public static Path getDistributedCacheFile(Configuration conf, int index) throws IOException {
+ Path[] files = HadoopUtil.getCachedFiles(conf);
+
+ if (files.length <= index) {
+ throw new IOException("path not found in the DistributedCache");
+ }
+
+ return files[index];
+ }
+
+ /**
+ * Helper method. Load a Dataset stored in the DistributedCache
+ *
+ * @param conf
+ * configuration
+ * @return loaded Dataset
+ * @throws IOException
+ * if we cannot retrieve the Dataset path from the DistributedCache, or the Dataset could not be
+ * loaded
+ */
+ public static Dataset loadDataset(Configuration conf) throws IOException {
+ Path datasetPath = getDistributedCacheFile(conf, 0);
+
+ return Dataset.load(conf, datasetPath);
+ }
+
+ /**
+ * Used by the inheriting classes to configure the job
+ *
+ * @param job
+ * Hadoop's Job
+ * @throws IOException
+ * if anything goes wrong while configuring the job
+ */
+ protected abstract void configureJob(Job job) throws IOException;
+
+ /**
+ * Sequential implementations should override this method to simulate the job execution
+ *
+ * @param job
+ * Hadoop's job
+ * @return true if the job succeeded
+ */
+ protected boolean runJob(Job job) throws ClassNotFoundException, IOException, InterruptedException {
+ return job.waitForCompletion(true);
+ }
+
+ /**
+ * Parse the output files to extract the trees
+ *
+ * @param job
+ * Hadoop's job
+ * @return Built DecisionForest
+ * @throws IOException
+ * if anything goes wrong while parsing the output
+ */
+ protected abstract DecisionForest parseOutput(Job job) throws IOException;
+
+ public DecisionForest build(int nbTrees)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ Path outputPath = getOutputPath(conf);
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ // check the output
+ if (fs.exists(outputPath)) {
+ throw new IOException("Output path already exists : " + outputPath);
+ }
+
+ if (seed != null) {
+ setRandomSeed(conf, seed);
+ }
+ setNbTrees(conf, nbTrees);
+ setTreeBuilder(conf, treeBuilder);
+
+ // put the dataset into the DistributedCache
+ DistributedCache.addCacheFile(datasetPath.toUri(), conf);
+
+ Job job = new Job(conf, "decision forest builder");
+
+ log.debug("Configuring the job...");
+ configureJob(job);
+
+ log.debug("Running the job...");
+ if (!runJob(job)) {
+ log.error("Job failed!");
+ return null;
+ }
+
+ if (isOutput(conf)) {
+ log.debug("Parsing the output...");
+ DecisionForest forest = parseOutput(job);
+ HadoopUtil.delete(conf, outputPath);
+ return forest;
+ }
+
+ return null;
+ }
+
+ /**
+ * Sort the splits in order of decreasing size, so that the biggest go first.<br>
+ * This is the same code used by Hadoop's JobClient.
+ *
+ * @param splits
+ * input splits
+ */
+ public static void sortSplits(InputSplit[] splits) {
+ Arrays.sort(splits, new Comparator<InputSplit>() {
+ @Override
+ public int compare(InputSplit a, InputSplit b) {
+ try {
+ long left = a.getLength();
+ long right = b.getLength();
+ if (left == right) {
+ return 0;
+ } else if (left < right) {
+ return 1;
+ } else {
+ return -1;
+ }
+ } catch (IOException ie) {
+ throw new IllegalStateException("Problem getting input split size", ie);
+ } catch (InterruptedException ie) {
+ throw new IllegalStateException("Problem getting input split size", ie);
+ }
+ }
+ });
+ }
+
+}
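For orientation, a minimal sketch of how a driver drives a concrete Builder subclass end to end. This is an illustration, not part of the commit: PartialBuilder is the subclass added further below, DecisionTreeBuilder is assumed to be one of the TreeBuilder implementations in org.apache.mahout.classifier.df.builder, and the paths, seed and tree count are hypothetical values.

    Configuration conf = new Configuration();
    Path dataPath = new Path("data/train.csv");      // hypothetical training data
    Path datasetPath = new Path("data/train.info");  // hypothetical Dataset descriptor
    TreeBuilder treeBuilder = new DecisionTreeBuilder();
    Builder forestBuilder = new PartialBuilder(treeBuilder, dataPath, datasetPath, 42L, conf);
    DecisionForest forest = forestBuilder.build(100); // configures, runs and parses the job
    if (forest == null) {
      // the job failed, or debug.mahout.rf.output was set to false
    }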
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java
new file mode 100644
index 0000000..1a35cfe
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java
@@ -0,0 +1,238 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.FileSplit;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * MapReduce implementation that classifies the input data using a previously built decision forest
+ */
+@Deprecated
+public class Classifier {
+
+ private static final Logger log = LoggerFactory.getLogger(Classifier.class);
+
+ private final Path forestPath;
+ private final Path inputPath;
+ private final Path datasetPath;
+ private final Configuration conf;
+ private final Path outputPath; // path that will contain the final output of the classifier
+ private final Path mappersOutputPath; // mappers will output here
+ private double[][] results;
+
+ public
double[][] getResults() { + return results; + } + + public Classifier(Path forestPath, + Path inputPath, + Path datasetPath, + Path outputPath, + Configuration conf) { + this.forestPath = forestPath; + this.inputPath = inputPath; + this.datasetPath = datasetPath; + this.outputPath = outputPath; + this.conf = conf; + + mappersOutputPath = new Path(outputPath, "mappers"); + } + + private void configureJob(Job job) throws IOException { + + job.setJarByClass(Classifier.class); + + FileInputFormat.setInputPaths(job, inputPath); + FileOutputFormat.setOutputPath(job, mappersOutputPath); + + job.setOutputKeyClass(DoubleWritable.class); + job.setOutputValueClass(Text.class); + + job.setMapperClass(CMapper.class); + job.setNumReduceTasks(0); // no reducers + + job.setInputFormatClass(CTextInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + } + + public void run() throws IOException, ClassNotFoundException, InterruptedException { + FileSystem fs = FileSystem.get(conf); + + // check the output + if (fs.exists(outputPath)) { + throw new IOException("Output path already exists : " + outputPath); + } + + log.info("Adding the dataset to the DistributedCache"); + // put the dataset into the DistributedCache + DistributedCache.addCacheFile(datasetPath.toUri(), conf); + + log.info("Adding the decision forest to the DistributedCache"); + DistributedCache.addCacheFile(forestPath.toUri(), conf); + + Job job = new Job(conf, "decision forest classifier"); + + log.info("Configuring the job..."); + configureJob(job); + + log.info("Running the job..."); + if (!job.waitForCompletion(true)) { + throw new IllegalStateException("Job failed!"); + } + + parseOutput(job); + + HadoopUtil.delete(conf, mappersOutputPath); + } + + /** + * Extract the prediction for each mapper and write them in the corresponding output file. + * The name of the output file is based on the name of the corresponding input file. + * Will compute the ConfusionMatrix if necessary. + */ + private void parseOutput(JobContext job) throws IOException { + Configuration conf = job.getConfiguration(); + FileSystem fs = mappersOutputPath.getFileSystem(conf); + + Path[] outfiles = DFUtils.listOutputFiles(fs, mappersOutputPath); + + // read all the output + List<double[]> resList = new ArrayList<>(); + for (Path path : outfiles) { + FSDataOutputStream ofile = null; + try { + for (Pair<DoubleWritable,Text> record : new SequenceFileIterable<DoubleWritable,Text>(path, true, conf)) { + double key = record.getFirst().get(); + String value = record.getSecond().toString(); + if (ofile == null) { + // this is the first value, it contains the name of the input file + ofile = fs.create(new Path(outputPath, value).suffix(".out")); + } else { + // The key contains the correct label of the data. The value contains a prediction + ofile.writeChars(value); // write the prediction + ofile.writeChar('\n'); + + resList.add(new double[]{key, Double.valueOf(value)}); + } + } + } finally { + Closeables.close(ofile, false); + } + } + results = new double[resList.size()][2]; + resList.toArray(results); + } + + /** + * TextInputFormat that does not split the input files. This ensures that each input file is processed by one single + * mapper. 
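+ * (Splitting must stay disabled here because the first record each mapper emits carries its input
+ * file's name, which parseOutput() then uses to name that file's prediction output.)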
+ */ + private static class CTextInputFormat extends TextInputFormat { + @Override + protected boolean isSplitable(JobContext jobContext, Path path) { + return false; + } + } + + public static class CMapper extends Mapper<LongWritable, Text, DoubleWritable, Text> { + + /** used to convert input values to data instances */ + private DataConverter converter; + private DecisionForest forest; + private final Random rng = RandomUtils.getRandom(); + private boolean first = true; + private final Text lvalue = new Text(); + private Dataset dataset; + private final DoubleWritable lkey = new DoubleWritable(); + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); //To change body of overridden methods use File | Settings | File Templates. + + Configuration conf = context.getConfiguration(); + + Path[] files = HadoopUtil.getCachedFiles(conf); + + if (files.length < 2) { + throw new IOException("not enough paths in the DistributedCache"); + } + dataset = Dataset.load(conf, files[0]); + converter = new DataConverter(dataset); + + forest = DecisionForest.load(conf, files[1]); + if (forest == null) { + throw new InterruptedException("DecisionForest not found!"); + } + } + + @Override + protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { + if (first) { + FileSplit split = (FileSplit) context.getInputSplit(); + Path path = split.getPath(); // current split path + lvalue.set(path.getName()); + lkey.set(key.get()); + context.write(lkey, lvalue); + + first = false; + } + + String line = value.toString(); + if (!line.isEmpty()) { + Instance instance = converter.convert(line); + double prediction = forest.classify(dataset, rng, instance); + lkey.set(dataset.getLabel(instance)); + lvalue.set(Double.toString(prediction)); + context.write(lkey, lvalue); + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java new file mode 100644 index 0000000..4d0f3f1 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java @@ -0,0 +1,75 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.classifier.df.mapreduce;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Dataset;
+
+import java.io.IOException;
+
+/**
+ * Base class for Mapred mappers. Loads common parameters from the job
+ */
+@Deprecated
+public class MapredMapper<KEYIN,VALUEIN,KEYOUT,VALUEOUT> extends Mapper<KEYIN,VALUEIN,KEYOUT,VALUEOUT> {
+
+ private boolean noOutput;
+
+ private TreeBuilder treeBuilder;
+
+ private Dataset dataset;
+
+ /**
+ * @return whether the mapper estimates and outputs predictions
+ */
+ protected boolean isOutput() {
+ return !noOutput;
+ }
+
+ protected TreeBuilder getTreeBuilder() {
+ return treeBuilder;
+ }
+
+ protected Dataset getDataset() {
+ return dataset;
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+
+ Configuration conf = context.getConfiguration();
+
+ configure(!Builder.isOutput(conf), Builder.getTreeBuilder(conf), Builder
+ .loadDataset(conf));
+ }
+
+ /**
+ * Useful for testing
+ */
+ protected void configure(boolean noOutput, TreeBuilder treeBuilder, Dataset dataset) {
+ Preconditions.checkArgument(treeBuilder != null, "TreeBuilder not found in the Job parameters");
+ this.noOutput = noOutput;
+ this.treeBuilder = treeBuilder;
+ this.dataset = dataset;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java
new file mode 100644
index 0000000..56cabb2
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java
@@ -0,0 +1,120 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.node.Node;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Used by various implementations to return the results of a build.<br>
+ * Contains a grown tree and its oob predictions.
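+ * (The wire format written below is a boolean "tree present" flag followed by the serialized
+ * tree, then a boolean "predictions present" flag followed by the int array; readFields()
+ * reads the two sections back in the same order.)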
+ */
+@Deprecated
+public class MapredOutput implements Writable, Cloneable {
+
+ private Node tree;
+
+ private int[] predictions;
+
+ public MapredOutput() {
+ }
+
+ public MapredOutput(Node tree, int[] predictions) {
+ this.tree = tree;
+ this.predictions = predictions;
+ }
+
+ public MapredOutput(Node tree) {
+ this(tree, null);
+ }
+
+ public Node getTree() {
+ return tree;
+ }
+
+ int[] getPredictions() {
+ return predictions;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ boolean readTree = in.readBoolean();
+ if (readTree) {
+ tree = Node.read(in);
+ }
+
+ boolean readPredictions = in.readBoolean();
+ if (readPredictions) {
+ predictions = DFUtils.readIntArray(in);
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeBoolean(tree != null);
+ if (tree != null) {
+ tree.write(out);
+ }
+
+ out.writeBoolean(predictions != null);
+ if (predictions != null) {
+ DFUtils.writeArray(out, predictions);
+ }
+ }
+
+ @Override
+ public MapredOutput clone() {
+ return new MapredOutput(tree, predictions);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof MapredOutput)) {
+ return false;
+ }
+
+ MapredOutput mo = (MapredOutput) obj;
+
+ return ((tree == null && mo.getTree() == null) || (tree != null && tree.equals(mo.getTree())))
+ && Arrays.equals(predictions, mo.getPredictions());
+ }
+
+ @Override
+ public int hashCode() {
+ int hashCode = tree == null ? 1 : tree.hashCode();
+ if (predictions != null) { // predictions may be null (see the single-argument constructor)
+ for (int prediction : predictions) {
+ hashCode = 31 * hashCode + prediction;
+ }
+ }
+ return hashCode;
+ }
+
+ @Override
+ public String toString() {
+ return "{" + tree + " | " + Arrays.toString(predictions) + '}';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java
new file mode 100644
index 0000000..86d4404
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+
+/**
+ * MapReduce implementation where each mapper loads a full copy of the data in-memory. The forest trees are
+ * split across all the mappers
+ */
+@Deprecated
+public class InMemBuilder extends Builder {
+
+ public InMemBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed, Configuration conf) {
+ super(treeBuilder, dataPath, datasetPath, seed, conf);
+ }
+
+ public InMemBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath) {
+ this(treeBuilder, dataPath, datasetPath, null, new Configuration());
+ }
+
+ @Override
+ protected void configureJob(Job job) throws IOException {
+ Configuration conf = job.getConfiguration();
+
+ job.setJarByClass(InMemBuilder.class);
+
+ FileOutputFormat.setOutputPath(job, getOutputPath(conf));
+
+ // put the data in the DistributedCache
+ DistributedCache.addCacheFile(getDataPath().toUri(), conf);
+
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(MapredOutput.class);
+
+ job.setMapperClass(InMemMapper.class);
+ job.setNumReduceTasks(0); // no reducers
+
+ job.setInputFormatClass(InMemInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ }
+
+ @Override
+ protected DecisionForest parseOutput(Job job) throws IOException {
+ Configuration conf = job.getConfiguration();
+
+ Map<Integer,MapredOutput> output = new HashMap<>();
+
+ Path outputPath = getOutputPath(conf);
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);
+
+ // import the InMemOutputs
+ for (Path path : outfiles) {
+ for (Pair<IntWritable,MapredOutput> record : new SequenceFileIterable<IntWritable,MapredOutput>(path, conf)) {
+ output.put(record.getFirst().get(), record.getSecond());
+ }
+ }
+
+ return processOutput(output);
+ }
+
+ /**
+ * Process the output, extracting the trees
+ */
+ private static DecisionForest processOutput(Map<Integer,MapredOutput> output) {
+ List<Node> trees = new ArrayList<>();
+
+ for (Map.Entry<Integer,MapredOutput> entry : output.entrySet()) {
+ MapredOutput value = entry.getValue();
+ trees.add(value.getTree());
+ }
+
+ return new DecisionForest(trees);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java
new file mode 100644
index 0000000..c3b2fa3
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java
@@ -0,0 +1,284 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Custom InputFormat that generates InputSplits given the desired number of trees.<br>
+ * Each input split contains a subset of the trees.<br>
+ * The number of splits is equal to the number of requested map tasks
+ */
+@Deprecated
+public class InMemInputFormat extends InputFormat<IntWritable,NullWritable> {
+
+ private static final Logger log = LoggerFactory.getLogger(InMemInputFormat.class);
+
+ private Random rng;
+
+ private Long seed;
+
+ private boolean isSingleSeed;
+
+ /**
+ * Used for DEBUG purposes only. If true and a seed is available, all the mappers use the same seed, so
+ * all the mappers should take roughly the same time to build their trees.
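+ * (Enabled by setting the boolean job property "debug.mahout.rf.single.seed" to true, as read
+ * by the method below.)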
+ */
+ private static boolean isSingleSeed(Configuration conf) {
+ return conf.getBoolean("debug.mahout.rf.single.seed", false);
+ }
+
+ @Override
+ public RecordReader<IntWritable,NullWritable> createRecordReader(InputSplit split, TaskAttemptContext context)
+ throws IOException, InterruptedException {
+ Preconditions.checkArgument(split instanceof InMemInputSplit);
+ return new InMemRecordReader((InMemInputSplit) split);
+ }
+
+ @Override
+ public List<InputSplit> getSplits(JobContext context) throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+ int numSplits = conf.getInt("mapred.map.tasks", -1);
+
+ return getSplits(conf, numSplits);
+ }
+
+ public List<InputSplit> getSplits(Configuration conf, int numSplits) {
+ int nbTrees = Builder.getNbTrees(conf);
+ int splitSize = nbTrees / numSplits;
+
+ seed = Builder.getRandomSeed(conf);
+ isSingleSeed = isSingleSeed(conf);
+
+ if (rng != null && seed != null) {
+ log.warn("getSplits() was called more than once and the 'seed' is set, "
+ + "this can lead to non-repeatable behavior");
+ }
+
+ rng = seed == null || isSingleSeed ? null : RandomUtils.getRandom(seed);
+
+ int id = 0;
+
+ List<InputSplit> splits = new ArrayList<>(numSplits);
+
+ for (int index = 0; index < numSplits - 1; index++) {
+ splits.add(new InMemInputSplit(id, splitSize, nextSeed()));
+ id += splitSize;
+ }
+
+ // take care of the remainder
+ splits.add(new InMemInputSplit(id, nbTrees - id, nextSeed()));
+
+ return splits;
+ }
+
+ /**
+ * @return the seed for the next InputSplit
+ */
+ private Long nextSeed() {
+ if (seed == null) {
+ return null;
+ } else if (isSingleSeed) {
+ return seed;
+ } else {
+ return rng.nextLong();
+ }
+ }
+
+ public static class InMemRecordReader extends RecordReader<IntWritable,NullWritable> {
+
+ private final InMemInputSplit split;
+ private int pos;
+ private IntWritable key;
+ private NullWritable value;
+
+ public InMemRecordReader(InMemInputSplit split) {
+ this.split = split;
+ }
+
+ @Override
+ public float getProgress() throws IOException {
+ return pos == 0 ?
0.0f : (float) (pos - 1) / split.nbTrees; + } + + @Override + public IntWritable getCurrentKey() throws IOException, InterruptedException { + return key; + } + + @Override + public NullWritable getCurrentValue() throws IOException, InterruptedException { + return value; + } + + @Override + public void initialize(InputSplit arg0, TaskAttemptContext arg1) throws IOException, InterruptedException { + key = new IntWritable(); + value = NullWritable.get(); + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (pos < split.nbTrees) { + key.set(split.firstId + pos); + pos++; + return true; + } else { + return false; + } + } + + @Override + public void close() throws IOException { + } + + } + + /** + * Custom InputSplit that indicates how many trees are built by each mapper + */ + public static class InMemInputSplit extends InputSplit implements Writable { + + private static final String[] NO_LOCATIONS = new String[0]; + + /** Id of the first tree of this split */ + private int firstId; + + private int nbTrees; + + private Long seed; + + public InMemInputSplit() { } + + public InMemInputSplit(int firstId, int nbTrees, Long seed) { + this.firstId = firstId; + this.nbTrees = nbTrees; + this.seed = seed; + } + + /** + * @return the Id of the first tree of this split + */ + public int getFirstId() { + return firstId; + } + + /** + * @return the number of trees + */ + public int getNbTrees() { + return nbTrees; + } + + /** + * @return the random seed or null if no seed is available + */ + public Long getSeed() { + return seed; + } + + @Override + public long getLength() throws IOException { + return nbTrees; + } + + @Override + public String[] getLocations() throws IOException { + return NO_LOCATIONS; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof InMemInputSplit)) { + return false; + } + + InMemInputSplit split = (InMemInputSplit) obj; + + if (firstId != split.firstId || nbTrees != split.nbTrees) { + return false; + } + if (seed == null) { + return split.seed == null; + } else { + return seed.equals(split.seed); + } + + } + + @Override + public int hashCode() { + return firstId + nbTrees + (seed == null ? 0 : seed.intValue()); + } + + @Override + public String toString() { + return String.format(Locale.ENGLISH, "[firstId:%d, nbTrees:%d, seed:%d]", firstId, nbTrees, seed); + } + + @Override + public void readFields(DataInput in) throws IOException { + firstId = in.readInt(); + nbTrees = in.readInt(); + boolean isSeed = in.readBoolean(); + seed = isSeed ? 
in.readLong() : null; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(firstId); + out.writeInt(nbTrees); + out.writeBoolean(seed != null); + if (seed != null) { + out.writeLong(seed); + } + } + + public static InMemInputSplit read(DataInput in) throws IOException { + InMemInputSplit split = new InMemInputSplit(); + split.readFields(in); + return split; + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java new file mode 100644 index 0000000..2fc67ba --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java @@ -0,0 +1,106 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.inmem; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.NullWritable; +import org.apache.mahout.classifier.df.Bagging; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataLoader; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.classifier.df.mapreduce.MapredMapper; +import org.apache.mahout.classifier.df.mapreduce.MapredOutput; +import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Random; + +/** + * In-memory mapper that grows the trees using a full copy of the data loaded in-memory. The number of trees + * to grow is determined by the current InMemInputSplit. 
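+ * (For illustration, with hypothetical values: a split with firstId=20 and nbTrees=5 makes this
+ * mapper grow the trees keyed 20..24, seeding its RNG from the split's seed.)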
+ */
+@Deprecated
+public class InMemMapper extends MapredMapper<IntWritable,NullWritable,IntWritable,MapredOutput> {
+
+ private static final Logger log = LoggerFactory.getLogger(InMemMapper.class);
+
+ private Bagging bagging;
+
+ private Random rng;
+
+ /**
+ * Load the training data
+ */
+ private static Data loadData(Configuration conf, Dataset dataset) throws IOException {
+ Path dataPath = Builder.getDistributedCacheFile(conf, 1);
+ FileSystem fs = FileSystem.get(dataPath.toUri(), conf);
+ return DataLoader.loadData(dataset, fs, dataPath);
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+
+ Configuration conf = context.getConfiguration();
+
+ log.info("Loading the data...");
+ Data data = loadData(conf, getDataset());
+ log.info("Data loaded : {} instances", data.size());
+
+ bagging = new Bagging(getTreeBuilder(), data);
+ }
+
+ @Override
+ protected void map(IntWritable key,
+ NullWritable value,
+ Context context) throws IOException, InterruptedException {
+ map(key, context);
+ }
+
+ void map(IntWritable key, Context context) throws IOException, InterruptedException {
+
+ initRandom((InMemInputSplit) context.getInputSplit());
+
+ log.debug("Building...");
+ Node tree = bagging.build(rng);
+
+ if (isOutput()) {
+ log.debug("Outputting...");
+ MapredOutput mrOut = new MapredOutput(tree);
+
+ context.write(key, mrOut);
+ }
+ }
+
+ void initRandom(InMemInputSplit split) {
+ if (rng == null) { // first execution of this mapper
+ Long seed = split.getSeed();
+ log.debug("Initialising rng with seed : {}", seed);
+ rng = seed == null ? RandomUtils.getRandom() : RandomUtils.getRandom(seed);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java
new file mode 100644
index 0000000..61e65e8
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java
@@ -0,0 +1,22 @@
+/**
+ * <h2>In-memory mapreduce implementation of Random Decision Forests</h2>
+ *
+ * <p>Each mapper is responsible for growing a number of trees with a whole copy of the dataset loaded in memory;
+ * it uses the reference implementation's code to build each tree and estimate the oob error.</p>
+ *
+ * <p>The dataset is distributed to the slave nodes using the {@link org.apache.hadoop.filecache.DistributedCache}.
+ * A custom {@link org.apache.hadoop.mapreduce.InputFormat}
+ * ({@link org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat}) is configured with the
+ * desired number of trees and generates a number of {@link org.apache.hadoop.mapreduce.InputSplit}s
+ * equal to the configured number of maps.</p>
+ *
+ * <p>There is no need for reducers: each map outputs the trees it built and, for each tree, the labels the
+ * tree predicted for each out-of-bag instance. This step has to be done in the mapper because only there do we
+ * know which instances are out-of-bag.</p>
+ *
+ * <p>The Forest builder ({@link org.apache.mahout.classifier.df.mapreduce.inmem.InMemBuilder}) is responsible
+ * for configuring and launching the job.
+ * At the end of the job it parses the output files and builds the corresponding + * {@link org.apache.mahout.classifier.df.DecisionForest}.</p> + */ +package org.apache.mahout.classifier.df.mapreduce.inmem; http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java new file mode 100644 index 0000000..9236af3 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java @@ -0,0 +1,158 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.partial; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.JobContext; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.classifier.df.DFUtils; +import org.apache.mahout.classifier.df.DecisionForest; +import org.apache.mahout.classifier.df.builder.TreeBuilder; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.classifier.df.mapreduce.MapredOutput; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +/** + * Builds a random forest using partial data. 
Each mapper uses only the data given by its InputSplit + */ +@Deprecated +public class PartialBuilder extends Builder { + + private static final Logger log = LoggerFactory.getLogger(PartialBuilder.class); + + public PartialBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed) { + this(treeBuilder, dataPath, datasetPath, seed, new Configuration()); + } + + public PartialBuilder(TreeBuilder treeBuilder, + Path dataPath, + Path datasetPath, + Long seed, + Configuration conf) { + super(treeBuilder, dataPath, datasetPath, seed, conf); + } + + @Override + protected void configureJob(Job job) throws IOException { + Configuration conf = job.getConfiguration(); + + job.setJarByClass(PartialBuilder.class); + + FileInputFormat.setInputPaths(job, getDataPath()); + FileOutputFormat.setOutputPath(job, getOutputPath(conf)); + + job.setOutputKeyClass(TreeID.class); + job.setOutputValueClass(MapredOutput.class); + + job.setMapperClass(Step1Mapper.class); + job.setNumReduceTasks(0); // no reducers + + job.setInputFormatClass(TextInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + // For this implementation to work, mapred.map.tasks needs to be set to the actual + // number of mappers Hadoop will use: + TextInputFormat inputFormat = new TextInputFormat(); + List<?> splits = inputFormat.getSplits(job); + if (splits == null || splits.isEmpty()) { + log.warn("Unable to compute number of splits?"); + } else { + int numSplits = splits.size(); + log.info("Setting mapred.map.tasks = {}", numSplits); + conf.setInt("mapred.map.tasks", numSplits); + } + } + + @Override + protected DecisionForest parseOutput(Job job) throws IOException { + Configuration conf = job.getConfiguration(); + + int numTrees = Builder.getNbTrees(conf); + + Path outputPath = getOutputPath(conf); + + TreeID[] keys = new TreeID[numTrees]; + Node[] trees = new Node[numTrees]; + + processOutput(job, outputPath, keys, trees); + + return new DecisionForest(Arrays.asList(trees)); + } + + /** + * Processes the output from the output path.<br> + * + * @param outputPath + * directory that contains the output of the job + * @param keys + * can be null + * @param trees + * can be null + * @throws java.io.IOException + */ + protected static void processOutput(JobContext job, + Path outputPath, + TreeID[] keys, + Node[] trees) throws IOException { + Preconditions.checkArgument(keys == null && trees == null || keys != null && trees != null, + "if keys is null, trees should also be null"); + Preconditions.checkArgument(keys == null || keys.length == trees.length, "keys.length != trees.length"); + + Configuration conf = job.getConfiguration(); + + FileSystem fs = outputPath.getFileSystem(conf); + + Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath); + + // read all the outputs + int index = 0; + for (Path path : outfiles) { + for (Pair<TreeID,MapredOutput> record : new SequenceFileIterable<TreeID, MapredOutput>(path, conf)) { + TreeID key = record.getFirst(); + MapredOutput value = record.getSecond(); + if (keys != null) { + keys[index] = key; + } + if (trees != null) { + trees[index] = value.getTree(); + } + index++; + } + } + + // make sure we got all the keys/values + if (keys != null && index != keys.length) { + throw new IllegalStateException("Some key/values are missing from the output"); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java 
---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java new file mode 100644 index 0000000..9474236 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java @@ -0,0 +1,168 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.partial; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.mahout.classifier.df.Bagging; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataConverter; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.classifier.df.mapreduce.MapredMapper; +import org.apache.mahout.classifier.df.mapreduce.MapredOutput; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * First step of the Partial Data Builder. Builds the trees using the data available in the InputSplit. + * Predict the oob classes for each tree in its growing partition (input split). 
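+ * (Tree ids are assigned contiguously by partition. As a worked example with hypothetical
+ * numbers: 10 maps and 23 trees give treesPerMapper=2 and remainder=3, so partitions 0-2 grow
+ * 3 trees each and partitions 3-9 grow 2 each; see nbTrees() below.)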
+ */
+@Deprecated
+public class Step1Mapper extends MapredMapper<LongWritable,Text,TreeID,MapredOutput> {
+
+ private static final Logger log = LoggerFactory.getLogger(Step1Mapper.class);
+
+ /** used to convert input values to data instances */
+ private DataConverter converter;
+
+ private Random rng;
+
+ /** number of trees to be built by this mapper */
+ private int nbTrees;
+
+ /** id of the first tree */
+ private int firstTreeId;
+
+ /** mapper's partition */
+ private int partition;
+
+ /** will contain all instances of this mapper's split */
+ private final List<Instance> instances = new ArrayList<>();
+
+ public int getFirstTreeId() {
+ return firstTreeId;
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+
+ configure(Builder.getRandomSeed(conf), conf.getInt("mapred.task.partition", -1),
+ Builder.getNumMaps(conf), Builder.getNbTrees(conf));
+ }
+
+ /**
+ * Useful when testing
+ *
+ * @param partition
+ * current mapper inputSplit partition
+ * @param numMapTasks
+ * number of running map tasks
+ * @param numTrees
+ * total number of trees in the forest
+ */
+ protected void configure(Long seed, int partition, int numMapTasks, int numTrees) {
+ converter = new DataConverter(getDataset());
+
+ // prepare the random-numbers generator
+ log.debug("seed : {}", seed);
+ if (seed == null) {
+ rng = RandomUtils.getRandom();
+ } else {
+ rng = RandomUtils.getRandom(seed);
+ }
+
+ // mapper's partition
+ Preconditions.checkArgument(partition >= 0, "Wrong partition ID: " + partition + ". Partition must be >= 0!");
+ this.partition = partition;
+
+ // compute number of trees to build
+ nbTrees = nbTrees(numMapTasks, numTrees, partition);
+
+ // compute first tree id
+ firstTreeId = 0;
+ for (int p = 0; p < partition; p++) {
+ firstTreeId += nbTrees(numMapTasks, numTrees, p);
+ }
+
+ log.debug("partition : {}", partition);
+ log.debug("nbTrees : {}", nbTrees);
+ log.debug("firstTreeId : {}", firstTreeId);
+ }
+
+ /**
+ * Compute the number of trees for a given partition. The first partitions may build one more tree
+ * than the rest because of the remainder.
+ *
+ * @param numMaps
+ * total number of maps (partitions)
+ * @param numTrees
+ * total number of trees to build
+ * @param partition
+ * partition to compute the number of trees for
+ */
+ public static int nbTrees(int numMaps, int numTrees, int partition) {
+ int treesPerMapper = numTrees / numMaps;
+ int remainder = numTrees - numMaps * treesPerMapper;
+ return treesPerMapper + (partition < remainder ?
1 : 0); + } + + @Override + protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { + instances.add(converter.convert(value.toString())); + } + + @Override + protected void cleanup(Context context) throws IOException, InterruptedException { + // prepare the data + log.debug("partition: {} numInstances: {}", partition, instances.size()); + + Data data = new Data(getDataset(), instances); + Bagging bagging = new Bagging(getTreeBuilder(), data); + + TreeID key = new TreeID(); + + log.debug("Building {} trees", nbTrees); + for (int treeId = 0; treeId < nbTrees; treeId++) { + log.debug("Building tree number : {}", treeId); + + Node tree = bagging.build(rng); + + key.set(partition, firstTreeId + treeId); + + if (isOutput()) { + MapredOutput emOut = new MapredOutput(tree); + context.write(key, emOut); + } + + context.progress(); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java new file mode 100644 index 0000000..c296061 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.partial; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.io.LongWritable; + +/** + * Indicates both the tree and the data partition used to grow the tree + */ +@Deprecated +public class TreeID extends LongWritable implements Cloneable { + + public static final int MAX_TREEID = 100000; + + public TreeID() { } + + public TreeID(int partition, int treeId) { + Preconditions.checkArgument(partition >= 0, "Wrong partition: " + partition + ". Partition must be >= 0!"); + Preconditions.checkArgument(treeId >= 0, "Wrong treeId: " + treeId + ". 
+
+  public void set(int partition, int treeId) {
+    set((long) partition * MAX_TREEID + treeId);
+  }
+
+  /**
+   * Data partition (InputSplit's index) that was used to grow the tree
+   */
+  public int partition() {
+    return (int) (get() / MAX_TREEID);
+  }
+
+  public int treeId() {
+    return (int) (get() % MAX_TREEID);
+  }
+
+  @Override
+  public TreeID clone() {
+    return new TreeID(partition(), treeId());
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java
new file mode 100644
index 0000000..e621c91
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java
@@ -0,0 +1,16 @@
+/**
+ * <h2>Partial-data mapreduce implementation of Random Decision Forests</h2>
+ *
+ * <p>The builder splits the data, using a FileInputSplit, among the mappers.
+ * Building the forest and estimating the oob error takes two job steps.</p>
+ *
+ * <p>In the first step, each mapper is responsible for growing a number of trees using its partition's data,
+ * loading the data instances in its {@code map()} function, then building the trees in the {@code cleanup()} method.
+ * It uses the reference implementation's code to build each tree and estimate the oob error.</p>
+ *
+ * <p>The second step is needed only when estimating the oob error. Each mapper loads all the trees that do not
+ * belong to its own partition (i.e. were not built from that partition's data) and uses them to classify the
+ * partition's data instances. The data instances are loaded in the {@code map()} method and the classification
+ * is performed in the {@code cleanup()} method.</p>
+ */
+package org.apache.mahout.classifier.df.mapreduce.partial;
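
For orientation, a rough driver sketch of this partial implementation: PartialBuilder (the concrete Builder of this package) and the build(int) entry point are assumed from the surrounding Mahout code rather than shown in this excerpt, and the input paths are hypothetical.

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.classifier.df.DecisionForest;
    import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
    import org.apache.mahout.classifier.df.mapreduce.Builder;
    import org.apache.mahout.classifier.df.mapreduce.partial.PartialBuilder;

    public class GrowForest {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        Path data = new Path("data/instances.csv");     // hypothetical training data
        Path dataset = new Path("data/instances.info"); // hypothetical Dataset descriptor
        // a fixed seed (42L here) makes the per-partition bagging reproducible
        Builder builder = new PartialBuilder(new DecisionTreeBuilder(), data, dataset, 42L, conf);
        DecisionForest forest = builder.build(10);      // ten trees, spread over the map tasks
      }
    }
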
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
new file mode 100644
index 0000000..1f91842
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
@@ -0,0 +1,134 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+
+@Deprecated
+public class CategoricalNode extends Node {
+
+  private int attr;
+  private double[] values;
+  private Node[] childs;
+
+  public CategoricalNode() {
+  }
+
+  public CategoricalNode(int attr, double[] values, Node[] childs) {
+    this.attr = attr;
+    this.values = values;
+    this.childs = childs;
+  }
+
+  @Override
+  public double classify(Instance instance) {
+    int index = ArrayUtils.indexOf(values, instance.get(attr));
+    if (index == -1) {
+      // value not available, we cannot predict
+      return Double.NaN;
+    }
+    return childs[index].classify(instance);
+  }
+
+  @Override
+  public long maxDepth() {
+    long max = 0;
+
+    for (Node child : childs) {
+      long depth = child.maxDepth();
+      if (depth > max) {
+        max = depth;
+      }
+    }
+
+    return 1 + max;
+  }
+
+  @Override
+  public long nbNodes() {
+    long nbNodes = 1;
+
+    for (Node child : childs) {
+      nbNodes += child.nbNodes();
+    }
+
+    return nbNodes;
+  }
+
+  @Override
+  protected Type getType() {
+    return Type.CATEGORICAL;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (!(obj instanceof CategoricalNode)) {
+      return false;
+    }
+
+    CategoricalNode node = (CategoricalNode) obj;
+
+    return attr == node.attr && Arrays.equals(values, node.values) && Arrays.equals(childs, node.childs);
+  }
+
+  @Override
+  public int hashCode() {
+    int hashCode = attr;
+    for (double value : values) {
+      hashCode = 31 * hashCode + (int) Double.doubleToLongBits(value);
+    }
+    for (Node node : childs) {
+      hashCode = 31 * hashCode + node.hashCode();
+    }
+    return hashCode;
+  }
+
+  @Override
+  protected String getString() {
+    StringBuilder buffer = new StringBuilder();
+
+    for (Node child : childs) {
+      buffer.append(child).append(',');
+    }
+
+    return buffer.toString();
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    attr = in.readInt();
+    values = DFUtils.readDoubleArray(in);
+    childs = DFUtils.readNodeArray(in);
+  }
+
+  @Override
+  protected void writeNode(DataOutput out) throws IOException {
+    out.writeInt(attr);
+    DFUtils.writeArray(out, values);
+    DFUtils.writeArray(out, childs);
+  }
+}
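
A minimal sketch of CategoricalNode at classification time. It assumes Instance wraps an org.apache.mahout.math Vector of encoded attribute values, as elsewhere in this module; the category codes and labels are made up.

    import org.apache.mahout.classifier.df.data.Instance;
    import org.apache.mahout.classifier.df.node.CategoricalNode;
    import org.apache.mahout.classifier.df.node.Leaf;
    import org.apache.mahout.classifier.df.node.Node;
    import org.apache.mahout.math.DenseVector;

    public class CategoricalNodeExample {
      public static void main(String[] args) {
        // split on attribute 0, whose observed category codes are 0.0 and 1.0
        Node node = new CategoricalNode(0, new double[] {0.0, 1.0},
            new Node[] {new Leaf(7.0), new Leaf(9.0)});

        System.out.println(node.classify(new Instance(new DenseVector(new double[] {1.0})))); // 9.0
        System.out.println(node.classify(new Instance(new DenseVector(new double[] {2.0})))); // NaN: unseen category
      }
    }
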
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java
new file mode 100644
index 0000000..3360bb5
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Represents a Leaf node
+ */
+@Deprecated
+public class Leaf extends Node {
+  private static final double EPSILON = 1.0e-6;
+
+  private double label;
+
+  Leaf() { }
+
+  public Leaf(double label) {
+    this.label = label;
+  }
+
+  @Override
+  public double classify(Instance instance) {
+    return label;
+  }
+
+  @Override
+  public long maxDepth() {
+    return 1;
+  }
+
+  @Override
+  public long nbNodes() {
+    return 1;
+  }
+
+  @Override
+  protected Type getType() {
+    return Type.LEAF;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (!(obj instanceof Leaf)) {
+      return false;
+    }
+
+    Leaf leaf = (Leaf) obj;
+
+    return Math.abs(label - leaf.label) < EPSILON;
+  }
+
+  @Override
+  public int hashCode() {
+    long bits = Double.doubleToLongBits(label);
+    return (int) (bits ^ (bits >>> 32));
+  }
+
+  @Override
+  protected String getString() {
+    return "";
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    label = in.readDouble();
+  }
+
+  @Override
+  protected void writeNode(DataOutput out) throws IOException {
+    out.writeDouble(label);
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java
new file mode 100644
index 0000000..73d516d
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Represents an abstract node of a decision tree
+ */
+@Deprecated
+public abstract class Node implements Writable {
+
+  protected enum Type {
+    LEAF,
+    NUMERICAL,
+    CATEGORICAL
+  }
+
+  /**
+   * Predicts the label for the given instance
+   *
+   * @return the predicted label, or {@code Double.NaN} if the label cannot be predicted
+   */
+  public abstract double classify(Instance instance);
+
+  /**
+   * @return the total number of nodes of the tree
+   */
+  public abstract long nbNodes();
+
+  /**
+   * @return the maximum depth of the tree
+   */
+  public abstract long maxDepth();
+
+  protected abstract Type getType();
+
+  public static Node read(DataInput in) throws IOException {
+    Type type = Type.values()[in.readInt()];
+    Node node;
+
+    switch (type) {
+      case LEAF:
+        node = new Leaf();
+        break;
+      case NUMERICAL:
+        node = new NumericalNode();
+        break;
+      case CATEGORICAL:
+        node = new CategoricalNode();
+        break;
+      default:
+        throw new IllegalStateException("This implementation is not currently supported");
+    }
+
+    node.readFields(in);
+
+    return node;
+  }
+
+  @Override
+  public final String toString() {
+    return getType() + ":" + getString() + ';';
+  }
+
+  protected abstract String getString();
+
+  @Override
+  public final void write(DataOutput out) throws IOException {
+    out.writeInt(getType().ordinal());
+    writeNode(out);
+  }
+
+  protected abstract void writeNode(DataOutput out) throws IOException;
+
+}
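
The type tag written by write() is what lets the static Node.read() instantiate the right concrete class. A minimal round-trip sketch, using only classes from this commit:

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;
    import org.apache.mahout.classifier.df.node.Leaf;
    import org.apache.mahout.classifier.df.node.Node;
    import org.apache.mahout.classifier.df.node.NumericalNode;

    public class NodeRoundTrip {
      public static void main(String[] args) throws IOException {
        // attribute 0, split value 5.0; label 0.0 below the split, 1.0 at or above it
        Node tree = new NumericalNode(0, 5.0, new Leaf(0.0), new Leaf(1.0));

        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        tree.write(new DataOutputStream(bytes));  // writes the Type ordinal, then the node body

        Node copy = Node.read(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(copy.equals(tree));    // true: attr, split and both children restored
      }
    }
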
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java
new file mode 100644
index 0000000..aa02089
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java
@@ -0,0 +1,115 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Represents a node that splits using a numerical attribute
+ */
+@Deprecated
+public class NumericalNode extends Node {
+  /** numerical attribute to split for */
+  private int attr;
+
+  /** split value */
+  private double split;
+
+  /** child node when attribute's value < split value */
+  private Node loChild;
+
+  /** child node when attribute's value >= split value */
+  private Node hiChild;
+
+  public NumericalNode() { }
+
+  public NumericalNode(int attr, double split, Node loChild, Node hiChild) {
+    this.attr = attr;
+    this.split = split;
+    this.loChild = loChild;
+    this.hiChild = hiChild;
+  }
+
+  @Override
+  public double classify(Instance instance) {
+    if (instance.get(attr) < split) {
+      return loChild.classify(instance);
+    } else {
+      return hiChild.classify(instance);
+    }
+  }
+
+  @Override
+  public long maxDepth() {
+    return 1 + Math.max(loChild.maxDepth(), hiChild.maxDepth());
+  }
+
+  @Override
+  public long nbNodes() {
+    return 1 + loChild.nbNodes() + hiChild.nbNodes();
+  }
+
+  @Override
+  protected Type getType() {
+    return Type.NUMERICAL;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (!(obj instanceof NumericalNode)) {
+      return false;
+    }
+
+    NumericalNode node = (NumericalNode) obj;
+
+    return attr == node.attr && split == node.split && loChild.equals(node.loChild) && hiChild.equals(node.hiChild);
+  }
+
+  @Override
+  public int hashCode() {
+    return attr + (int) Double.doubleToLongBits(split) + loChild.hashCode() + hiChild.hashCode();
+  }
+
+  @Override
+  protected String getString() {
+    return loChild.toString() + ',' + hiChild.toString();
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    attr = in.readInt();
+    split = in.readDouble();
+    loChild = Node.read(in);
+    hiChild = Node.read(in);
+  }
+
+  @Override
+  protected void writeNode(DataOutput out) throws IOException {
+    out.writeInt(attr);
+    out.writeDouble(split);
+    loChild.write(out);
+    hiChild.write(out);
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java
new file mode 100644
index 0000000..7ef907e
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.ref; + +import org.apache.mahout.classifier.df.Bagging; +import org.apache.mahout.classifier.df.DecisionForest; +import org.apache.mahout.classifier.df.builder.TreeBuilder; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.node.Node; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Builds a Random Decision Forest using a given TreeBuilder to grow the trees + */ +@Deprecated +public class SequentialBuilder { + + private static final Logger log = LoggerFactory.getLogger(SequentialBuilder.class); + + private final Random rng; + + private final Bagging bagging; + + /** + * Constructor + * + * @param rng + * random-numbers generator + * @param treeBuilder + * tree builder + * @param data + * training data + */ + public SequentialBuilder(Random rng, TreeBuilder treeBuilder, Data data) { + this.rng = rng; + bagging = new Bagging(treeBuilder, data); + } + + public DecisionForest build(int nbTrees) { + List<Node> trees = new ArrayList<>(); + + for (int treeId = 0; treeId < nbTrees; treeId++) { + trees.add(bagging.build(rng)); + logProgress(((float) treeId + 1) / nbTrees); + } + + return new DecisionForest(trees); + } + + private static void logProgress(float progress) { + int percent = (int) (progress * 100); + if (percent % 10 == 0) { + log.info("Building {}%", percent); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java new file mode 100644 index 0000000..3f1cfdf --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java @@ -0,0 +1,118 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+
+import java.util.Arrays;
+
+/**
+ * Default, not optimized, implementation of IgSplit
+ */
+@Deprecated
+public class DefaultIgSplit extends IgSplit {
+
+  /** used by entropy() */
+  private int[] counts;
+
+  @Override
+  public Split computeSplit(Data data, int attr) {
+    if (data.getDataset().isNumerical(attr)) {
+      double[] values = data.values(attr);
+      double bestIg = -1;
+      double bestSplit = 0.0;
+
+      for (double value : values) {
+        double ig = numericalIg(data, attr, value);
+        if (ig > bestIg) {
+          bestIg = ig;
+          bestSplit = value;
+        }
+      }
+
+      return new Split(attr, bestIg, bestSplit);
+    } else {
+      double ig = categoricalIg(data, attr);
+
+      return new Split(attr, ig);
+    }
+  }
+
+  /**
+   * Computes the Information Gain for a CATEGORICAL attribute
+   */
+  double categoricalIg(Data data, int attr) {
+    double[] values = data.values(attr);
+    double hy = entropy(data); // H(Y)
+    double hyx = 0.0; // H(Y|X)
+    double invDataSize = 1.0 / data.size();
+
+    for (double value : values) {
+      Data subset = data.subset(Condition.equals(attr, value));
+      hyx += subset.size() * invDataSize * entropy(subset);
+    }
+
+    return hy - hyx;
+  }
+
+  /**
+   * Computes the Information Gain for a NUMERICAL attribute given a splitting value
+   */
+  double numericalIg(Data data, int attr, double split) {
+    double hy = entropy(data);
+    double invDataSize = 1.0 / data.size();
+
+    // LO subset
+    Data subset = data.subset(Condition.lesser(attr, split));
+    hy -= subset.size() * invDataSize * entropy(subset);
+
+    // HI subset
+    subset = data.subset(Condition.greaterOrEquals(attr, split));
+    hy -= subset.size() * invDataSize * entropy(subset);
+
+    return hy;
+  }
+
+  /**
+   * Computes the Entropy
+   */
+  protected double entropy(Data data) {
+    double invDataSize = 1.0 / data.size();
+
+    if (counts == null) {
+      counts = new int[data.getDataset().nblabels()];
+    }
+
+    Arrays.fill(counts, 0);
+    data.countLabels(counts);
+
+    double entropy = 0.0;
+    for (int label = 0; label < data.getDataset().nblabels(); label++) {
+      int count = counts[label];
+      if (count == 0) {
+        continue; // otherwise we get a NaN
+      }
+      double p = count * invDataSize;
+      entropy += -p * Math.log(p) / LOG2;
+    }
+
+    return entropy;
+  }
+
+}
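
To make the entropy and information-gain arithmetic above concrete, a small self-contained example with made-up label counts:

    public class EntropyExample {
      public static void main(String[] args) {
        // a data slice of size 10 with label counts {8, 2}, i.e. p = {0.8, 0.2}
        double h = -(0.8 * Math.log(0.8) + 0.2 * Math.log(0.2)) / Math.log(2.0);
        System.out.println(h); // ≈ 0.722 bits
        // A candidate split that separates the two labels perfectly yields two
        // pure subsets (entropy 0), so its information gain is the full 0.722
        // bits; computeSplit() keeps whichever candidate maximizes that gain.
      }
    }
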
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java
new file mode 100644
index 0000000..aff94e1
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java
@@ -0,0 +1,35 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.classifier.df.data.Data;
+
+/**
+ * Computes the best split using the Information Gain measure
+ */
+@Deprecated
+public abstract class IgSplit {
+
+  static final double LOG2 = Math.log(2.0);
+
+  /**
+   * Computes the best split for the given attribute
+   */
+  public abstract Split computeSplit(Data data, int attr);
+
+}
