/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.naivebayes;

import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

import com.google.common.base.Preconditions;

/**
 * NaiveBayesModel holds the weight matrix, the feature and label sums and the
 * weight normalizer vectors.
 */
public class NaiveBayesModel {

  private final Vector weightsPerLabel;
  private final Vector perlabelThetaNormalizer;
  private final Vector weightsPerFeature;
  private final Matrix weightsPerLabelAndFeature;
  private final float alphaI;
  private final double numFeatures;
  private final double totalWeightSum;
  private final boolean isComplementary;

  /** Marker used to flag a model that was trained in complementary mode. */
  public static final String COMPLEMENTARY_MODEL = "COMPLEMENTARY_MODEL";

  /**
   * @param weightMatrix     per-label, per-feature weights
   * @param weightsPerFeature summed weight of each feature over all labels
   * @param weightsPerLabel   summed weight of each label over all features
   * @param thetaNormalizer  per-label theta normalizers; may be null for a
   *                         standard (non-complementary) model
   * @param alphaI           Laplace smoothing parameter; must be &gt; 0
   * @param isComplementary  whether this model was trained complementary
   */
  public NaiveBayesModel(Matrix weightMatrix, Vector weightsPerFeature, Vector weightsPerLabel, Vector thetaNormalizer,
                         float alphaI, boolean isComplementary) {
    this.weightsPerLabelAndFeature = weightMatrix;
    this.weightsPerFeature = weightsPerFeature;
    this.weightsPerLabel = weightsPerLabel;
    this.perlabelThetaNormalizer = thetaNormalizer;
    this.numFeatures = weightsPerFeature.getNumNondefaultElements();
    this.totalWeightSum = weightsPerLabel.zSum();
    this.alphaI = alphaI;
    this.isComplementary = isComplementary;
  }

  /** Returns the summed weight of the given label. */
  public double labelWeight(int label) {
    return weightsPerLabel.getQuick(label);
  }

  /** Returns the theta normalizer of the given label (complementary models only). */
  public double thetaNormalizer(int label) {
    return perlabelThetaNormalizer.get(label);
  }

  /** Returns the summed weight of the given feature. */
  public double featureWeight(int feature) {
    return weightsPerFeature.getQuick(feature);
  }

  /** Returns the weight of the given (label, feature) pair. */
  public double weight(int label, int feature) {
    return weightsPerLabelAndFeature.getQuick(label, feature);
  }

  public float alphaI() {
    return alphaI;
  }

  public double numFeatures() {
    return numFeatures;
  }

  public double totalWeightSum() {
    return totalWeightSum;
  }

  public int numLabels() {
    return weightsPerLabel.size();
  }

  /** Creates an empty vector with one slot per label, suitable for scoring. */
  public Vector createScoringVector() {
    return weightsPerLabel.like();
  }

  /** Returns true if this model was trained in complementary mode. */
  public boolean isComplementary() {
    return isComplementary;
  }

  /**
   * @deprecated misspelling kept for backward compatibility with existing
   *             callers; use {@link #isComplementary()} instead.
   */
  @Deprecated
  public boolean isComplemtary() {
    return isComplementary();
  }

