http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java new file mode 100644 index 0000000..0f88a70 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java @@ -0,0 +1,332 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.classifier.mlp; + +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.util.List; +import java.util.Map; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.math.Arrays; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.io.Closeables; + +/** Train a {@link MultilayerPerceptron}. */ +public final class TrainMultilayerPerceptron { + + private static final Logger log = LoggerFactory.getLogger(TrainMultilayerPerceptron.class); + + /** The parameters used by MLP. 
*/ + static class Parameters { + double learningRate; + double momemtumWeight; + double regularizationWeight; + + String inputFilePath; + boolean skipHeader; + Map<String, Integer> labelsIndex = Maps.newHashMap(); + + String modelFilePath; + boolean updateModel; + List<Integer> layerSizeList = Lists.newArrayList(); + String squashingFunctionName; + } + + /* + private double learningRate; + private double momemtumWeight; + private double regularizationWeight; + + private String inputFilePath; + private boolean skipHeader; + private Map<String, Integer> labelsIndex = Maps.newHashMap(); + + private String modelFilePath; + private boolean updateModel; + private List<Integer> layerSizeList = Lists.newArrayList(); + private String squashingFunctionName;*/ + + public static void main(String[] args) throws Exception { + Parameters parameters = new Parameters(); + + if (parseArgs(args, parameters)) { + log.info("Validate model..."); + // check whether the model already exists + Path modelPath = new Path(parameters.modelFilePath); + FileSystem modelFs = modelPath.getFileSystem(new Configuration()); + MultilayerPerceptron mlp; + + if (modelFs.exists(modelPath) && parameters.updateModel) { + // incrementally update existing model + log.info("Build model from existing model..."); + mlp = new MultilayerPerceptron(parameters.modelFilePath); + } else { + if (modelFs.exists(modelPath)) { + modelFs.delete(modelPath, true); // delete the existing file + } + log.info("Build model from scratch..."); + mlp = new MultilayerPerceptron(); + for (int i = 0; i < parameters.layerSizeList.size(); ++i) { + if (i != parameters.layerSizeList.size() - 1) { + mlp.addLayer(parameters.layerSizeList.get(i), false, parameters.squashingFunctionName); + } else { + mlp.addLayer(parameters.layerSizeList.get(i), true, parameters.squashingFunctionName); + } + mlp.setCostFunction("Minus_Squared"); + mlp.setLearningRate(parameters.learningRate) + .setMomentumWeight(parameters.momemtumWeight) + 
.setRegularizationWeight(parameters.regularizationWeight); + } + mlp.setModelPath(parameters.modelFilePath); + } + + // set the parameters + mlp.setLearningRate(parameters.learningRate) + .setMomentumWeight(parameters.momemtumWeight) + .setRegularizationWeight(parameters.regularizationWeight); + + // train by the training data + Path trainingDataPath = new Path(parameters.inputFilePath); + FileSystem dataFs = trainingDataPath.getFileSystem(new Configuration()); + + Preconditions.checkArgument(dataFs.exists(trainingDataPath), "Training dataset %s cannot be found!", + parameters.inputFilePath); + + log.info("Read data and train model..."); + BufferedReader reader = null; + + try { + reader = new BufferedReader(new InputStreamReader(dataFs.open(trainingDataPath))); + String line; + + // read training data line by line + if (parameters.skipHeader) { + reader.readLine(); + } + + int labelDimension = parameters.labelsIndex.size(); + while ((line = reader.readLine()) != null) { + String[] token = line.split(","); + String label = token[token.length - 1]; + int labelIndex = parameters.labelsIndex.get(label); + + double[] instances = new double[token.length - 1 + labelDimension]; + for (int i = 0; i < token.length - 1; ++i) { + instances[i] = Double.parseDouble(token[i]); + } + for (int i = 0; i < labelDimension; ++i) { + instances[token.length - 1 + i] = 0; + } + // set the corresponding dimension + instances[token.length - 1 + labelIndex] = 1; + + Vector instance = new DenseVector(instances).viewPart(0, instances.length); + mlp.trainOnline(instance); + } + + // write model back + log.info("Write trained model to {}", parameters.modelFilePath); + mlp.writeModelToFile(); + mlp.close(); + } finally { + Closeables.close(reader, true); + } + } + } + + /** + * Parse the input arguments. + * + * @param args The input arguments + * @param parameters The parameters parsed. + * @return Whether the input arguments are valid. 
+ * @throws Exception + */ + private static boolean parseArgs(String[] args, Parameters parameters) throws Exception { + // build the options + log.info("Validate and parse arguments..."); + DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder(); + GroupBuilder groupBuilder = new GroupBuilder(); + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + + // whether skip the first row of the input file + Option skipHeaderOption = optionBuilder.withLongName("skipHeader") + .withShortName("sh").create(); + + Group skipHeaderGroup = groupBuilder.withOption(skipHeaderOption).create(); + + Option inputOption = optionBuilder + .withLongName("input") + .withShortName("i") + .withRequired(true) + .withChildren(skipHeaderGroup) + .withArgument(argumentBuilder.withName("path").withMinimum(1).withMaximum(1) + .create()).withDescription("the file path of training dataset") + .create(); + + Option labelsOption = optionBuilder + .withLongName("labels") + .withShortName("labels") + .withRequired(true) + .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create()) + .withDescription("label names").create(); + + Option updateOption = optionBuilder + .withLongName("update") + .withShortName("u") + .withDescription("whether to incrementally update model if the model exists") + .create(); + + Group modelUpdateGroup = groupBuilder.withOption(updateOption).create(); + + Option modelOption = optionBuilder + .withLongName("model") + .withShortName("mo") + .withRequired(true) + .withArgument(argumentBuilder.withName("model-path").withMinimum(1).withMaximum(1).create()) + .withDescription("the path to store the trained model") + .withChildren(modelUpdateGroup).create(); + + Option layerSizeOption = optionBuilder + .withLongName("layerSize") + .withShortName("ls") + .withRequired(true) + .withArgument(argumentBuilder.withName("size of layer").withMinimum(2).withMaximum(5).create()) + .withDescription("the size of each layer").create(); + + Option 
squashingFunctionOption = optionBuilder + .withLongName("squashingFunction") + .withShortName("sf") + .withArgument(argumentBuilder.withName("squashing function").withMinimum(1).withMaximum(1) + .withDefault("Sigmoid").create()) + .withDescription("the name of squashing function (currently only supports Sigmoid)") + .create(); + + Option learningRateOption = optionBuilder + .withLongName("learningRate") + .withShortName("l") + .withArgument(argumentBuilder.withName("learning rate").withMaximum(1) + .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_LEARNING_RATE).create()) + .withDescription("learning rate").create(); + + Option momemtumOption = optionBuilder + .withLongName("momemtumWeight") + .withShortName("m") + .withArgument(argumentBuilder.withName("momemtum weight").withMaximum(1) + .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_MOMENTUM_WEIGHT).create()) + .withDescription("momemtum weight").create(); + + Option regularizationOption = optionBuilder + .withLongName("regularizationWeight") + .withShortName("r") + .withArgument(argumentBuilder.withName("regularization weight").withMaximum(1) + .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_REGULARIZATION_WEIGHT).create()) + .withDescription("regularization weight").create(); + + // parse the input + Parser parser = new Parser(); + Group normalOptions = groupBuilder.withOption(inputOption) + .withOption(skipHeaderOption).withOption(updateOption) + .withOption(labelsOption).withOption(modelOption) + .withOption(layerSizeOption).withOption(squashingFunctionOption) + .withOption(learningRateOption).withOption(momemtumOption) + .withOption(regularizationOption).create(); + + parser.setGroup(normalOptions); + + CommandLine commandLine = parser.parseAndHelp(args); + if (commandLine == null) { + return false; + } + + parameters.learningRate = getDouble(commandLine, learningRateOption); + parameters.momemtumWeight = getDouble(commandLine, momemtumOption); + parameters.regularizationWeight = getDouble(commandLine, 
regularizationOption); + + parameters.inputFilePath = getString(commandLine, inputOption); + parameters.skipHeader = commandLine.hasOption(skipHeaderOption); + + List<String> labelsList = getStringList(commandLine, labelsOption); + int currentIndex = 0; + for (String label : labelsList) { + parameters.labelsIndex.put(label, currentIndex++); + } + + parameters.modelFilePath = getString(commandLine, modelOption); + parameters.updateModel = commandLine.hasOption(updateOption); + + parameters.layerSizeList = getIntegerList(commandLine, layerSizeOption); + + parameters.squashingFunctionName = getString(commandLine, squashingFunctionOption); + + System.out.printf("Input: %s, Model: %s, Update: %s, Layer size: %s, Squashing function: %s, Learning rate: %f," + + " Momemtum weight: %f, Regularization Weight: %f\n", parameters.inputFilePath, parameters.modelFilePath, + parameters.updateModel, Arrays.toString(parameters.layerSizeList.toArray()), + parameters.squashingFunctionName, parameters.learningRate, parameters.momemtumWeight, + parameters.regularizationWeight); + + return true; + } + + static Double getDouble(CommandLine commandLine, Option option) { + Object val = commandLine.getValue(option); + if (val != null) { + return Double.parseDouble(val.toString()); + } + return null; + } + + static String getString(CommandLine commandLine, Option option) { + Object val = commandLine.getValue(option); + if (val != null) { + return val.toString(); + } + return null; + } + + static List<Integer> getIntegerList(CommandLine commandLine, Option option) { + List<String> list = commandLine.getValues(option); + List<Integer> valList = Lists.newArrayList(); + for (String str : list) { + valList.add(Integer.parseInt(str)); + } + return valList; + } + + static List<String> getStringList(CommandLine commandLine, Option option) { + return commandLine.getValues(option); + } + +} \ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java new file mode 100644 index 0000000..f0794b3 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java @@ -0,0 +1,82 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes; + +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; + +/** + * Class implementing the Naive Bayes Classifier Algorithm. Note that this class + * supports {@link #classifyFull}, but not {@code classify} or + * {@code classifyScalar}. The reason that these two methods are not + * supported is because the scores computed by a NaiveBayesClassifier do not + * represent probabilities. 
+ */ +public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier { + + private final NaiveBayesModel model; + + protected AbstractNaiveBayesClassifier(NaiveBayesModel model) { + this.model = model; + } + + protected NaiveBayesModel getModel() { + return model; + } + + protected abstract double getScoreForLabelFeature(int label, int feature); + + protected double getScoreForLabelInstance(int label, Vector instance) { + double result = 0.0; + for (Element e : instance.nonZeroes()) { + result += e.get() * getScoreForLabelFeature(label, e.index()); + } + return result; + } + + @Override + public int numCategories() { + return model.numLabels(); + } + + @Override + public Vector classifyFull(Vector instance) { + return classifyFull(model.createScoringVector(), instance); + } + + @Override + public Vector classifyFull(Vector r, Vector instance) { + for (int label = 0; label < model.numLabels(); label++) { + r.setQuick(label, getScoreForLabelInstance(label, instance)); + } + return r; + } + + /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */ + @Override + public double classifyScalar(Vector instance) { + throw new UnsupportedOperationException("Not supported in Naive Bayes"); + } + + /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. 
*/ + @Override + public Vector classify(Vector instance) { + throw new UnsupportedOperationException("probabilites not supported in Naive Bayes"); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java new file mode 100644 index 0000000..1e5171c --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java @@ -0,0 +1,167 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Pattern; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.naivebayes.training.ThetaMapper; +import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SparseMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.map.OpenObjectIntHashMap; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.google.common.io.Closeables; + +public final class BayesUtils { + + private static final Pattern SLASH = Pattern.compile("/"); + + private BayesUtils() {} + + public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) { + + float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f); + boolean isComplementary = conf.getBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, true); + + // read feature sums and label sums + Vector scoresPerLabel = null; + Vector scoresPerFeature = null; + for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>( + new Path(base, 
TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) { + String key = record.getFirst().toString(); + VectorWritable value = record.getSecond(); + if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) { + scoresPerFeature = value.get(); + } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) { + scoresPerLabel = value.get(); + } + } + + Preconditions.checkNotNull(scoresPerFeature); + Preconditions.checkNotNull(scoresPerLabel); + + Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size()); + for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>( + new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) { + scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get()); + } + + // perLabelThetaNormalizer is only used by the complementary model, we do not instantiate it for the standard model + Vector perLabelThetaNormalizer = null; + if (isComplementary) { + perLabelThetaNormalizer=scoresPerLabel.like(); + for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>( + new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) { + if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) { + perLabelThetaNormalizer = entry.getSecond().get(); + } + } + Preconditions.checkNotNull(perLabelThetaNormalizer); + } + + return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer, + alphaI, isComplementary); + } + + /** Write the list of labels into a map file */ + public static int writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath) + throws IOException { + FileSystem fs = FileSystem.get(indexPath.toUri(), conf); + SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class); + int i = 0; 
+ try { + for (String label : labels) { + writer.append(new Text(label), new IntWritable(i++)); + } + } finally { + Closeables.close(writer, false); + } + return i; + } + + public static int writeLabelIndex(Configuration conf, Path indexPath, + Iterable<Pair<Text,IntWritable>> labels) throws IOException { + FileSystem fs = FileSystem.get(indexPath.toUri(), conf); + SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class); + Collection<String> seen = Sets.newHashSet(); + int i = 0; + try { + for (Object label : labels) { + String theLabel = SLASH.split(((Pair<?, ?>) label).getFirst().toString())[1]; + if (!seen.contains(theLabel)) { + writer.append(new Text(theLabel), new IntWritable(i++)); + seen.add(theLabel); + } + } + } finally { + Closeables.close(writer, false); + } + return i; + } + + public static Map<Integer, String> readLabelIndex(Configuration conf, Path indexPath) { + Map<Integer, String> labelMap = new HashMap<>(); + for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(indexPath, true, conf)) { + labelMap.put(pair.getSecond().get(), pair.getFirst().toString()); + } + return labelMap; + } + + public static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException { + OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<>(); + for (Pair<Writable,IntWritable> entry + : new SequenceFileIterable<Writable,IntWritable>(HadoopUtil.getSingleCachedFile(conf), conf)) { + index.put(entry.getFirst().toString(), entry.getSecond().get()); + } + return index; + } + + public static Map<String,Vector> readScoresFromCache(Configuration conf) throws IOException { + Map<String,Vector> sumVectors = Maps.newHashMap(); + for (Pair<Text,VectorWritable> entry + : new SequenceFileDirIterable<Text,VectorWritable>(HadoopUtil.getSingleCachedFile(conf), + PathType.LIST, PathFilters.partFilter(), conf)) { + sumVectors.put(entry.getFirst().toString(), 
entry.getSecond().get()); + } + return sumVectors; + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java new file mode 100644 index 0000000..18bd3d6 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes; + + +/** Implementation of the Naive Bayes Classifier Algorithm */ +public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier { + public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) { + super(model); + } + + @Override + public double getScoreForLabelFeature(int label, int feature) { + NaiveBayesModel model = getModel(); + double weight = computeWeight(model.featureWeight(feature), model.weight(label, feature), + model.totalWeightSum(), model.labelWeight(label), model.alphaI(), model.numFeatures()); + // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors + return weight / model.thetaNormalizer(label); + } + + // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.1, Skewed Data bias + public static double computeWeight(double featureWeight, double featureLabelWeight, + double totalWeight, double labelWeight, double alphaI, double numFeatures) { + double numerator = featureWeight - featureLabelWeight + alphaI; + double denominator = totalWeight - labelWeight + alphaI * numFeatures; + return -Math.log(numerator / denominator); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java new file mode 100644 index 0000000..f180e8b --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java @@ -0,0 +1,176 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes; + +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; + +/** NaiveBayesModel holds the weight matrix, the feature and label sums and the weight normalizer vectors.*/ +public class NaiveBayesModel { + + private final Vector weightsPerLabel; + private final Vector perlabelThetaNormalizer; + private final Vector weightsPerFeature; + private final Matrix weightsPerLabelAndFeature; + private final float alphaI; + private final double numFeatures; + private final double totalWeightSum; + private final boolean isComplementary; + + public final static String COMPLEMENTARY_MODEL = "COMPLEMENTARY_MODEL"; + + public NaiveBayesModel(Matrix weightMatrix, Vector weightsPerFeature, Vector weightsPerLabel, Vector thetaNormalizer, + float alphaI, boolean isComplementary) { + this.weightsPerLabelAndFeature = weightMatrix; + this.weightsPerFeature = 
weightsPerFeature; + this.weightsPerLabel = weightsPerLabel; + this.perlabelThetaNormalizer = thetaNormalizer; + this.numFeatures = weightsPerFeature.getNumNondefaultElements(); + this.totalWeightSum = weightsPerLabel.zSum(); + this.alphaI = alphaI; + this.isComplementary=isComplementary; + } + + public double labelWeight(int label) { + return weightsPerLabel.getQuick(label); + } + + public double thetaNormalizer(int label) { + return perlabelThetaNormalizer.get(label); + } + + public double featureWeight(int feature) { + return weightsPerFeature.getQuick(feature); + } + + public double weight(int label, int feature) { + return weightsPerLabelAndFeature.getQuick(label, feature); + } + + public float alphaI() { + return alphaI; + } + + public double numFeatures() { + return numFeatures; + } + + public double totalWeightSum() { + return totalWeightSum; + } + + public int numLabels() { + return weightsPerLabel.size(); + } + + public Vector createScoringVector() { + return weightsPerLabel.like(); + } + + public boolean isComplemtary(){ + return isComplementary; + } + + public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException { + FileSystem fs = output.getFileSystem(conf); + + Vector weightsPerLabel = null; + Vector perLabelThetaNormalizer = null; + Vector weightsPerFeature = null; + Matrix weightsPerLabelAndFeature; + float alphaI; + boolean isComplementary; + + FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin")); + try { + alphaI = in.readFloat(); + isComplementary = in.readBoolean(); + weightsPerFeature = VectorWritable.readVector(in); + weightsPerLabel = new DenseVector(VectorWritable.readVector(in)); + if (isComplementary){ + perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in)); + } + weightsPerLabelAndFeature = new SparseRowMatrix(weightsPerLabel.size(), weightsPerFeature.size()); + for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) { + 
weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in)); + } + } finally { + Closeables.close(in, true); + } + NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel, + perLabelThetaNormalizer, alphaI, isComplementary); + model.validate(); + return model; + } + + public void serialize(Path output, Configuration conf) throws IOException { + FileSystem fs = output.getFileSystem(conf); + FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin")); + try { + out.writeFloat(alphaI); + out.writeBoolean(isComplementary); + VectorWritable.writeVector(out, weightsPerFeature); + VectorWritable.writeVector(out, weightsPerLabel); + if (isComplementary){ + VectorWritable.writeVector(out, perlabelThetaNormalizer); + } + for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) { + VectorWritable.writeVector(out, weightsPerLabelAndFeature.viewRow(row)); + } + } finally { + Closeables.close(out, false); + } + } + + public void validate() { + Preconditions.checkState(alphaI > 0, "alphaI has to be greater than 0!"); + Preconditions.checkArgument(numFeatures > 0, "the vocab count has to be greater than 0!"); + Preconditions.checkArgument(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!"); + Preconditions.checkNotNull(weightsPerLabel, "the number of labels has to be defined!"); + Preconditions.checkArgument(weightsPerLabel.getNumNondefaultElements() > 0, + "the number of labels has to be greater than 0!"); + Preconditions.checkNotNull(weightsPerFeature, "the feature sums have to be defined"); + Preconditions.checkArgument(weightsPerFeature.getNumNondefaultElements() > 0, + "the feature sums have to be greater than 0!"); + if (isComplementary){ + Preconditions.checkArgument(perlabelThetaNormalizer != null, "the theta normalizers have to be defined"); + Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0, + "the number of theta normalizers has to 
be greater than 0!"); + Preconditions.checkArgument(Math.signum(perlabelThetaNormalizer.minValue()) + == Math.signum(perlabelThetaNormalizer.maxValue()), + "Theta normalizers do not all have the same sign"); + Preconditions.checkArgument(perlabelThetaNormalizer.getNumNonZeroElements() + == perlabelThetaNormalizer.size(), + "Theta normalizers can not have zero value."); + } + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java new file mode 100644 index 0000000..e4ce8aa --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes; + + +/** Implementation of the Naive Bayes Classifier Algorithm */ +public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier { + + public StandardNaiveBayesClassifier(NaiveBayesModel model) { + super(model); + } + + @Override + public double getScoreForLabelFeature(int label, int feature) { + NaiveBayesModel model = getModel(); + // Standard Naive Bayes does not use weight normalization + return computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI(), model.numFeatures()); + } + + public static double computeWeight(double featureLabelWeight, double labelWeight, double alphaI, double numFeatures) { + double numerator = featureLabelWeight + alphaI; + double denominator = labelWeight + alphaI * numFeatures; + return Math.log(numerator / denominator); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java new file mode 100644 index 0000000..37a3b71 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
package org.apache.mahout.classifier.naivebayes.test;

import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

import java.io.IOException;
import java.util.regex.Pattern;

/**
 * Runs every input vector through the model.
 * <p/>
 * The output key is the expected label, taken from the second "/"-separated component of the
 * input key; the output value is the vector returned by
 * {@link AbstractNaiveBayesClassifier#classifyFull(Vector)} for the instance.
 */
public class BayesTestMapper extends Mapper<Text, VectorWritable, Text, VectorWritable> {

  private static final Pattern SLASH = Pattern.compile("/");

  private AbstractNaiveBayesClassifier classifier;

  /**
   * Loads the model from the single cached file and builds the classifier selected by the
   * {@link TestNaiveBayesDriver#COMPLEMENTARY} configuration entry.
   */
  @Override
  protected void setup(Context context) throws IOException, InterruptedException {
    super.setup(context);
    Configuration conf = context.getConfiguration();
    Path modelPath = HadoopUtil.getSingleCachedFile(conf);
    NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
    boolean isComplementary = Boolean.parseBoolean(conf.get(TestNaiveBayesDriver.COMPLEMENTARY));

    if (!isComplementary) {
      // a complementary model also works for standard classification, so no check is needed here
      classifier = new StandardNaiveBayesClassifier(model);
      return;
    }

    // a standard model will not work for complementary classification:
    // when testing in complementary mode the model must have been trained complementary
    Preconditions.checkArgument((model.isComplemtary()),
        "Complementary mode in model is different than test mode");
    classifier = new ComplementaryNaiveBayesClassifier(model);
  }

  /** Classifies one instance and emits its scores keyed by the expected label. */
  @Override
  protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException {
    Vector scores = classifier.classifyFull(value.get());
    // the expected label is the second component of the slash-separated key
    String expectedLabel = SLASH.split(key.toString())[1];
    context.write(new Text(expectedLabel), new VectorWritable(scores));
  }
}
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.naivebayes.test;

import com.google.common.base.Preconditions;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tests the (Complementary) Naive Bayes model that was built during training by
 * iterating over the test set, scoring every instance with the model and feeding
 * the best-scoring label of each instance into a {@link ResultAnalyzer}.
 */
public class TestNaiveBayesDriver extends AbstractJob {

  private static final Logger log = LoggerFactory.getLogger(TestNaiveBayesDriver.class);

  /**
   * Configuration key that tells {@link BayesTestMapper} whether to test in complementary mode.
   * The value stored under it is the string form of a boolean.
   */
  public static final String COMPLEMENTARY = "class";
  private static final Pattern SLASH = Pattern.compile("/");

  public static void main(String[] args) throws Exception {
    ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), args);
  }

  /**
   * Parses the command line, classifies the test set (sequentially or as a MapReduce job)
   * and logs the analysis of the results.
   *
   * @return 0 on success, -1 if the arguments could not be parsed or the MapReduce job failed
   */
  @Override
  public int run(String[] args) throws Exception {
    addInputOption();
    addOutputOption();
    // register the overwrite option exactly once
    addOption(DefaultOptionCreator.overwriteOption().create());
    addOption("model", "m", "The path to the model built during training", true);
    addOption(buildOption("testComplementary", "c", "test complementary?", false, false, String.valueOf(false)));
    addOption(buildOption("runSequential", "seq", "run sequential?", false, false, String.valueOf(false)));
    addOption("labelIndex", "l", "The path to the location of the label index", true);
    Map<String, List<String>> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
      return -1;
    }
    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
      HadoopUtil.delete(getConf(), getOutputPath());
    }

    if (hasOption("runSequential")) {
      runSequential();
    } else if (!runMapReduce()) {
      return -1;
    }

    // load the labels
    Map<Integer, String> labelMap = BayesUtils.readLabelIndex(getConf(), new Path(getOption("labelIndex")));

    // loop over the results and create the confusion matrix
    SequenceFileDirIterable<Text, VectorWritable> dirIterable =
        new SequenceFileDirIterable<>(getOutputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
    ResultAnalyzer analyzer = new ResultAnalyzer(labelMap.values(), "DEFAULT");
    analyzeResults(labelMap, dirIterable, analyzer);

    log.info("{} Results: {}", hasOption("testComplementary") ? "Complementary" : "Standard NB", analyzer);
    return 0;
  }

  /**
   * Classifies the test set in-process and writes the scores of every instance to
   * {@code part-r-00000} beneath the output path.
   */
  private void runSequential() throws IOException {
    boolean complementary = hasOption("testComplementary");
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());

    // Ensure that if we are testing in complementary mode, the model has been
    // trained complementary. A complementary model will work for standard classification,
    // a standard model will not work for complementary classification.
    if (complementary) {
      Preconditions.checkArgument((model.isComplemtary()),
          "Complementary mode in model is different from test mode");
    }

    AbstractNaiveBayesClassifier classifier;
    if (complementary) {
      classifier = new ComplementaryNaiveBayesClassifier(model);
    } else {
      classifier = new StandardNaiveBayesClassifier(model);
    }

    // Resolve the file system from the output path itself so that an output location
    // outside the default file system is written through the matching FileSystem.
    Path outputFile = new Path(getOutputPath(), "part-r-00000");
    FileSystem fs = outputFile.getFileSystem(getConf());
    SequenceFile.Writer writer = SequenceFile.createWriter(fs, getConf(), outputFile,
        Text.class, VectorWritable.class);

    try {
      SequenceFileDirIterable<Text, VectorWritable> dirIterable =
          new SequenceFileDirIterable<>(getInputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
      // loop through the part-r-* files in getInputPath() and get classification scores for all entries
      for (Pair<Text, VectorWritable> pair : dirIterable) {
        writer.append(new Text(SLASH.split(pair.getFirst().toString())[1]),
            new VectorWritable(classifier.classifyFull(pair.getSecond().get())));
      }
    } finally {
      Closeables.close(writer, false);
    }
  }

  /**
   * Classifies the test set with a MapReduce job that runs {@link BayesTestMapper}.
   *
   * @return whether the job completed successfully
   */
  private boolean runMapReduce() throws IOException,
      InterruptedException, ClassNotFoundException {
    Path model = new Path(getOption("model"));
    HadoopUtil.cacheFiles(model, getConf());
    // the output key is the expected value, the output value are the scores for all the labels
    Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
        Text.class, VectorWritable.class, SequenceFileOutputFormat.class);

    boolean complementary = hasOption("testComplementary");
    testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
    return testJob.waitForCompletion(true);
  }

  /**
   * Feeds the best-scoring label of every classified instance into {@code analyzer}.
   *
   * @param labelMap    maps a label index to its label
   * @param dirIterable classified instances: expected label paired with the score of every label
   * @param analyzer    receives one result per instance that has at least one usable score
   */
  private static void analyzeResults(Map<Integer, String> labelMap,
                                     SequenceFileDirIterable<Text, VectorWritable> dirIterable,
                                     ResultAnalyzer analyzer) {
    for (Pair<Text, VectorWritable> pair : dirIterable) {
      int bestIdx = Integer.MIN_VALUE;
      // start below every finite double so that arbitrarily small scores can still win
      double bestScore = Double.NEGATIVE_INFINITY;
      for (Vector.Element element : pair.getSecond().get().all()) {
        if (element.get() > bestScore) {
          bestScore = element.get();
          bestIdx = element.index();
        }
      }
      // bestIdx is untouched when no element scored above the sentinel; such instances are skipped
      if (bestIdx != Integer.MIN_VALUE) {
        ClassifierResult classifierResult = new ClassifierResult(labelMap.get(bestIdx), bestScore);
        analyzer.addInstance(pair.getFirst().toString(), classifierResult);
      }
    }
  }
}
+ */ + +package org.apache.mahout.classifier.naivebayes.training; + +import com.google.common.base.Preconditions; +import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier; +import org.apache.mahout.math.Vector; + +public class ComplementaryThetaTrainer { + + private final Vector weightsPerFeature; + private final Vector weightsPerLabel; + private final Vector perLabelThetaNormalizer; + private final double alphaI; + private final double totalWeightSum; + private final double numFeatures; + + public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) { + Preconditions.checkNotNull(weightsPerFeature); + Preconditions.checkNotNull(weightsPerLabel); + this.weightsPerFeature = weightsPerFeature; + this.weightsPerLabel = weightsPerLabel; + this.alphaI = alphaI; + perLabelThetaNormalizer = weightsPerLabel.like(); + totalWeightSum = weightsPerLabel.zSum(); + numFeatures = weightsPerFeature.getNumNondefaultElements(); + } + + public void train(int label, Vector perLabelWeight) { + double labelWeight = labelWeight(label); + // sum weights for each label including those with zero word counts + for(int i = 0; i < perLabelWeight.size(); i++){ + Vector.Element perLabelWeightElement = perLabelWeight.getElement(i); + updatePerLabelThetaNormalizer(label, + ComplementaryNaiveBayesClassifier.computeWeight(featureWeight(perLabelWeightElement.index()), + perLabelWeightElement.get(), totalWeightSum(), labelWeight, alphaI(), numFeatures())); + } + } + + protected double alphaI() { + return alphaI; + } + + protected double numFeatures() { + return numFeatures; + } + + protected double labelWeight(int label) { + return weightsPerLabel.get(label); + } + + protected double totalWeightSum() { + return totalWeightSum; + } + + protected double featureWeight(int feature) { + return weightsPerFeature.get(feature); + } + + // http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors + protected 
void updatePerLabelThetaNormalizer(int label, double weight) { + perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + Math.abs(weight)); + } + + public Vector retrievePerLabelThetaNormalizer() { + return perLabelThetaNormalizer.clone(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java new file mode 100644 index 0000000..40ca2e9 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java @@ -0,0 +1,53 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.mahout.classifier.naivebayes.training;

import java.io.IOException;
import java.util.regex.Pattern;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenObjectIntHashMap;

/**
 * Replaces the textual label of every instance by its index in the label index.
 * Instances whose label is not part of the index are dropped and counted.
 */
public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {

  private static final Pattern SLASH = Pattern.compile("/");

  /** Counts the instances that were dropped because their label is unknown. */
  public enum Counter { SKIPPED_INSTANCES }

  private OpenObjectIntHashMap<String> labelIndex;

  /** Reads the label index from the cache. */
  @Override
  protected void setup(Context ctx) throws IOException, InterruptedException {
    super.setup(ctx);
    labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
  }

  /**
   * Emits the instance keyed by the index of its label; the label is the second
   * "/"-separated component of the input key.
   */
  @Override
  protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
    String[] keyParts = SLASH.split(labelText.toString());
    String label = keyParts[1];
    if (!labelIndex.containsKey(label)) {
      // unknown label: record the skip and emit nothing
      ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
      return;
    }
    ctx.write(new IntWritable(labelIndex.get(label)), instance);
  }
}
package org.apache.mahout.classifier.naivebayes.training;

import java.io.IOException;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
 * Feeds every (label, weight vector) pair into a {@link ComplementaryThetaTrainer} and emits
 * the accumulated per-label theta normalizers once, when the mapper is cleaned up.
 */
public class ThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {

  /** Configuration key of the smoothing parameter; defaults to 1.0 when absent. */
  public static final String ALPHA_I = ThetaMapper.class.getName() + ".alphaI";
  /** Configuration key of the complementary-training flag; it is not read by this mapper. */
  static final String TRAIN_COMPLEMENTARY = ThetaMapper.class.getName() + ".trainComplementary";

  private ComplementaryThetaTrainer trainer;

  /** Builds the trainer from the smoothing parameter and the cached per-feature and per-label weights. */
  @Override
  protected void setup(Context ctx) throws IOException, InterruptedException {
    super.setup(ctx);
    Configuration conf = ctx.getConfiguration();

    float alphaI = conf.getFloat(ALPHA_I, 1.0f);
    Map<String, Vector> scores = BayesUtils.readScoresFromCache(conf);
    Vector weightsPerFeature = scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE);
    Vector weightsPerLabel = scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL);

    trainer = new ComplementaryThetaTrainer(weightsPerFeature, weightsPerLabel, alphaI);
  }

  /** Accumulates the weights of one label; nothing is emitted here. */
  @Override
  protected void map(IntWritable key, VectorWritable value, Context ctx) throws IOException, InterruptedException {
    int label = key.get();
    Vector perLabelWeight = value.get();
    trainer.train(label, perLabelWeight);
  }

  /** Emits the accumulated normalizers under the {@link TrainNaiveBayesJob#LABEL_THETA_NORMALIZER} key. */
  @Override
  protected void cleanup(Context ctx) throws IOException, InterruptedException {
    Text outputKey = new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER);
    VectorWritable normalizers = new VectorWritable(trainer.retrievePerLabelThetaNormalizer());
    ctx.write(outputKey, normalizers);
    super.cleanup(ctx);
  }
}
+ */

