http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java new file mode 100644 index 0000000..c3b2fa3 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java @@ -0,0 +1,284 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.inmem; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Random; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.InputFormat; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.JobContext; +import org.apache.hadoop.mapreduce.RecordReader; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Custom InputFormat that generates InputSplits given the desired number of trees.<br> + * Each input split contains a subset of the trees.<br> + * The number of splits is equal to the number of requested map tasks + */ +@Deprecated +public class InMemInputFormat extends InputFormat<IntWritable,NullWritable> { + + private static final Logger log = LoggerFactory.getLogger(InMemInputFormat.class); + + private Random rng; + + private Long seed; + + private boolean isSingleSeed; + + /** + * Used for DEBUG purposes only. If true and a seed is available, all the mappers use the same seed, so + * all the mappers should take the same time to build their trees. 
+ */ + private static boolean isSingleSeed(Configuration conf) { + return conf.getBoolean("debug.mahout.rf.single.seed", false); + } + + @Override + public RecordReader<IntWritable,NullWritable> createRecordReader(InputSplit split, TaskAttemptContext context) + throws IOException, InterruptedException { + Preconditions.checkArgument(split instanceof InMemInputSplit); + return new InMemRecordReader((InMemInputSplit) split); + } + + @Override + public List<InputSplit> getSplits(JobContext context) throws IOException, InterruptedException { + Configuration conf = context.getConfiguration(); + int numSplits = conf.getInt("mapred.map.tasks", -1); + + return getSplits(conf, numSplits); + } + + public List<InputSplit> getSplits(Configuration conf, int numSplits) { + int nbTrees = Builder.getNbTrees(conf); + int splitSize = nbTrees / numSplits; + + seed = Builder.getRandomSeed(conf); + isSingleSeed = isSingleSeed(conf); + + if (rng != null && seed != null) { + log.warn("getSplits() was called more than once and the 'seed' is set; " + + "this can lead to non-repeatable behavior"); + } + + rng = seed == null || isSingleSeed ? null : RandomUtils.getRandom(seed); + + int id = 0; + + List<InputSplit> splits = new ArrayList<>(numSplits); + + for (int index = 0; index < numSplits - 1; index++) { + splits.add(new InMemInputSplit(id, splitSize, nextSeed())); + id += splitSize; + } + + // take care of the remainder + splits.add(new InMemInputSplit(id, nbTrees - id, nextSeed())); + + return splits; + } + + /** + * @return the seed for the next InputSplit + */ + private Long nextSeed() { + if (seed == null) { + return null; + } else if (isSingleSeed) { + return seed; + } else { + return rng.nextLong(); + } + } + + public static class InMemRecordReader extends RecordReader<IntWritable,NullWritable> { + + private final InMemInputSplit split; + private int pos; + private IntWritable key; + private NullWritable value; + + public InMemRecordReader(InMemInputSplit split) { + this.split = split; + } + + @Override + public float getProgress() throws IOException { + return pos == 0 ? 
0.0f : (float) (pos - 1) / split.nbTrees; + } + + @Override + public IntWritable getCurrentKey() throws IOException, InterruptedException { + return key; + } + + @Override + public NullWritable getCurrentValue() throws IOException, InterruptedException { + return value; + } + + @Override + public void initialize(InputSplit arg0, TaskAttemptContext arg1) throws IOException, InterruptedException { + key = new IntWritable(); + value = NullWritable.get(); + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (pos < split.nbTrees) { + key.set(split.firstId + pos); + pos++; + return true; + } else { + return false; + } + } + + @Override + public void close() throws IOException { + } + + } + + /** + * Custom InputSplit that indicates how many trees are built by each mapper + */ + public static class InMemInputSplit extends InputSplit implements Writable { + + private static final String[] NO_LOCATIONS = new String[0]; + + /** Id of the first tree of this split */ + private int firstId; + + private int nbTrees; + + private Long seed; + + public InMemInputSplit() { } + + public InMemInputSplit(int firstId, int nbTrees, Long seed) { + this.firstId = firstId; + this.nbTrees = nbTrees; + this.seed = seed; + } + + /** + * @return the Id of the first tree of this split + */ + public int getFirstId() { + return firstId; + } + + /** + * @return the number of trees + */ + public int getNbTrees() { + return nbTrees; + } + + /** + * @return the random seed or null if no seed is available + */ + public Long getSeed() { + return seed; + } + + @Override + public long getLength() throws IOException { + return nbTrees; + } + + @Override + public String[] getLocations() throws IOException { + return NO_LOCATIONS; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof InMemInputSplit)) { + return false; + } + + InMemInputSplit split = (InMemInputSplit) obj; + + if (firstId != split.firstId || nbTrees != split.nbTrees) { + return false; + } + if (seed == null) { + return split.seed == null; + } else { + return seed.equals(split.seed); + } + + } + + @Override + public int hashCode() { + return firstId + nbTrees + (seed == null ? 0 : seed.intValue()); + } + + @Override + public String toString() { + return String.format(Locale.ENGLISH, "[firstId:%d, nbTrees:%d, seed:%d]", firstId, nbTrees, seed); + } + + @Override + public void readFields(DataInput in) throws IOException { + firstId = in.readInt(); + nbTrees = in.readInt(); + boolean isSeed = in.readBoolean(); + seed = isSeed ? in.readLong() : null; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(firstId); + out.writeInt(nbTrees); + out.writeBoolean(seed != null); + if (seed != null) { + out.writeLong(seed); + } + } + + public static InMemInputSplit read(DataInput in) throws IOException { + InMemInputSplit split = new InMemInputSplit(); + split.readFields(in); + return split; + } + + } + +}
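A minimal standalone sketch of the split arithmetic used by getSplits() above (plain Java, no Hadoop or Mahout dependencies; the class name SplitArithmeticDemo and its hard-coded values are illustrative only): each of the first numSplits - 1 splits receives nbTrees / numSplits trees, and the final split absorbs the remainder, so tree ids 0..nbTrees-1 are assigned exactly once.

public class SplitArithmeticDemo {
  public static void main(String[] args) {
    int nbTrees = 10;   // total number of trees requested (what Builder.getNbTrees(conf) returns)
    int numSplits = 3;  // number of map tasks (what "mapred.map.tasks" is set to)
    int splitSize = nbTrees / numSplits;
    int id = 0;
    for (int index = 0; index < numSplits - 1; index++) {
      System.out.printf("split %d -> firstId=%d, nbTrees=%d%n", index, id, splitSize);
      id += splitSize;
    }
    // the last split takes care of the remainder
    System.out.printf("split %d -> firstId=%d, nbTrees=%d%n", numSplits - 1, id, nbTrees - id);
    // prints splits of sizes 3, 3 and 4: all 10 tree ids are covered exactly once
  }
}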
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java new file mode 100644 index 0000000..2fc67ba --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java @@ -0,0 +1,106 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.inmem; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.NullWritable; +import org.apache.mahout.classifier.df.Bagging; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataLoader; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.classifier.df.mapreduce.MapredMapper; +import org.apache.mahout.classifier.df.mapreduce.MapredOutput; +import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Random; + +/** + * In-memory mapper that grows the trees using a full copy of the data loaded in-memory. The number of trees + * to grow is determined by the current InMemInputSplit. 
+ */ +@Deprecated +public class InMemMapper extends MapredMapper<IntWritable,NullWritable,IntWritable,MapredOutput> { + + private static final Logger log = LoggerFactory.getLogger(InMemMapper.class); + + private Bagging bagging; + + private Random rng; + + /** + * Load the training data + */ + private static Data loadData(Configuration conf, Dataset dataset) throws IOException { + Path dataPath = Builder.getDistributedCacheFile(conf, 1); + FileSystem fs = FileSystem.get(dataPath.toUri(), conf); + return DataLoader.loadData(dataset, fs, dataPath); + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + + Configuration conf = context.getConfiguration(); + + log.info("Loading the data..."); + Data data = loadData(conf, getDataset()); + log.info("Data loaded : {} instances", data.size()); + + bagging = new Bagging(getTreeBuilder(), data); + } + + @Override + protected void map(IntWritable key, + NullWritable value, + Context context) throws IOException, InterruptedException { + map(key, context); + } + + void map(IntWritable key, Context context) throws IOException, InterruptedException { + + initRandom((InMemInputSplit) context.getInputSplit()); + + log.debug("Building..."); + Node tree = bagging.build(rng); + + if (isOutput()) { + log.debug("Outputting..."); + MapredOutput mrOut = new MapredOutput(tree); + + context.write(key, mrOut); + } + } + + void initRandom(InMemInputSplit split) { + if (rng == null) { // first execution of this mapper + Long seed = split.getSeed(); + log.debug("Initialising rng with seed : {}", seed); + rng = seed == null ? RandomUtils.getRandom() : RandomUtils.getRandom(seed); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java new file mode 100644 index 0000000..61e65e8 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java @@ -0,0 +1,22 @@ +/** + * <h2>In-memory mapreduce implementation of Random Decision Forests</h2> + * + * <p>Each mapper is responsible for growing a number of trees with a whole copy of the dataset loaded in memory; + * it uses the reference implementation's code to build each tree and estimate the oob error.</p> + * + * <p>The dataset is distributed to the slave nodes using the {@link org.apache.hadoop.filecache.DistributedCache}. + * A custom {@link org.apache.hadoop.mapreduce.InputFormat} + * ({@link org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat}) is configured with the + * desired number of trees and generates a number of {@link org.apache.hadoop.mapreduce.InputSplit}s + * equal to the configured number of maps.</p> + * + * <p>There is no need for reducers: each mapper outputs the trees it built and, for each tree, the labels the + * tree predicted for each out-of-bag instance. This step has to be done in the mapper because only the mapper + * knows which instances are out-of-bag.</p> + * + * <p>The Forest builder ({@link org.apache.mahout.classifier.df.mapreduce.inmem.InMemBuilder}) is responsible + * for configuring and launching the job. 
+ * At the end of the job it parses the output files and builds the corresponding + * {@link org.apache.mahout.classifier.df.DecisionForest}.</p> + */ +package org.apache.mahout.classifier.df.mapreduce.inmem; http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java new file mode 100644 index 0000000..9236af3 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java @@ -0,0 +1,158 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.partial; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.JobContext; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.classifier.df.DFUtils; +import org.apache.mahout.classifier.df.DecisionForest; +import org.apache.mahout.classifier.df.builder.TreeBuilder; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.classifier.df.mapreduce.MapredOutput; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +/** + * Builds a random forest using partial data. 
Each mapper uses only the data given by its InputSplit + */ +@Deprecated +public class PartialBuilder extends Builder { + + private static final Logger log = LoggerFactory.getLogger(PartialBuilder.class); + + public PartialBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed) { + this(treeBuilder, dataPath, datasetPath, seed, new Configuration()); + } + + public PartialBuilder(TreeBuilder treeBuilder, + Path dataPath, + Path datasetPath, + Long seed, + Configuration conf) { + super(treeBuilder, dataPath, datasetPath, seed, conf); + } + + @Override + protected void configureJob(Job job) throws IOException { + Configuration conf = job.getConfiguration(); + + job.setJarByClass(PartialBuilder.class); + + FileInputFormat.setInputPaths(job, getDataPath()); + FileOutputFormat.setOutputPath(job, getOutputPath(conf)); + + job.setOutputKeyClass(TreeID.class); + job.setOutputValueClass(MapredOutput.class); + + job.setMapperClass(Step1Mapper.class); + job.setNumReduceTasks(0); // no reducers + + job.setInputFormatClass(TextInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + // For this implementation to work, mapred.map.tasks needs to be set to the actual + // number of mappers Hadoop will use: + TextInputFormat inputFormat = new TextInputFormat(); + List<?> splits = inputFormat.getSplits(job); + if (splits == null || splits.isEmpty()) { + log.warn("Unable to compute number of splits?"); + } else { + int numSplits = splits.size(); + log.info("Setting mapred.map.tasks = {}", numSplits); + conf.setInt("mapred.map.tasks", numSplits); + } + } + + @Override + protected DecisionForest parseOutput(Job job) throws IOException { + Configuration conf = job.getConfiguration(); + + int numTrees = Builder.getNbTrees(conf); + + Path outputPath = getOutputPath(conf); + + TreeID[] keys = new TreeID[numTrees]; + Node[] trees = new Node[numTrees]; + + processOutput(job, outputPath, keys, trees); + + return new DecisionForest(Arrays.asList(trees)); + } + + /** + * Processes the output from the output path.<br> + * + * @param outputPath + * directory that contains the output of the job + * @param keys + * can be null + * @param trees + * can be null + * @throws java.io.IOException + */ + protected static void processOutput(JobContext job, + Path outputPath, + TreeID[] keys, + Node[] trees) throws IOException { + Preconditions.checkArgument(keys == null && trees == null || keys != null && trees != null, + "if keys is null, trees should also be null"); + Preconditions.checkArgument(keys == null || keys.length == trees.length, "keys.length != trees.length"); + + Configuration conf = job.getConfiguration(); + + FileSystem fs = outputPath.getFileSystem(conf); + + Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath); + + // read all the outputs + int index = 0; + for (Path path : outfiles) { + for (Pair<TreeID,MapredOutput> record : new SequenceFileIterable<TreeID, MapredOutput>(path, conf)) { + TreeID key = record.getFirst(); + MapredOutput value = record.getSecond(); + if (keys != null) { + keys[index] = key; + } + if (trees != null) { + trees[index] = value.getTree(); + } + index++; + } + } + + // make sure we got all the keys/values + if (keys != null && index != keys.length) { + throw new IllegalStateException("Some key/values are missing from the output"); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java 
---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java new file mode 100644 index 0000000..9474236 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java @@ -0,0 +1,168 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.partial; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.mahout.classifier.df.Bagging; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataConverter; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.classifier.df.mapreduce.MapredMapper; +import org.apache.mahout.classifier.df.mapreduce.MapredOutput; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * First step of the Partial Data Builder. Builds the trees using the data available in the InputSplit. + * Predicts the oob classes for each tree in its growing partition (input split). 
+ */ +@Deprecated +public class Step1Mapper extends MapredMapper<LongWritable,Text,TreeID,MapredOutput> { + + private static final Logger log = LoggerFactory.getLogger(Step1Mapper.class); + + /** used to convert input values to data instances */ + private DataConverter converter; + + private Random rng; + + /** number of trees to be built by this mapper */ + private int nbTrees; + + /** id of the first tree */ + private int firstTreeId; + + /** mapper's partition */ + private int partition; + + /** will contain all the instances of this mapper's split */ + private final List<Instance> instances = new ArrayList<>(); + + public int getFirstTreeId() { + return firstTreeId; + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + Configuration conf = context.getConfiguration(); + + configure(Builder.getRandomSeed(conf), conf.getInt("mapred.task.partition", -1), + Builder.getNumMaps(conf), Builder.getNbTrees(conf)); + } + + /** + * Useful when testing + * + * @param seed + * random seed, may be null + * @param partition + * current mapper inputSplit partition + * @param numMapTasks + * number of running map tasks + * @param numTrees + * total number of trees in the forest + */ + protected void configure(Long seed, int partition, int numMapTasks, int numTrees) { + converter = new DataConverter(getDataset()); + + // prepare the random-numbers generator + log.debug("seed : {}", seed); + if (seed == null) { + rng = RandomUtils.getRandom(); + } else { + rng = RandomUtils.getRandom(seed); + } + + // mapper's partition + Preconditions.checkArgument(partition >= 0, "Wrong partition ID: " + partition + ". Partition must be >= 0!"); + this.partition = partition; + + // compute number of trees to build + nbTrees = nbTrees(numMapTasks, numTrees, partition); + + // compute first tree id + firstTreeId = 0; + for (int p = 0; p < partition; p++) { + firstTreeId += nbTrees(numMapTasks, numTrees, p); + } + + log.debug("partition : {}", partition); + log.debug("nbTrees : {}", nbTrees); + log.debug("firstTreeId : {}", firstTreeId); + } + + /** + * Compute the number of trees for a given partition. The first partitions may build one more tree + * than the rest to account for the remainder. + * + * @param numMaps + * total number of maps (partitions) + * @param numTrees + * total number of trees to build + * @param partition + * partition to compute the number of trees for + */ + public static int nbTrees(int numMaps, int numTrees, int partition) { + int treesPerMapper = numTrees / numMaps; + int remainder = numTrees - numMaps * treesPerMapper; + return treesPerMapper + (partition < remainder ? 
1 : 0); + } + + @Override + protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { + instances.add(converter.convert(value.toString())); + } + + @Override + protected void cleanup(Context context) throws IOException, InterruptedException { + // prepare the data + log.debug("partition: {} numInstances: {}", partition, instances.size()); + + Data data = new Data(getDataset(), instances); + Bagging bagging = new Bagging(getTreeBuilder(), data); + + TreeID key = new TreeID(); + + log.debug("Building {} trees", nbTrees); + for (int treeId = 0; treeId < nbTrees; treeId++) { + log.debug("Building tree number : {}", treeId); + + Node tree = bagging.build(rng); + + key.set(partition, firstTreeId + treeId); + + if (isOutput()) { + MapredOutput emOut = new MapredOutput(tree); + context.write(key, emOut); + } + + context.progress(); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java new file mode 100644 index 0000000..c296061 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce.partial; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.io.LongWritable; + +/** + * Indicates both the tree and the data partition used to grow the tree + */ +@Deprecated +public class TreeID extends LongWritable implements Cloneable { + + public static final int MAX_TREEID = 100000; + + public TreeID() { } + + public TreeID(int partition, int treeId) { + Preconditions.checkArgument(partition >= 0, "Wrong partition: " + partition + ". Partition must be >= 0!"); + Preconditions.checkArgument(treeId >= 0, "Wrong treeId: " + treeId + ". 
TreeId must be >= 0!"); + set(partition, treeId); + } + + public void set(int partition, int treeId) { + set((long) partition * MAX_TREEID + treeId); + } + + /** + * Data partition (InputSplit's index) that was used to grow the tree + */ + public int partition() { + return (int) (get() / MAX_TREEID); + } + + public int treeId() { + return (int) (get() % MAX_TREEID); + } + + @Override + public TreeID clone() { + return new TreeID(partition(), treeId()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java new file mode 100644 index 0000000..e621c91 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java @@ -0,0 +1,16 @@ +/** + * <h2>Partial-data mapreduce implementation of Random Decision Forests</h2> + * + * <p>The builder splits the data, using a FileInputSplit, among the mappers. + * Building the forest and estimating the oob error takes two job steps.</p> + * + * <p>In the first step, each mapper is responsible for growing a number of trees using only its partition's data, + * loading the data instances in its {@code map()} function, then building the trees in the {@code cleanup()} method. It + * uses the reference implementation's code to build each tree and estimate the oob error.</p> + * + * <p>The second step is needed when estimating the oob error. Each mapper loads all the trees that do not + * belong to its own partition (i.e., that were not built using its partition's data) and uses them to classify the + * partition's data instances. The data instances are loaded in the {@code map()} method and the classification + * is performed in the {@code cleanup()} method.</p> + */ +package org.apache.mahout.classifier.df.mapreduce.partial; http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java new file mode 100644 index 0000000..1f91842 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java @@ -0,0 +1,134 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.node; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.mahout.classifier.df.DFUtils; +import org.apache.mahout.classifier.df.data.Instance; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Arrays; +@Deprecated +public class CategoricalNode extends Node { + + private int attr; + private double[] values; + private Node[] childs; + + public CategoricalNode() { + } + + public CategoricalNode(int attr, double[] values, Node[] childs) { + this.attr = attr; + this.values = values; + this.childs = childs; + } + + @Override + public double classify(Instance instance) { + int index = ArrayUtils.indexOf(values, instance.get(attr)); + if (index == -1) { + // value not available, we cannot predict + return Double.NaN; + } + return childs[index].classify(instance); + } + + @Override + public long maxDepth() { + long max = 0; + + for (Node child : childs) { + long depth = child.maxDepth(); + if (depth > max) { + max = depth; + } + } + + return 1 + max; + } + + @Override + public long nbNodes() { + long nbNodes = 1; + + for (Node child : childs) { + nbNodes += child.nbNodes(); + } + + return nbNodes; + } + + @Override + protected Type getType() { + return Type.CATEGORICAL; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof CategoricalNode)) { + return false; + } + + CategoricalNode node = (CategoricalNode) obj; + + return attr == node.attr && Arrays.equals(values, node.values) && Arrays.equals(childs, node.childs); + } + + @Override + public int hashCode() { + int hashCode = attr; + for (double value : values) { + hashCode = 31 * hashCode + (int) Double.doubleToLongBits(value); + } + for (Node node : childs) { + hashCode = 31 * hashCode + node.hashCode(); + } + return hashCode; + } + + @Override + protected String getString() { + StringBuilder buffer = new StringBuilder(); + + for (Node child : childs) { + buffer.append(child).append(','); + } + + return buffer.toString(); + } + + @Override + public void readFields(DataInput in) throws IOException { + attr = in.readInt(); + values = DFUtils.readDoubleArray(in); + childs = DFUtils.readNodeArray(in); + } + + @Override + protected void writeNode(DataOutput out) throws IOException { + out.writeInt(attr); + DFUtils.writeArray(out, values); + DFUtils.writeArray(out, childs); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java new file mode 100644 index 0000000..3360bb5 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java @@ -0,0 +1,95 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.node; + +import org.apache.mahout.classifier.df.data.Instance; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Represents a Leaf node + */ +@Deprecated +public class Leaf extends Node { + private static final double EPSILON = 1.0e-6; + + private double label; + + Leaf() { } + + public Leaf(double label) { + this.label = label; + } + + @Override + public double classify(Instance instance) { + return label; + } + + @Override + public long maxDepth() { + return 1; + } + + @Override + public long nbNodes() { + return 1; + } + + @Override + protected Type getType() { + return Type.LEAF; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof Leaf)) { + return false; + } + + Leaf leaf = (Leaf) obj; + + return Math.abs(label - leaf.label) < EPSILON; + } + + @Override + public int hashCode() { + long bits = Double.doubleToLongBits(label); + return (int)(bits ^ (bits >>> 32)); + } + + @Override + protected String getString() { + return ""; + } + + @Override + public void readFields(DataInput in) throws IOException { + label = in.readDouble(); + } + + @Override + protected void writeNode(DataOutput out) throws IOException { + out.writeDouble(label); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java new file mode 100644 index 0000000..73d516d --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.node; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.df.data.Instance; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Represents an abstract node of a decision tree + */ +@Deprecated +public abstract class Node implements Writable { + + protected enum Type { + LEAF, + NUMERICAL, + CATEGORICAL + } + + /** + * predicts the label for the instance + * + * @return the predicted label, or {@code Double.NaN} if the label cannot be predicted + */ + public abstract double classify(Instance instance); + + /** + * @return the total number of nodes of the tree + */ + public abstract long nbNodes(); + + /** + * @return the maximum depth of the tree + */ + public abstract long maxDepth(); + + protected abstract Type getType(); + + public static Node read(DataInput in) throws IOException { + Type type = Type.values()[in.readInt()]; + Node node; + + switch (type) { + case LEAF: + node = new Leaf(); + break; + case NUMERICAL: + node = new NumericalNode(); + break; + case CATEGORICAL: + node = new CategoricalNode(); + break; + default: + throw new IllegalStateException("This implementation is not currently supported"); + } + + node.readFields(in); + + return node; + } + + @Override + public final String toString() { + return getType() + ":" + getString() + ';'; + } + + protected abstract String getString(); + + @Override + public final void write(DataOutput out) throws IOException { + out.writeInt(getType().ordinal()); + writeNode(out); + } + + protected abstract void writeNode(DataOutput out) throws IOException; + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java new file mode 100644 index 0000000..aa02089 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java @@ -0,0 +1,115 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.node; + +import org.apache.mahout.classifier.df.data.Instance; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Represents a node that splits using a numerical attribute + */ +@Deprecated +public class NumericalNode extends Node { + /** numerical attribute to split for */ + private int attr; + + /** split value */ + private double split; + + /** child node when attribute's value < split value */ + private Node loChild; + + /** child node when attribute's value >= split value */ + private Node hiChild; + + public NumericalNode() { } + + public NumericalNode(int attr, double split, Node loChild, Node hiChild) { + this.attr = attr; + this.split = split; + this.loChild = loChild; + this.hiChild = hiChild; + } + + @Override + public double classify(Instance instance) { + if (instance.get(attr) < split) { + return loChild.classify(instance); + } else { + return hiChild.classify(instance); + } + } + + @Override + public long maxDepth() { + return 1 + Math.max(loChild.maxDepth(), hiChild.maxDepth()); + } + + @Override + public long nbNodes() { + return 1 + loChild.nbNodes() + hiChild.nbNodes(); + } + + @Override + protected Type getType() { + return Type.NUMERICAL; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof NumericalNode)) { + return false; + } + + NumericalNode node = (NumericalNode) obj; + + return attr == node.attr && split == node.split && loChild.equals(node.loChild) && hiChild.equals(node.hiChild); + } + + @Override + public int hashCode() { + return attr + (int) Double.doubleToLongBits(split) + loChild.hashCode() + hiChild.hashCode(); + } + + @Override + protected String getString() { + return loChild.toString() + ',' + hiChild.toString(); + } + + @Override + public void readFields(DataInput in) throws IOException { + attr = in.readInt(); + split = in.readDouble(); + loChild = Node.read(in); + hiChild = Node.read(in); + } + + @Override + protected void writeNode(DataOutput out) throws IOException { + out.writeInt(attr); + out.writeDouble(split); + loChild.write(out); + hiChild.write(out); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java new file mode 100644 index 0000000..7ef907e --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java @@ -0,0 +1,78 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.ref; + +import org.apache.mahout.classifier.df.Bagging; +import org.apache.mahout.classifier.df.DecisionForest; +import org.apache.mahout.classifier.df.builder.TreeBuilder; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.node.Node; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Builds a Random Decision Forest using a given TreeBuilder to grow the trees + */ +@Deprecated +public class SequentialBuilder { + + private static final Logger log = LoggerFactory.getLogger(SequentialBuilder.class); + + private final Random rng; + + private final Bagging bagging; + + /** + * Constructor + * + * @param rng + * random-numbers generator + * @param treeBuilder + * tree builder + * @param data + * training data + */ + public SequentialBuilder(Random rng, TreeBuilder treeBuilder, Data data) { + this.rng = rng; + bagging = new Bagging(treeBuilder, data); + } + + public DecisionForest build(int nbTrees) { + List<Node> trees = new ArrayList<>(); + + for (int treeId = 0; treeId < nbTrees; treeId++) { + trees.add(bagging.build(rng)); + logProgress(((float) treeId + 1) / nbTrees); + } + + return new DecisionForest(trees); + } + + private static void logProgress(float progress) { + int percent = (int) (progress * 100); + if (percent % 10 == 0) { + log.info("Building {}%", percent); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java new file mode 100644 index 0000000..3f1cfdf --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java @@ -0,0 +1,118 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.split; + +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.conditions.Condition; + +import java.util.Arrays; + +/** + * Default, not optimized, implementation of IgSplit + */ +@Deprecated +public class DefaultIgSplit extends IgSplit { + + /** used by entropy() */ + private int[] counts; + + @Override + public Split computeSplit(Data data, int attr) { + if (data.getDataset().isNumerical(attr)) { + double[] values = data.values(attr); + double bestIg = -1; + double bestSplit = 0.0; + + for (double value : values) { + double ig = numericalIg(data, attr, value); + if (ig > bestIg) { + bestIg = ig; + bestSplit = value; + } + } + + return new Split(attr, bestIg, bestSplit); + } else { + double ig = categoricalIg(data, attr); + + return new Split(attr, ig); + } + } + + /** + * Computes the Information Gain for a CATEGORICAL attribute + */ + double categoricalIg(Data data, int attr) { + double[] values = data.values(attr); + double hy = entropy(data); // H(Y) + double hyx = 0.0; // H(Y|X) + double invDataSize = 1.0 / data.size(); + + for (double value : values) { + Data subset = data.subset(Condition.equals(attr, value)); + hyx += subset.size() * invDataSize * entropy(subset); + } + + return hy - hyx; + } + + /** + * Computes the Information Gain for a NUMERICAL attribute given a splitting value + */ + double numericalIg(Data data, int attr, double split) { + double hy = entropy(data); + double invDataSize = 1.0 / data.size(); + + // LO subset + Data subset = data.subset(Condition.lesser(attr, split)); + hy -= subset.size() * invDataSize * entropy(subset); + + // HI subset + subset = data.subset(Condition.greaterOrEquals(attr, split)); + hy -= subset.size() * invDataSize * entropy(subset); + + return hy; + } + + /** + * Computes the Entropy + */ + protected double entropy(Data data) { + double invDataSize = 1.0 / data.size(); + + if (counts == null) { + counts = new int[data.getDataset().nblabels()]; + } + + Arrays.fill(counts, 0); + data.countLabels(counts); + + double entropy = 0.0; + for (int label = 0; label < data.getDataset().nblabels(); label++) { + int count = counts[label]; + if (count == 0) { + continue; // otherwise we get a NaN + } + double p = count * invDataSize; + entropy += -p * Math.log(p) / LOG2; + } + + return entropy; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java new file mode 100644 index 0000000..aff94e1 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.split; + +import org.apache.mahout.classifier.df.data.Data; + +/** + * Computes the best split using the Information Gain measure + */ +@Deprecated +public abstract class IgSplit { + + static final double LOG2 = Math.log(2.0); + + /** + * Computes the best split for the given attribute + */ + public abstract Split computeSplit(Data data, int attr); + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java new file mode 100644 index 0000000..56b1a04 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.split; + +import org.apache.commons.math3.stat.descriptive.rank.Percentile; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataUtils; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.TreeSet; + +/** + * <p>Optimized implementation of IgSplit. + * This class can be used when the criterion variable is the categorical attribute.</p> + * + * <p>This code was changed in MAHOUT-1419 to deal in sampled splits among numeric + * features to fix a performance problem. To generate some synthetic data that exercises + * the issue, try for example generating 4 features of Normal(0,1) values with a random + * boolean 0/1 categorical feature. 
In Scala:</p> + * + * {@code + * val r = new scala.util.Random() + * val pw = new java.io.PrintWriter("random.csv") + * (1 to 10000000).foreach(e => + * pw.println(r.nextDouble() + "," + + * r.nextDouble() + "," + + * r.nextDouble() + "," + + * r.nextDouble() + "," + + * (if (r.nextBoolean()) 1 else 0)) + * ) + * pw.close() + * } + */ +@Deprecated +public class OptIgSplit extends IgSplit { + + private static final int MAX_NUMERIC_SPLITS = 16; + + @Override + public Split computeSplit(Data data, int attr) { + if (data.getDataset().isNumerical(attr)) { + return numericalSplit(data, attr); + } else { + return categoricalSplit(data, attr); + } + } + + /** + * Computes the split for a CATEGORICAL attribute + */ + private static Split categoricalSplit(Data data, int attr) { + double[] values = data.values(attr).clone(); + + double[] splitPoints = chooseCategoricalSplitPoints(values); + + int numLabels = data.getDataset().nblabels(); + int[][] counts = new int[splitPoints.length][numLabels]; + int[] countAll = new int[numLabels]; + + computeFrequencies(data, attr, splitPoints, counts, countAll); + + int size = data.size(); + double hy = entropy(countAll, size); // H(Y) + double hyx = 0.0; // H(Y|X) + double invDataSize = 1.0 / size; + + for (int index = 0; index < splitPoints.length; index++) { + size = DataUtils.sum(counts[index]); + hyx += size * invDataSize * entropy(counts[index], size); + } + + double ig = hy - hyx; + return new Split(attr, ig); + } + + static void computeFrequencies(Data data, + int attr, + double[] splitPoints, + int[][] counts, + int[] countAll) { + Dataset dataset = data.getDataset(); + + for (int index = 0; index < data.size(); index++) { + Instance instance = data.get(index); + int label = (int) dataset.getLabel(instance); + double value = instance.get(attr); + int split = 0; + while (split < splitPoints.length && value > splitPoints[split]) { + split++; + } + if (split < splitPoints.length) { + counts[split][label]++; + } // Otherwise it's in the last split, which we don't need to count + countAll[label]++; + } + } + + /** + * Computes the best split for a NUMERICAL attribute + */ + static Split numericalSplit(Data data, int attr) { + double[] values = data.values(attr).clone(); + Arrays.sort(values); + + double[] splitPoints = chooseNumericSplitPoints(values); + + int numLabels = data.getDataset().nblabels(); + int[][] counts = new int[splitPoints.length][numLabels]; + int[] countAll = new int[numLabels]; + int[] countLess = new int[numLabels]; + + computeFrequencies(data, attr, splitPoints, counts, countAll); + + int size = data.size(); + double hy = entropy(countAll, size); + double invDataSize = 1.0 / size; + + int best = -1; + double bestIg = -1.0; + + // try each possible split value + for (int index = 0; index < splitPoints.length; index++) { + double ig = hy; + + DataUtils.add(countLess, counts[index]); + DataUtils.dec(countAll, counts[index]); + + // instance with attribute value < values[index] + size = DataUtils.sum(countLess); + ig -= size * invDataSize * entropy(countLess, size); + // instance with attribute value >= values[index] + size = DataUtils.sum(countAll); + ig -= size * invDataSize * entropy(countAll, size); + + if (ig > bestIg) { + bestIg = ig; + best = index; + } + } + + if (best == -1) { + throw new IllegalStateException("no best split found !"); + } + return new Split(attr, bestIg, splitPoints[best]); + } + + /** + * @return an array of values to split the numeric feature's values on when + * building candidate splits. 
When input size is <= MAX_NUMERIC_SPLITS + 1, it will + * return the averages between successive values as split points. When larger, it will + * return MAX_NUMERIC_SPLITS approximate percentiles through the data. + */ + private static double[] chooseNumericSplitPoints(double[] values) { + if (values.length <= 1) { + return values; + } + if (values.length <= MAX_NUMERIC_SPLITS + 1) { + double[] splitPoints = new double[values.length - 1]; + for (int i = 1; i < values.length; i++) { + splitPoints[i-1] = (values[i] + values[i-1]) / 2.0; + } + return splitPoints; + } + Percentile distribution = new Percentile(); + distribution.setData(values); + double[] percentiles = new double[MAX_NUMERIC_SPLITS]; + for (int i = 0; i < percentiles.length; i++) { + double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0)); + percentiles[i] = distribution.evaluate(p); + } + return percentiles; + } + + private static double[] chooseCategoricalSplitPoints(double[] values) { + // There is no great reason to believe that categorical value order matters, + // but the original code worked this way, and it's not terrible in the absence + // of more sophisticated analysis + Collection<Double> uniqueOrderedCategories = new TreeSet<>(); + for (double v : values) { + uniqueOrderedCategories.add(v); + } + double[] uniqueValues = new double[uniqueOrderedCategories.size()]; + Iterator<Double> it = uniqueOrderedCategories.iterator(); + for (int i = 0; i < uniqueValues.length; i++) { + uniqueValues[i] = it.next(); + } + return uniqueValues; + } + + /** + * Computes the entropy + * + * @param counts counts[i] = numInstances with label i + * @param dataSize numInstances + */ + private static double entropy(int[] counts, int dataSize) { + if (dataSize == 0) { + return 0.0; + } + + double entropy = 0.0; + + for (int count : counts) { + if (count > 0) { + double p = count / (double) dataSize; + entropy -= p * Math.log(p); + } + } + + return entropy / LOG2; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java new file mode 100644 index 0000000..38695a3 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java @@ -0,0 +1,177 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.classifier.df.split; + +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.Instance; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; + +/** + * Regression problem implementation of IgSplit. This class can be used when the criterion variable is a numerical + * attribute. + */ +@Deprecated +public class RegressionSplit extends IgSplit { + + /** + * Comparator for Instance sort + */ + private static class InstanceComparator implements Comparator<Instance>, Serializable { + private final int attr; + + InstanceComparator(int attr) { + this.attr = attr; + } + + @Override + public int compare(Instance arg0, Instance arg1) { + return Double.compare(arg0.get(attr), arg1.get(attr)); + } + } + + @Override + public Split computeSplit(Data data, int attr) { + if (data.getDataset().isNumerical(attr)) { + return numericalSplit(data, attr); + } else { + return categoricalSplit(data, attr); + } + } + + /** + * Computes the split for a CATEGORICAL attribute + */ + private static Split categoricalSplit(Data data, int attr) { + FullRunningAverage[] ra = new FullRunningAverage[data.getDataset().nbValues(attr)]; + double[] sk = new double[data.getDataset().nbValues(attr)]; + for (int i = 0; i < ra.length; i++) { + ra[i] = new FullRunningAverage(); + } + FullRunningAverage totalRa = new FullRunningAverage(); + double totalSk = 0.0; + + for (int i = 0; i < data.size(); i++) { + // computes the variance + Instance instance = data.get(i); + int value = (int) instance.get(attr); + double xk = data.getDataset().getLabel(instance); + if (ra[value].getCount() == 0) { + ra[value].addDatum(xk); + sk[value] = 0.0; + } else { + double mk = ra[value].getAverage(); + ra[value].addDatum(xk); + sk[value] += (xk - mk) * (xk - ra[value].getAverage()); + } + + // total variance + if (i == 0) { + totalRa.addDatum(xk); + totalSk = 0.0; + } else { + double mk = totalRa.getAverage(); + totalRa.addDatum(xk); + totalSk += (xk - mk) * (xk - totalRa.getAverage()); + } + } + + // computes the variance gain + double ig = totalSk; + for (double aSk : sk) { + ig -= aSk; + } + + return new Split(attr, ig); + } + + /** + * Computes the best split for a NUMERICAL attribute + */ + private static Split numericalSplit(Data data, int attr) { + FullRunningAverage[] ra = new FullRunningAverage[2]; + for (int i = 0; i < ra.length; i++) { + ra[i] = new FullRunningAverage(); + } + + // Instance sort + Instance[] instances = new Instance[data.size()]; + for (int i = 0; i < data.size(); i++) { + instances[i] = data.get(i); + } + Arrays.sort(instances, new InstanceComparator(attr)); + + // computes the total variance (all instances start in the right-hand branch) + double[] sk = new double[2]; + for (Instance instance : instances) { + double xk = data.getDataset().getLabel(instance); + if (ra[1].getCount() == 0) { + ra[1].addDatum(xk); + sk[1] = 0.0; + } else { + double mk = ra[1].getAverage(); + ra[1].addDatum(xk); + sk[1] += (xk - mk) * (xk - ra[1].getAverage()); + } + } + double totalSk = sk[1]; + + // find the best split point + double split = Double.NaN; + double preSplit = Double.NaN; + double bestVal = Double.MAX_VALUE; + double bestSk = 0.0; + + // evaluate a candidate split point before each new attribute value + for (Instance instance : instances) { + double xk = data.getDataset().getLabel(instance); + + if (instance.get(attr) > preSplit) { + double curVal = sk[0] / ra[0].getCount() + sk[1] / ra[1].getCount(); + if (curVal < bestVal) { + bestVal = curVal; + bestSk = sk[0] + sk[1];
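+ // sk[0] and sk[1] accumulate each branch's sum of squared deviations (Welford-style updates below), so curVal compares per-branch variances; the threshold recorded next is the midpoint between this and the previous attribute value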
+ split = (instance.get(attr) + preSplit) / 2.0; + } + } + + // computes the variance + if (ra[0].getCount() == 0) { + ra[0].addDatum(xk); + sk[0] = 0.0; + } else { + double mk = ra[0].getAverage(); + ra[0].addDatum(xk); + sk[0] += (xk - mk) * (xk - ra[0].getAverage()); + } + + double mk = ra[1].getAverage(); + ra[1].removeDatum(xk); + sk[1] -= (xk - mk) * (xk - ra[1].getAverage()); + + preSplit = instance.get(attr); + } + + // computes the variance gain + double ig = totalSk - bestSk; + + return new Split(attr, ig, split); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java new file mode 100644 index 0000000..2a6a322 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java @@ -0,0 +1,68 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.split; + +import java.util.Locale; + +/** + * Contains enough information to identify each split + */ +@Deprecated +public final class Split { + + private final int attr; + private final double ig; + private final double split; + + public Split(int attr, double ig, double split) { + this.attr = attr; + this.ig = ig; + this.split = split; + } + + public Split(int attr, double ig) { + this(attr, ig, Double.NaN); + } + + /** + * @return attribute to split for + */ + public int getAttr() { + return attr; + } + + /** + * @return Information Gain of the split + */ + public double getIg() { + return ig; + } + + /** + * @return split value for NUMERICAL attributes + */ + public double getSplit() { + return split; + } + + @Override + public String toString() { + return String.format(Locale.ENGLISH, "attr: %d, ig: %f, split: %f", attr, ig, split); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java new file mode 100644 index 0000000..f29faed --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java @@ -0,0 +1,166 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.tools; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.Tool; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.classifier.df.DFUtils; +import org.apache.mahout.classifier.df.data.DataLoader; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.DescriptorException; +import org.apache.mahout.classifier.df.data.DescriptorUtils; +import org.apache.mahout.common.CommandLineUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Generates a descriptor file for a given dataset + */ +public final class Describe implements Tool { + + private static final Logger log = LoggerFactory.getLogger(Describe.class); + + private Describe() {} + + public static void main(String[] args) throws Exception { + ToolRunner.run(new Describe(), args); + } + + @Override + public int run(String[] args) throws Exception { + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option pathOpt = obuilder.withLongName("path").withShortName("p").withRequired(true).withArgument( + abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create(); + + Option descriptorOpt = obuilder.withLongName("descriptor").withShortName("d").withRequired(true) + .withArgument(abuilder.withName("descriptor").withMinimum(1).create()).withDescription( + "data descriptor").create(); + + Option descPathOpt = obuilder.withLongName("file").withShortName("f").withRequired(true).withArgument( + abuilder.withName("file").withMinimum(1).withMaximum(1).create()).withDescription( + "Path to generated descriptor file").create(); + + Option regOpt = obuilder.withLongName("regression").withDescription("Regression Problem").withShortName("r") + .create(); + + Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") + .create(); + + Group group = gbuilder.withName("Options").withOption(pathOpt).withOption(descPathOpt).withOption( + descriptorOpt).withOption(regOpt).withOption(helpOpt).create(); + + try { + Parser parser = new Parser(); +
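+ // commons-cli2: bind the option group, then parse; an OptionException is caught below and prints the help text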
parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption(helpOpt)) { + CommandLineUtil.printHelp(group); + return -1; + } + + String dataPath = cmdLine.getValue(pathOpt).toString(); + String descPath = cmdLine.getValue(descPathOpt).toString(); + List<String> descriptor = convert(cmdLine.getValues(descriptorOpt)); + boolean regression = cmdLine.hasOption(regOpt); + + log.debug("Data path : {}", dataPath); + log.debug("Descriptor path : {}", descPath); + log.debug("Descriptor : {}", descriptor); + log.debug("Regression : {}", regression); + + runTool(dataPath, descriptor, descPath, regression); + } catch (OptionException e) { + log.warn(e.toString()); + CommandLineUtil.printHelp(group); + } + return 0; + } + + private void runTool(String dataPath, Iterable<String> description, String filePath, boolean regression) + throws DescriptorException, IOException { + log.info("Generating the descriptor..."); + String descriptor = DescriptorUtils.generateDescriptor(description); + + Path fPath = validateOutput(filePath); + + log.info("generating the dataset..."); + Dataset dataset = generateDataset(descriptor, dataPath, regression); + + log.info("storing the dataset description"); + String json = dataset.toJSON(); + DFUtils.storeString(conf, fPath, json); + } + + private Dataset generateDataset(String descriptor, String dataPath, boolean regression) throws IOException, + DescriptorException { + Path path = new Path(dataPath); + FileSystem fs = path.getFileSystem(conf); + + return DataLoader.generateDataset(descriptor, regression, fs, path); + } + + private Path validateOutput(String filePath) throws IOException { + Path path = new Path(filePath); + FileSystem fs = path.getFileSystem(conf); + if (fs.exists(path)) { + throw new IllegalStateException("Descriptor's file already exists"); + } + + return path; + } + + private static List<String> convert(Collection<?> values) { + List<String> list = new ArrayList<>(values.size()); + for (Object value : values) { + list.add(value.toString()); + } + return list; + } + + private Configuration conf; + + @Override + public void setConf(Configuration entries) { + this.conf = entries; + } + + @Override + public Configuration getConf() { + return conf; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java new file mode 100644 index 0000000..b421c4e --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java @@ -0,0 +1,158 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.tools; + +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Collection; +import java.util.List; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.classifier.df.DecisionForest; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.CommandLineUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Visualizes a Decision Forest + */ +@Deprecated +public final class ForestVisualizer { + + private static final Logger log = LoggerFactory.getLogger(ForestVisualizer.class); + + private ForestVisualizer() { + } + + public static String toString(DecisionForest forest, Dataset dataset, String[] attrNames) { + + List<Node> trees; + try { + Method getTrees = forest.getClass().getDeclaredMethod("getTrees"); + getTrees.setAccessible(true); + trees = (List<Node>) getTrees.invoke(forest); + } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { + throw new IllegalStateException(e); + } + + int cnt = 1; + StringBuilder buff = new StringBuilder(); + for (Node tree : trees) { + buff.append("Tree[").append(cnt).append("]:"); + buff.append(TreeVisualizer.toString(tree, dataset, attrNames)); + buff.append('\n'); + cnt++; + } + return buff.toString(); + } + + /** + * Decision Forest to String + * @param forestPath + * path to the Decision Forest + * @param datasetPath + * dataset path + * @param attrNames + * attribute names + */ + public static String toString(String forestPath, String datasetPath, String[] attrNames) throws IOException { + Configuration conf = new Configuration(); + DecisionForest forest = DecisionForest.load(conf, new Path(forestPath)); + Dataset dataset = Dataset.load(conf, new Path(datasetPath)); + return toString(forest, dataset, attrNames); + } + + /** + * Print Decision Forest + * @param forestPath + * path to the Decision Forest + * @param datasetPath + * dataset path + * @param attrNames + * attribute names + */ + public static void print(String forestPath, String datasetPath, String[] attrNames) throws IOException { + System.out.println(toString(forestPath, datasetPath, attrNames)); + } + + public static void main(String[] args) { + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
.withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create()) + .withDescription("Dataset path").create(); + + Option modelOpt = obuilder.withLongName("model").withShortName("m").withRequired(true) + .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create()) + .withDescription("Path to the Decision Forest").create(); + + Option attrNamesOpt = obuilder.withLongName("names").withShortName("n").withRequired(false) + .withArgument(abuilder.withName("names").withMinimum(1).create()) + .withDescription("Optional, Attribute names").create(); + + Option helpOpt = obuilder.withLongName("help").withShortName("h") + .withDescription("Print out help").create(); + + Group group = gbuilder.withName("Options").withOption(datasetOpt).withOption(modelOpt) + .withOption(attrNamesOpt).withOption(helpOpt).create(); + + try { + Parser parser = new Parser(); + parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption("help")) { + CommandLineUtil.printHelp(group); + return; + } + + String datasetName = cmdLine.getValue(datasetOpt).toString(); + String modelName = cmdLine.getValue(modelOpt).toString(); + String[] attrNames = null; + if (cmdLine.hasOption(attrNamesOpt)) { + Collection<String> names = (Collection<String>) cmdLine.getValues(attrNamesOpt); + if (!names.isEmpty()) { + attrNames = new String[names.size()]; + names.toArray(attrNames); + } + } + + print(modelName, datasetName, attrNames); + } catch (Exception e) { + log.error("Exception", e); + CommandLineUtil.printHelp(group); + } + } +}
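The information-gain arithmetic that OptIgSplit implements above can be exercised in isolation. The following sketch is not part of the commit; the class name and test values are invented, but the entropy and gain formulas mirror entropy() and categoricalSplit() exactly:

public class IgSketch {
  private static final double LOG2 = Math.log(2.0);

  // counts[i] = number of instances with label i; dataSize = sum of counts
  static double entropy(int[] counts, int dataSize) {
    if (dataSize == 0) {
      return 0.0;
    }
    double h = 0.0;
    for (int count : counts) {
      if (count > 0) {
        double p = count / (double) dataSize;
        h -= p * Math.log(p); // accumulate in nats
      }
    }
    return h / LOG2; // convert to bits, as IgSplit.LOG2 does
  }

  // branchCounts[b][i] = label counts inside branch b of a candidate split
  static double informationGain(int[][] branchCounts, int[] totalCounts) {
    int size = 0;
    for (int c : totalCounts) {
      size += c;
    }
    double ig = entropy(totalCounts, size); // H(Y)
    for (int[] branch : branchCounts) {
      int branchSize = 0;
      for (int c : branch) {
        branchSize += c;
      }
      // subtract the weighted branch entropy, i.e. H(Y|X)
      ig -= branchSize / (double) size * entropy(branch, branchSize);
    }
    return ig;
  }

  public static void main(String[] args) {
    // A perfectly separating split over balanced binary labels gains 1 bit.
    int[][] branches = { {5, 0}, {0, 5} };
    System.out.println(informationGain(branches, new int[] {5, 5})); // 1.0
  }
}

A perfectly separating split over balanced binary labels yields exactly 1.0 bit, which is a quick sanity check on the LOG2 scaling.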

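To see how the tools above fit together: a hypothetical driver, with placeholder paths (and assuming a forest trained separately, e.g. with the mapreduce builders), could first generate a descriptor file with Describe and then render the forest with ForestVisualizer. The "-d" tokens follow the descriptor grammar consumed by DescriptorUtils.generateDescriptor; here they are assumed to mean four numerical attributes followed by the label:

import org.apache.mahout.classifier.df.tools.Describe;
import org.apache.mahout.classifier.df.tools.ForestVisualizer;

public class ToolsExample {
  public static void main(String[] args) throws Exception {
    // Write a dataset descriptor for random.csv; the output file must not
    // already exist (Describe.validateOutput throws otherwise).
    Describe.main(new String[] {
        "-p", "random.csv",   // input data (placeholder path)
        "-f", "random.info",  // descriptor file to create
        "-d", "4", "N", "L"   // 4 numerical attributes, then the label
    });

    // Pretty-print a previously built forest against that descriptor; a null
    // attribute-name array is allowed, as ForestVisualizer.main itself passes
    // null when --names is omitted.
    ForestVisualizer.print("forest.seq", "random.info", null);
  }
}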