  /**
   * Reads a model previously written by {@link #serialize(Path, Configuration)}
   * from {@code output/naiveBayesModel.bin} and validates it.
   *
   * @throws IOException if the model file cannot be read
   */
  public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException {
    FileSystem fs = output.getFileSystem(conf);

    Vector weightsPerLabel;
    Vector perLabelThetaNormalizer = null;
    Vector weightsPerFeature;
    Matrix weightsPerLabelAndFeature;
    float alphaI;
    boolean isComplementary;

    try (FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"))) {
      alphaI = in.readFloat();
      isComplementary = in.readBoolean();
      weightsPerFeature = VectorWritable.readVector(in);
      weightsPerLabel = new DenseVector(VectorWritable.readVector(in));
      // the theta normalizer vector is only written for complementary models
      if (isComplementary) {
        perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in));
      }
      weightsPerLabelAndFeature = new SparseRowMatrix(weightsPerLabel.size(), weightsPerFeature.size());
      for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
        weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
      }
    }

    NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel,
        perLabelThetaNormalizer, alphaI, isComplementary);
    model.validate();
    return model;
  }

  /**
   * Writes this model to {@code output/naiveBayesModel.bin} in the format
   * expected by {@link #materialize(Path, Configuration)}.
   *
   * @throws IOException if the model file cannot be written
   */
  public void serialize(Path output, Configuration conf) throws IOException {
    FileSystem fs = output.getFileSystem(conf);
    try (FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"))) {
      out.writeFloat(alphaI);
      out.writeBoolean(isComplementary);
      VectorWritable.writeVector(out, weightsPerFeature);
      VectorWritable.writeVector(out, weightsPerLabel);
      // keep the on-disk layout in sync with materialize(): theta normalizers
      // are present only for complementary models
      if (isComplementary) {
        VectorWritable.writeVector(out, perlabelThetaNormalizer);
      }
      for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
        VectorWritable.writeVector(out, weightsPerLabelAndFeature.viewRow(row));
      }
    }
  }

  /**
   * Checks the internal consistency of this model.
   *
   * @throws IllegalStateException    if alphaI is not positive
   * @throws IllegalArgumentException if any sum or vector is empty/invalid
   * @throws NullPointerException     if a required vector is missing
   */
  public void validate() {
    Preconditions.checkState(alphaI > 0, "alphaI has to be greater than 0!");
    Preconditions.checkArgument(numFeatures > 0, "the vocab count has to be greater than 0!");
    Preconditions.checkArgument(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!");
    Preconditions.checkNotNull(weightsPerLabel, "the number of labels has to be defined!");
    Preconditions.checkArgument(weightsPerLabel.getNumNondefaultElements() > 0,
        "the number of labels has to be greater than 0!");
    Preconditions.checkNotNull(weightsPerFeature, "the feature sums have to be defined");
    Preconditions.checkArgument(weightsPerFeature.getNumNondefaultElements() > 0,
        "the feature sums have to be greater than 0!");
    if (isComplementary) {
      Preconditions.checkNotNull(perlabelThetaNormalizer, "the theta normalizers have to be defined");
      Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0,
          "the number of theta normalizers has to be greater than 0!");
      Preconditions.checkArgument(Math.signum(perlabelThetaNormalizer.minValue())
          == Math.signum(perlabelThetaNormalizer.maxValue()),
          "Theta normalizers do not all have the same sign");
      Preconditions.checkArgument(perlabelThetaNormalizer.getNumNonZeroElements()
          == perlabelThetaNormalizer.size(),
          "Theta normalizers can not have zero value.");
    }
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java new file mode 100644 index 0000000..e4ce8aa --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.mahout.classifier.naivebayes;


/** Implementation of the Naive Bayes Classifier Algorithm */
public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier {

  public StandardNaiveBayesClassifier(NaiveBayesModel model) {
    super(model);
  }

  @Override
  public double getScoreForLabelFeature(int label, int feature) {
    NaiveBayesModel model = getModel();
    // Standard Naive Bayes does not use weight normalization
    return computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI(), model.numFeatures());
  }

  /**
   * Smoothed log-likelihood of a feature given a label:
   * log((featureLabelWeight + alphaI) / (labelWeight + alphaI * numFeatures)).
   *
   * @param featureLabelWeight weight of the (label, feature) pair
   * @param labelWeight        total weight of the label
   * @param alphaI             Laplace smoothing parameter
   * @param numFeatures        vocabulary size used for smoothing the denominator
   * @return the log of the smoothed conditional probability
   */
  public static double computeWeight(double featureLabelWeight, double labelWeight, double alphaI, double numFeatures) {
    double numerator = featureLabelWeight + alphaI;
    double denominator = labelWeight + alphaI * numFeatures;
    return Math.log(numerator / denominator);
  }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes.test; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier; +import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier; +import org.apache.mahout.classifier.naivebayes.NaiveBayesModel; +import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.IOException; +import java.util.regex.Pattern; + +/** + * Run the input through the model and see if it matches. 
+ * <p/> + * The output value is the generated label, the Pair is the expected label and true if they match: + */ +public class BayesTestMapper extends Mapper<Text, VectorWritable, Text, VectorWritable> { + + private static final Pattern SLASH = Pattern.compile("/"); + + private AbstractNaiveBayesClassifier classifier; + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + Configuration conf = context.getConfiguration(); + Path modelPath = HadoopUtil.getSingleCachedFile(conf); + NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf); + boolean isComplementary = Boolean.parseBoolean(conf.get(TestNaiveBayesDriver.COMPLEMENTARY)); + + // ensure that if we are testing in complementary mode, the model has been + // trained complementary. a complementarty model will work for standard classification + // a standard model will not work for complementary classification + if (isComplementary) { + Preconditions.checkArgument((model.isComplemtary()), + "Complementary mode in model is different than test mode"); + } + + if (isComplementary) { + classifier = new ComplementaryNaiveBayesClassifier(model); + } else { + classifier = new StandardNaiveBayesClassifier(model); + } + } + + @Override + protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException { + Vector result = classifier.classifyFull(value.get()); + //the key is the expected value + context.write(new Text(SLASH.split(key.toString())[1]), new VectorWritable(result)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java 
b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java new file mode 100644 index 0000000..d9eedcf --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java @@ -0,0 +1,176 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.mahout.classifier.naivebayes.test;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Test the (Complementary) Naive Bayes model that was built during training
 * by iterating the test set and comparing it to the model.
 */
public class TestNaiveBayesDriver extends AbstractJob {

  private static final Logger log = LoggerFactory.getLogger(TestNaiveBayesDriver.class);

  public static final String COMPLEMENTARY = "class"; //b for bayes, c for complementary
  private static final Pattern SLASH = Pattern.compile("/");

  public static void main(String[] args) throws Exception {
    ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), args);
  }

  @Override
  public int run(String[] args) throws Exception {
    addInputOption();
    addOutputOption();
    // fix: the overwrite option was previously registered twice via
    // addOption(addOption(...)); register it exactly once
    addOption(DefaultOptionCreator.overwriteOption().create());
    addOption("model", "m", "The path to the model built during training", true);
    addOption(buildOption("testComplementary", "c", "test complementary?", false, false, String.valueOf(false)));
    addOption(buildOption("runSequential", "seq", "run sequential?", false, false, String.valueOf(false)));
    addOption("labelIndex", "l", "The path to the location of the label index", true);
    Map<String, List<String>> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
      return -1;
    }
    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
      HadoopUtil.delete(getConf(), getOutputPath());
    }

    boolean sequential = hasOption("runSequential");
    if (sequential) {
      runSequential();
    } else {
      boolean succeeded = runMapReduce();
      if (!succeeded) {
        return -1;
      }
    }

    //load the labels
    Map<Integer, String> labelMap = BayesUtils.readLabelIndex(getConf(), new Path(getOption("labelIndex")));

    //loop over the results and create the confusion matrix
    SequenceFileDirIterable<Text, VectorWritable> dirIterable =
        new SequenceFileDirIterable<>(getOutputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
    ResultAnalyzer analyzer = new ResultAnalyzer(labelMap.values(), "DEFAULT");
    analyzeResults(labelMap, dirIterable, analyzer);

    log.info("{} Results: {}", hasOption("testComplementary") ? "Complementary" : "Standard NB", analyzer);
    return 0;
  }

  /**
   * Classifies the whole test set in-process (no MapReduce job) and writes the
   * scores to a single sequence file under the output path.
   */
  private void runSequential() throws IOException {
    boolean complementary = hasOption("testComplementary");
    FileSystem fs = FileSystem.get(getConf());
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());

    // Ensure that if we are testing in complementary mode, the model has been
    // trained complementary. a complementary model will work for standard
    // classification, but a standard model will not work for complementary
    // classification
    if (complementary) {
      Preconditions.checkArgument(model.isComplemtary(),
          "Complementary mode in model is different from test mode");
    }

    AbstractNaiveBayesClassifier classifier;
    if (complementary) {
      classifier = new ComplementaryNaiveBayesClassifier(model);
    } else {
      classifier = new StandardNaiveBayesClassifier(model);
    }

    try (SequenceFile.Writer writer =
             SequenceFile.createWriter(fs, getConf(), new Path(getOutputPath(), "part-r-00000"),
                 Text.class, VectorWritable.class)) {
      SequenceFileDirIterable<Text, VectorWritable> dirIterable =
          new SequenceFileDirIterable<>(getInputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
      // loop through the part-r-* files in getInputPath() and get classification scores for all entries
      for (Pair<Text, VectorWritable> pair : dirIterable) {
        writer.append(new Text(SLASH.split(pair.getFirst().toString())[1]),
            new VectorWritable(classifier.classifyFull(pair.getSecond().get())));
      }
    }
  }

  /**
   * Runs the classification as a MapReduce job.
   *
   * @return true if the job completed successfully
   */
  private boolean runMapReduce() throws IOException,
      InterruptedException, ClassNotFoundException {
    Path model = new Path(getOption("model"));
    HadoopUtil.cacheFiles(model, getConf());
    //the output key is the expected value, the output value are the scores for all the labels
    Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
        Text.class, VectorWritable.class, SequenceFileOutputFormat.class);

    boolean complementary = hasOption("testComplementary");
    testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
    return testJob.waitForCompletion(true);
  }

  /**
   * Picks the best-scoring label for each instance and feeds the result to the
   * analyzer to build the confusion matrix.
   */
  private static void analyzeResults(Map<Integer, String> labelMap,
                                     SequenceFileDirIterable<Text, VectorWritable> dirIterable,
                                     ResultAnalyzer analyzer) {
    for (Pair<Text, VectorWritable> pair : dirIterable) {
      int bestIdx = Integer.MIN_VALUE;
      double bestScore = Long.MIN_VALUE;
      for (Vector.Element element : pair.getSecond().get().all()) {
        if (element.get() > bestScore) {
          bestScore = element.get();
          bestIdx = element.index();
        }
      }
      if (bestIdx != Integer.MIN_VALUE) {
        ClassifierResult classifierResult = new ClassifierResult(labelMap.get(bestIdx), bestScore);
        analyzer.addInstance(pair.getFirst().toString(), classifierResult);
      }
    }
  }
}
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes.training; + +import com.google.common.base.Preconditions; +import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier; +import org.apache.mahout.math.Vector; + +public class ComplementaryThetaTrainer { + + private final Vector weightsPerFeature; + private final Vector weightsPerLabel; + private final Vector perLabelThetaNormalizer; + private final double alphaI; + private final double totalWeightSum; + private final double numFeatures; + + public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) { + Preconditions.checkNotNull(weightsPerFeature); + Preconditions.checkNotNull(weightsPerLabel); + this.weightsPerFeature = weightsPerFeature; + this.weightsPerLabel = weightsPerLabel; + this.alphaI = alphaI; + perLabelThetaNormalizer = weightsPerLabel.like(); + totalWeightSum = weightsPerLabel.zSum(); + numFeatures = weightsPerFeature.getNumNondefaultElements(); + } + + public void train(int label, Vector perLabelWeight) { + double labelWeight = labelWeight(label); + // sum weights for each label including those with zero word counts + for(int i = 0; i < perLabelWeight.size(); i++){ + Vector.Element perLabelWeightElement = perLabelWeight.getElement(i); + updatePerLabelThetaNormalizer(label, + 
ComplementaryNaiveBayesClassifier.computeWeight(featureWeight(perLabelWeightElement.index()), + perLabelWeightElement.get(), totalWeightSum(), labelWeight, alphaI(), numFeatures())); + } + } + + protected double alphaI() { + return alphaI; + } + + protected double numFeatures() { + return numFeatures; + } + + protected double labelWeight(int label) { + return weightsPerLabel.get(label); + } + + protected double totalWeightSum() { + return totalWeightSum; + } + + protected double featureWeight(int feature) { + return weightsPerFeature.get(feature); + } + + // http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors + protected void updatePerLabelThetaNormalizer(int label, double weight) { + perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + Math.abs(weight)); + } + + public Vector retrievePerLabelThetaNormalizer() { + return perLabelThetaNormalizer.clone(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java new file mode 100644 index 0000000..4df869e --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java @@ -0,0 +1,53 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
package org.apache.mahout.classifier.naivebayes.training;