package org.apache.mahout.classifier.naivebayes.training;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.VectorWritable;

import com.google.common.base.Splitter;

/**
 * Trains a Naive Bayes Classifier (parameters for both Naive Bayes and Complementary Naive Bayes).
 *
 * <p>The driver chains up to three MapReduce jobs (instance summing, weight summing and, for the
 * complementary variant only, theta normalizer summing), then reads the intermediate results back
 * from the temp path, validates the resulting model and serializes it to the output path.
 */
public final class TrainNaiveBayesJob extends AbstractJob {
  private static final String TRAIN_COMPLEMENTARY = "trainComplementary";
  private static final String ALPHA_I = "alphaI";
  private static final String LABEL_INDEX = "labelIndex";
  private static final String EXTRACT_LABELS = "extractLabels";
  private static final String LABELS = "labels";
  public static final String WEIGHTS_PER_FEATURE = "__SPF";
  public static final String WEIGHTS_PER_LABEL = "__SPL";
  public static final String LABEL_THETA_NORMALIZER = "_LTN";

  public static final String SUMMED_OBSERVATIONS = "summedObservations";
  public static final String WEIGHTS = "weights";
  public static final String THETAS = "thetas";

  /** Command-line entry point; hands the arguments to {@link ToolRunner}. */
  public static void main(String[] args) throws Exception {
    ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), args);
  }

  /**
   * Parses the command line and runs the training pipeline.
   *
   * @param args command-line arguments
   * @return 0 when the model was written, -1 when argument parsing failed or one of the jobs did
   *         not complete successfully
   */
  @Override
  public int run(String[] args) throws Exception {

    addInputOption();
    addOutputOption();
    addOption(LABELS, "l", "comma-separated list of labels to include in training", false);

    addOption(buildOption(EXTRACT_LABELS, "el", "Extract the labels from the input", false, false, ""));
    addOption(ALPHA_I, "a", "smoothing parameter", String.valueOf(1.0f));
    addOption(buildOption(TRAIN_COMPLEMENTARY, "c", "train complementary?", false, false, String.valueOf(false)));
    addOption(LABEL_INDEX, "li", "The path to store the label index in", false);
    addOption(DefaultOptionCreator.overwriteOption().create());

    if (parseArguments(args) == null) {
      return -1;
    }

    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
      HadoopUtil.delete(getConf(), getOutputPath());
      HadoopUtil.delete(getConf(), getTempPath());
    }

    // An explicitly supplied label index location wins; otherwise the index lives under the temp path.
    String labelIndexLocation = getOption(LABEL_INDEX);
    Path labelIndexPath = labelIndexLocation == null ? getTempPath(LABEL_INDEX) : new Path(labelIndexLocation);

    long numLabels = createLabelIndex(labelIndexPath);
    float smoothing = Float.parseFloat(getOption(ALPHA_I));
    boolean complementary = hasOption(TRAIN_COMPLEMENTARY);

    HadoopUtil.setSerializations(getConf());
    HadoopUtil.cacheFiles(labelIndexPath, getConf());

    // Add up all the vectors with the same labels, while mapping the labels into our index
    if (!newIndexInstancesJob().waitForCompletion(true)) {
      return -1;
    }

    // Sum up all the weights from the previous step, per label and per feature
    if (!newWeightSummerJob(numLabels).waitForCompletion(true)) {
      return -1;
    }

    // Put the per label and per feature vectors into the cache
    HadoopUtil.cacheFiles(getTempPath(WEIGHTS), getConf());

    if (complementary) {
      // Calculate the per label theta normalizers, write out to LABEL_THETA_NORMALIZER vector
      // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
      if (!newThetaSummerJob(smoothing, complementary).waitForCompletion(true)) {
        return -1;
      }
    }

    // Put the per label theta normalizers into the cache
    HadoopUtil.cacheFiles(getTempPath(THETAS), getConf());

    // Validate our model and then write it out to the official output
    getConf().setFloat(ThetaMapper.ALPHA_I, smoothing);
    getConf().setBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, complementary);
    NaiveBayesModel model = BayesUtils.readModelFromDir(getTempPath(), getConf());
    model.validate();
    model.serialize(getOutputPath(), getConf());

    return 0;
  }

  /** Builds the job that reads the input and writes to the SUMMED_OBSERVATIONS temp path. */
  private Job newIndexInstancesJob() throws IOException {
    Job job = prepareJob(getInputPath(),
                         getTempPath(SUMMED_OBSERVATIONS),
                         SequenceFileInputFormat.class,
                         IndexInstancesMapper.class,
                         IntWritable.class,
                         VectorWritable.class,
                         VectorSumReducer.class,
                         IntWritable.class,
                         VectorWritable.class,
                         SequenceFileOutputFormat.class);
    job.setCombinerClass(VectorSumReducer.class);
    return job;
  }

  /** Builds the job that reads SUMMED_OBSERVATIONS and writes to the WEIGHTS temp path. */
  private Job newWeightSummerJob(long numLabels) throws IOException {
    Job job = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
                         getTempPath(WEIGHTS),
                         SequenceFileInputFormat.class,
                         WeightsMapper.class,
                         Text.class,
                         VectorWritable.class,
                         VectorSumReducer.class,
                         Text.class,
                         VectorWritable.class,
                         SequenceFileOutputFormat.class);
    job.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(numLabels));
    job.setCombinerClass(VectorSumReducer.class);
    return job;
  }

  /** Builds the job that reads SUMMED_OBSERVATIONS and writes to the THETAS temp path. */
  private Job newThetaSummerJob(float smoothing, boolean complementary) throws IOException {
    Job job = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
                         getTempPath(THETAS),
                         SequenceFileInputFormat.class,
                         ThetaMapper.class,
                         Text.class,
                         VectorWritable.class,
                         VectorSumReducer.class,
                         Text.class,
                         VectorWritable.class,
                         SequenceFileOutputFormat.class);
    job.setCombinerClass(VectorSumReducer.class);
    job.getConfiguration().setFloat(ThetaMapper.ALPHA_I, smoothing);
    job.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, complementary);
    return job;
  }

  /**
   * Writes the label index to {@code labPath}, either from the explicit label list option or from
   * the labels found in the input.
   *
   * @param labPath where the label index is written
   * @return the value reported by {@link BayesUtils#writeLabelIndex}, or 0 when neither the labels
   *         option nor the extract-labels option was given (nothing is written in that case)
   */
  private long createLabelIndex(Path labPath) throws IOException {
    if (hasOption(LABELS)) {
      Iterable<String> requestedLabels = Splitter.on(",").split(getOption(LABELS));
      return BayesUtils.writeLabelIndex(getConf(), requestedLabels, labPath);
    }
    if (hasOption(EXTRACT_LABELS)) {
      Iterable<Pair<Text,IntWritable>> inputRecords =
          new SequenceFileDirIterable<Text, IntWritable>(getInputPath(),
                                                         PathType.LIST,
                                                         PathFilters.logsCRCFilter(),
                                                         getConf());
      return BayesUtils.writeLabelIndex(getConf(), labPath, inputRecords);
    }
    return 0;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java new file mode 100644 index 0000000..5563057 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java @@ -0,0 +1,68 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.classifier.naivebayes.training;

import java.io.IOException;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;

import com.google.common.base.Preconditions;

/**
 * Accumulates two running sums over every (label index, vector) pair this mapper instance sees:
 * an element-wise sum of all vectors (weights per feature) and, per label index, the sum of each
 * vector's {@code zSum()} (weights per label). Both sums are emitted once, in {@link #cleanup},
 * under the keys {@link TrainNaiveBayesJob#WEIGHTS_PER_FEATURE} and
 * {@link TrainNaiveBayesJob#WEIGHTS_PER_LABEL}.
 */
public class WeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {

  /** Configuration key holding the number of labels; must be set to a positive integer. */
  static final String NUM_LABELS = WeightsMapper.class.getName() + ".numLabels";

  /** Element-wise sum of all mapped vectors; created lazily from the first vector's size. */
  private Vector weightsPerFeature;
  /** Per-label sum of vector totals, indexed by label index. */
  private Vector weightsPerLabel;

  /**
   * Reads the number of labels from the job configuration and allocates the per-label vector.
   *
   * @throws IllegalArgumentException if {@link #NUM_LABELS} is missing or not greater than zero
   */
  @Override
  protected void setup(Context ctx) throws IOException, InterruptedException {
    super.setup(ctx);
    // getInt falls back to 0 when the property is absent, so a missing NUM_LABELS is reported by
    // the precondition below rather than as an opaque NumberFormatException from parseInt(null).
    int numLabels = ctx.getConfiguration().getInt(NUM_LABELS, 0);
    Preconditions.checkArgument(numLabels > 0, "Wrong numLabels: " + numLabels + ". Must be > 0!");
    weightsPerLabel = new DenseVector(numLabels);
  }

  /**
   * Adds one vector into both running sums.
   *
   * @param index label index of the vector
   * @param value the vector to add
   */
  @Override
  protected void map(IntWritable index, VectorWritable value, Context ctx) throws IOException, InterruptedException {
    Vector instance = value.get();
    if (weightsPerFeature == null) {
      // The feature-sum vector takes its cardinality from the first vector mapped.
      weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements());
    }

    int label = index.get();
    weightsPerFeature.assign(instance, Functions.PLUS);
    weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum());
  }

  /** Emits both accumulated vectors; writes nothing when no input record was mapped. */
  @Override
  protected void cleanup(Context ctx) throws IOException, InterruptedException {
    if (weightsPerFeature != null) {
      ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature));
      ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel));
    }
    super.cleanup(ctx);
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java new file mode 100644 index 0000000..942a101 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java @@ -0,0 +1,165 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.classifier.sequencelearning.hmm;

