http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java new file mode 100644 index 0000000..c37af4e --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java @@ -0,0 +1,122 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.tools; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.conf.Configured; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.Tool; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.CommandLineUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; + +/** + * Compute the frequency distribution of the "class label"<br> + * This class can be used when the criterion variable is the categorical attribute. 
+ */ +@Deprecated +public final class Frequencies extends Configured implements Tool { + + private static final Logger log = LoggerFactory.getLogger(Frequencies.class); + + private Frequencies() { } + + @Override + public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException { + + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument( + abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create(); + + Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument( + abuilder.withName("path").withMinimum(1).create()).withDescription("Dataset path").create(); + + Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") + .create(); + + Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt).withOption(helpOpt) + .create(); + + try { + Parser parser = new Parser(); + parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption(helpOpt)) { + CommandLineUtil.printHelp(group); + return 0; + } + + String dataPath = cmdLine.getValue(dataOpt).toString(); + String datasetPath = cmdLine.getValue(datasetOpt).toString(); + + log.debug("Data path : {}", dataPath); + log.debug("Dataset path : {}", datasetPath); + + runTool(dataPath, datasetPath); + } catch (OptionException e) { + log.warn(e.toString(), e); + CommandLineUtil.printHelp(group); + } + + return 0; + } + + private void runTool(String data, String dataset) throws IOException, + ClassNotFoundException, + InterruptedException { + + FileSystem fs = FileSystem.get(getConf()); + Path workingDir = fs.getWorkingDirectory(); + + Path dataPath = new Path(data); + Path datasetPath = new Path(dataset); + + log.info("Computing the frequencies..."); + FrequenciesJob job = new FrequenciesJob(new Path(workingDir, "output"), dataPath, datasetPath); + + int[][] counts = job.run(getConf()); + + // outputting the frequencies + log.info("counts[partition][class]"); + for (int[] count : counts) { + log.info(Arrays.toString(count)); + } + } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new Configuration(), new Frequencies(), args); + } + +}
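For reference, a minimal sketch of driving this tool programmatically; the driver class name and the input paths are hypothetical, invented only for illustration:

import org.apache.mahout.classifier.df.tools.Frequencies;

public class FrequenciesDemo {
  public static void main(String[] args) throws Exception {
    // Delegates to the Tool's own main(), which wires up ToolRunner.
    Frequencies.main(new String[] {
        "--data", "/user/me/glass.data",    // CSV data, one instance per line
        "--dataset", "/user/me/glass.info"  // serialized Dataset descriptor
    });
  }
}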
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java new file mode 100644 index 0000000..9d7e2ff --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java @@ -0,0 +1,297 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.tools; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.JobContext; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.classifier.df.DFUtils; +import org.apache.mahout.classifier.df.data.DataConverter; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.mapreduce.Builder; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; + +/** + * Temporary class used to compute the frequency distribution of the "class attribute".<br> + * This class can be used when the criterion variable is the categorical attribute. 
+ */ +@Deprecated +public class FrequenciesJob { + + private static final Logger log = LoggerFactory.getLogger(FrequenciesJob.class); + + /** directory that will hold this job's output */ + private final Path outputPath; + + /** file that contains the serialized dataset */ + private final Path datasetPath; + + /** directory that contains the data used in the first step */ + private final Path dataPath; + + /** + * @param base + * base directory + * @param dataPath + * data used in the first step + */ + public FrequenciesJob(Path base, Path dataPath, Path datasetPath) { + this.outputPath = new Path(base, "frequencies.output"); + this.dataPath = dataPath; + this.datasetPath = datasetPath; + } + + /** + * @return counts[partition][label] = num tuples from 'partition' with class == label + */ + public int[][] run(Configuration conf) throws IOException, ClassNotFoundException, InterruptedException { + + // check the output + FileSystem fs = outputPath.getFileSystem(conf); + if (fs.exists(outputPath)) { + throw new IOException("Output path already exists : " + outputPath); + } + + // put the dataset into the DistributedCache + URI[] files = {datasetPath.toUri()}; + DistributedCache.setCacheFiles(files, conf); + + Job job = new Job(conf); + job.setJarByClass(FrequenciesJob.class); + + FileInputFormat.setInputPaths(job, dataPath); + FileOutputFormat.setOutputPath(job, outputPath); + + job.setMapOutputKeyClass(LongWritable.class); + job.setMapOutputValueClass(IntWritable.class); + job.setOutputKeyClass(LongWritable.class); + job.setOutputValueClass(Frequencies.class); + + job.setMapperClass(FrequenciesMapper.class); + job.setReducerClass(FrequenciesReducer.class); + + job.setInputFormatClass(TextInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + // run the job + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + + int[][] counts = parseOutput(job); + + HadoopUtil.delete(conf, outputPath); + + return counts; + } + + /** + * Extracts the output and processes it + * + * @return counts[partition][label] = num tuples from 'partition' with class == label + */ + int[][] parseOutput(JobContext job) throws IOException { + Configuration conf = job.getConfiguration(); + + int numMaps = conf.getInt("mapred.map.tasks", -1); + log.info("mapred.map.tasks = {}", numMaps); + + FileSystem fs = outputPath.getFileSystem(conf); + + Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath); + + Frequencies[] values = new Frequencies[numMaps]; + + // read all the outputs + int index = 0; + for (Path path : outfiles) { + for (Frequencies value : new SequenceFileValueIterable<Frequencies>(path, conf)) { + values[index++] = value; + } + } + + if (index < numMaps) { + throw new IllegalStateException("number of output Frequencies (" + index + + ") is lesser than the number of mappers!"); + } + + // sort the frequencies using the firstIds + Arrays.sort(values); + return Frequencies.extractCounts(values); + } + + /** + * Outputs the first key and the label of each tuple + * + */ + private static class FrequenciesMapper extends Mapper<LongWritable,Text,LongWritable,IntWritable> { + + private LongWritable firstId; + + private DataConverter converter; + private Dataset dataset; + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + Configuration conf = context.getConfiguration(); + + dataset = Builder.loadDataset(conf); + setup(dataset); + } + + /** + * Useful when testing + */ + 
void setup(Dataset dataset) { + converter = new DataConverter(dataset); + } + + @Override + protected void map(LongWritable key, Text value, Context context) throws IOException, + InterruptedException { + if (firstId == null) { + firstId = new LongWritable(key.get()); + } + + Instance instance = converter.convert(value.toString()); + + context.write(firstId, new IntWritable((int) dataset.getLabel(instance))); + } + + } + + private static class FrequenciesReducer extends Reducer<LongWritable,IntWritable,LongWritable,Frequencies> { + + private int nblabels; + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + Configuration conf = context.getConfiguration(); + Dataset dataset = Builder.loadDataset(conf); + setup(dataset.nblabels()); + } + + /** + * Useful when testing + */ + void setup(int nblabels) { + this.nblabels = nblabels; + } + + @Override + protected void reduce(LongWritable key, Iterable<IntWritable> values, Context context) + throws IOException, InterruptedException { + int[] counts = new int[nblabels]; + for (IntWritable value : values) { + counts[value.get()]++; + } + + context.write(key, new Frequencies(key.get(), counts)); + } + } + + /** + * Output of the job + * + */ + private static class Frequencies implements Writable, Comparable<Frequencies>, Cloneable { + + /** first key of the partition used to sort the partitions */ + private long firstId; + + /** counts[c] = num tuples from the partition with label == c */ + private int[] counts; + + Frequencies() { } + + Frequencies(long firstId, int[] counts) { + this.firstId = firstId; + this.counts = Arrays.copyOf(counts, counts.length); + } + + @Override + public void readFields(DataInput in) throws IOException { + firstId = in.readLong(); + counts = DFUtils.readIntArray(in); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeLong(firstId); + DFUtils.writeArray(out, counts); + } + + @Override + public boolean equals(Object other) { + return other instanceof Frequencies && firstId == ((Frequencies) other).firstId; + } + + @Override + public int hashCode() { + return (int) firstId; + } + + @Override + protected Frequencies clone() { + return new Frequencies(firstId, counts); + } + + @Override + public int compareTo(Frequencies obj) { + if (firstId < obj.firstId) { + return -1; + } else if (firstId > obj.firstId) { + return 1; + } else { + return 0; + } + } + + public static int[][] extractCounts(Frequencies[] partitions) { + int[][] counts = new int[partitions.length][]; + for (int p = 0; p < partitions.length; p++) { + counts[p] = partitions[p].counts; + } + return counts; + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java new file mode 100644 index 0000000..a2a3458 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java @@ -0,0 +1,264 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.tools; + +import java.lang.reflect.Field; +import java.text.DecimalFormat; +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.node.CategoricalNode; +import org.apache.mahout.classifier.df.node.Leaf; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.classifier.df.node.NumericalNode; + +/** + * This tool is to visualize the Decision tree + */ +@Deprecated +public final class TreeVisualizer { + + private TreeVisualizer() {} + + private static String doubleToString(double value) { + DecimalFormat df = new DecimalFormat("0.##"); + return df.format(value); + } + + private static String toStringNode(Node node, Dataset dataset, + String[] attrNames, Map<String,Field> fields, int layer) { + + StringBuilder buff = new StringBuilder(); + + try { + if (node instanceof CategoricalNode) { + CategoricalNode cnode = (CategoricalNode) node; + int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode); + double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode); + Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode); + String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset); + for (int i = 0; i < attrValues[attr].length; i++) { + int index = ArrayUtils.indexOf(values, i); + if (index < 0) { + continue; + } + buff.append('\n'); + for (int j = 0; j < layer; j++) { + buff.append("| "); + } + buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ") + .append(attrValues[attr][i]); + buff.append(toStringNode(childs[index], dataset, attrNames, fields, layer + 1)); + } + } else if (node instanceof NumericalNode) { + NumericalNode nnode = (NumericalNode) node; + int attr = (Integer) fields.get("NumericalNode.attr").get(nnode); + double split = (Double) fields.get("NumericalNode.split").get(nnode); + Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode); + Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode); + buff.append('\n'); + for (int j = 0; j < layer; j++) { + buff.append("| "); + } + buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ") + .append(doubleToString(split)); + buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1)); + buff.append('\n'); + for (int j = 0; j < layer; j++) { + buff.append("| "); + } + buff.append(attrNames == null ? 
attr : attrNames[attr]).append(" >= ") + .append(doubleToString(split)); + buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1)); + } else if (node instanceof Leaf) { + Leaf leaf = (Leaf) node; + double label = (Double) fields.get("Leaf.label").get(leaf); + if (dataset.isNumerical(dataset.getLabelId())) { + buff.append(" : ").append(doubleToString(label)); + } else { + buff.append(" : ").append(dataset.getLabelString(label)); + } + } + } catch (IllegalAccessException iae) { + throw new IllegalStateException(iae); + } + + return buff.toString(); + } + + private static Map<String,Field> getReflectMap() { + Map<String,Field> fields = new HashMap<>(); + + try { + Field m = CategoricalNode.class.getDeclaredField("attr"); + m.setAccessible(true); + fields.put("CategoricalNode.attr", m); + m = CategoricalNode.class.getDeclaredField("values"); + m.setAccessible(true); + fields.put("CategoricalNode.values", m); + m = CategoricalNode.class.getDeclaredField("childs"); + m.setAccessible(true); + fields.put("CategoricalNode.childs", m); + m = NumericalNode.class.getDeclaredField("attr"); + m.setAccessible(true); + fields.put("NumericalNode.attr", m); + m = NumericalNode.class.getDeclaredField("split"); + m.setAccessible(true); + fields.put("NumericalNode.split", m); + m = NumericalNode.class.getDeclaredField("loChild"); + m.setAccessible(true); + fields.put("NumericalNode.loChild", m); + m = NumericalNode.class.getDeclaredField("hiChild"); + m.setAccessible(true); + fields.put("NumericalNode.hiChild", m); + m = Leaf.class.getDeclaredField("label"); + m.setAccessible(true); + fields.put("Leaf.label", m); + m = Dataset.class.getDeclaredField("values"); + m.setAccessible(true); + fields.put("Dataset.values", m); + } catch (NoSuchFieldException nsfe) { + throw new IllegalStateException(nsfe); + } + + return fields; + } + + /** + * Decision tree to String + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static String toString(Node tree, Dataset dataset, String[] attrNames) { + return toStringNode(tree, dataset, attrNames, getReflectMap(), 0); + } + + /** + * Print Decision tree + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static void print(Node tree, Dataset dataset, String[] attrNames) { + System.out.println(toString(tree, dataset, attrNames)); + } + + private static String toStringPredict(Node node, Instance instance, + Dataset dataset, String[] attrNames, Map<String,Field> fields) { + StringBuilder buff = new StringBuilder(); + + try { + if (node instanceof CategoricalNode) { + CategoricalNode cnode = (CategoricalNode) node; + int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode); + double[] values = (double[]) fields.get("CategoricalNode.values").get( + cnode); + Node[] childs = (Node[]) fields.get("CategoricalNode.childs") + .get(cnode); + String[][] attrValues = (String[][]) fields.get("Dataset.values").get( + dataset); + + int index = ArrayUtils.indexOf(values, instance.get(attr)); + if (index >= 0) { + buff.append(attrNames == null ? 
attr : attrNames[attr]).append(" = ") + .append(attrValues[attr][(int) instance.get(attr)]); + buff.append(" -> "); + buff.append(toStringPredict(childs[index], instance, dataset, + attrNames, fields)); + } + } else if (node instanceof NumericalNode) { + NumericalNode nnode = (NumericalNode) node; + int attr = (Integer) fields.get("NumericalNode.attr").get(nnode); + double split = (Double) fields.get("NumericalNode.split").get(nnode); + Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode); + Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode); + + if (instance.get(attr) < split) { + buff.append('(').append(attrNames == null ? attr : attrNames[attr]) + .append(" = ").append(doubleToString(instance.get(attr))) + .append(") < ").append(doubleToString(split)); + buff.append(" -> "); + buff.append(toStringPredict(loChild, instance, dataset, attrNames, + fields)); + } else { + buff.append('(').append(attrNames == null ? attr : attrNames[attr]) + .append(" = ").append(doubleToString(instance.get(attr))) + .append(") >= ").append(doubleToString(split)); + buff.append(" -> "); + buff.append(toStringPredict(hiChild, instance, dataset, attrNames, + fields)); + } + } else if (node instanceof Leaf) { + Leaf leaf = (Leaf) node; + double label = (Double) fields.get("Leaf.label").get(leaf); + if (dataset.isNumerical(dataset.getLabelId())) { + buff.append(doubleToString(label)); + } else { + buff.append(dataset.getLabelString(label)); + } + } + } catch (IllegalAccessException iae) { + throw new IllegalStateException(iae); + } + + return buff.toString(); + } + + /** + * Predict trace to String + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static String[] predictTrace(Node tree, Data data, String[] attrNames) { + Map<String,Field> reflectMap = getReflectMap(); + String[] prediction = new String[data.size()]; + for (int i = 0; i < data.size(); i++) { + prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(), + attrNames, reflectMap); + } + return prediction; + } + + /** + * Print predict trace + * + * @param tree + * Node of tree + * @param attrNames + * attribute names + */ + public static void predictTracePrint(Node tree, Data data, String[] attrNames) { + Map<String,Field> reflectMap = getReflectMap(); + for (int i = 0; i < data.size(); i++) { + System.out.println(toStringPredict(tree, data.get(i), data.getDataset(), + attrNames, reflectMap)); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java new file mode 100644 index 0000000..e1b55ab --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java @@ -0,0 +1,212 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.tools; + +import java.io.File; +import java.io.IOException; +import java.util.Locale; +import java.util.Random; +import java.util.Scanner; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.FileUtil; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.classifier.df.data.DataConverter; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This tool is used to uniformly distribute the class of all the tuples of the dataset over a given number of + * partitions.<br> + * This class can be used when the criterion variable is the categorical attribute. + */ +@Deprecated +public final class UDistrib { + + private static final Logger log = LoggerFactory.getLogger(UDistrib.class); + + private UDistrib() {} + + /** + * Launch the uniform distribution tool. 
Requires the following command line arguments:<br> + * + * data : data path dataset : dataset path numpartitions : num partitions output : output path + * + * @throws java.io.IOException + */ + public static void main(String[] args) throws IOException { + + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument( + abuilder.withName("data").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create(); + + Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument( + abuilder.withName("dataset").withMinimum(1).create()).withDescription("Dataset path").create(); + + Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument( + abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription( + "Path to generated files").create(); + + Option partitionsOpt = obuilder.withLongName("numpartitions").withShortName("p").withRequired(true) + .withArgument(abuilder.withName("numparts").withMinimum(1).withMinimum(1).create()).withDescription( + "Number of partitions to create").create(); + Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") + .create(); + + Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(outputOpt).withOption( + datasetOpt).withOption(partitionsOpt).withOption(helpOpt).create(); + + try { + Parser parser = new Parser(); + parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption(helpOpt)) { + CommandLineUtil.printHelp(group); + return; + } + + String data = cmdLine.getValue(dataOpt).toString(); + String dataset = cmdLine.getValue(datasetOpt).toString(); + int numPartitions = Integer.parseInt(cmdLine.getValue(partitionsOpt).toString()); + String output = cmdLine.getValue(outputOpt).toString(); + + runTool(data, dataset, output, numPartitions); + } catch (OptionException e) { + log.warn(e.toString(), e); + CommandLineUtil.printHelp(group); + } + + } + + private static void runTool(String dataStr, String datasetStr, String output, int numPartitions) throws IOException { + + Preconditions.checkArgument(numPartitions > 0, "numPartitions <= 0"); + + // make sure the output file does not exist + Path outputPath = new Path(output); + Configuration conf = new Configuration(); + FileSystem fs = outputPath.getFileSystem(conf); + + Preconditions.checkArgument(!fs.exists(outputPath), "Output path already exists"); + + // create a new file corresponding to each partition + // Path workingDir = fs.getWorkingDirectory(); + // FileSystem wfs = workingDir.getFileSystem(conf); + // File parentFile = new File(workingDir.toString()); + // File tempFile = FileUtil.createLocalTempFile(parentFile, "Parts", true); + // File tempFile = File.createTempFile("df.tools.UDistrib",""); + // tempFile.deleteOnExit(); + File tempFile = FileUtil.createLocalTempFile(new File(""), "df.tools.UDistrib", true); + Path partsPath = new Path(tempFile.toString()); + FileSystem pfs = partsPath.getFileSystem(conf); + + Path[] partPaths = new Path[numPartitions]; + FSDataOutputStream[] files = new FSDataOutputStream[numPartitions]; + for (int p = 0; p < numPartitions; p++) { + partPaths[p] = new Path(partsPath, String.format(Locale.ENGLISH, "part.%03d", p)); + files[p] = pfs.create(partPaths[p]); + } 
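+ // The tuples of each label are then dealt out round-robin over these part files, + // starting from a random partition per label (the 'currents' array below), so + // every partition receives a near-uniform share of each class.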
+ + Path datasetPath = new Path(datasetStr); + Dataset dataset = Dataset.load(conf, datasetPath); + + // currents[label] = next partition file where to place the tuple + int[] currents = new int[dataset.nblabels()]; + + // currents is initialized randomly in the range [0, numpartitions[ + Random random = RandomUtils.getRandom(); + for (int c = 0; c < currents.length; c++) { + currents[c] = random.nextInt(numPartitions); + } + + // foreach tuple of the data + Path dataPath = new Path(dataStr); + FileSystem ifs = dataPath.getFileSystem(conf); + FSDataInputStream input = ifs.open(dataPath); + Scanner scanner = new Scanner(input, "UTF-8"); + DataConverter converter = new DataConverter(dataset); + + int id = 0; + while (scanner.hasNextLine()) { + if (id % 1000 == 0) { + log.info("progress : {}", id); + } + + String line = scanner.nextLine(); + if (line.isEmpty()) { + continue; // skip empty lines + } + + // write the tuple in files[tuple.label] + Instance instance = converter.convert(line); + int label = (int) dataset.getLabel(instance); + files[currents[label]].writeBytes(line); + files[currents[label]].writeChar('\n'); + + // update currents + currents[label]++; + if (currents[label] == numPartitions) { + currents[label] = 0; + } + } + + // close all the files. + scanner.close(); + for (FSDataOutputStream file : files) { + Closeables.close(file, false); + } + + // merge all output files + FileUtil.copyMerge(pfs, partsPath, fs, outputPath, true, conf, null); + /* + * FSDataOutputStream joined = fs.create(new Path(outputPath, "uniform.data")); for (int p = 0; p < + * numPartitions; p++) {log.info("Joining part : {}", p); FSDataInputStream partStream = + * fs.open(partPaths[p]); + * + * IOUtils.copyBytes(partStream, joined, conf, false); + * + * partStream.close(); } + * + * joined.close(); + * + * fs.delete(partsPath, true); + */ + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java new file mode 100644 index 0000000..049f9bf --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.evaluation; + +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.list.DoubleArrayList; + +import com.google.common.base.Preconditions; + +import java.util.Random; + +/** + * Computes AUC and a few other accuracy statistics without storing huge amounts of data. This is + * done by keeping uniform samples of the positive and negative scores. Then, when AUC is to be + * computed, the remaining scores are sorted and a rank-sum statistic is used to compute the AUC. + * Since AUC is invariant with respect to down-sampling of either positives or negatives, this is + * close to correct and is exactly correct if maxBufferSize or fewer positive and negative scores + * are examined. + */ +public class Auc { + + private int maxBufferSize = 10000; + private final DoubleArrayList[] scores = {new DoubleArrayList(), new DoubleArrayList()}; + private final Random rand; + private int samples; + private final double threshold; + private final Matrix confusion; + private final DenseMatrix entropy; + + private boolean probabilityScore = true; + + private boolean hasScore; + + /** + * Allocates a new data-structure for accumulating information about AUC and a few other accuracy + * measures. + * @param threshold The threshold to use in computing the confusion matrix. + */ + public Auc(double threshold) { + confusion = new DenseMatrix(2, 2); + entropy = new DenseMatrix(2, 2); + this.rand = RandomUtils.getRandom(); + this.threshold = threshold; + } + + public Auc() { + this(0.5); + } + + /** + * Adds a score to the AUC buffers. + * + * @param trueValue Whether this score is for a true-positive or a true-negative example. + * @param score The score for this example. + */ + public void add(int trueValue, double score) { + Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1"); + hasScore = true; + + int predictedClass = score > threshold ? 1 : 0; + confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1); + + samples++; + if (isProbabilityScore()) { + double limited = Math.max(1.0e-20, Math.min(score, 1 - 1.0e-20)); + double v0 = entropy.get(trueValue, 0); + entropy.set(trueValue, 0, (Math.log1p(-limited) - v0) / samples + v0); + + double v1 = entropy.get(trueValue, 1); + entropy.set(trueValue, 1, (Math.log(limited) - v1) / samples + v1); + } + + // add to buffers + DoubleArrayList buf = scores[trueValue]; + if (buf.size() >= maxBufferSize) { + // but if too many points are seen, we insert into a random + // place and discard the predecessor. The random place could + // be anywhere, possibly not even in the buffer. + // this is a special case of Knuth's permutation algorithm + // but since we don't ever shuffle the first maxBufferSize + // samples, the result isn't just a fair sample of the prefixes + // of all permutations. 
The contents of the result, however, + // will be a fair and uniform sample of maxBufferSize elements + // chosen from all elements without replacement + int index = rand.nextInt(samples); + if (index < buf.size()) { + buf.set(index, score); + } + } else { + // for small buffers, we collect all points without permuting + // since we sort the data later, permuting now would just be + // pedantic + buf.add(score); + } + } + + public void add(int trueValue, int predictedClass) { + hasScore = false; + Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1"); + confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1); + } + + /** + * Computes the AUC of points seen so far. This can be moderately expensive since it requires + * that all points that have been retained be sorted. + * + * @return The value of the area under the receiver operating characteristic (ROC) curve. + */ + public double auc() { + Preconditions.checkArgument(hasScore, "Can't compute AUC for classifier without a score"); + scores[0].sort(); + scores[1].sort(); + + double n0 = scores[0].size(); + double n1 = scores[1].size(); + + if (n0 == 0 || n1 == 0) { + return 0.5; + } + + // scan the data + int i0 = 0; + int i1 = 0; + int rank = 1; + double rankSum = 0; + while (i0 < n0 && i1 < n1) { + + double v0 = scores[0].get(i0); + double v1 = scores[1].get(i1); + + if (v0 < v1) { + i0++; + rank++; + } else if (v1 < v0) { + i1++; + rankSum += rank; + rank++; + } else { + // ties have to be handled delicately + double tieScore = v0; + + // how many negatives are tied? + int k0 = 0; + while (i0 < n0 && scores[0].get(i0) == tieScore) { + k0++; + i0++; + } + + // and how many positives + int k1 = 0; + while (i1 < n1 && scores[1].get(i1) == tieScore) { + k1++; + i1++; + } + + // we found k0 + k1 tied values which have + // ranks in the half open interval [rank, rank + k0 + k1) + // the average rank is assigned to all + rankSum += (rank + (k0 + k1 - 1) / 2.0) * k1; + rank += k0 + k1; + } + } + + if (i1 < n1) { + rankSum += (rank + (n1 - i1 - 1) / 2.0) * (n1 - i1); + rank += (int) (n1 - i1); + } + + return (rankSum / n1 - (n1 + 1) / 2) / n0; + } + + /** + * Returns the confusion matrix for the classifier supposing that we were to use a particular + * threshold. + * @return The confusion matrix. + */ + public Matrix confusion() { + return confusion; + } + + /** + * Returns a matrix related to the confusion matrix and to the log-likelihood. For a + * pretty accurate classifier, N + entropy is nearly the same as the confusion matrix + * because log(1-eps) \approx -eps if eps is small. + * + * For lower accuracy classifiers, this measure will give us a better picture of how + * things work out.
+ * + * Also, by definition, log-likelihood = sum(diag(entropy)) + * @return Returns a cell by cell break-down of the log-likelihood + */ + public Matrix entropy() { + if (!hasScore) { + // find a constant score that would optimize log-likelihood, but use a dash of Bayesian + // conservatism to avoid dividing by zero or taking log(0) + double p = (0.5 + confusion.get(1, 1)) / (1 + confusion.get(0, 0) + confusion.get(1, 1)); + entropy.set(0, 0, confusion.get(0, 0) * Math.log1p(-p)); + entropy.set(0, 1, confusion.get(0, 1) * Math.log(p)); + entropy.set(1, 0, confusion.get(1, 0) * Math.log1p(-p)); + entropy.set(1, 1, confusion.get(1, 1) * Math.log(p)); + } + return entropy; + } + + public void setMaxBufferSize(int maxBufferSize) { + this.maxBufferSize = maxBufferSize; + } + + public boolean isProbabilityScore() { + return probabilityScore; + } + + public void setProbabilityScore(boolean probabilityScore) { + this.probabilityScore = probabilityScore; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java new file mode 100644 index 0000000..f0794b3 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java @@ -0,0 +1,82 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes; + +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; + +/** + * Class implementing the Naive Bayes Classifier Algorithm. Note that this class + * supports {@link #classifyFull}, but not {@code classify} or + * {@code classifyScalar}. The reason that these two methods are not + * supported is because the scores computed by a NaiveBayesClassifier do not + * represent probabilities. 
+ */ +public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier { + + private final NaiveBayesModel model; + + protected AbstractNaiveBayesClassifier(NaiveBayesModel model) { + this.model = model; + } + + protected NaiveBayesModel getModel() { + return model; + } + + protected abstract double getScoreForLabelFeature(int label, int feature); + + protected double getScoreForLabelInstance(int label, Vector instance) { + double result = 0.0; + for (Element e : instance.nonZeroes()) { + result += e.get() * getScoreForLabelFeature(label, e.index()); + } + return result; + } + + @Override + public int numCategories() { + return model.numLabels(); + } + + @Override + public Vector classifyFull(Vector instance) { + return classifyFull(model.createScoringVector(), instance); + } + + @Override + public Vector classifyFull(Vector r, Vector instance) { + for (int label = 0; label < model.numLabels(); label++) { + r.setQuick(label, getScoreForLabelInstance(label, instance)); + } + return r; + } + + /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */ + @Override + public double classifyScalar(Vector instance) { + throw new UnsupportedOperationException("Not supported in Naive Bayes"); + } + + /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */ + @Override + public Vector classify(Vector instance) { + throw new UnsupportedOperationException("probabilities not supported in Naive Bayes"); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java new file mode 100644 index 0000000..4db8b17 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.classifier.naivebayes; + +import com.google.common.base.Preconditions; +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.regex.Pattern; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.naivebayes.training.ThetaMapper; +import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SparseMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.map.OpenObjectIntHashMap; + +public final class BayesUtils { + + private static final Pattern SLASH = Pattern.compile("/"); + + private BayesUtils() {} + + public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) { + + float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f); + boolean isComplementary = conf.getBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, true); + + // read feature sums and label sums + Vector scoresPerLabel = null; + Vector scoresPerFeature = null; + for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>( + new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) { + String key = record.getFirst().toString(); + VectorWritable value = record.getSecond(); + if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) { + scoresPerFeature = value.get(); + } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) { + scoresPerLabel = value.get(); + } + } + + Preconditions.checkNotNull(scoresPerFeature); + Preconditions.checkNotNull(scoresPerLabel); + + Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size()); + for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>( + new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) { + scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get()); + } + + // perLabelThetaNormalizer is only used by the complementary model, we do not instantiate it for the standard model + Vector perLabelThetaNormalizer = null; + if (isComplementary) { + perLabelThetaNormalizer=scoresPerLabel.like(); + for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>( + new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) { + if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) { + perLabelThetaNormalizer = entry.getSecond().get(); + } + } + Preconditions.checkNotNull(perLabelThetaNormalizer); + } + + return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer, + alphaI, isComplementary); + } + + /** Write the list of 
labels into a map file */ + public static int writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath) + throws IOException { + FileSystem fs = FileSystem.get(indexPath.toUri(), conf); + int i = 0; + try (SequenceFile.Writer writer = + SequenceFile.createWriter(fs.getConf(), SequenceFile.Writer.file(indexPath), + SequenceFile.Writer.keyClass(Text.class), SequenceFile.Writer.valueClass(IntWritable.class))) { + for (String label : labels) { + writer.append(new Text(label), new IntWritable(i++)); + } + } + return i; + } + + public static int writeLabelIndex(Configuration conf, Path indexPath, + Iterable<Pair<Text,IntWritable>> labels) throws IOException { + FileSystem fs = FileSystem.get(indexPath.toUri(), conf); + Collection<String> seen = new HashSet<>(); + int i = 0; + try (SequenceFile.Writer writer = + SequenceFile.createWriter(fs.getConf(), SequenceFile.Writer.file(indexPath), + SequenceFile.Writer.keyClass(Text.class), SequenceFile.Writer.valueClass(IntWritable.class))){ + for (Object label : labels) { + String theLabel = SLASH.split(((Pair<?, ?>) label).getFirst().toString())[1]; + if (!seen.contains(theLabel)) { + writer.append(new Text(theLabel), new IntWritable(i++)); + seen.add(theLabel); + } + } + } + return i; + } + + public static Map<Integer, String> readLabelIndex(Configuration conf, Path indexPath) { + Map<Integer, String> labelMap = new HashMap<>(); + for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(indexPath, true, conf)) { + labelMap.put(pair.getSecond().get(), pair.getFirst().toString()); + } + return labelMap; + } + + public static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException { + OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<>(); + for (Pair<Writable,IntWritable> entry + : new SequenceFileIterable<Writable,IntWritable>(HadoopUtil.getSingleCachedFile(conf), conf)) { + index.put(entry.getFirst().toString(), entry.getSecond().get()); + } + return index; + } + + public static Map<String,Vector> readScoresFromCache(Configuration conf) throws IOException { + Map<String,Vector> sumVectors = new HashMap<>(); + for (Pair<Text,VectorWritable> entry + : new SequenceFileDirIterable<Text,VectorWritable>(HadoopUtil.getSingleCachedFile(conf), + PathType.LIST, PathFilters.partFilter(), conf)) { + sumVectors.put(entry.getFirst().toString(), entry.getSecond().get()); + } + return sumVectors; + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java new file mode 100644 index 0000000..18bd3d6 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes; + + +/** Implementation of the Naive Bayes Classifier Algorithm */ +public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier { + public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) { + super(model); + } + + @Override + public double getScoreForLabelFeature(int label, int feature) { + NaiveBayesModel model = getModel(); + double weight = computeWeight(model.featureWeight(feature), model.weight(label, feature), + model.totalWeightSum(), model.labelWeight(label), model.alphaI(), model.numFeatures()); + // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors + return weight / model.thetaNormalizer(label); + } + + // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.1, Skewed Data bias + public static double computeWeight(double featureWeight, double featureLabelWeight, + double totalWeight, double labelWeight, double alphaI, double numFeatures) { + double numerator = featureWeight - featureLabelWeight + alphaI; + double denominator = totalWeight - labelWeight + alphaI * numFeatures; + return -Math.log(numerator / denominator); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java new file mode 100644 index 0000000..9f85aab --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java @@ -0,0 +1,170 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes; + +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import com.google.common.base.Preconditions; + +/** NaiveBayesModel holds the weight matrix, the feature and label sums and the weight normalizer vectors.*/ +public class NaiveBayesModel { + + private final Vector weightsPerLabel; + private final Vector perlabelThetaNormalizer; + private final Vector weightsPerFeature; + private final Matrix weightsPerLabelAndFeature; + private final float alphaI; + private final double numFeatures; + private final double totalWeightSum; + private final boolean isComplementary; + + public final static String COMPLEMENTARY_MODEL = "COMPLEMENTARY_MODEL"; + + public NaiveBayesModel(Matrix weightMatrix, Vector weightsPerFeature, Vector weightsPerLabel, Vector thetaNormalizer, + float alphaI, boolean isComplementary) { + this.weightsPerLabelAndFeature = weightMatrix; + this.weightsPerFeature = weightsPerFeature; + this.weightsPerLabel = weightsPerLabel; + this.perlabelThetaNormalizer = thetaNormalizer; + this.numFeatures = weightsPerFeature.getNumNondefaultElements(); + this.totalWeightSum = weightsPerLabel.zSum(); + this.alphaI = alphaI; + this.isComplementary=isComplementary; + } + + public double labelWeight(int label) { + return weightsPerLabel.getQuick(label); + } + + public double thetaNormalizer(int label) { + return perlabelThetaNormalizer.get(label); + } + + public double featureWeight(int feature) { + return weightsPerFeature.getQuick(feature); + } + + public double weight(int label, int feature) { + return weightsPerLabelAndFeature.getQuick(label, feature); + } + + public float alphaI() { + return alphaI; + } + + public double numFeatures() { + return numFeatures; + } + + public double totalWeightSum() { + return totalWeightSum; + } + + public int numLabels() { + return weightsPerLabel.size(); + } + + public Vector createScoringVector() { + return weightsPerLabel.like(); + } + + public boolean isComplemtary(){ + return isComplementary; + } + + public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException { + FileSystem fs = output.getFileSystem(conf); + + Vector weightsPerLabel; + Vector perLabelThetaNormalizer = null; + Vector weightsPerFeature; + Matrix weightsPerLabelAndFeature; + float alphaI; + boolean isComplementary; + + try (FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"))) { + alphaI = in.readFloat(); + isComplementary = in.readBoolean(); + weightsPerFeature = VectorWritable.readVector(in); + weightsPerLabel = new DenseVector(VectorWritable.readVector(in)); + if (isComplementary){ + perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in)); + } + weightsPerLabelAndFeature = new SparseRowMatrix(weightsPerLabel.size(), weightsPerFeature.size()); + for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) { + weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in)); + } + } + + NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel, + perLabelThetaNormalizer, alphaI, 
+        isComplementary);
+    model.validate();
+    return model;
+  }
+
+  public void serialize(Path output, Configuration conf) throws IOException {
+    FileSystem fs = output.getFileSystem(conf);
+    try (FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"))) {
+      out.writeFloat(alphaI);
+      out.writeBoolean(isComplementary);
+      VectorWritable.writeVector(out, weightsPerFeature);
+      VectorWritable.writeVector(out, weightsPerLabel);
+      if (isComplementary) {
+        VectorWritable.writeVector(out, perlabelThetaNormalizer);
+      }
+      for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
+        VectorWritable.writeVector(out, weightsPerLabelAndFeature.viewRow(row));
+      }
+    }
+  }
+
+  public void validate() {
+    Preconditions.checkState(alphaI > 0, "alphaI has to be greater than 0!");
+    Preconditions.checkArgument(numFeatures > 0, "the vocab count has to be greater than 0!");
+    Preconditions.checkArgument(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!");
+    Preconditions.checkNotNull(weightsPerLabel, "the label sums have to be defined!");
+    Preconditions.checkArgument(weightsPerLabel.getNumNondefaultElements() > 0,
+        "the number of labels has to be greater than 0!");
+    Preconditions.checkNotNull(weightsPerFeature, "the feature sums have to be defined");
+    Preconditions.checkArgument(weightsPerFeature.getNumNondefaultElements() > 0,
+        "the feature sums have to be greater than 0!");
+    if (isComplementary) {
+      Preconditions.checkArgument(perlabelThetaNormalizer != null, "the theta normalizers have to be defined");
+      Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0,
+          "the number of theta normalizers has to be greater than 0!");
+      Preconditions.checkArgument(Math.signum(perlabelThetaNormalizer.minValue())
+          == Math.signum(perlabelThetaNormalizer.maxValue()),
+          "Theta normalizers do not all have the same sign");
+      Preconditions.checkArgument(perlabelThetaNormalizer.getNumNonZeroElements()
+          == perlabelThetaNormalizer.size(),
+          "Theta normalizers cannot have zero values.");
+    }
+  }
+}
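A round trip through serialize/materialize can be sketched with a tiny hand-built model. This is not part of the commit; the class name and the /tmp paths are placeholders, and mahout-mr plus hadoop-client are assumed on the classpath.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;

public class ModelRoundTripDemo {
  public static void main(String[] args) throws Exception {
    Vector featureSums = new DenseVector(new double[] {3.0, 2.0, 5.0});  // per-feature weights
    Vector labelSums = new DenseVector(new double[] {6.0, 4.0});         // per-label weights
    Matrix weights = new SparseRowMatrix(2, 3);                          // label x feature weights
    weights.set(0, 0, 3.0);
    weights.set(0, 2, 3.0);
    weights.set(1, 1, 2.0);
    weights.set(1, 2, 2.0);
    // standard (non-complementary) model, so no theta normalizer is required
    NaiveBayesModel model = new NaiveBayesModel(weights, featureSums, labelSums, null, 1.0f, false);
    model.validate();

    Configuration conf = new Configuration();
    Path dir = new Path("/tmp/nb-model");
    model.serialize(dir, conf);  // writes /tmp/nb-model/naiveBayesModel.bin
    NaiveBayesModel restored = NaiveBayesModel.materialize(dir, conf);
    System.out.println("labels: " + restored.numLabels() + ", features: " + restored.numFeatures());
  }
}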
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
new file mode 100644
index 0000000..e4ce8aa
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+/** Implementation of the standard (multinomial) Naive Bayes Classifier algorithm. */
+public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+  public StandardNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    NaiveBayesModel model = getModel();
+    // Standard Naive Bayes does not use weight normalization
+    return computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI(),
+        model.numFeatures());
+  }
+
+  public static double computeWeight(double featureLabelWeight, double labelWeight, double alphaI,
+      double numFeatures) {
+    double numerator = featureLabelWeight + alphaI;
+    double denominator = labelWeight + alphaI * numFeatures;
+    return Math.log(numerator / denominator);
+  }
+}
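The standard weight is just a smoothed log-likelihood, checkable by hand. A one-method sketch with illustrative numbers, not part of the commit (assumes mahout-mr on the classpath):

import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;

public class StandardWeightDemo {
  public static void main(String[] args) {
    // feature weight 2 within the label, label weight 10, alphaI = 1, vocabulary of 5 features:
    // log((2 + 1) / (10 + 1 * 5)) = log(0.2) ≈ -1.609
    System.out.println(StandardNaiveBayesClassifier.computeWeight(2.0, 10.0, 1.0, 5.0));
  }
}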
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
new file mode 100644
index 0000000..37a3b71
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.test;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+/**
+ * Runs the test input through the model.
+ * <p/>
+ * The input key holds the expected label (in the form "/label/docId"); the output key is that
+ * expected label and the output value is the vector of scores the classifier assigned to each label.
+ */
+public class BayesTestMapper extends Mapper<Text, VectorWritable, Text, VectorWritable> {
+
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  private AbstractNaiveBayesClassifier classifier;
+
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    Path modelPath = HadoopUtil.getSingleCachedFile(conf);
+    NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
+    boolean isComplementary = Boolean.parseBoolean(conf.get(TestNaiveBayesDriver.COMPLEMENTARY));
+
+    // ensure that if we are testing in complementary mode, the model has been trained
+    // complementary. a complementary model will work for standard classification, but a
+    // standard model will not work for complementary classification
+    if (isComplementary) {
+      Preconditions.checkArgument(model.isComplementary(),
+          "Complementary mode in model is different from test mode");
+      classifier = new ComplementaryNaiveBayesClassifier(model);
+    } else {
+      classifier = new StandardNaiveBayesClassifier(model);
+    }
+  }
+
+  @Override
+  protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException {
+    Vector result = classifier.classifyFull(value.get());
+    // the key is of the form "/label/docId"; part 1 of the split is the expected label
+    context.write(new Text(SLASH.split(key.toString())[1]), new VectorWritable(result));
+  }
+}
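The mapper assumes keys of the form "/label/docId" and takes element 1 of the split as the expected label. A standalone illustration of that parsing, not part of the commit (plain Java, no Hadoop needed; the key is made up):

import java.util.regex.Pattern;

public class KeyParseDemo {
  private static final Pattern SLASH = Pattern.compile("/");

  public static void main(String[] args) {
    String key = "/sports/doc42";
    // split yields ["", "sports", "doc42"]; index 1 is the expected label
    System.out.println(SLASH.split(key)[1]);  // prints "sports"
  }
}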
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
new file mode 100644
index 0000000..d9eedcf
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
@@ -0,0 +1,176 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.test;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Tests the (complementary) Naive Bayes model that was built during training
+ * by iterating over the test set and comparing the model's predictions to the known labels.
+ */
+public class TestNaiveBayesDriver extends AbstractJob {
+
+  private static final Logger log = LoggerFactory.getLogger(TestNaiveBayesDriver.class);
+
+  public static final String COMPLEMENTARY = "class"; // b for bayes, c for complementary
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), args);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.overwriteOption().create());
+    addOption("model", "m", "The path to the model built during training", true);
+    addOption(buildOption("testComplementary", "c", "test complementary?", false, false, String.valueOf(false)));
+    addOption(buildOption("runSequential", "seq", "run sequential?", false, false, String.valueOf(false)));
+    addOption("labelIndex", "l", "The path to the location of the label index", true);
+    Map<String, List<String>> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
+    }
+    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+      HadoopUtil.delete(getConf(), getOutputPath());
+    }
+
+    boolean sequential = hasOption("runSequential");
+    if (sequential) {
+      runSequential();
+    } else {
+      boolean succeeded = runMapReduce();
+      if (!succeeded) {
+        return -1;
+      }
+    }
+
+    // load the labels
+    Map<Integer, String> labelMap = BayesUtils.readLabelIndex(getConf(), new Path(getOption("labelIndex")));
+
+    // loop over the results and create the confusion matrix
+    SequenceFileDirIterable<Text, VectorWritable> dirIterable =
+        new SequenceFileDirIterable<>(getOutputPath(),
+            PathType.LIST, PathFilters.partFilter(), getConf());
+    ResultAnalyzer analyzer = new ResultAnalyzer(labelMap.values(), "DEFAULT");
+    analyzeResults(labelMap, dirIterable, analyzer);
+
+    log.info("{} Results: {}", hasOption("testComplementary") ? "Complementary NB" : "Standard NB", analyzer);
+    return 0;
+  }
+
+  private void runSequential() throws IOException {
+    boolean complementary = hasOption("testComplementary");
+    FileSystem fs = FileSystem.get(getConf());
+    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());
+
+    // Ensure that if we are testing in complementary mode, the model has been trained
+    // complementary. A complementary model will work for standard classification, but a
+    // standard model will not work for complementary classification.
+    if (complementary) {
+      Preconditions.checkArgument(model.isComplementary(),
+          "Complementary mode in model is different from test mode");
+    }
+
+    AbstractNaiveBayesClassifier classifier;
+    if (complementary) {
+      classifier = new ComplementaryNaiveBayesClassifier(model);
+    } else {
+      classifier = new StandardNaiveBayesClassifier(model);
+    }
+
+    try (SequenceFile.Writer writer =
+             SequenceFile.createWriter(fs, getConf(), new Path(getOutputPath(), "part-r-00000"),
+                 Text.class, VectorWritable.class)) {
+      SequenceFileDirIterable<Text, VectorWritable> dirIterable =
+          new SequenceFileDirIterable<>(getInputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
+      // loop through the part-r-* files in getInputPath() and get classification scores for all entries
+      for (Pair<Text, VectorWritable> pair : dirIterable) {
+        writer.append(new Text(SLASH.split(pair.getFirst().toString())[1]),
+            new VectorWritable(classifier.classifyFull(pair.getSecond().get())));
+      }
+    }
+  }
+
+  private boolean runMapReduce() throws IOException, InterruptedException, ClassNotFoundException {
+    Path model = new Path(getOption("model"));
+    HadoopUtil.cacheFiles(model, getConf());
+    // the output key is the expected value, the output value holds the scores for all the labels
+    Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
+        Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+    //testJob.getConfiguration().set(LABEL_KEY, getOption("--labels"));
+
+    boolean complementary = hasOption("testComplementary");
+    testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
+    return testJob.waitForCompletion(true);
+  }
+
+  private static void analyzeResults(Map<Integer, String> labelMap,
+                                     SequenceFileDirIterable<Text, VectorWritable> dirIterable,
+                                     ResultAnalyzer analyzer) {
+    for (Pair<Text, VectorWritable> pair : dirIterable) {
+      int bestIdx = Integer.MIN_VALUE;
+      double bestScore = Double.NEGATIVE_INFINITY; // scores are log-likelihoods, so start below any real score
+      for (Vector.Element element : pair.getSecond().get().all()) {
+        if (element.get() > bestScore) {
+          bestScore = element.get();
+          bestIdx = element.index();
+        }
+      }
+      if (bestIdx != Integer.MIN_VALUE) {
+        ClassifierResult classifierResult = new ClassifierResult(labelMap.get(bestIdx), bestScore);
+        analyzer.addInstance(pair.getFirst().toString(), classifierResult);
+      }
+    }
+  }
+}
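A hedged sketch of invoking the driver programmatically, not part of the commit. The "model" and "labelIndex" option names come from run() above; "input", "output", and "overwrite" follow DefaultOptionCreator's standard options; all paths are placeholders, and "--testComplementary" would be appended when testing a complementary-trained model.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver;

public class TestNaiveBayesDriverDemo {
  public static void main(String[] ignored) throws Exception {
    String[] args = {
        "--input", "/tmp/nb-test-vectors",   // SequenceFile<Text, VectorWritable> of test instances
        "--output", "/tmp/nb-test-results",
        "--model", "/tmp/nb-model",          // directory containing naiveBayesModel.bin
        "--labelIndex", "/tmp/nb-labelindex",
        "--overwrite"
    };
    ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), args);
  }
}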
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
new file mode 100644
index 0000000..2b8ee1e
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.math.Vector;
+
+public class ComplementaryThetaTrainer {
+
+  private final Vector weightsPerFeature;
+  private final Vector weightsPerLabel;
+  private final Vector perLabelThetaNormalizer;
+  private final double alphaI;
+  private final double totalWeightSum;
+  private final double numFeatures;
+
+  public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+    Preconditions.checkNotNull(weightsPerFeature);
+    Preconditions.checkNotNull(weightsPerLabel);
+    this.weightsPerFeature = weightsPerFeature;
+    this.weightsPerLabel = weightsPerLabel;
+    this.alphaI = alphaI;
+    perLabelThetaNormalizer = weightsPerLabel.like();
+    totalWeightSum = weightsPerLabel.zSum();
+    numFeatures = weightsPerFeature.getNumNondefaultElements();
+  }
+
+  public void train(int label, Vector perLabelWeight) {
+    double labelWeight = labelWeight(label);
+    // sum weights for each label including those with zero word counts
+    for (int i = 0; i < perLabelWeight.size(); i++) {
+      Vector.Element perLabelWeightElement = perLabelWeight.getElement(i);
+      updatePerLabelThetaNormalizer(label,
+          ComplementaryNaiveBayesClassifier.computeWeight(featureWeight(perLabelWeightElement.index()),
+              perLabelWeightElement.get(), totalWeightSum(), labelWeight, alphaI(), numFeatures()));
+    }
+  }
+
+  protected double alphaI() {
+    return alphaI;
+  }
+
+  protected double numFeatures() {
+    return numFeatures;
+  }
+
+  protected double labelWeight(int label) {
+    return weightsPerLabel.get(label);
+  }
+
+  protected double totalWeightSum() {
+    return totalWeightSum;
+  }
+
+  protected double featureWeight(int feature) {
+    return weightsPerFeature.get(feature);
+  }
+
+  // http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+  protected void updatePerLabelThetaNormalizer(int label, double weight) {
+    perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + Math.abs(weight));
+  }
+
+  public Vector retrievePerLabelThetaNormalizer() {
+    return perLabelThetaNormalizer.clone();
+  }
+}
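Feeding one label's per-feature weights through the trainer shows how the normalizer accumulates absolute complementary weights. A minimal sketch, not part of the commit, with toy counts (mahout-mr assumed on the classpath):

import org.apache.mahout.classifier.naivebayes.training.ComplementaryThetaTrainer;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class ThetaTrainerDemo {
  public static void main(String[] args) {
    Vector featureSums = new DenseVector(new double[] {3.0, 2.0, 5.0});  // per-feature totals
    Vector labelSums = new DenseVector(new double[] {6.0, 4.0});         // per-label totals
    ComplementaryThetaTrainer trainer = new ComplementaryThetaTrainer(featureSums, labelSums, 1.0);
    // per-feature weights observed for label 0; zero entries still contribute smoothing mass
    trainer.train(0, new DenseVector(new double[] {3.0, 0.0, 3.0}));
    System.out.println(trainer.retrievePerLabelThetaNormalizer());  // entry 0 is now non-zero
  }
}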
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
new file mode 100644
index 0000000..4df869e
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  enum Counter { SKIPPED_INSTANCES }
+
+  private OpenObjectIntHashMap<String> labelIndex;
+
+  @Override
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    super.setup(ctx);
+    labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
+  }
+
+  @Override
+  protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
+    String label = SLASH.split(labelText.toString())[1];
+    if (labelIndex.containsKey(label)) {
+      ctx.write(new IntWritable(labelIndex.get(label)), instance);
+    } else {
+      ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
+    }
+  }
+}
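The mapper's core decision is a lookup against the cached label index: known labels are emitted with their integer id, unknown ones are counted and dropped. A toy illustration of that lookup, not part of the commit (mahout-math assumed on the classpath; the labels are made up):

import org.apache.mahout.math.map.OpenObjectIntHashMap;

public class LabelLookupDemo {
  public static void main(String[] args) {
    OpenObjectIntHashMap<String> labelIndex = new OpenObjectIntHashMap<>();
    labelIndex.put("sports", 0);
    labelIndex.put("politics", 1);
    String label = "weather";
    if (labelIndex.containsKey(label)) {
      System.out.println("label id: " + labelIndex.get(label));
    } else {
      // the real mapper increments Counter.SKIPPED_INSTANCES here
      System.out.println("skipped: unknown label " + label);
    }
  }
}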