import java.io.IOException;
import java.util.regex.Pattern;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenObjectIntHashMap;

/** Maps string labels to their numeric index; instances with unknown labels are counted and dropped. */
public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {

  private static final Pattern SLASH = Pattern.compile("/");

  enum Counter { SKIPPED_INSTANCES }

  private OpenObjectIntHashMap<String> labelIndex;

  @Override
  protected void setup(Context ctx) throws IOException, InterruptedException {
    super.setup(ctx);
    // the label index is distributed via the cache by the training driver
    labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
  }

  @Override
  protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
    // keys look like "/<label>/..."; the label is the first path segment
    String label = SLASH.split(labelText.toString())[1];
    if (!labelIndex.containsKey(label)) {
      ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
      return;
    }
    ctx.write(new IntWritable(labelIndex.get(label)), instance);
  }
}
a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java new file mode 100644 index 0000000..ff2ea40 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.mahout.classifier.naivebayes.training;

import java.io.IOException;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/** Feeds each label's summed weight vector into a theta trainer and emits the normalizers on cleanup. */
public class ThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {

  public static final String ALPHA_I = ThetaMapper.class.getName() + ".alphaI";
  static final String TRAIN_COMPLEMENTARY = ThetaMapper.class.getName() + ".trainComplementary";

  private ComplementaryThetaTrainer trainer;