import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Date;
import java.util.List;
import java.util.Scanner;

import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;

/**
 * A class for EM training of HMM from console
 */
public final class BaumWelchTrainer {

  private BaumWelchTrainer() {
  }

  /**
   * Trains an HMM with Baum-Welch on a file of integer observations, serializes the trained model
   * to the output file and prints its parameters to standard output.
   *
   * <p>Required options: input, output, {@code nrOfHiddenStates}, {@code nrOfObservedStates},
   * {@code epsilon} and {@code max-iterations}. When option parsing fails, the usage help is
   * printed instead.
   *
   * @param args command-line arguments
   * @throws IOException if the input cannot be read or the model cannot be written
   */
  public static void main(String[] args) throws IOException {
    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
    ArgumentBuilder argumentBuilder = new ArgumentBuilder();

    Option inputOption = DefaultOptionCreator.inputOption().create();

    Option outputOption = DefaultOptionCreator.outputOption().create();

    Option stateNumberOption = requiredNumberOption(optionBuilder, argumentBuilder,
        "nrOfHiddenStates", "Number of hidden states", "nh");

    Option observedStateNumberOption = requiredNumberOption(optionBuilder, argumentBuilder,
        "nrOfObservedStates", "Number of observed states", "no");

    Option epsilonOption = requiredNumberOption(optionBuilder, argumentBuilder,
        "epsilon", "Convergence threshold", "e");

    Option iterationsOption = requiredNumberOption(optionBuilder, argumentBuilder,
        "max-iterations", "Maximum iterations number", "m");

    Group optionGroup = new GroupBuilder().withOption(inputOption).
        withOption(outputOption).withOption(stateNumberOption).withOption(observedStateNumberOption).
        withOption(epsilonOption).withOption(iterationsOption).
        withName("Options").create();

    try {
      Parser parser = new Parser();
      parser.setGroup(optionGroup);
      CommandLine commandLine = parser.parse(args);

      String input = (String) commandLine.getValue(inputOption);
      String output = (String) commandLine.getValue(outputOption);

      int nrOfHiddenStates = Integer.parseInt((String) commandLine.getValue(stateNumberOption));
      int nrOfObservedStates = Integer.parseInt((String) commandLine.getValue(observedStateNumberOption));

      double epsilon = Double.parseDouble((String) commandLine.getValue(epsilonOption));
      int maxIterations = Integer.parseInt((String) commandLine.getValue(iterationsOption));

      //constructing random-generated HMM
      HmmModel model = new HmmModel(nrOfHiddenStates, nrOfObservedStates, new Date().getTime());

      //reading observations
      int[] observationsArray = readObservations(input);

      //training
      HmmModel trainedModel = HmmTrainer.trainBaumWelch(model,
          observationsArray, epsilon, maxIterations, true);

      //serializing trained model
      try (DataOutputStream stream = new DataOutputStream(new FileOutputStream(output))) {
        LossyHmmSerializer.serialize(trainedModel, stream);
      }

      //printing trained model
      printModel(trainedModel);
    } catch (OptionException e) {
      CommandLineUtil.printHelp(optionGroup);
    }
  }

