http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
new file mode 100644
index 0000000..56b1a04
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.commons.math3.stat.descriptive.rank.Percentile;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataUtils;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.TreeSet;
+
+/**
+ * <p>Optimized implementation of IgSplit.
+ * This class can be used when the criterion variable is a categorical attribute.</p>
+ *
+ * <p>This code was changed in MAHOUT-1419 to evaluate only a sampled set of split points
+ * for numeric features, which fixes a performance problem. To generate synthetic data
+ * that exercises the issue, try for example generating 4 features of random [0,1) values
+ * together with a random boolean 0/1 categorical feature.
In Scala:</p> + * + * {@code + * val r = new scala.util.Random() + * val pw = new java.io.PrintWriter("random.csv") + * (1 to 10000000).foreach(e => + * pw.println(r.nextDouble() + "," + + * r.nextDouble() + "," + + * r.nextDouble() + "," + + * r.nextDouble() + "," + + * (if (r.nextBoolean()) 1 else 0)) + * ) + * pw.close() + * } + */ +@Deprecated +public class OptIgSplit extends IgSplit { + + private static final int MAX_NUMERIC_SPLITS = 16; + + @Override + public Split computeSplit(Data data, int attr) { + if (data.getDataset().isNumerical(attr)) { + return numericalSplit(data, attr); + } else { + return categoricalSplit(data, attr); + } + } + + /** + * Computes the split for a CATEGORICAL attribute + */ + private static Split categoricalSplit(Data data, int attr) { + double[] values = data.values(attr).clone(); + + double[] splitPoints = chooseCategoricalSplitPoints(values); + + int numLabels = data.getDataset().nblabels(); + int[][] counts = new int[splitPoints.length][numLabels]; + int[] countAll = new int[numLabels]; + + computeFrequencies(data, attr, splitPoints, counts, countAll); + + int size = data.size(); + double hy = entropy(countAll, size); // H(Y) + double hyx = 0.0; // H(Y|X) + double invDataSize = 1.0 / size; + + for (int index = 0; index < splitPoints.length; index++) { + size = DataUtils.sum(counts[index]); + hyx += size * invDataSize * entropy(counts[index], size); + } + + double ig = hy - hyx; + return new Split(attr, ig); + } + + static void computeFrequencies(Data data, + int attr, + double[] splitPoints, + int[][] counts, + int[] countAll) { + Dataset dataset = data.getDataset(); + + for (int index = 0; index < data.size(); index++) { + Instance instance = data.get(index); + int label = (int) dataset.getLabel(instance); + double value = instance.get(attr); + int split = 0; + while (split < splitPoints.length && value > splitPoints[split]) { + split++; + } + if (split < splitPoints.length) { + counts[split][label]++; + } // Otherwise it's in the last split, which we don't need to count + countAll[label]++; + } + } + + /** + * Computes the best split for a NUMERICAL attribute + */ + static Split numericalSplit(Data data, int attr) { + double[] values = data.values(attr).clone(); + Arrays.sort(values); + + double[] splitPoints = chooseNumericSplitPoints(values); + + int numLabels = data.getDataset().nblabels(); + int[][] counts = new int[splitPoints.length][numLabels]; + int[] countAll = new int[numLabels]; + int[] countLess = new int[numLabels]; + + computeFrequencies(data, attr, splitPoints, counts, countAll); + + int size = data.size(); + double hy = entropy(countAll, size); + double invDataSize = 1.0 / size; + + int best = -1; + double bestIg = -1.0; + + // try each possible split value + for (int index = 0; index < splitPoints.length; index++) { + double ig = hy; + + DataUtils.add(countLess, counts[index]); + DataUtils.dec(countAll, counts[index]); + + // instance with attribute value < values[index] + size = DataUtils.sum(countLess); + ig -= size * invDataSize * entropy(countLess, size); + // instance with attribute value >= values[index] + size = DataUtils.sum(countAll); + ig -= size * invDataSize * entropy(countAll, size); + + if (ig > bestIg) { + bestIg = ig; + best = index; + } + } + + if (best == -1) { + throw new IllegalStateException("no best split found !"); + } + return new Split(attr, bestIg, splitPoints[best]); + } + + /** + * @return an array of values to split the numeric feature's values on when + * building candidate splits. 
When input size is <= MAX_NUMERIC_SPLITS + 1, it will
+   * return the averages between successive values as split points. When larger, it will
+   * return MAX_NUMERIC_SPLITS approximate percentiles through the data.
+   */
+  private static double[] chooseNumericSplitPoints(double[] values) {
+    if (values.length <= 1) {
+      return values;
+    }
+    if (values.length <= MAX_NUMERIC_SPLITS + 1) {
+      double[] splitPoints = new double[values.length - 1];
+      for (int i = 1; i < values.length; i++) {
+        splitPoints[i - 1] = (values[i] + values[i - 1]) / 2.0;
+      }
+      return splitPoints;
+    }
+    Percentile distribution = new Percentile();
+    distribution.setData(values);
+    double[] percentiles = new double[MAX_NUMERIC_SPLITS];
+    for (int i = 0; i < percentiles.length; i++) {
+      double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0));
+      percentiles[i] = distribution.evaluate(p);
+    }
+    return percentiles;
+  }
+
+  private static double[] chooseCategoricalSplitPoints(double[] values) {
+    // There is no great reason to believe that categorical value order matters,
+    // but the original code worked this way, and it's not terrible in the absence
+    // of more sophisticated analysis
+    Collection<Double> uniqueOrderedCategories = new TreeSet<>();
+    for (double v : values) {
+      uniqueOrderedCategories.add(v);
+    }
+    double[] uniqueValues = new double[uniqueOrderedCategories.size()];
+    Iterator<Double> it = uniqueOrderedCategories.iterator();
+    for (int i = 0; i < uniqueValues.length; i++) {
+      uniqueValues[i] = it.next();
+    }
+    return uniqueValues;
+  }
+
+  /**
+   * Computes the entropy
+   *
+   * @param counts counts[i] = numInstances with label i
+   * @param dataSize numInstances
+   */
+  private static double entropy(int[] counts, int dataSize) {
+    if (dataSize == 0) {
+      return 0.0;
+    }
+
+    double entropy = 0.0;
+
+    for (int count : counts) {
+      if (count > 0) {
+        double p = count / (double) dataSize;
+        entropy -= p * Math.log(p);
+      }
+    }
+
+    return entropy / LOG2;
+  }
+
+}
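
For readers following the math: numericalSplit scores each candidate split point as H(Y) - H(Y|X), sliding bucketed label counts from the "value >= split" side to the "value < split" side so all candidates are evaluated in a single pass. The standalone sketch below (illustrative only; it is not part of this commit, and the class name and sample counts are invented) reproduces that arithmetic:

import java.util.Arrays;

public class IgSketch {

  // Shannon entropy of label counts, normalized to bits, matching OptIgSplit.entropy
  static double entropy(int[] counts, int total) {
    if (total == 0) {
      return 0.0;
    }
    double h = 0.0;
    for (int c : counts) {
      if (c > 0) {
        double p = c / (double) total;
        h -= p * Math.log(p);
      }
    }
    return h / Math.log(2.0);
  }

  public static void main(String[] args) {
    // counts[i][label] = instances whose value falls in bucket i (one bucket per split point)
    int[][] counts = {{4, 1}, {3, 2}, {0, 5}};
    int[] countAll = {7, 8};   // total count per label
    int size = 15;

    double hy = entropy(countAll, size);
    int[] countLess = new int[countAll.length];

    double bestIg = -1.0;
    int best = -1;
    for (int i = 0; i < counts.length; i++) {
      // move bucket i from the "greater or equal" side to the "less than" side
      for (int label = 0; label < countAll.length; label++) {
        countLess[label] += counts[i][label];
        countAll[label] -= counts[i][label];
      }
      int lo = Arrays.stream(countLess).sum();
      int hi = size - lo;
      double ig = hy - (lo * entropy(countLess, lo) + hi * entropy(countAll, hi)) / size;
      if (ig > bestIg) {
        bestIg = ig;
        best = i;
      }
    }
    System.out.printf("best split index: %d, information gain: %.4f%n", best, bestIg);
  }
}

Running it prints the index of the winning candidate and its information gain in bits.
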
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java new file mode 100644 index 0000000..38695a3 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java @@ -0,0 +1,177 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.split; + +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.Instance; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; + +/** + * Regression problem implementation of IgSplit. This class can be used when the criterion variable is the numerical + * attribute. 
+ */
+@Deprecated
+public class RegressionSplit extends IgSplit {
+
+  /**
+   * Comparator used to sort instances on a given attribute
+   */
+  private static class InstanceComparator implements Comparator<Instance>, Serializable {
+    private final int attr;
+
+    InstanceComparator(int attr) {
+      this.attr = attr;
+    }
+
+    @Override
+    public int compare(Instance arg0, Instance arg1) {
+      return Double.compare(arg0.get(attr), arg1.get(attr));
+    }
+  }
+
+  @Override
+  public Split computeSplit(Data data, int attr) {
+    if (data.getDataset().isNumerical(attr)) {
+      return numericalSplit(data, attr);
+    } else {
+      return categoricalSplit(data, attr);
+    }
+  }
+
+  /**
+   * Computes the split for a CATEGORICAL attribute
+   */
+  private static Split categoricalSplit(Data data, int attr) {
+    FullRunningAverage[] ra = new FullRunningAverage[data.getDataset().nbValues(attr)];
+    double[] sk = new double[data.getDataset().nbValues(attr)];
+    for (int i = 0; i < ra.length; i++) {
+      ra[i] = new FullRunningAverage();
+    }
+    FullRunningAverage totalRa = new FullRunningAverage();
+    double totalSk = 0.0;
+
+    for (int i = 0; i < data.size(); i++) {
+      Instance instance = data.get(i);
+      int value = (int) instance.get(attr);
+      double xk = data.getDataset().getLabel(instance);
+
+      // update the running variance of the instance's category (Welford's method)
+      if (ra[value].getCount() == 0) {
+        ra[value].addDatum(xk);
+        sk[value] = 0.0;
+      } else {
+        double mk = ra[value].getAverage();
+        ra[value].addDatum(xk);
+        sk[value] += (xk - mk) * (xk - ra[value].getAverage());
+      }
+
+      // update the overall running variance
+      if (i == 0) {
+        totalRa.addDatum(xk);
+        totalSk = 0.0;
+      } else {
+        double mk = totalRa.getAverage();
+        totalRa.addDatum(xk);
+        totalSk += (xk - mk) * (xk - totalRa.getAverage());
+      }
+    }
+
+    // computes the variance gain
+    double ig = totalSk;
+    for (double aSk : sk) {
+      ig -= aSk;
+    }
+
+    return new Split(attr, ig);
+  }
+
+  /**
+   * Computes the best split for a NUMERICAL attribute
+   */
+  private static Split numericalSplit(Data data, int attr) {
+    FullRunningAverage[] ra = new FullRunningAverage[2];
+    for (int i = 0; i < ra.length; i++) {
+      ra[i] = new FullRunningAverage();
+    }
+
+    // sort the instances on the given attribute
+    Instance[] instances = new Instance[data.size()];
+    for (int i = 0; i < data.size(); i++) {
+      instances[i] = data.get(i);
+    }
+    Arrays.sort(instances, new InstanceComparator(attr));
+
+    // seed the "high" side (index 1) with every instance; this also yields the total variance
+    double[] sk = new double[2];
+    for (Instance instance : instances) {
+      double xk = data.getDataset().getLabel(instance);
+      if (ra[1].getCount() == 0) {
+        ra[1].addDatum(xk);
+        sk[1] = 0.0;
+      } else {
+        double mk = ra[1].getAverage();
+        ra[1].addDatum(xk);
+        sk[1] += (xk - mk) * (xk - ra[1].getAverage());
+      }
+    }
+    double totalSk = sk[1];
+
+    double split = Double.NaN;
+    double preSplit = Double.NaN;
+    double bestVal = Double.MAX_VALUE;
+    double bestSk = 0.0;
+
+    // find the best split point, moving each instance in turn from the "high" side
+    // to the "low" side (index 0) and scoring a candidate split between distinct values
+    for (Instance instance : instances) {
+      double xk = data.getDataset().getLabel(instance);
+
+      if (instance.get(attr) > preSplit) {
+        double curVal = sk[0] / ra[0].getCount() + sk[1] / ra[1].getCount();
+        if (curVal < bestVal) {
+          bestVal = curVal;
+          bestSk = sk[0] + sk[1];
+          split = (instance.get(attr) + preSplit) / 2.0;
+        }
+      }
+
+      // add the label to the "low" side's running variance
+      if (ra[0].getCount() == 0) {
+        ra[0].addDatum(xk);
+        sk[0] = 0.0;
+      } else {
+        double mk = ra[0].getAverage();
+        ra[0].addDatum(xk);
+        sk[0] += (xk - mk) * (xk - ra[0].getAverage());
+      }
+
+      // and remove it from the "high" side's running variance
+      double mk = ra[1].getAverage();
+      ra[1].removeDatum(xk);
+      sk[1] -= (xk - mk) * (xk - ra[1].getAverage());
+
+      preSplit = instance.get(attr);
+    }
+
+    // computes the variance gain
+ double ig = totalSk - bestSk; + + return new Split(attr, ig, split); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java new file mode 100644 index 0000000..2a6a322 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java @@ -0,0 +1,68 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.split; + +import java.util.Locale; + +/** + * Contains enough information to identify each split + */ +@Deprecated +public final class Split { + + private final int attr; + private final double ig; + private final double split; + + public Split(int attr, double ig, double split) { + this.attr = attr; + this.ig = ig; + this.split = split; + } + + public Split(int attr, double ig) { + this(attr, ig, Double.NaN); + } + + /** + * @return attribute to split for + */ + public int getAttr() { + return attr; + } + + /** + * @return Information Gain of the split + */ + public double getIg() { + return ig; + } + + /** + * @return split value for NUMERICAL attributes + */ + public double getSplit() { + return split; + } + + @Override + public String toString() { + return String.format(Locale.ENGLISH, "attr: %d, ig: %f, split: %f", attr, ig, split); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java new file mode 100644 index 0000000..f29faed --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java @@ -0,0 +1,166 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.DescriptorUtils;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Generates a file descriptor for a given dataset
+ */
+public final class Describe implements Tool {
+
+  private static final Logger log = LoggerFactory.getLogger(Describe.class);
+
+  private Describe() {}
+
+  public static void main(String[] args) throws Exception {
+    System.exit(ToolRunner.run(new Describe(), args));
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option pathOpt = obuilder.withLongName("path").withShortName("p").withRequired(true).withArgument(
+        abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+    Option descriptorOpt = obuilder.withLongName("descriptor").withShortName("d").withRequired(true)
+        .withArgument(abuilder.withName("descriptor").withMinimum(1).create()).withDescription(
+          "data descriptor").create();
+
+    Option descPathOpt = obuilder.withLongName("file").withShortName("f").withRequired(true).withArgument(
+        abuilder.withName("file").withMinimum(1).withMaximum(1).create()).withDescription(
+          "Path to generated descriptor file").create();
+
+    Option regOpt = obuilder.withLongName("regression").withDescription("Regression Problem").withShortName("r")
+        .create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+
+    Group group = gbuilder.withName("Options").withOption(pathOpt).withOption(descPathOpt).withOption(
+        descriptorOpt).withOption(regOpt).withOption(helpOpt).create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return -1;
+      }
+
+      String dataPath = cmdLine.getValue(pathOpt).toString();
+      String descPath = cmdLine.getValue(descPathOpt).toString();
List<String> descriptor = convert(cmdLine.getValues(descriptorOpt)); + boolean regression = cmdLine.hasOption(regOpt); + + log.debug("Data path : {}", dataPath); + log.debug("Descriptor path : {}", descPath); + log.debug("Descriptor : {}", descriptor); + log.debug("Regression : {}", regression); + + runTool(dataPath, descriptor, descPath, regression); + } catch (OptionException e) { + log.warn(e.toString()); + CommandLineUtil.printHelp(group); + } + return 0; + } + + private void runTool(String dataPath, Iterable<String> description, String filePath, boolean regression) + throws DescriptorException, IOException { + log.info("Generating the descriptor..."); + String descriptor = DescriptorUtils.generateDescriptor(description); + + Path fPath = validateOutput(filePath); + + log.info("generating the dataset..."); + Dataset dataset = generateDataset(descriptor, dataPath, regression); + + log.info("storing the dataset description"); + String json = dataset.toJSON(); + DFUtils.storeString(conf, fPath, json); + } + + private Dataset generateDataset(String descriptor, String dataPath, boolean regression) throws IOException, + DescriptorException { + Path path = new Path(dataPath); + FileSystem fs = path.getFileSystem(conf); + + return DataLoader.generateDataset(descriptor, regression, fs, path); + } + + private Path validateOutput(String filePath) throws IOException { + Path path = new Path(filePath); + FileSystem fs = path.getFileSystem(conf); + if (fs.exists(path)) { + throw new IllegalStateException("Descriptor's file already exists"); + } + + return path; + } + + private static List<String> convert(Collection<?> values) { + List<String> list = new ArrayList<>(values.size()); + for (Object value : values) { + list.add(value.toString()); + } + return list; + } + + private Configuration conf; + + @Override + public void setConf(Configuration entries) { + this.conf = entries; + } + + @Override + public Configuration getConf() { + return conf; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java new file mode 100644 index 0000000..b421c4e --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java @@ -0,0 +1,158 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.tools; + +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Collection; +import java.util.List; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.classifier.df.DecisionForest; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.CommandLineUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This tool is to visualize the Decision Forest + */ +@Deprecated +public final class ForestVisualizer { + + private static final Logger log = LoggerFactory.getLogger(ForestVisualizer.class); + + private ForestVisualizer() { + } + + public static String toString(DecisionForest forest, Dataset dataset, String[] attrNames) { + + List<Node> trees; + try { + Method getTrees = forest.getClass().getDeclaredMethod("getTrees"); + getTrees.setAccessible(true); + trees = (List<Node>) getTrees.invoke(forest); + } catch (IllegalAccessException e) { + throw new IllegalStateException(e); + } catch (InvocationTargetException e) { + throw new IllegalStateException(e); + } catch (NoSuchMethodException e) { + throw new IllegalStateException(e); + } + + int cnt = 1; + StringBuilder buff = new StringBuilder(); + for (Node tree : trees) { + buff.append("Tree[").append(cnt).append("]:"); + buff.append(TreeVisualizer.toString(tree, dataset, attrNames)); + buff.append('\n'); + cnt++; + } + return buff.toString(); + } + + /** + * Decision Forest to String + * @param forestPath + * path to the Decision Forest + * @param datasetPath + * dataset path + * @param attrNames + * attribute names + */ + public static String toString(String forestPath, String datasetPath, String[] attrNames) throws IOException { + Configuration conf = new Configuration(); + DecisionForest forest = DecisionForest.load(conf, new Path(forestPath)); + Dataset dataset = Dataset.load(conf, new Path(datasetPath)); + return toString(forest, dataset, attrNames); + } + + /** + * Print Decision Forest + * @param forestPath + * path to the Decision Forest + * @param datasetPath + * dataset path + * @param attrNames + * attribute names + */ + public static void print(String forestPath, String datasetPath, String[] attrNames) throws IOException { + System.out.println(toString(forestPath, datasetPath, attrNames)); + } + + public static void main(String[] args) { + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true) + .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create()) + .withDescription("Dataset path").create(); + + Option modelOpt = obuilder.withLongName("model").withShortName("m").withRequired(true) + .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create()) + .withDescription("Path to the Decision Forest").create(); + + Option attrNamesOpt = 
obuilder.withLongName("names").withShortName("n").withRequired(false) + .withArgument(abuilder.withName("names").withMinimum(1).create()) + .withDescription("Optional, Attribute names").create(); + + Option helpOpt = obuilder.withLongName("help").withShortName("h") + .withDescription("Print out help").create(); + + Group group = gbuilder.withName("Options").withOption(datasetOpt).withOption(modelOpt) + .withOption(attrNamesOpt).withOption(helpOpt).create(); + + try { + Parser parser = new Parser(); + parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption("help")) { + CommandLineUtil.printHelp(group); + return; + } + + String datasetName = cmdLine.getValue(datasetOpt).toString(); + String modelName = cmdLine.getValue(modelOpt).toString(); + String[] attrNames = null; + if (cmdLine.hasOption(attrNamesOpt)) { + Collection<String> names = (Collection<String>) cmdLine.getValues(attrNamesOpt); + if (!names.isEmpty()) { + attrNames = new String[names.size()]; + names.toArray(attrNames); + } + } + + print(modelName, datasetName, attrNames); + } catch (Exception e) { + log.error("Exception", e); + CommandLineUtil.printHelp(group); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java new file mode 100644 index 0000000..c37af4e --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java @@ -0,0 +1,122 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.tools; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.conf.Configured; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.Tool; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.CommandLineUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; + +/** + * Compute the frequency distribution of the "class label"<br> + * This class can be used when the criterion variable is the categorical attribute. + */ +@Deprecated +public final class Frequencies extends Configured implements Tool { + + private static final Logger log = LoggerFactory.getLogger(Frequencies.class); + + private Frequencies() { } + + @Override + public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException { + + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument( + abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create(); + + Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument( + abuilder.withName("path").withMinimum(1).create()).withDescription("dataset path").create(); + + Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") + .create(); + + Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt).withOption(helpOpt) + .create(); + + try { + Parser parser = new Parser(); + parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption(helpOpt)) { + CommandLineUtil.printHelp(group); + return 0; + } + + String dataPath = cmdLine.getValue(dataOpt).toString(); + String datasetPath = cmdLine.getValue(datasetOpt).toString(); + + log.debug("Data path : {}", dataPath); + log.debug("Dataset path : {}", datasetPath); + + runTool(dataPath, datasetPath); + } catch (OptionException e) { + log.warn(e.toString(), e); + CommandLineUtil.printHelp(group); + } + + return 0; + } + + private void runTool(String data, String dataset) throws IOException, + ClassNotFoundException, + InterruptedException { + + FileSystem fs = FileSystem.get(getConf()); + Path workingDir = fs.getWorkingDirectory(); + + Path dataPath = new Path(data); + Path datasetPath = new Path(dataset); + + log.info("Computing the frequencies..."); + FrequenciesJob job = new FrequenciesJob(new Path(workingDir, "output"), dataPath, datasetPath); + + int[][] counts = job.run(getConf()); + + // outputing the frequencies + log.info("counts[partition][class]"); + for (int[] count : counts) { + log.info(Arrays.toString(count)); + } + } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new Configuration(), new Frequencies(), args); + } + +} 
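
Stripped of the MapReduce plumbing, what the Frequencies tool reports is simply a histogram of class labels per input partition. A single-process equivalent is easy to sketch (illustrative only, not part of this commit; it assumes the label is the last comma-separated field, as in the synthetic data from the OptIgSplit javadoc, and takes the CSV path as its first argument):

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Stream;

public class LabelHistogram {
  public static void main(String[] args) throws IOException {
    Map<String, Long> counts = new TreeMap<>();
    try (Stream<String> lines = Files.lines(Paths.get(args[0]))) {
      lines.filter(l -> !l.isEmpty())
           .forEach(l -> {
             // take the label from the last comma-separated field
             String label = l.substring(l.lastIndexOf(',') + 1).trim();
             counts.merge(label, 1L, Long::sum);
           });
    }
    counts.forEach((label, n) -> System.out.println(label + " : " + n));
  }
}

The MapReduce version produces the same counts, but keyed by the offset of each mapper's first record so that per-partition histograms can be reassembled in input order.
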
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java new file mode 100644 index 0000000..9d7e2ff --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java @@ -0,0 +1,297 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.tools; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.JobContext; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.classifier.df.DFUtils; +import org.apache.mahout.classifier.df.data.DataConverter; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; + +/** + * Temporary class used to compute the frequency distribution of the "class attribute".<br> + * This class can be used when the criterion variable is the categorical attribute. 
+ */ +@Deprecated +public class FrequenciesJob { + + private static final Logger log = LoggerFactory.getLogger(FrequenciesJob.class); + + /** directory that will hold this job's output */ + private final Path outputPath; + + /** file that contains the serialized dataset */ + private final Path datasetPath; + + /** directory that contains the data used in the first step */ + private final Path dataPath; + + /** + * @param base + * base directory + * @param dataPath + * data used in the first step + */ + public FrequenciesJob(Path base, Path dataPath, Path datasetPath) { + this.outputPath = new Path(base, "frequencies.output"); + this.dataPath = dataPath; + this.datasetPath = datasetPath; + } + + /** + * @return counts[partition][label] = num tuples from 'partition' with class == label + */ + public int[][] run(Configuration conf) throws IOException, ClassNotFoundException, InterruptedException { + + // check the output + FileSystem fs = outputPath.getFileSystem(conf); + if (fs.exists(outputPath)) { + throw new IOException("Output path already exists : " + outputPath); + } + + // put the dataset into the DistributedCache + URI[] files = {datasetPath.toUri()}; + DistributedCache.setCacheFiles(files, conf); + + Job job = new Job(conf); + job.setJarByClass(FrequenciesJob.class); + + FileInputFormat.setInputPaths(job, dataPath); + FileOutputFormat.setOutputPath(job, outputPath); + + job.setMapOutputKeyClass(LongWritable.class); + job.setMapOutputValueClass(IntWritable.class); + job.setOutputKeyClass(LongWritable.class); + job.setOutputValueClass(Frequencies.class); + + job.setMapperClass(FrequenciesMapper.class); + job.setReducerClass(FrequenciesReducer.class); + + job.setInputFormatClass(TextInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + // run the job + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + + int[][] counts = parseOutput(job); + + HadoopUtil.delete(conf, outputPath); + + return counts; + } + + /** + * Extracts the output and processes it + * + * @return counts[partition][label] = num tuples from 'partition' with class == label + */ + int[][] parseOutput(JobContext job) throws IOException { + Configuration conf = job.getConfiguration(); + + int numMaps = conf.getInt("mapred.map.tasks", -1); + log.info("mapred.map.tasks = {}", numMaps); + + FileSystem fs = outputPath.getFileSystem(conf); + + Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath); + + Frequencies[] values = new Frequencies[numMaps]; + + // read all the outputs + int index = 0; + for (Path path : outfiles) { + for (Frequencies value : new SequenceFileValueIterable<Frequencies>(path, conf)) { + values[index++] = value; + } + } + + if (index < numMaps) { + throw new IllegalStateException("number of output Frequencies (" + index + + ") is lesser than the number of mappers!"); + } + + // sort the frequencies using the firstIds + Arrays.sort(values); + return Frequencies.extractCounts(values); + } + + /** + * Outputs the first key and the label of each tuple + * + */ + private static class FrequenciesMapper extends Mapper<LongWritable,Text,LongWritable,IntWritable> { + + private LongWritable firstId; + + private DataConverter converter; + private Dataset dataset; + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + Configuration conf = context.getConfiguration(); + + dataset = Builder.loadDataset(conf); + setup(dataset); + } + + /** + * Useful when testing + */ + 
void setup(Dataset dataset) { + converter = new DataConverter(dataset); + } + + @Override + protected void map(LongWritable key, Text value, Context context) throws IOException, + InterruptedException { + if (firstId == null) { + firstId = new LongWritable(key.get()); + } + + Instance instance = converter.convert(value.toString()); + + context.write(firstId, new IntWritable((int) dataset.getLabel(instance))); + } + + } + + private static class FrequenciesReducer extends Reducer<LongWritable,IntWritable,LongWritable,Frequencies> { + + private int nblabels; + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + Configuration conf = context.getConfiguration(); + Dataset dataset = Builder.loadDataset(conf); + setup(dataset.nblabels()); + } + + /** + * Useful when testing + */ + void setup(int nblabels) { + this.nblabels = nblabels; + } + + @Override + protected void reduce(LongWritable key, Iterable<IntWritable> values, Context context) + throws IOException, InterruptedException { + int[] counts = new int[nblabels]; + for (IntWritable value : values) { + counts[value.get()]++; + } + + context.write(key, new Frequencies(key.get(), counts)); + } + } + + /** + * Output of the job + * + */ + private static class Frequencies implements Writable, Comparable<Frequencies>, Cloneable { + + /** first key of the partition used to sort the partitions */ + private long firstId; + + /** counts[c] = num tuples from the partition with label == c */ + private int[] counts; + + Frequencies() { } + + Frequencies(long firstId, int[] counts) { + this.firstId = firstId; + this.counts = Arrays.copyOf(counts, counts.length); + } + + @Override + public void readFields(DataInput in) throws IOException { + firstId = in.readLong(); + counts = DFUtils.readIntArray(in); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeLong(firstId); + DFUtils.writeArray(out, counts); + } + + @Override + public boolean equals(Object other) { + return other instanceof Frequencies && firstId == ((Frequencies) other).firstId; + } + + @Override + public int hashCode() { + return (int) firstId; + } + + @Override + protected Frequencies clone() { + return new Frequencies(firstId, counts); + } + + @Override + public int compareTo(Frequencies obj) { + if (firstId < obj.firstId) { + return -1; + } else if (firstId > obj.firstId) { + return 1; + } else { + return 0; + } + } + + public static int[][] extractCounts(Frequencies[] partitions) { + int[][] counts = new int[partitions.length][]; + for (int p = 0; p < partitions.length; p++) { + counts[p] = partitions[p].counts; + } + return counts; + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java new file mode 100644 index 0000000..a2a3458 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java @@ -0,0 +1,264 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.tools; + +import java.lang.reflect.Field; +import java.text.DecimalFormat; +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.node.CategoricalNode; +import org.apache.mahout.classifier.df.node.Leaf; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.classifier.df.node.NumericalNode; + +/** + * This tool is to visualize the Decision tree + */ +@Deprecated +public final class TreeVisualizer { + + private TreeVisualizer() {} + + private static String doubleToString(double value) { + DecimalFormat df = new DecimalFormat("0.##"); + return df.format(value); + } + + private static String toStringNode(Node node, Dataset dataset, + String[] attrNames, Map<String,Field> fields, int layer) { + + StringBuilder buff = new StringBuilder(); + + try { + if (node instanceof CategoricalNode) { + CategoricalNode cnode = (CategoricalNode) node; + int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode); + double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode); + Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode); + String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset); + for (int i = 0; i < attrValues[attr].length; i++) { + int index = ArrayUtils.indexOf(values, i); + if (index < 0) { + continue; + } + buff.append('\n'); + for (int j = 0; j < layer; j++) { + buff.append("| "); + } + buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ") + .append(attrValues[attr][i]); + buff.append(toStringNode(childs[index], dataset, attrNames, fields, layer + 1)); + } + } else if (node instanceof NumericalNode) { + NumericalNode nnode = (NumericalNode) node; + int attr = (Integer) fields.get("NumericalNode.attr").get(nnode); + double split = (Double) fields.get("NumericalNode.split").get(nnode); + Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode); + Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode); + buff.append('\n'); + for (int j = 0; j < layer; j++) { + buff.append("| "); + } + buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ") + .append(doubleToString(split)); + buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1)); + buff.append('\n'); + for (int j = 0; j < layer; j++) { + buff.append("| "); + } + buff.append(attrNames == null ? 
attr : attrNames[attr]).append(" >= ") + .append(doubleToString(split)); + buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1)); + } else if (node instanceof Leaf) { + Leaf leaf = (Leaf) node; + double label = (Double) fields.get("Leaf.label").get(leaf); + if (dataset.isNumerical(dataset.getLabelId())) { + buff.append(" : ").append(doubleToString(label)); + } else { + buff.append(" : ").append(dataset.getLabelString(label)); + } + } + } catch (IllegalAccessException iae) { + throw new IllegalStateException(iae); + } + + return buff.toString(); + } + + private static Map<String,Field> getReflectMap() { + Map<String,Field> fields = new HashMap<>(); + + try { + Field m = CategoricalNode.class.getDeclaredField("attr"); + m.setAccessible(true); + fields.put("CategoricalNode.attr", m); + m = CategoricalNode.class.getDeclaredField("values"); + m.setAccessible(true); + fields.put("CategoricalNode.values", m); + m = CategoricalNode.class.getDeclaredField("childs"); + m.setAccessible(true); + fields.put("CategoricalNode.childs", m); + m = NumericalNode.class.getDeclaredField("attr"); + m.setAccessible(true); + fields.put("NumericalNode.attr", m); + m = NumericalNode.class.getDeclaredField("split"); + m.setAccessible(true); + fields.put("NumericalNode.split", m); + m = NumericalNode.class.getDeclaredField("loChild"); + m.setAccessible(true); + fields.put("NumericalNode.loChild", m); + m = NumericalNode.class.getDeclaredField("hiChild"); + m.setAccessible(true); + fields.put("NumericalNode.hiChild", m); + m = Leaf.class.getDeclaredField("label"); + m.setAccessible(true); + fields.put("Leaf.label", m); + m = Dataset.class.getDeclaredField("values"); + m.setAccessible(true); + fields.put("Dataset.values", m); + } catch (NoSuchFieldException nsfe) { + throw new IllegalStateException(nsfe); + } + + return fields; + } + + /** + * Decision tree to String + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static String toString(Node tree, Dataset dataset, String[] attrNames) { + return toStringNode(tree, dataset, attrNames, getReflectMap(), 0); + } + + /** + * Print Decision tree + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static void print(Node tree, Dataset dataset, String[] attrNames) { + System.out.println(toString(tree, dataset, attrNames)); + } + + private static String toStringPredict(Node node, Instance instance, + Dataset dataset, String[] attrNames, Map<String,Field> fields) { + StringBuilder buff = new StringBuilder(); + + try { + if (node instanceof CategoricalNode) { + CategoricalNode cnode = (CategoricalNode) node; + int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode); + double[] values = (double[]) fields.get("CategoricalNode.values").get( + cnode); + Node[] childs = (Node[]) fields.get("CategoricalNode.childs") + .get(cnode); + String[][] attrValues = (String[][]) fields.get("Dataset.values").get( + dataset); + + int index = ArrayUtils.indexOf(values, instance.get(attr)); + if (index >= 0) { + buff.append(attrNames == null ? 
attr : attrNames[attr]).append(" = ") + .append(attrValues[attr][(int) instance.get(attr)]); + buff.append(" -> "); + buff.append(toStringPredict(childs[index], instance, dataset, + attrNames, fields)); + } + } else if (node instanceof NumericalNode) { + NumericalNode nnode = (NumericalNode) node; + int attr = (Integer) fields.get("NumericalNode.attr").get(nnode); + double split = (Double) fields.get("NumericalNode.split").get(nnode); + Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode); + Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode); + + if (instance.get(attr) < split) { + buff.append('(').append(attrNames == null ? attr : attrNames[attr]) + .append(" = ").append(doubleToString(instance.get(attr))) + .append(") < ").append(doubleToString(split)); + buff.append(" -> "); + buff.append(toStringPredict(loChild, instance, dataset, attrNames, + fields)); + } else { + buff.append('(').append(attrNames == null ? attr : attrNames[attr]) + .append(" = ").append(doubleToString(instance.get(attr))) + .append(") >= ").append(doubleToString(split)); + buff.append(" -> "); + buff.append(toStringPredict(hiChild, instance, dataset, attrNames, + fields)); + } + } else if (node instanceof Leaf) { + Leaf leaf = (Leaf) node; + double label = (Double) fields.get("Leaf.label").get(leaf); + if (dataset.isNumerical(dataset.getLabelId())) { + buff.append(doubleToString(label)); + } else { + buff.append(dataset.getLabelString(label)); + } + } + } catch (IllegalAccessException iae) { + throw new IllegalStateException(iae); + } + + return buff.toString(); + } + + /** + * Predict trace to String + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static String[] predictTrace(Node tree, Data data, String[] attrNames) { + Map<String,Field> reflectMap = getReflectMap(); + String[] prediction = new String[data.size()]; + for (int i = 0; i < data.size(); i++) { + prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(), + attrNames, reflectMap); + } + return prediction; + } + + /** + * Print predict trace + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static void predictTracePrint(Node tree, Data data, String[] attrNames) { + Map<String,Field> reflectMap = getReflectMap(); + for (int i = 0; i < data.size(); i++) { + System.out.println(toStringPredict(tree, data.get(i), data.getDataset(), + attrNames, reflectMap)); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java new file mode 100644 index 0000000..e1b55ab --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java @@ -0,0 +1,212 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.tools; + +import java.io.File; +import java.io.IOException; +import java.util.Locale; +import java.util.Random; +import java.util.Scanner; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.FileUtil; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.classifier.df.data.DataConverter; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This tool is used to uniformly distribute the class of all the tuples of the dataset over a given number of + * partitions.<br> + * This class can be used when the criterion variable is the categorical attribute. + */ +@Deprecated +public final class UDistrib { + + private static final Logger log = LoggerFactory.getLogger(UDistrib.class); + + private UDistrib() {} + + /** + * Launch the uniform distribution tool. 
Requires the following command line arguments:<br>
+ *
+ * data : data path<br>
+ * dataset : dataset path<br>
+ * numpartitions : num partitions<br>
+ * output : output path
+ *
+ * @throws java.io.IOException
+ */
+  public static void main(String[] args) throws IOException {
+
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+        abuilder.withName("data").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+        abuilder.withName("dataset").withMinimum(1).create()).withDescription("Dataset path").create();
+
+    Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument(
+        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+          "Path to generated files").create();
+
+    Option partitionsOpt = obuilder.withLongName("numpartitions").withShortName("p").withRequired(true)
+        .withArgument(abuilder.withName("numparts").withMinimum(1).withMaximum(1).create()).withDescription(
+          "Number of partitions to create").create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+
+    Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(outputOpt).withOption(
+        datasetOpt).withOption(partitionsOpt).withOption(helpOpt).create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+
+      String data = cmdLine.getValue(dataOpt).toString();
+      String dataset = cmdLine.getValue(datasetOpt).toString();
+      int numPartitions = Integer.parseInt(cmdLine.getValue(partitionsOpt).toString());
+      String output = cmdLine.getValue(outputOpt).toString();
+
+      runTool(data, dataset, output, numPartitions);
+    } catch (OptionException e) {
+      log.warn(e.toString(), e);
+      CommandLineUtil.printHelp(group);
+    }
+
+  }
+
+  private static void runTool(String dataStr, String datasetStr, String output, int numPartitions) throws IOException {
+
+    Preconditions.checkArgument(numPartitions > 0, "numPartitions <= 0");
+
+    // make sure the output file does not exist
+    Path outputPath = new Path(output);
+    Configuration conf = new Configuration();
+    FileSystem fs = outputPath.getFileSystem(conf);
+
+    Preconditions.checkArgument(!fs.exists(outputPath), "Output path already exists");
+
+    // create a local temporary directory holding one file per partition
+    File tempFile = FileUtil.createLocalTempFile(new File(""), "df.tools.UDistrib", true);
+    Path partsPath = new Path(tempFile.toString());
+    FileSystem pfs = partsPath.getFileSystem(conf);
+
+    Path[] partPaths = new Path[numPartitions];
+    FSDataOutputStream[] files = new FSDataOutputStream[numPartitions];
+    for (int p = 0; p < numPartitions; p++) {
+      partPaths[p] = new Path(partsPath, String.format(Locale.ENGLISH, "part.%03d", p));
+      files[p] = pfs.create(partPaths[p]);
+    }
+ + Path datasetPath = new Path(datasetStr); + Dataset dataset = Dataset.load(conf, datasetPath); + + // currents[label] = index of the next partition file in which to place a tuple with this label + int[] currents = new int[dataset.nblabels()]; + + // currents is initialized randomly in the range [0, numPartitions) + Random random = RandomUtils.getRandom(); + for (int c = 0; c < currents.length; c++) { + currents[c] = random.nextInt(numPartitions); + } + + // for each tuple of the data + Path dataPath = new Path(dataStr); + FileSystem ifs = dataPath.getFileSystem(conf); + FSDataInputStream input = ifs.open(dataPath); + Scanner scanner = new Scanner(input, "UTF-8"); + DataConverter converter = new DataConverter(dataset); + + int id = 0; + while (scanner.hasNextLine()) { + if (id % 1000 == 0) { + log.info("progress : {}", id); + } + id++; // count lines for the progress log + + String line = scanner.nextLine(); + if (line.isEmpty()) { + continue; // skip empty lines + } + + // write the tuple to files[currents[label]] + Instance instance = converter.convert(line); + int label = (int) dataset.getLabel(instance); + files[currents[label]].writeBytes(line); + files[currents[label]].writeChar('\n'); + + // advance currents[label] to the next partition, wrapping around + currents[label]++; + if (currents[label] == numPartitions) { + currents[label] = 0; + } + } + + // close all the files + scanner.close(); + for (FSDataOutputStream file : files) { + Closeables.close(file, false); + } + + // merge all the part files into the final output path, deleting the parts + FileUtil.copyMerge(pfs, partsPath, fs, outputPath, true, conf, null); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java new file mode 100644 index 0000000..049f9bf --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *     http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.classifier.evaluation; + +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.list.DoubleArrayList; + +import com.google.common.base.Preconditions; + +import java.util.Random; + +/** + * Computes AUC and a few other accuracy statistics without storing huge amounts of data. This is + * done by keeping uniform samples of the positive and negative scores. Then, when AUC is to be + * computed, the retained scores are sorted and a rank-sum statistic is used to compute the AUC. + * Since AUC is invariant with respect to down-sampling of either positives or negatives, this is + * close to correct and is exactly correct if maxBufferSize or fewer positive and negative scores + * are examined. + */ +public class Auc { + + private int maxBufferSize = 10000; + private final DoubleArrayList[] scores = {new DoubleArrayList(), new DoubleArrayList()}; + private final Random rand; + private int samples; + private final double threshold; + private final Matrix confusion; + private final DenseMatrix entropy; + + private boolean probabilityScore = true; + + private boolean hasScore; + + /** + * Allocates a new data-structure for accumulating information about AUC and a few other accuracy + * measures. + * @param threshold The threshold to use in computing the confusion matrix. + */ + public Auc(double threshold) { + confusion = new DenseMatrix(2, 2); + entropy = new DenseMatrix(2, 2); + this.rand = RandomUtils.getRandom(); + this.threshold = threshold; + } + + public Auc() { + this(0.5); + } + + /** + * Adds a score to the AUC buffers. + * + * @param trueValue Whether this score is for a true-positive or a true-negative example. + * @param score     The score for this example. + */ + public void add(int trueValue, double score) { + Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1"); + hasScore = true; + + int predictedClass = score > threshold ? 1 : 0; + confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1); + + samples++; + if (isProbabilityScore()) { + double limited = Math.max(1.0e-20, Math.min(score, 1 - 1.0e-20)); + double v0 = entropy.get(trueValue, 0); + entropy.set(trueValue, 0, (Math.log1p(-limited) - v0) / samples + v0); + + double v1 = entropy.get(trueValue, 1); + entropy.set(trueValue, 1, (Math.log(limited) - v1) / samples + v1); + } + + // add to buffers + DoubleArrayList buf = scores[trueValue]; + if (buf.size() >= maxBufferSize) { + // Once too many points have been seen, we overwrite a randomly chosen + // slot. The random index is drawn from all samples seen so far, so it + // may fall outside the buffer, in which case the new score is dropped. + // This is a special case of Knuth's permutation algorithm, but since we + // never shuffle the first maxBufferSize samples, the result isn't just + // a fair sample of the prefixes of all permutations.
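+ // (This replacement scheme is better known as reservoir sampling.)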
The contents of the buffer, however, + // will be a fair and uniform sample of maxBufferSize elements chosen + // without replacement from all elements seen so far. + int index = rand.nextInt(samples); + if (index < buf.size()) { + buf.set(index, score); + } + } else { + // for small buffers, we collect all points without permuting; + // since we sort the data later, permuting now would be pointless + buf.add(score); + } + } + + public void add(int trueValue, int predictedClass) { + hasScore = false; + Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1"); + confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1); + } + + /** + * Computes the AUC of points seen so far. This can be moderately expensive since it requires + * that all points that have been retained be sorted. + * + * @return The area under the receiver operating characteristic (ROC) curve. + */ + public double auc() { + Preconditions.checkArgument(hasScore, "Can't compute AUC for classifier without a score"); + scores[0].sort(); + scores[1].sort(); + + double n0 = scores[0].size(); + double n1 = scores[1].size(); + + if (n0 == 0 || n1 == 0) { + return 0.5; + } + + // scan the data + int i0 = 0; + int i1 = 0; + int rank = 1; + double rankSum = 0; + while (i0 < n0 && i1 < n1) { + + double v0 = scores[0].get(i0); + double v1 = scores[1].get(i1); + + if (v0 < v1) { + i0++; + rank++; + } else if (v1 < v0) { + i1++; + rankSum += rank; + rank++; + } else { + // ties have to be handled delicately + double tieScore = v0; + + // how many negatives are tied? + int k0 = 0; + while (i0 < n0 && scores[0].get(i0) == tieScore) { + k0++; + i0++; + } + + // and how many positives + int k1 = 0; + while (i1 < n1 && scores[1].get(i1) == tieScore) { + k1++; + i1++; + } + + // we found k0 + k1 tied values whose ranks fill the + // half-open interval [rank, rank + k0 + k1); + // the average rank is assigned to all of them + rankSum += (rank + (k0 + k1 - 1) / 2.0) * k1; + rank += k0 + k1; + } + } + + if (i1 < n1) { + rankSum += (rank + (n1 - i1 - 1) / 2.0) * (n1 - i1); + rank += (int) (n1 - i1); + } + + return (rankSum / n1 - (n1 + 1) / 2) / n0; + } + + /** + * Returns the confusion matrix for the classifier supposing that we were to use a particular + * threshold. + * @return The confusion matrix. + */ + public Matrix confusion() { + return confusion; + } + + /** + * Returns a matrix related to the confusion matrix and to the log-likelihood. For a + * pretty accurate classifier, N + entropy is nearly the same as the confusion matrix + * because log(1-eps) \approx -eps if eps is small. + * + * For lower accuracy classifiers, this measure will give us a better picture of how + * things work out.
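+ * For example, a confidently wrong score (near 1 for a true negative) contributes + * a large negative log term here, while the confusion matrix records only a single miscount.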
+ * + * Also, by definition, log-likelihood = sum(diag(entropy)) + * @return A cell-by-cell breakdown of the log-likelihood + */ + public Matrix entropy() { + if (!hasScore) { + // find a constant score that would optimize log-likelihood, but use a dash of Bayesian + // conservatism to avoid dividing by zero or taking log(0) + double p = (0.5 + confusion.get(1, 1)) / (1 + confusion.get(0, 0) + confusion.get(1, 1)); + entropy.set(0, 0, confusion.get(0, 0) * Math.log1p(-p)); + entropy.set(0, 1, confusion.get(0, 1) * Math.log(p)); + entropy.set(1, 0, confusion.get(1, 0) * Math.log1p(-p)); + entropy.set(1, 1, confusion.get(1, 1) * Math.log(p)); + } + return entropy; + } + + public void setMaxBufferSize(int maxBufferSize) { + this.maxBufferSize = maxBufferSize; + } + + public boolean isProbabilityScore() { + return probabilityScore; + } + + public void setProbabilityScore(boolean probabilityScore) { + this.probabilityScore = probabilityScore; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java new file mode 100644 index 0000000..f0794b3 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java @@ -0,0 +1,82 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *     http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes; + +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; + +/** + * Class implementing the Naive Bayes Classifier Algorithm. Note that this class + * supports {@link #classifyFull}, but not {@code classify} or + * {@code classifyScalar}. The reason these two methods are not supported is that + * the scores computed by a naive Bayes classifier do not represent probabilities.
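+ * + * <p>A minimal usage sketch (illustrative only; assumes a trained {@code NaiveBayesModel model} + * and a feature vector {@code instance} are at hand):</p> + * + * {@code + * AbstractNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model); + * Vector scores = classifier.classifyFull(instance); + * int bestLabel = scores.maxValueIndex(); + * }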
+ */ +public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier { + + private final NaiveBayesModel model; + + protected AbstractNaiveBayesClassifier(NaiveBayesModel model) { + this.model = model; + } + + protected NaiveBayesModel getModel() { + return model; + } + + protected abstract double getScoreForLabelFeature(int label, int feature); + + protected double getScoreForLabelInstance(int label, Vector instance) { + double result = 0.0; + for (Element e : instance.nonZeroes()) { + result += e.get() * getScoreForLabelFeature(label, e.index()); + } + return result; + } + + @Override + public int numCategories() { + return model.numLabels(); + } + + @Override + public Vector classifyFull(Vector instance) { + return classifyFull(model.createScoringVector(), instance); + } + + @Override + public Vector classifyFull(Vector r, Vector instance) { + for (int label = 0; label < model.numLabels(); label++) { + r.setQuick(label, getScoreForLabelInstance(label, instance)); + } + return r; + } + + /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */ + @Override + public double classifyScalar(Vector instance) { + throw new UnsupportedOperationException("Not supported in Naive Bayes"); + } + + /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */ + @Override + public Vector classify(Vector instance) { + throw new UnsupportedOperationException("Probabilities not supported in Naive Bayes"); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java new file mode 100644 index 0000000..4db8b17 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *     http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.classifier.naivebayes; + +import com.google.common.base.Preconditions; +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.regex.Pattern; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.naivebayes.training.ThetaMapper; +import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SparseMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.map.OpenObjectIntHashMap; + +public final class BayesUtils { + + private static final Pattern SLASH = Pattern.compile("/"); + + private BayesUtils() {} + + public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) { + + float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f); + boolean isComplementary = conf.getBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, true); + + // read feature sums and label sums + Vector scoresPerLabel = null; + Vector scoresPerFeature = null; + for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>( + new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) { + String key = record.getFirst().toString(); + VectorWritable value = record.getSecond(); + if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) { + scoresPerFeature = value.get(); + } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) { + scoresPerLabel = value.get(); + } + } + + Preconditions.checkNotNull(scoresPerFeature); + Preconditions.checkNotNull(scoresPerLabel); + + Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size()); + for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>( + new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) { + scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get()); + } + + // perLabelThetaNormalizer is only used by the complementary model; we do not instantiate it for the standard model + Vector perLabelThetaNormalizer = null; + if (isComplementary) { + perLabelThetaNormalizer = scoresPerLabel.like(); + for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>( + new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) { + if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) { + perLabelThetaNormalizer = entry.getSecond().get(); + } + } + Preconditions.checkNotNull(perLabelThetaNormalizer); + } + + return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer, + alphaI, isComplementary); + } + + /** Write the list of
labels as (label, index) pairs into a sequence file */ + public static int writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath) + throws IOException { + FileSystem fs = FileSystem.get(indexPath.toUri(), conf); + int i = 0; + try (SequenceFile.Writer writer = + SequenceFile.createWriter(fs.getConf(), SequenceFile.Writer.file(indexPath), + SequenceFile.Writer.keyClass(Text.class), SequenceFile.Writer.valueClass(IntWritable.class))) { + for (String label : labels) { + writer.append(new Text(label), new IntWritable(i++)); + } + } + return i; + } + + public static int writeLabelIndex(Configuration conf, Path indexPath, + Iterable<Pair<Text,IntWritable>> labels) throws IOException { + FileSystem fs = FileSystem.get(indexPath.toUri(), conf); + Collection<String> seen = new HashSet<>(); + int i = 0; + try (SequenceFile.Writer writer = + SequenceFile.createWriter(fs.getConf(), SequenceFile.Writer.file(indexPath), + SequenceFile.Writer.keyClass(Text.class), SequenceFile.Writer.valueClass(IntWritable.class))){ + for (Pair<Text,IntWritable> label : labels) { + // keys look like "/label/...": keep the component after the leading slash + String theLabel = SLASH.split(label.getFirst().toString())[1]; + if (!seen.contains(theLabel)) { + writer.append(new Text(theLabel), new IntWritable(i++)); + seen.add(theLabel); + } + } + } + return i; + } + + public static Map<Integer, String> readLabelIndex(Configuration conf, Path indexPath) { + Map<Integer, String> labelMap = new HashMap<>(); + for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(indexPath, true, conf)) { + labelMap.put(pair.getSecond().get(), pair.getFirst().toString()); + } + return labelMap; + } + + public static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException { + OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<>(); + for (Pair<Writable,IntWritable> entry + : new SequenceFileIterable<Writable,IntWritable>(HadoopUtil.getSingleCachedFile(conf), conf)) { + index.put(entry.getFirst().toString(), entry.getSecond().get()); + } + return index; + } + + public static Map<String,Vector> readScoresFromCache(Configuration conf) throws IOException { + Map<String,Vector> sumVectors = new HashMap<>(); + for (Pair<Text,VectorWritable> entry + : new SequenceFileDirIterable<Text,VectorWritable>(HadoopUtil.getSingleCachedFile(conf), + PathType.LIST, PathFilters.partFilter(), conf)) { + sumVectors.put(entry.getFirst().toString(), entry.getSecond().get()); + } + return sumVectors; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java new file mode 100644 index 0000000..18bd3d6 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + *     http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes; + + +/** Implementation of the Complementary Naive Bayes Classifier Algorithm */ +public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier { + public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) { + super(model); + } + + @Override + public double getScoreForLabelFeature(int label, int feature) { + NaiveBayesModel model = getModel(); + double weight = computeWeight(model.featureWeight(feature), model.weight(label, feature), + model.totalWeightSum(), model.labelWeight(label), model.alphaI(), model.numFeatures()); + // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors + return weight / model.thetaNormalizer(label); + } + + // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.1, Skewed Data bias + public static double computeWeight(double featureWeight, double featureLabelWeight, + double totalWeight, double labelWeight, double alphaI, double numFeatures) { + double numerator = featureWeight - featureLabelWeight + alphaI; + double denominator = totalWeight - labelWeight + alphaI * numFeatures; + return -Math.log(numerator / denominator); + } +}
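For reference, a small worked example of the complementary weight formula above; every number here is invented purely for illustration:

    // hypothetical counts, not drawn from any real model
    double w = ComplementaryNaiveBayesClassifier.computeWeight(
        5.0,    // featureWeight: weight of the feature summed over all labels
        2.0,    // featureLabelWeight: weight of the feature within this label
        100.0,  // totalWeight: total weight sum of the model
        10.0,   // labelWeight: total weight of this label
        1.0,    // alphaI: smoothing parameter
        50.0);  // numFeatures
    // numerator   = 5.0 - 2.0 + 1.0           = 4.0
    // denominator = 100.0 - 10.0 + 1.0 * 50.0 = 140.0
    // w = -log(4.0 / 140.0), approximately 3.56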
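Similarly, a sketch of the BayesUtils label-index round trip defined earlier; the path and labels are invented for the example:

    import java.util.Arrays;
    import java.util.Map;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;

    Configuration conf = new Configuration();
    Path indexPath = new Path("/tmp/labelindex");  // hypothetical path
    int n = BayesUtils.writeLabelIndex(conf, Arrays.asList("ham", "spam"), indexPath);
    // n == 2: "ham" was assigned index 0 and "spam" index 1
    Map<Integer, String> labels = BayesUtils.readLabelIndex(conf, indexPath);
    // labels.get(0) -> "ham", labels.get(1) -> "spam"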