  @Override
  protected void setup(Context ctx) throws IOException, InterruptedException {
    super.setup(ctx);
    Configuration conf = ctx.getConfiguration();

    // smoothing parameter, defaulting to 1.0 when absent from the configuration
    float smoothing = conf.getFloat(ALPHA_I, 1.0f);
    // feature and label sums computed by the earlier weight-summing job,
    // distributed via the cache
    Map<String, Vector> cachedSums = BayesUtils.readScoresFromCache(conf);

    trainer = new ComplementaryThetaTrainer(
        cachedSums.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
        cachedSums.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL),
        smoothing);
  }

  @Override
  protected void map(IntWritable key, VectorWritable value, Context ctx) throws IOException, InterruptedException {
    // key: label index; value: that label's summed weight vector
    trainer.train(key.get(), value.get());
  }

  @Override
  protected void cleanup(Context ctx) throws IOException, InterruptedException {
    // emit the accumulated normalizers once, after all labels were seen
    ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER),
        new VectorWritable(trainer.retrievePerLabelThetaNormalizer()));
    super.cleanup(ctx);
  }
}
b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java new file mode 100644 index 0000000..cd18d28 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java @@ -0,0 +1,177 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes.training; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.classifier.naivebayes.BayesUtils; +import org.apache.mahout.classifier.naivebayes.NaiveBayesModel; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.common.mapreduce.VectorSumReducer; +import org.apache.mahout.math.VectorWritable; + +import com.google.common.base.Splitter; + +/** Trains a Naive Bayes Classifier (parameters for both Naive Bayes and Complementary Naive Bayes) */ +public final class TrainNaiveBayesJob extends AbstractJob { + private static final String TRAIN_COMPLEMENTARY = "trainComplementary"; + private static final String ALPHA_I = "alphaI"; + private static final String LABEL_INDEX = "labelIndex"; + public static final String WEIGHTS_PER_FEATURE = "__SPF"; + public static final String WEIGHTS_PER_LABEL = "__SPL"; + public static final String LABEL_THETA_NORMALIZER = "_LTN"; + public static final String SUMMED_OBSERVATIONS = "summedObservations"; + public static final String WEIGHTS = "weights"; + public static final String THETAS = "thetas"; + + public static void main(String[] args) throws Exception { + 
ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), args); + } + + @Override + public int run(String[] args) throws Exception { + + addInputOption(); + addOutputOption(); + + addOption(ALPHA_I, "a", "smoothing parameter", String.valueOf(1.0f)); + addOption(buildOption(TRAIN_COMPLEMENTARY, "c", "train complementary?", false, false, String.valueOf(false))); + addOption(LABEL_INDEX, "li", "The path to store the label index in", false); + addOption(DefaultOptionCreator.overwriteOption().create()); + + Map<String, List<String>> parsedArgs = parseArguments(args); + if (parsedArgs == null) { + return -1; + } + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(getConf(), getOutputPath()); + HadoopUtil.delete(getConf(), getTempPath()); + } + Path labPath; + String labPathStr = getOption(LABEL_INDEX); + if (labPathStr != null) { + labPath = new Path(labPathStr); + } else { + labPath = getTempPath(LABEL_INDEX); + } + long labelSize = createLabelIndex(labPath); + float alphaI = Float.parseFloat(getOption(ALPHA_I)); + boolean trainComplementary = hasOption(TRAIN_COMPLEMENTARY); + + HadoopUtil.setSerializations(getConf()); + HadoopUtil.cacheFiles(labPath, getConf()); + + // Add up all the vectors with the same labels, while mapping the labels into our index + Job indexInstances = prepareJob(getInputPath(), + getTempPath(SUMMED_OBSERVATIONS), + SequenceFileInputFormat.class, + IndexInstancesMapper.class, + IntWritable.class, + VectorWritable.class, + VectorSumReducer.class, + IntWritable.class, + VectorWritable.class, + SequenceFileOutputFormat.class); + indexInstances.setCombinerClass(VectorSumReducer.class); + boolean succeeded = indexInstances.waitForCompletion(true); + if (!succeeded) { + return -1; + } + // Sum up all the weights from the previous step, per label and per feature + Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS), + getTempPath(WEIGHTS), + SequenceFileInputFormat.class, + WeightsMapper.class, + Text.class, + 
VectorWritable.class, + VectorSumReducer.class, + Text.class, + VectorWritable.class, + SequenceFileOutputFormat.class); + weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelSize)); + weightSummer.setCombinerClass(VectorSumReducer.class); + succeeded = weightSummer.waitForCompletion(true); + if (!succeeded) { + return -1; + } + + // Put the per label and per feature vectors into the cache + HadoopUtil.cacheFiles(getTempPath(WEIGHTS), getConf()); + + if (trainComplementary){ + // Calculate the per label theta normalizers, write out to LABEL_THETA_NORMALIZER vector + // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors + Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS), + getTempPath(THETAS), + SequenceFileInputFormat.class, + ThetaMapper.class, + Text.class, + VectorWritable.class, + VectorSumReducer.class, + Text.class, + VectorWritable.class, + SequenceFileOutputFormat.class); + thetaSummer.setCombinerClass(VectorSumReducer.class); + thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI); + thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary); + succeeded = thetaSummer.waitForCompletion(true); + if (!succeeded) { + return -1; + } + } + + // Put the per label theta normalizers into the cache + HadoopUtil.cacheFiles(getTempPath(THETAS), getConf()); + + // Validate our model and then write it out to the official output + getConf().setFloat(ThetaMapper.ALPHA_I, alphaI); + getConf().setBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, trainComplementary); + NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(getTempPath(), getConf()); + naiveBayesModel.validate(); + naiveBayesModel.serialize(getOutputPath(), getConf()); + + return 0; + } + + private long createLabelIndex(Path labPath) throws IOException { + long labelSize = 0; + Iterable<Pair<Text,IntWritable>> iterable = + new 
SequenceFileDirIterable<>(getInputPath(), + PathType.LIST, + PathFilters.logsCRCFilter(), + getConf()); + labelSize = BayesUtils.writeLabelIndex(getConf(), labPath, iterable); + return labelSize; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java new file mode 100644 index 0000000..5563057 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java @@ -0,0 +1,68 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes.training; + +import java.io.IOException; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; + +import com.google.common.base.Preconditions; + +public class WeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> { + + static final String NUM_LABELS = WeightsMapper.class.getName() + ".numLabels"; + + private Vector weightsPerFeature; + private Vector weightsPerLabel; + + @Override + protected void setup(Context ctx) throws IOException, InterruptedException { + super.setup(ctx); + int numLabels = Integer.parseInt(ctx.getConfiguration().get(NUM_LABELS)); + Preconditions.checkArgument(numLabels > 0, "Wrong numLabels: " + numLabels + ". 
Must be > 0!"); + weightsPerLabel = new DenseVector(numLabels); + } + + @Override + protected void map(IntWritable index, VectorWritable value, Context ctx) throws IOException, InterruptedException { + Vector instance = value.get(); + if (weightsPerFeature == null) { + weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements()); + } + + int label = index.get(); + weightsPerFeature.assign(instance, Functions.PLUS); + weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum()); + } + + @Override + protected void cleanup(Context ctx) throws IOException, InterruptedException { + if (weightsPerFeature != null) { + ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature)); + ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel)); + } + super.cleanup(ctx); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java new file mode 100644 index 0000000..6d4e2b0 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.io.DataOutputStream; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Scanner; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; + +/** + * A class for EM training of HMM from console + */ +public final class BaumWelchTrainer { + + private BaumWelchTrainer() { + } + + public static void main(String[] args) throws IOException { + DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder(); + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + + Option inputOption = DefaultOptionCreator.inputOption().create(); + + Option outputOption = DefaultOptionCreator.outputOption().create(); + + Option stateNumberOption = optionBuilder.withLongName("nrOfHiddenStates"). + withDescription("Number of hidden states"). + withShortName("nh").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). 
+ withName("number").create()).withRequired(true).create(); + + Option observedStateNumberOption = optionBuilder.withLongName("nrOfObservedStates"). + withDescription("Number of observed states"). + withShortName("no").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("number").create()).withRequired(true).create(); + + Option epsilonOption = optionBuilder.withLongName("epsilon"). + withDescription("Convergence threshold"). + withShortName("e").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("number").create()).withRequired(true).create(); + + Option iterationsOption = optionBuilder.withLongName("max-iterations"). + withDescription("Maximum iterations number"). + withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("number").create()).withRequired(true).create(); + + Group optionGroup = new GroupBuilder().withOption(inputOption). + withOption(outputOption).withOption(stateNumberOption).withOption(observedStateNumberOption). + withOption(epsilonOption).withOption(iterationsOption). 
+ withName("Options").create(); + + try { + Parser parser = new Parser(); + parser.setGroup(optionGroup); + CommandLine commandLine = parser.parse(args); + + String input = (String) commandLine.getValue(inputOption); + String output = (String) commandLine.getValue(outputOption); + + int nrOfHiddenStates = Integer.parseInt((String) commandLine.getValue(stateNumberOption)); + int nrOfObservedStates = Integer.parseInt((String) commandLine.getValue(observedStateNumberOption)); + + double epsilon = Double.parseDouble((String) commandLine.getValue(epsilonOption)); + int maxIterations = Integer.parseInt((String) commandLine.getValue(iterationsOption)); + + //constructing random-generated HMM + HmmModel model = new HmmModel(nrOfHiddenStates, nrOfObservedStates, new Date().getTime()); + List<Integer> observations = new ArrayList<>(); + + //reading observations + try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) { + while (scanner.hasNextInt()) { + observations.add(scanner.nextInt()); + } + } + + int[] observationsArray = new int[observations.size()]; + for (int i = 0; i < observations.size(); ++i) { + observationsArray[i] = observations.get(i); + } + + //training + HmmModel trainedModel = HmmTrainer.trainBaumWelch(model, + observationsArray, epsilon, maxIterations, true); + + //serializing trained model + try (DataOutputStream stream = new DataOutputStream(new FileOutputStream(output))){ + LossyHmmSerializer.serialize(trainedModel, stream); + } + + //printing tranied model + System.out.println("Initial probabilities: "); + for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) { + System.out.print(i + " "); + } + System.out.println(); + for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) { + System.out.print(trainedModel.getInitialProbabilities().get(i) + " "); + } + System.out.println(); + + System.out.println("Transition matrix:"); + System.out.print(" "); + for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) { + 
System.out.print(i + " "); + } + System.out.println(); + for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) { + System.out.print(i + " "); + for (int j = 0; j < trainedModel.getNrOfHiddenStates(); ++j) { + System.out.print(trainedModel.getTransitionMatrix().get(i, j) + " "); + } + System.out.println(); + } + System.out.println("Emission matrix: "); + System.out.print(" "); + for (int i = 0; i < trainedModel.getNrOfOutputStates(); ++i) { + System.out.print(i + " "); + } + System.out.println(); + for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) { + System.out.print(i + " "); + for (int j = 0; j < trainedModel.getNrOfOutputStates(); ++j) { + System.out.print(trainedModel.getEmissionMatrix().get(i, j) + " "); + } + System.out.println(); + } + } catch (OptionException e) { + CommandLineUtil.printHelp(optionGroup); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java new file mode 100644 index 0000000..c1d328e --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java @@ -0,0 +1,306 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sequencelearning.hmm;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/**
 * Class containing implementations of the three major HMM algorithms: forward,
 * backward and Viterbi.
 *
 * <p>Each algorithm exists in a public allocating variant and a package-private
 * variant that writes into caller-supplied arrays/matrices so that buffers can
 * be reused across iterations (e.g. during Baum-Welch training). The
 * {@code scaled} flag switches all three to log-space arithmetic, which is
 * slower but numerically stable for long observation sequences.
 */
public final class HmmAlgorithms {


  /**
   * No public constructors for utility classes.
   */
  private HmmAlgorithms() {
    // nothing to do here really
  }

  /**
   * External function to compute a matrix of alpha factors
   *
   * @param model        model to run forward algorithm for.
   * @param observations observation sequence to train on.
   * @param scaled       Should log-scaled alpha factors be computed?
   * @return matrix of alpha factors, one row per time step, one column per hidden state.
   */
  public static Matrix forwardAlgorithm(HmmModel model, int[] observations, boolean scaled) {
    Matrix alpha = new DenseMatrix(observations.length, model.getNrOfHiddenStates());
    forwardAlgorithm(alpha, model, observations, scaled);

    return alpha;
  }

  /**
   * Internal function to compute the alpha factors
   *
   * alpha[t][i] = P(o_0..o_t, state_t = i | model), or its log when scaled.
   *
   * @param alpha        matrix to store alpha factors in (reused by callers).
   * @param model        model to use for alpha factor computation.
   * @param observations observation sequence seen.
   * @param scaled       set to true if log-scaled alpha factors should be computed.
   */
  static void forwardAlgorithm(Matrix alpha, HmmModel model, int[] observations, boolean scaled) {

    // fetch references to the model parameters
    Vector ip = model.getInitialProbabilities();
    Matrix b = model.getEmissionMatrix();
    Matrix a = model.getTransitionMatrix();

    if (scaled) { // compute log scaled alpha values
      // Initialization: alpha_0(i) = log(pi_i * b_i(o_0))
      for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
        alpha.setQuick(0, i, Math.log(ip.getQuick(i) * b.getQuick(i, observations[0])));
      }

      // Induction: log-sum-exp over predecessor states
      for (int t = 1; t < observations.length; t++) {
        for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
          double sum = Double.NEGATIVE_INFINITY; // log(0)
          for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
            double tmp = alpha.getQuick(t - 1, j) + Math.log(a.getQuick(j, i));
            if (tmp > Double.NEGATIVE_INFINITY) {
              // make sure we handle log(0) correctly
              // NOTE(review): this always folds `sum` into `tmp`; when
              // sum >> tmp, exp(sum - tmp) may overflow to infinity. The
              // conventional log-sum-exp adds the smaller term to the larger
              // one — confirm whether this is reachable with real inputs.
              sum = tmp + Math.log1p(Math.exp(sum - tmp));
            }
          }
          alpha.setQuick(t, i, sum + Math.log(b.getQuick(i, observations[t])));
        }
      }
    } else {

      // Initialization: alpha_0(i) = pi_i * b_i(o_0)
      for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
        alpha.setQuick(0, i, ip.getQuick(i) * b.getQuick(i, observations[0]));
      }

      // Induction: alpha_t(i) = (sum_j alpha_{t-1}(j) * a_ji) * b_i(o_t)
      for (int t = 1; t < observations.length; t++) {
        for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
          double sum = 0.0;
          for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
            sum += alpha.getQuick(t - 1, j) * a.getQuick(j, i);
          }
          alpha.setQuick(t, i, sum * b.getQuick(i, observations[t]));
        }
      }
    }
  }

  /**
   * External function to compute a matrix of beta factors
   *
   * @param model        model to use for estimation.
   * @param observations observation sequence seen.
   * @param scaled       Set to true if log-scaled beta factors should be computed.
   * @return beta factors based on the model and observation sequence.
   */
  public static Matrix backwardAlgorithm(HmmModel model, int[] observations, boolean scaled) {
    // initialize the matrix
    Matrix beta = new DenseMatrix(observations.length, model.getNrOfHiddenStates());
    // compute the beta factors
    backwardAlgorithm(beta, model, observations, scaled);

    return beta;
  }

  /**
   * Internal function to compute the beta factors
   *
   * beta[t][i] = P(o_{t+1}..o_{T-1} | state_t = i, model), or its log when scaled.
   *
   * @param beta         Matrix to store resulting factors in (reused by callers).
   * @param model        model to use for factor estimation.
   * @param observations sequence of observations to estimate.
   * @param scaled       set to true to compute log-scaled parameters.
   */
  static void backwardAlgorithm(Matrix beta, HmmModel model, int[] observations, boolean scaled) {
    // fetch references to the model parameters
    Matrix b = model.getEmissionMatrix();
    Matrix a = model.getTransitionMatrix();

    if (scaled) { // compute log-scaled factors
      // initialization: log(1) == 0 at the final time step
      for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
        beta.setQuick(observations.length - 1, i, 0);
      }

      // induction: walk backwards in time, log-sum-exp over successor states
      for (int t = observations.length - 2; t >= 0; t--) {
        for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
          double sum = Double.NEGATIVE_INFINITY; // log(0)
          for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
            double tmp = beta.getQuick(t + 1, j) + Math.log(a.getQuick(i, j))
                + Math.log(b.getQuick(j, observations[t + 1]));
            if (tmp > Double.NEGATIVE_INFINITY) {
              // handle log(0)
              // NOTE(review): same log-sum-exp direction caveat as in
              // forwardAlgorithm above.
              sum = tmp + Math.log1p(Math.exp(sum - tmp));
            }
          }
          beta.setQuick(t, i, sum);
        }
      }
    } else {
      // initialization: beta_{T-1}(i) = 1
      for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
        beta.setQuick(observations.length - 1, i, 1);
      }
      // induction: beta_t(i) = sum_j beta_{t+1}(j) * a_ij * b_j(o_{t+1})
      for (int t = observations.length - 2; t >= 0; t--) {
        for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
          double sum = 0;
          for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
            sum += beta.getQuick(t + 1, j) * a.getQuick(i, j) * b.getQuick(j, observations[t + 1]);
          }
          beta.setQuick(t, i, sum);
        }
      }
    }
  }

  /**
   * Viterbi algorithm to compute the most likely hidden sequence for a given
   * model and observed sequence
   *
   * @param model        HmmModel for which the Viterbi path should be computed
   * @param observations Sequence of observations
   * @param scaled       Use log-scaled computations, this requires higher computational
   *                     effort but is numerically more stable for large observation
   *                     sequences
   * @return nrOfObservations 1D int array containing the most likely hidden
   *         sequence
   */
  public static int[] viterbiAlgorithm(HmmModel model, int[] observations, boolean scaled) {

    // probability that the most probable hidden states ends at state i at
    // time t
    double[][] delta = new double[observations.length][model
        .getNrOfHiddenStates()];

    // previous hidden state in the most probable state leading up to state
    // i at time t
    int[][] phi = new int[observations.length - 1][model.getNrOfHiddenStates()];

    // initialize the return array
    int[] sequence = new int[observations.length];

    viterbiAlgorithm(sequence, delta, phi, model, observations, scaled);

    return sequence;
  }

  /**
   * Internal version of the viterbi algorithm, allowing to reuse existing
   * arrays instead of allocating new ones
   *
   * @param sequence     NrOfObservations 1D int array for storing the viterbi sequence
   * @param delta        NrOfObservations x NrHiddenStates 2D double array for storing the
   *                     delta factors
   * @param phi          NrOfObservations-1 x NrHiddenStates 2D int array for storing the
   *                     phi values
   * @param model        HmmModel for which the viterbi path should be computed
   * @param observations Sequence of observations
   * @param scaled       Use log-scaled computations, this requires higher computational
   *                     effort but is numerically more stable for large observation
   *                     sequences
   */
  static void viterbiAlgorithm(int[] sequence, double[][] delta, int[][] phi, HmmModel model, int[] observations,
                               boolean scaled) {
    // fetch references to the model parameters
    Vector ip = model.getInitialProbabilities();
    Matrix b = model.getEmissionMatrix();
    Matrix a = model.getTransitionMatrix();

    // Initialization: delta_0(i) = pi_i * b_i(o_0) (or its log)
    if (scaled) {
      for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
        delta[0][i] = Math.log(ip.getQuick(i) * b.getQuick(i, observations[0]));
      }
    } else {

      for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
        delta[0][i] = ip.getQuick(i) * b.getQuick(i, observations[0]);
      }
    }

    // Induction
    // iterate over the time
    if (scaled) {
      for (int t = 1; t < observations.length; t++) {
        // iterate over the hidden states
        for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
          // find the maximum probability and most likely state
          // leading up
          // to this
          int maxState = 0;
          double maxProb = delta[t - 1][0] + Math.log(a.getQuick(0, i));
          for (int j = 1; j < model.getNrOfHiddenStates(); j++) {
            double prob = delta[t - 1][j] + Math.log(a.getQuick(j, i));
            if (prob > maxProb) {
              maxProb = prob;
              maxState = j;
            }
          }
          delta[t][i] = maxProb + Math.log(b.getQuick(i, observations[t]));
          // phi stores the argmax predecessor for backtracking
          phi[t - 1][i] = maxState;
        }
      }
    } else {
      for (int t = 1; t < observations.length; t++) {
        // iterate over the hidden states
        for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
          // find the maximum probability and most likely state
          // leading up
          // to this
          int maxState = 0;
          double maxProb = delta[t - 1][0] * a.getQuick(0, i);
          for (int j = 1; j < model.getNrOfHiddenStates(); j++) {
            double prob = delta[t - 1][j] * a.getQuick(j, i);
            if (prob > maxProb) {
              maxProb = prob;
              maxState = j;
            }
          }
          delta[t][i] = maxProb * b.getQuick(i, observations[t]);
          // phi stores the argmax predecessor for backtracking
          phi[t - 1][i] = maxState;
        }
      }
    }

    // find the most likely end state for initialization
    double maxProb;
    if (scaled) {
      maxProb = Double.NEGATIVE_INFINITY;
    } else {
      maxProb = 0.0;
    }
    for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
      if (delta[observations.length - 1][i] > maxProb) {
        maxProb = delta[observations.length - 1][i];
        sequence[observations.length - 1] = i;
      }
    }

    // now backtrack to find the most likely hidden sequence
    for (int t = observations.length - 2; t >= 0; t--) {
      sequence[t] = phi[t][sequence[t + 1]];
    }
  }

}
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sequencelearning.hmm;