  /** Builds a required option that takes exactly one argument named "number". */
  private static Option requiredNumberOption(DefaultOptionBuilder optionBuilder,
                                             ArgumentBuilder argumentBuilder,
                                             String longName,
                                             String description,
                                             String shortName) {
    return optionBuilder.withLongName(longName).
        withDescription(description).
        withShortName(shortName).withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
        withName("number").create()).withRequired(true).create();
  }

  /** Reads integers from the UTF-8 file {@code input} until the next token is not an integer. */
  private static int[] readObservations(String input) throws IOException {
    List<Integer> observations = Lists.newArrayList();
    try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) {
      while (scanner.hasNextInt()) {
        observations.add(scanner.nextInt());
      }
    }

    int[] observationsArray = new int[observations.size()];
    for (int i = 0; i < observationsArray.length; ++i) {
      observationsArray[i] = observations.get(i);
    }
    return observationsArray;
  }

  /** Prints the initial probabilities, transition matrix and emission matrix to standard output. */
  private static void printModel(HmmModel trainedModel) {
    System.out.println("Initial probabilities: ");
    for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
      System.out.print(i + " ");
    }
    System.out.println();
    for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
      System.out.print(trainedModel.getInitialProbabilities().get(i) + " ");
    }
    System.out.println();

    System.out.println("Transition matrix:");
    System.out.print(" ");
    for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
      System.out.print(i + " ");
    }
    System.out.println();
    for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
      System.out.print(i + " ");
      for (int j = 0; j < trainedModel.getNrOfHiddenStates(); ++j) {
        System.out.print(trainedModel.getTransitionMatrix().get(i, j) + " ");
      }
      System.out.println();
    }

    System.out.println("Emission matrix: ");
    System.out.print(" ");
    for (int i = 0; i < trainedModel.getNrOfOutputStates(); ++i) {
      System.out.print(i + " ");
    }
    System.out.println();
    for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
      System.out.print(i + " ");
      for (int j = 0; j < trainedModel.getNrOfOutputStates(); ++j) {
        System.out.print(trainedModel.getEmissionMatrix().get(i, j) + " ");
      }
      System.out.println();
    }
  }
}