import java.util.Random;

import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/**
 * The HmmEvaluator class offers several methods to evaluate an HMM Model. The
 * following use-cases are covered: 1) Generate a sequence of output states from
 * a given model (prediction). 2) Compute the likelihood that a given model
 * generated a given sequence of output states (model likelihood). 3) Compute
 * the most likely hidden sequence for a given model and a given observed
 * sequence (decoding).
 */
public final class HmmEvaluator {

  /**
   * No constructor for utility classes.
   */
  private HmmEvaluator() {}

  /**
   * Predict a sequence of steps output states for the given HMM model
   *
   * @param model The Hidden Markov model used to generate the output sequence
   * @param steps Size of the generated output sequence
   * @return integer array containing a sequence of steps output state IDs,
   *         generated by the specified model
   */
  public static int[] predict(HmmModel model, int steps) {
    return predict(model, steps, RandomUtils.getRandom());
  }

  /**
   * Predict a sequence of steps output states for the given HMM model
   *
   * @param model The Hidden Markov model used to generate the output sequence
   * @param steps Size of the generated output sequence
   * @param seed  seed to use for the RNG
   * @return integer array containing a sequence of steps output state IDs,
   *         generated by the specified model
   */
  public static int[] predict(HmmModel model, int steps, long seed) {
    return predict(model, steps, RandomUtils.getRandom(seed));
  }

  /**
   * Predict a sequence of steps output states for the given HMM model using the
   * given RNG for probabilistic experiments.
   *
   * @param model The Hidden Markov model used to generate the output sequence
   * @param steps Size of the generated output sequence
   * @param rand  RNG to use
   * @return integer array containing a sequence of steps output state IDs,
   *         generated by the specified model
   */
  private static int[] predict(HmmModel model, int steps, Random rand) {
    // fetch the cumulative distributions once; they are sampled repeatedly below
    Vector cip = HmmUtils.getCumulativeInitialProbabilities(model);
    Matrix ctm = HmmUtils.getCumulativeTransitionMatrix(model);
    Matrix com = HmmUtils.getCumulativeOutputMatrix(model);
    int nrOfHiddenStates = model.getNrOfHiddenStates();
    int nrOfOutputStates = model.getNrOfOutputStates();
    // allocate the result array
    int[] result = new int[steps];
    // choose the initial hidden state by inverse-transform sampling. The walk
    // is bounded by the state count: floating-point rounding can leave the
    // last cumulative value slightly below the drawn number, and an unbounded
    // walk would then run past the end of the distribution.
    int hiddenState = 0;
    double randnr = rand.nextDouble();
    while (hiddenState < nrOfHiddenStates - 1 && cip.get(hiddenState) < randnr) {
      hiddenState++;
    }
    // now draw steps output states according to the cumulative distributions
    for (int step = 0; step < steps; ++step) {
      // choose an output state for the current hidden state (bounded walk, see above)
      randnr = rand.nextDouble();
      int outputState = 0;
      while (outputState < nrOfOutputStates - 1 && com.get(hiddenState, outputState) < randnr) {
        outputState++;
      }
      result[step] = outputState;
      // choose the next hidden state (bounded walk, see above)
      randnr = rand.nextDouble();
      int nextHiddenState = 0;
      while (nextHiddenState < nrOfHiddenStates - 1 && ctm.get(hiddenState, nextHiddenState) < randnr) {
        nextHiddenState++;
      }
      hiddenState = nextHiddenState;
    }
    return result;
  }

  /**
   * Returns the likelihood that a given output sequence was produced by the
   * given model. Internally, this function calls the forward algorithm to
   * compute the alpha values and then uses the overloaded function to compute
   * the actual model likelihood.
   *
   * @param model          Model to base the likelihood on.
   * @param outputSequence Sequence to compute likelihood for.
   * @param scaled         Use log-scaled parameters for computation. This is
   *                       computationally more expensive, but offers better
   *                       numerical stability in case of long output sequences
   * @return Likelihood that the given model produced the given sequence
   */
  public static double modelLikelihood(HmmModel model, int[] outputSequence, boolean scaled) {
    return modelLikelihood(HmmAlgorithms.forwardAlgorithm(model, outputSequence, scaled), scaled);
  }

  /**
   * Computes the model likelihood from the alpha values produced by the forward
   * algorithm: the likelihood is the sum over the last row of alpha (the alpha
   * values of the final time step). The output sequence itself is already
   * folded into alpha and is not needed here.
   *
   * @param alpha  Matrix of alpha values
   * @param scaled Set to true if the alpha values are log-scaled.
   * @return model likelihood.
   */
  public static double modelLikelihood(Matrix alpha, boolean scaled) {
    double likelihood = 0;
    if (scaled) {
      // alpha holds log-probabilities; exponentiate before summing
      for (int i = 0; i < alpha.numCols(); ++i) {
        likelihood += Math.exp(alpha.getQuick(alpha.numRows() - 1, i));
      }
    } else {
      for (int i = 0; i < alpha.numCols(); ++i) {
        likelihood += alpha.getQuick(alpha.numRows() - 1, i);
      }
    }
    return likelihood;
  }

  /**
   * Computes the likelihood that a given output sequence was computed by a
   * given model.
   *
   * @param model          model to compute sequence likelihood for.
   * @param outputSequence sequence to base computation on.
   * @param beta           beta parameters.
   * @param scaled         set to true if betas are log-scaled.
   * @return likelihood of the outputSequence given the model.
   */
  public static double modelLikelihood(HmmModel model, int[] outputSequence, Matrix beta, boolean scaled) {
    double likelihood = 0;
    // fetch the emission probabilities
    Matrix e = model.getEmissionMatrix();
    Vector pi = model.getInitialProbabilities();
    int firstOutput = outputSequence[0];
    if (scaled) {
      for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
        likelihood += pi.getQuick(i) * Math.exp(beta.getQuick(0, i)) * e.getQuick(i, firstOutput);
      }
    } else {
      for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
        likelihood += pi.getQuick(i) * beta.getQuick(0, i) * e.getQuick(i, firstOutput);
      }
    }
    return likelihood;
  }

  /**
   * Returns the most likely sequence of hidden states for the given model and
   * observation
   *
   * @param model        model to use for decoding.
   * @param observations integer Array containing a sequence of observed state IDs
   * @param scaled       Use log-scaled computations, this requires higher
   *                     computational effort but is numerically more stable for
   *                     large observation sequences
   * @return integer array containing the most likely sequence of hidden state IDs
   */
  public static int[] decode(HmmModel model, int[] observations, boolean scaled) {
    return HmmAlgorithms.viterbiAlgorithm(model, observations, scaled);
  }

}
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sequencelearning.hmm;

import java.util.Map;
import java.util.Random;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/**
 * Main class defining a Hidden Markov Model
 */
public class HmmModel implements Cloneable {

  /** Bi-directional Map for storing the observed state names */
  private BiMap<String,Integer> outputStateNames;

  /** Bi-Directional Map for storing the hidden state names */
  private BiMap<String,Integer> hiddenStateNames;

  /** Number of hidden states */
  private int nrOfHiddenStates;

  /** Number of output states */
  private int nrOfOutputStates;

  /**
   * Transition matrix containing the transition probabilities between hidden
   * states. TransitionMatrix(i,j) is the probability that we change from hidden
   * state i to hidden state j In general: P(h(t+1)=h_j | h(t) = h_i) =
   * transitionMatrix(i,j) Since we have to make sure that each hidden state can
   * be "left", the following normalization condition has to hold:
   * sum(transitionMatrix(i,j),j=1..hiddenStates) = 1
   */
  private Matrix transitionMatrix;

  /**
   * Output matrix containing the probabilities that we observe a given output
   * state given a hidden state. outputMatrix(i,j) is the probability that we
   * observe output state j if we are in hidden state i Formally: P(o(t)=o_j |
   * h(t)=h_i) = outputMatrix(i,j) Since we always have an observation for each
   * hidden state, the following normalization condition has to hold:
   * sum(outputMatrix(i,j),j=1..outputStates) = 1
   */
  private Matrix emissionMatrix;

  /**
   * Vector containing the initial hidden state probabilities. That is
   * P(h(0)=h_i) = initialProbabilities(i). Since we are dealing with
   * probabilities the following normalization condition has to hold:
   * sum(initialProbabilities(i),i=1..hiddenStates) = 1
   */
  private Vector initialProbabilities;


  /**
   * Get a copy of this model
   */
  @Override
  public HmmModel clone() {
    HmmModel model = new HmmModel(transitionMatrix.clone(), emissionMatrix.clone(), initialProbabilities.clone());
    if (hiddenStateNames != null) {
      model.hiddenStateNames = HashBiMap.create(hiddenStateNames);
    }
    if (outputStateNames != null) {
      model.outputStateNames = HashBiMap.create(outputStateNames);
    }
    return model;
  }

  /**
   * Assign the content of another HMM model to this one
   *
   * @param model The HmmModel that will be assigned to this one
   */
  public void assign(HmmModel model) {
    this.nrOfHiddenStates = model.nrOfHiddenStates;
    this.nrOfOutputStates = model.nrOfOutputStates;
    this.hiddenStateNames = model.hiddenStateNames;
    this.outputStateNames = model.outputStateNames;
    // for now clone the matrix/vectors
    this.initialProbabilities = model.initialProbabilities.clone();
    this.emissionMatrix = model.emissionMatrix.clone();
    this.transitionMatrix = model.transitionMatrix.clone();
  }

  /**
   * Construct a valid random Hidden-Markov parameter set with the given number
   * of hidden and output states using a given seed.
   *
   * @param nrOfHiddenStates Number of hidden states
   * @param nrOfOutputStates Number of output states
   * @param seed Seed for the random initialization, if set to 0 the current time
   *             is used
   */
  public HmmModel(int nrOfHiddenStates, int nrOfOutputStates, long seed) {
    this.nrOfHiddenStates = nrOfHiddenStates;
    this.nrOfOutputStates = nrOfOutputStates;
    this.transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
    this.emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
    this.initialProbabilities = new DenseVector(nrOfHiddenStates);
    // initialize a random, valid parameter set
    initRandomParameters(seed);
  }

  /**
   * Construct a valid random Hidden-Markov parameter set with the given number
   * of hidden and output states.
   *
   * @param nrOfHiddenStates Number of hidden states
   * @param nrOfOutputStates Number of output states
   */
  public HmmModel(int nrOfHiddenStates, int nrOfOutputStates) {
    this(nrOfHiddenStates, nrOfOutputStates, 0);
  }

  /**
   * Generates a Hidden Markov model using the specified parameters
   *
   * @param transitionMatrix transition probabilities.
   * @param emissionMatrix emission probabilities.
   * @param initialProbabilities initial start probabilities.
   * @throws IllegalArgumentException If the given parameter set is invalid
   */
  public HmmModel(Matrix transitionMatrix, Matrix emissionMatrix, Vector initialProbabilities) {
    this.nrOfHiddenStates = initialProbabilities.size();
    this.nrOfOutputStates = emissionMatrix.numCols();
    // enforce the documented contract: the dimensions of the three parameter
    // structures must agree on the number of hidden states
    if (transitionMatrix.numRows() != nrOfHiddenStates || transitionMatrix.numCols() != nrOfHiddenStates) {
      throw new IllegalArgumentException("The transition matrix must be square with one row/column per hidden state");
    }
    if (emissionMatrix.numRows() != nrOfHiddenStates) {
      throw new IllegalArgumentException("The emission matrix must contain one row per hidden state");
    }
    this.transitionMatrix = transitionMatrix;
    this.emissionMatrix = emissionMatrix;
    this.initialProbabilities = initialProbabilities;
  }

  /**
   * Initialize a valid random set of HMM parameters
   *
   * @param seed seed to use for Random initialization. Use 0 to use Java-built-in-version.
   */
  private void initRandomParameters(long seed) {
    Random rand;
    // initialize the random number generator
    if (seed == 0) {
      rand = RandomUtils.getRandom();
    } else {
      rand = RandomUtils.getRandom(seed);
    }
    // initialize the initial Probabilities
    double sum = 0; // used for normalization
    for (int i = 0; i < nrOfHiddenStates; i++) {
      double nextRand = rand.nextDouble();
      initialProbabilities.set(i, nextRand);
      sum += nextRand;
    }
    // "normalize" the vector to generate probabilities
    initialProbabilities = initialProbabilities.divide(sum);

    // initialize the transition matrix
    double[] values = new double[nrOfHiddenStates];
    for (int i = 0; i < nrOfHiddenStates; i++) {
      sum = 0;
      for (int j = 0; j < nrOfHiddenStates; j++) {
        values[j] = rand.nextDouble();
        sum += values[j];
      }
      // normalize the random values to obtain probabilities
      for (int j = 0; j < nrOfHiddenStates; j++) {
        values[j] /= sum;
      }
      // set this row of the transition matrix
      transitionMatrix.set(i, values);
    }

    // initialize the output matrix
    values = new double[nrOfOutputStates];
    for (int i = 0; i < nrOfHiddenStates; i++) {
      sum = 0;
      for (int j = 0; j < nrOfOutputStates; j++) {
        values[j] = rand.nextDouble();
        sum += values[j];
      }
      // normalize the random values to obtain probabilities
      for (int j = 0; j < nrOfOutputStates; j++) {
        values[j] /= sum;
      }
      // set this row of the output matrix
      emissionMatrix.set(i, values);
    }
  }

  /**
   * Getter Method for the number of hidden states
   *
   * @return Number of hidden states
   */
  public int getNrOfHiddenStates() {
    return nrOfHiddenStates;
  }

  /**
   * Getter Method for the number of output states
   *
   * @return Number of output states
   */
  public int getNrOfOutputStates() {
    return nrOfOutputStates;
  }

  /**
   * Getter function to get the hidden state transition matrix
   *
   * @return returns the model's transition matrix.
   */
  public Matrix getTransitionMatrix() {
    return transitionMatrix;
  }

  /**
   * Getter function to get the output state probability matrix
   *
   * @return returns the models emission matrix.
   */
  public Matrix getEmissionMatrix() {
    return emissionMatrix;
  }

  /**
   * Getter function to return the vector of initial hidden state probabilities
   *
   * @return returns the model's init probabilities.
   */
  public Vector getInitialProbabilities() {
    return initialProbabilities;
  }

  /**
   * Getter method for the hidden state Names map
   *
   * @return hidden state names.
   */
  public Map<String, Integer> getHiddenStateNames() {
    return hiddenStateNames;
  }

  /**
   * Register an array of hidden state Names. We assume that the state name at
   * position i has the ID i
   *
   * @param stateNames names of hidden states.
   */
  public void registerHiddenStateNames(String[] stateNames) {
    if (stateNames != null) {
      hiddenStateNames = HashBiMap.create();
      for (int i = 0; i < stateNames.length; ++i) {
        hiddenStateNames.put(stateNames[i], i);
      }
    }
  }

  /**
   * Register a map of hidden state Names/state IDs
   *
   * @param stateNames <String,Integer> Map that assigns each state name an integer ID
   */
  public void registerHiddenStateNames(Map<String, Integer> stateNames) {
    if (stateNames != null) {
      hiddenStateNames = HashBiMap.create(stateNames);
    }
  }

  /**
   * Lookup the name for the given hidden state ID
   *
   * @param id Integer id of the hidden state
   * @return String containing the name for the given ID, null if this ID is not
   *         known or no hidden state names were specified
   */
  public String getHiddenStateName(int id) {
    if (hiddenStateNames == null) {
      return null;
    }
    return hiddenStateNames.inverse().get(id);
  }

  /**
   * Lookup the ID for the given hidden state name
   *
   * @param name Name of the hidden state
   * @return int containing the ID for the given name, -1 if this name is not
   *         known or no hidden state names were specified
   */
  public int getHiddenStateID(String name) {
    if (hiddenStateNames == null) {
      return -1;
    }
    Integer tmp = hiddenStateNames.get(name);
    return tmp == null ? -1 : tmp;
  }

  /**
   * Getter method for the output state Names map
   *
   * @return names of output states.
   */
  public Map<String, Integer> getOutputStateNames() {
    return outputStateNames;
  }

  /**
   * Register an array of output state Names. We assume that the state name at
   * position i has the ID i
   *
   * @param stateNames state names to register.
   */
  public void registerOutputStateNames(String[] stateNames) {
    if (stateNames != null) {
      outputStateNames = HashBiMap.create();
      for (int i = 0; i < stateNames.length; ++i) {
        outputStateNames.put(stateNames[i], i);
      }
    }
  }

  /**
   * Register a map of output state Names/state IDs
   *
   * @param stateNames <String,Integer> Map that assigns each state name an integer ID
   */
  public void registerOutputStateNames(Map<String, Integer> stateNames) {
    if (stateNames != null) {
      outputStateNames = HashBiMap.create(stateNames);
    }
  }

  /**
   * Lookup the name for the given output state id
   *
   * @param id Integer id of the output state
   * @return String containing the name for the given id, null if this id is not
   *         known or no output state names were specified
   */
  public String getOutputStateName(int id) {
    if (outputStateNames == null) {
      return null;
    }
    return outputStateNames.inverse().get(id);
  }

  /**
   * Lookup the ID for the given output state name
   *
   * @param name Name of the output state
   * @return int containing the ID for the given name, -1 if this name is not
   *         known or no output state names were specified
   */
  public int getOutputStateID(String name) {
    if (outputStateNames == null) {
      return -1;
    }
    Integer tmp = outputStateNames.get(name);
    return tmp == null ? -1 : tmp;
  }

}
