http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java new file mode 100644 index 0000000..ff2ea40 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes.training; + +import java.io.IOException; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.classifier.naivebayes.BayesUtils; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +public class ThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> { + + public static final String ALPHA_I = ThetaMapper.class.getName() + ".alphaI"; + static final String TRAIN_COMPLEMENTARY = ThetaMapper.class.getName() + ".trainComplementary"; + + private ComplementaryThetaTrainer trainer; + + @Override + protected void setup(Context ctx) throws IOException, InterruptedException { + super.setup(ctx); + Configuration conf = ctx.getConfiguration(); + + float alphaI = conf.getFloat(ALPHA_I, 1.0f); + Map<String, Vector> scores = BayesUtils.readScoresFromCache(conf); + + trainer = new ComplementaryThetaTrainer(scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), + scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), alphaI); + } + + @Override + protected void map(IntWritable key, VectorWritable value, Context ctx) throws IOException, InterruptedException { + trainer.train(key.get(), value.get()); + } + + @Override + protected void cleanup(Context ctx) throws IOException, InterruptedException { + ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER), + new VectorWritable(trainer.retrievePerLabelThetaNormalizer())); + super.cleanup(ctx); + } +}
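A note on what ThetaMapper accumulates: ComplementaryThetaTrainer is not included in this patch, but the per-label quantity it builds is the weight normalizer of Rennie et al., "Tackling the Poor Assumptions of Naive Bayes Text Classifiers" (Section 3.2; the paper is also referenced in TrainNaiveBayesJob below). A rough standalone sketch of that accumulation, assuming plain double[] inputs in place of Mahout Vectors and Laplace-style smoothing (the exact update inside ComplementaryThetaTrainer may differ in detail):

    // Hypothetical sketch, not part of the patch. featureWeightsForLabel is the
    // summed feature vector the mapper receives for one label; labelWeight is the
    // corresponding entry of the per-label weight vector; alphaI is the ALPHA_I
    // smoothing parameter.
    public final class ThetaNormalizerSketch {
      public static double thetaNormalizer(double[] featureWeightsForLabel,
                                           double labelWeight, double alphaI) {
        int numFeatures = featureWeightsForLabel.length;
        double normalizer = 0.0;
        for (double w : featureWeightsForLabel) {
          // smoothed log-weight of one feature given the label
          double logTheta = Math.log((w + alphaI) / (labelWeight + alphaI * numFeatures));
          normalizer += Math.abs(logTheta);
        }
        return normalizer;
      }
    }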
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java new file mode 100644 index 0000000..cd18d28 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java @@ -0,0 +1,177 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes.training; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.classifier.naivebayes.BayesUtils; +import org.apache.mahout.classifier.naivebayes.NaiveBayesModel; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.common.mapreduce.VectorSumReducer; +import org.apache.mahout.math.VectorWritable; + +import com.google.common.base.Splitter; + +/** Trains a Naive Bayes Classifier (parameters for both Naive Bayes and Complementary Naive Bayes) */ +public final class TrainNaiveBayesJob extends AbstractJob { + private static final String TRAIN_COMPLEMENTARY = "trainComplementary"; + private static final String ALPHA_I = "alphaI"; + private static final String LABEL_INDEX = "labelIndex"; + public static final String WEIGHTS_PER_FEATURE = "__SPF"; + public static final String WEIGHTS_PER_LABEL = "__SPL"; + public static final String LABEL_THETA_NORMALIZER = "_LTN"; + public static final String SUMMED_OBSERVATIONS = "summedObservations"; + public static final String WEIGHTS = "weights"; + public static final String THETAS = "thetas"; + + public static void main(String[] args) throws Exception { + ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), args); + } + + @Override + public int 
run(String[] args) throws Exception { + + addInputOption(); + addOutputOption(); + + addOption(ALPHA_I, "a", "smoothing parameter", String.valueOf(1.0f)); + addOption(buildOption(TRAIN_COMPLEMENTARY, "c", "train complementary?", false, false, String.valueOf(false))); + addOption(LABEL_INDEX, "li", "The path to store the label index in", false); + addOption(DefaultOptionCreator.overwriteOption().create()); + + Map<String, List<String>> parsedArgs = parseArguments(args); + if (parsedArgs == null) { + return -1; + } + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(getConf(), getOutputPath()); + HadoopUtil.delete(getConf(), getTempPath()); + } + Path labPath; + String labPathStr = getOption(LABEL_INDEX); + if (labPathStr != null) { + labPath = new Path(labPathStr); + } else { + labPath = getTempPath(LABEL_INDEX); + } + long labelSize = createLabelIndex(labPath); + float alphaI = Float.parseFloat(getOption(ALPHA_I)); + boolean trainComplementary = hasOption(TRAIN_COMPLEMENTARY); + + HadoopUtil.setSerializations(getConf()); + HadoopUtil.cacheFiles(labPath, getConf()); + + // Add up all the vectors with the same labels, while mapping the labels into our index + Job indexInstances = prepareJob(getInputPath(), + getTempPath(SUMMED_OBSERVATIONS), + SequenceFileInputFormat.class, + IndexInstancesMapper.class, + IntWritable.class, + VectorWritable.class, + VectorSumReducer.class, + IntWritable.class, + VectorWritable.class, + SequenceFileOutputFormat.class); + indexInstances.setCombinerClass(VectorSumReducer.class); + boolean succeeded = indexInstances.waitForCompletion(true); + if (!succeeded) { + return -1; + } + // Sum up all the weights from the previous step, per label and per feature + Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS), + getTempPath(WEIGHTS), + SequenceFileInputFormat.class, + WeightsMapper.class, + Text.class, + VectorWritable.class, + VectorSumReducer.class, + Text.class, + VectorWritable.class, + SequenceFileOutputFormat.class); + weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelSize)); + weightSummer.setCombinerClass(VectorSumReducer.class); + succeeded = weightSummer.waitForCompletion(true); + if (!succeeded) { + return -1; + } + + // Put the per label and per feature vectors into the cache + HadoopUtil.cacheFiles(getTempPath(WEIGHTS), getConf()); + + if (trainComplementary){ + // Calculate the per label theta normalizers, write out to LABEL_THETA_NORMALIZER vector + // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors + Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS), + getTempPath(THETAS), + SequenceFileInputFormat.class, + ThetaMapper.class, + Text.class, + VectorWritable.class, + VectorSumReducer.class, + Text.class, + VectorWritable.class, + SequenceFileOutputFormat.class); + thetaSummer.setCombinerClass(VectorSumReducer.class); + thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI); + thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary); + succeeded = thetaSummer.waitForCompletion(true); + if (!succeeded) { + return -1; + } + } + + // Put the per label theta normalizers into the cache + HadoopUtil.cacheFiles(getTempPath(THETAS), getConf()); + + // Validate our model and then write it out to the official output + getConf().setFloat(ThetaMapper.ALPHA_I, alphaI); + getConf().setBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, trainComplementary); + NaiveBayesModel naiveBayesModel = 
BayesUtils.readModelFromDir(getTempPath(), getConf()); + naiveBayesModel.validate(); + naiveBayesModel.serialize(getOutputPath(), getConf()); + + return 0; + } + + private long createLabelIndex(Path labPath) throws IOException { + long labelSize = 0; + Iterable<Pair<Text,IntWritable>> iterable = + new SequenceFileDirIterable<>(getInputPath(), + PathType.LIST, + PathFilters.logsCRCFilter(), + getConf()); + labelSize = BayesUtils.writeLabelIndex(getConf(), labPath, iterable); + return labelSize; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java new file mode 100644 index 0000000..5563057 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java @@ -0,0 +1,68 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes.training; + +import java.io.IOException; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; + +import com.google.common.base.Preconditions; + +public class WeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> { + + static final String NUM_LABELS = WeightsMapper.class.getName() + ".numLabels"; + + private Vector weightsPerFeature; + private Vector weightsPerLabel; + + @Override + protected void setup(Context ctx) throws IOException, InterruptedException { + super.setup(ctx); + int numLabels = Integer.parseInt(ctx.getConfiguration().get(NUM_LABELS)); + Preconditions.checkArgument(numLabels > 0, "Wrong numLabels: " + numLabels + ". 
Must be > 0!"); + weightsPerLabel = new DenseVector(numLabels); + } + + @Override + protected void map(IntWritable index, VectorWritable value, Context ctx) throws IOException, InterruptedException { + Vector instance = value.get(); + if (weightsPerFeature == null) { + weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements()); + } + + int label = index.get(); + weightsPerFeature.assign(instance, Functions.PLUS); + weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum()); + } + + @Override + protected void cleanup(Context ctx) throws IOException, InterruptedException { + if (weightsPerFeature != null) { + ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature)); + ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel)); + } + super.cleanup(ctx); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java new file mode 100644 index 0000000..6d4e2b0 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.DataOutputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.List;
+import java.util.Scanner;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+
+/**
+ * A class for EM training of an HMM from the console
+ */
+public final class BaumWelchTrainer {
+
+  private BaumWelchTrainer() {
+  }
+
+  public static void main(String[] args) throws IOException {
+    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+    Option inputOption = DefaultOptionCreator.inputOption().create();
+
+    Option outputOption = DefaultOptionCreator.outputOption().create();
+
+    Option stateNumberOption = optionBuilder.withLongName("nrOfHiddenStates").
+      withDescription("Number of hidden states").
+      withShortName("nh").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("number").create()).withRequired(true).create();
+
+    Option observedStateNumberOption = optionBuilder.withLongName("nrOfObservedStates").
+      withDescription("Number of observed states").
+      withShortName("no").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("number").create()).withRequired(true).create();
+
+    Option epsilonOption = optionBuilder.withLongName("epsilon").
+      withDescription("Convergence threshold").
+      withShortName("e").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("number").create()).withRequired(true).create();
+
+    Option iterationsOption = optionBuilder.withLongName("max-iterations").
+      withDescription("Maximum number of iterations").
+      withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("number").create()).withRequired(true).create();
+
+    Group optionGroup = new GroupBuilder().withOption(inputOption).
+      withOption(outputOption).withOption(stateNumberOption).withOption(observedStateNumberOption).
+      withOption(epsilonOption).withOption(iterationsOption).
+      withName("Options").create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(optionGroup);
+      CommandLine commandLine = parser.parse(args);
+
+      String input = (String) commandLine.getValue(inputOption);
+      String output = (String) commandLine.getValue(outputOption);
+
+      int nrOfHiddenStates = Integer.parseInt((String) commandLine.getValue(stateNumberOption));
+      int nrOfObservedStates = Integer.parseInt((String) commandLine.getValue(observedStateNumberOption));
+
+      double epsilon = Double.parseDouble((String) commandLine.getValue(epsilonOption));
+      int maxIterations = Integer.parseInt((String) commandLine.getValue(iterationsOption));
+
+      //constructing a randomly generated HMM
+      HmmModel model = new HmmModel(nrOfHiddenStates, nrOfObservedStates, new Date().getTime());
+      List<Integer> observations = new ArrayList<>();
+
+      //reading observations
+      try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) {
+        while (scanner.hasNextInt()) {
+          observations.add(scanner.nextInt());
+        }
+      }
+
+      int[] observationsArray = new int[observations.size()];
+      for (int i = 0; i < observations.size(); ++i) {
+        observationsArray[i] = observations.get(i);
+      }
+
+      //training
+      HmmModel trainedModel = HmmTrainer.trainBaumWelch(model,
+        observationsArray, epsilon, maxIterations, true);
+
+      //serializing trained model
+      try (DataOutputStream stream = new DataOutputStream(new FileOutputStream(output))) {
+        LossyHmmSerializer.serialize(trainedModel, stream);
+      }
+
+      //printing trained model
+      System.out.println("Initial probabilities: ");
+      for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+        System.out.print(i + " ");
+      }
+      System.out.println();
+      for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+        System.out.print(trainedModel.getInitialProbabilities().get(i) + " ");
+      }
+      System.out.println();
+
+      System.out.println("Transition matrix:");
+      System.out.print("  ");
+      for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+        System.out.print(i + " ");
+      }
+      System.out.println();
+      for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+        System.out.print(i + " ");
+        for (int j = 0; j < trainedModel.getNrOfHiddenStates(); ++j) {
+          System.out.print(trainedModel.getTransitionMatrix().get(i, j) + " ");
+        }
+        System.out.println();
+      }
+      System.out.println("Emission matrix: ");
+      System.out.print("  ");
+      for (int i = 0; i < trainedModel.getNrOfOutputStates(); ++i) {
+        System.out.print(i + " ");
+      }
+      System.out.println();
+      for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+        System.out.print(i + " ");
+        for (int j = 0; j < trainedModel.getNrOfOutputStates(); ++j) {
+          System.out.print(trainedModel.getEmissionMatrix().get(i, j) + " ");
+        }
+        System.out.println();
+      }
+    } catch (OptionException e) {
+      CommandLineUtil.printHelp(optionGroup);
+    }
+  }
+}
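The console flow above can also be driven programmatically. The following minimal sketch reuses only signatures that appear in this patch (the HmmModel size/seed constructor and HmmTrainer.trainBaumWelch); the observation sequence and the epsilon/iteration values are illustrative:

    import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
    import org.apache.mahout.classifier.sequencelearning.hmm.HmmTrainer;

    public class BaumWelchSketch {
      public static void main(String[] args) {
        // 2 hidden states, 3 observed states, fixed seed for reproducibility
        HmmModel initial = new HmmModel(2, 3, 42);
        int[] observations = {0, 1, 2, 2, 1, 0, 0, 1};  // toy data
        // epsilon = 1e-4, at most 100 iterations, log-scaled computation
        HmmModel trained = HmmTrainer.trainBaumWelch(initial, observations, 0.0001, 100, true);
        System.out.println(trained.getTransitionMatrix());
      }
    }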
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java
new file mode 100644
index 0000000..c1d328e
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java
@@ -0,0 +1,306 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Class containing implementations of the three major HMM algorithms: forward,
+ * backward and Viterbi
+ */
+public final class HmmAlgorithms {
+
+  /**
+   * No public constructors for utility classes.
+   */
+  private HmmAlgorithms() {
+    // nothing to do here really
+  }
+
+  /**
+   * External function to compute a matrix of alpha factors
+   *
+   * @param model        model to run forward algorithm for.
+   * @param observations observation sequence to train on.
+   * @param scaled       Should log-scaled alpha factors be computed?
+   * @return matrix of alpha factors.
+   */
+  public static Matrix forwardAlgorithm(HmmModel model, int[] observations, boolean scaled) {
+    Matrix alpha = new DenseMatrix(observations.length, model.getNrOfHiddenStates());
+    forwardAlgorithm(alpha, model, observations, scaled);
+
+    return alpha;
+  }
+
+  /**
+   * Internal function to compute the alpha factors
+   *
+   * @param alpha        matrix to store alpha factors in.
+   * @param model        model to use for alpha factor computation.
+   * @param observations observation sequence seen.
+   * @param scaled       set to true if log-scaled alpha factors should be computed.
+ */ + static void forwardAlgorithm(Matrix alpha, HmmModel model, int[] observations, boolean scaled) { + + // fetch references to the model parameters + Vector ip = model.getInitialProbabilities(); + Matrix b = model.getEmissionMatrix(); + Matrix a = model.getTransitionMatrix(); + + if (scaled) { // compute log scaled alpha values + // Initialization + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + alpha.setQuick(0, i, Math.log(ip.getQuick(i) * b.getQuick(i, observations[0]))); + } + + // Induction + for (int t = 1; t < observations.length; t++) { + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + double sum = Double.NEGATIVE_INFINITY; // log(0) + for (int j = 0; j < model.getNrOfHiddenStates(); j++) { + double tmp = alpha.getQuick(t - 1, j) + Math.log(a.getQuick(j, i)); + if (tmp > Double.NEGATIVE_INFINITY) { + // make sure we handle log(0) correctly + sum = tmp + Math.log1p(Math.exp(sum - tmp)); + } + } + alpha.setQuick(t, i, sum + Math.log(b.getQuick(i, observations[t]))); + } + } + } else { + + // Initialization + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + alpha.setQuick(0, i, ip.getQuick(i) * b.getQuick(i, observations[0])); + } + + // Induction + for (int t = 1; t < observations.length; t++) { + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + double sum = 0.0; + for (int j = 0; j < model.getNrOfHiddenStates(); j++) { + sum += alpha.getQuick(t - 1, j) * a.getQuick(j, i); + } + alpha.setQuick(t, i, sum * b.getQuick(i, observations[t])); + } + } + } + } + + /** + * External function to compute a matrix of beta factors + * + * @param model model to use for estimation. + * @param observations observation sequence seen. + * @param scaled Set to true if log-scaled beta factors should be computed. + * @return beta factors based on the model and observation sequence. + */ + public static Matrix backwardAlgorithm(HmmModel model, int[] observations, boolean scaled) { + // initialize the matrix + Matrix beta = new DenseMatrix(observations.length, model.getNrOfHiddenStates()); + // compute the beta factors + backwardAlgorithm(beta, model, observations, scaled); + + return beta; + } + + /** + * Internal function to compute the beta factors + * + * @param beta Matrix to store resulting factors in. + * @param model model to use for factor estimation. + * @param observations sequence of observations to estimate. + * @param scaled set to true to compute log-scaled parameters. 
+ */ + static void backwardAlgorithm(Matrix beta, HmmModel model, int[] observations, boolean scaled) { + // fetch references to the model parameters + Matrix b = model.getEmissionMatrix(); + Matrix a = model.getTransitionMatrix(); + + if (scaled) { // compute log-scaled factors + // initialization + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + beta.setQuick(observations.length - 1, i, 0); + } + + // induction + for (int t = observations.length - 2; t >= 0; t--) { + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + double sum = Double.NEGATIVE_INFINITY; // log(0) + for (int j = 0; j < model.getNrOfHiddenStates(); j++) { + double tmp = beta.getQuick(t + 1, j) + Math.log(a.getQuick(i, j)) + + Math.log(b.getQuick(j, observations[t + 1])); + if (tmp > Double.NEGATIVE_INFINITY) { + // handle log(0) + sum = tmp + Math.log1p(Math.exp(sum - tmp)); + } + } + beta.setQuick(t, i, sum); + } + } + } else { + // initialization + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + beta.setQuick(observations.length - 1, i, 1); + } + // induction + for (int t = observations.length - 2; t >= 0; t--) { + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + double sum = 0; + for (int j = 0; j < model.getNrOfHiddenStates(); j++) { + sum += beta.getQuick(t + 1, j) * a.getQuick(i, j) * b.getQuick(j, observations[t + 1]); + } + beta.setQuick(t, i, sum); + } + } + } + } + + /** + * Viterbi algorithm to compute the most likely hidden sequence for a given + * model and observed sequence + * + * @param model HmmModel for which the Viterbi path should be computed + * @param observations Sequence of observations + * @param scaled Use log-scaled computations, this requires higher computational + * effort but is numerically more stable for large observation + * sequences + * @return nrOfObservations 1D int array containing the most likely hidden + * sequence + */ + public static int[] viterbiAlgorithm(HmmModel model, int[] observations, boolean scaled) { + + // probability that the most probable hidden states ends at state i at + // time t + double[][] delta = new double[observations.length][model + .getNrOfHiddenStates()]; + + // previous hidden state in the most probable state leading up to state + // i at time t + int[][] phi = new int[observations.length - 1][model.getNrOfHiddenStates()]; + + // initialize the return array + int[] sequence = new int[observations.length]; + + viterbiAlgorithm(sequence, delta, phi, model, observations, scaled); + + return sequence; + } + + /** + * Internal version of the viterbi algorithm, allowing to reuse existing + * arrays instead of allocating new ones + * + * @param sequence NrOfObservations 1D int array for storing the viterbi sequence + * @param delta NrOfObservations x NrHiddenStates 2D double array for storing the + * delta factors + * @param phi NrOfObservations-1 x NrHiddenStates 2D int array for storing the + * phi values + * @param model HmmModel for which the viterbi path should be computed + * @param observations Sequence of observations + * @param scaled Use log-scaled computations, this requires higher computational + * effort but is numerically more stable for large observation + * sequences + */ + static void viterbiAlgorithm(int[] sequence, double[][] delta, int[][] phi, HmmModel model, int[] observations, + boolean scaled) { + // fetch references to the model parameters + Vector ip = model.getInitialProbabilities(); + Matrix b = model.getEmissionMatrix(); + Matrix a = model.getTransitionMatrix(); + + // Initialization + if (scaled) 
{ + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + delta[0][i] = Math.log(ip.getQuick(i) * b.getQuick(i, observations[0])); + } + } else { + + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + delta[0][i] = ip.getQuick(i) * b.getQuick(i, observations[0]); + } + } + + // Induction + // iterate over the time + if (scaled) { + for (int t = 1; t < observations.length; t++) { + // iterate over the hidden states + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + // find the maximum probability and most likely state + // leading up + // to this + int maxState = 0; + double maxProb = delta[t - 1][0] + Math.log(a.getQuick(0, i)); + for (int j = 1; j < model.getNrOfHiddenStates(); j++) { + double prob = delta[t - 1][j] + Math.log(a.getQuick(j, i)); + if (prob > maxProb) { + maxProb = prob; + maxState = j; + } + } + delta[t][i] = maxProb + Math.log(b.getQuick(i, observations[t])); + phi[t - 1][i] = maxState; + } + } + } else { + for (int t = 1; t < observations.length; t++) { + // iterate over the hidden states + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + // find the maximum probability and most likely state + // leading up + // to this + int maxState = 0; + double maxProb = delta[t - 1][0] * a.getQuick(0, i); + for (int j = 1; j < model.getNrOfHiddenStates(); j++) { + double prob = delta[t - 1][j] * a.getQuick(j, i); + if (prob > maxProb) { + maxProb = prob; + maxState = j; + } + } + delta[t][i] = maxProb * b.getQuick(i, observations[t]); + phi[t - 1][i] = maxState; + } + } + } + + // find the most likely end state for initialization + double maxProb; + if (scaled) { + maxProb = Double.NEGATIVE_INFINITY; + } else { + maxProb = 0.0; + } + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + if (delta[observations.length - 1][i] > maxProb) { + maxProb = delta[observations.length - 1][i]; + sequence[observations.length - 1] = i; + } + } + + // now backtrack to find the most likely hidden sequence + for (int t = observations.length - 2; t >= 0; t--) { + sequence[t] = phi[t][sequence[t + 1]]; + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java new file mode 100644 index 0000000..6e2def6 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java @@ -0,0 +1,194 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Random;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * The HMMEvaluator class offers several methods to evaluate an HMM Model. The
+ * following use-cases are covered: 1) Generate a sequence of output states from
+ * a given model (prediction). 2) Compute the likelihood that a given model
+ * generated a given sequence of output states (model likelihood). 3) Compute
+ * the most likely hidden sequence for a given model and a given observed
+ * sequence (decoding).
+ */
+public final class HmmEvaluator {
+
+  /**
+   * No constructor for utility classes.
+   */
+  private HmmEvaluator() {}
+
+  /**
+   * Predict a sequence of steps output states for the given HMM model
+   *
+   * @param model The Hidden Markov model used to generate the output sequence
+   * @param steps Size of the generated output sequence
+   * @return integer array containing a sequence of steps output state IDs,
+   *         generated by the specified model
+   */
+  public static int[] predict(HmmModel model, int steps) {
+    return predict(model, steps, RandomUtils.getRandom());
+  }
+
+  /**
+   * Predict a sequence of steps output states for the given HMM model
+   *
+   * @param model The Hidden Markov model used to generate the output sequence
+   * @param steps Size of the generated output sequence
+   * @param seed  seed to use for the RNG
+   * @return integer array containing a sequence of steps output state IDs,
+   *         generated by the specified model
+   */
+  public static int[] predict(HmmModel model, int steps, long seed) {
+    return predict(model, steps, RandomUtils.getRandom(seed));
+  }
+
+  /**
+   * Predict a sequence of steps output states for the given HMM model using the
+   * given RNG for probabilistic experiments
+   *
+   * @param model The Hidden Markov model used to generate the output sequence
+   * @param steps Size of the generated output sequence
+   * @param rand  RNG to use
+   * @return integer array containing a sequence of steps output state IDs,
+   *         generated by the specified model
+   */
+  private static int[] predict(HmmModel model, int steps, Random rand) {
+    // fetch the cumulative distributions
+    Vector cip = HmmUtils.getCumulativeInitialProbabilities(model);
+    Matrix ctm = HmmUtils.getCumulativeTransitionMatrix(model);
+    Matrix com = HmmUtils.getCumulativeOutputMatrix(model);
+    // allocate the result array
+    int[] result = new int[steps];
+    // choose the initial state
+    int hiddenState = 0;
+
+    double randnr = rand.nextDouble();
+    while (cip.get(hiddenState) < randnr) {
+      hiddenState++;
+    }
+
+    // now draw steps output states according to the cumulative
+    // distributions
+    for (int step = 0; step < steps; ++step) {
+      // choose an output state for the current hidden state
+      randnr = rand.nextDouble();
+      int outputState = 0;
+      while (com.get(hiddenState, outputState) < randnr) {
+        outputState++;
+      }
+      result[step] = outputState;
+      // choose the next hidden state
+      randnr = rand.nextDouble();
+      int nextHiddenState = 0;
+      while (ctm.get(hiddenState, nextHiddenState) < randnr) {
+        nextHiddenState++;
+      }
+      hiddenState = nextHiddenState;
+    }
+    return result;
+  }
+
+  /**
+   * Returns the likelihood that a given output sequence was produced by the
+   * given model. Internally, this function calls the forward algorithm to
+   * compute the alpha values and then uses the overloaded function to compute
+   * the actual model likelihood.
+   *
+   * @param model          Model to base the likelihood on.
+   * @param outputSequence Sequence to compute likelihood for.
+   * @param scaled         Use log-scaled parameters for computation. This is computationally
+   *                       more expensive, but offers better numerical stability in case of
+   *                       long output sequences
+   * @return Likelihood that the given model produced the given sequence
+   */
+  public static double modelLikelihood(HmmModel model, int[] outputSequence, boolean scaled) {
+    return modelLikelihood(HmmAlgorithms.forwardAlgorithm(model, outputSequence, scaled), scaled);
+  }
+
+  /**
+   * Computes the likelihood that a given output sequence was produced by a
+   * given model using the alpha values computed by the forward algorithm.
+   * // TODO I am a bit confused here - where is the output sequence referenced in the comment above in the code?
+   * @param alpha  Matrix of alpha values
+   * @param scaled Set to true if the alpha values are log-scaled.
+   * @return model likelihood.
+   */
+  public static double modelLikelihood(Matrix alpha, boolean scaled) {
+    double likelihood = 0;
+    if (scaled) {
+      for (int i = 0; i < alpha.numCols(); ++i) {
+        likelihood += Math.exp(alpha.getQuick(alpha.numRows() - 1, i));
+      }
+    } else {
+      for (int i = 0; i < alpha.numCols(); ++i) {
+        likelihood += alpha.getQuick(alpha.numRows() - 1, i);
+      }
+    }
+    return likelihood;
+  }
+
+  /**
+   * Computes the likelihood that a given output sequence was produced by a
+   * given model.
+   *
+   * @param model          model to compute sequence likelihood for.
+   * @param outputSequence sequence to base computation on.
+   * @param beta           beta parameters.
+   * @param scaled         set to true if betas are log-scaled.
+   * @return likelihood of the outputSequence given the model.
+   */
+  public static double modelLikelihood(HmmModel model, int[] outputSequence, Matrix beta, boolean scaled) {
+    double likelihood = 0;
+    // fetch the emission probabilities
+    Matrix e = model.getEmissionMatrix();
+    Vector pi = model.getInitialProbabilities();
+    int firstOutput = outputSequence[0];
+    if (scaled) {
+      for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+        likelihood += pi.getQuick(i) * Math.exp(beta.getQuick(0, i)) * e.getQuick(i, firstOutput);
+      }
+    } else {
+      for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+        likelihood += pi.getQuick(i) * beta.getQuick(0, i) * e.getQuick(i, firstOutput);
+      }
+    }
+    return likelihood;
+  }
+
+  /**
+   * Returns the most likely sequence of hidden states for the given model and
+   * observation
+   *
+   * @param model        model to use for decoding.
+   * @param observations integer Array containing a sequence of observed state IDs
+   * @param scaled       Use log-scaled computations, this requires higher computational
+   *                     effort but is numerically more stable for large observation
+   *                     sequences
+   * @return integer array containing the most likely sequence of hidden state
+   *         IDs
+   */
+  public static int[] decode(HmmModel model, int[] observations, boolean scaled) {
+    return HmmAlgorithms.viterbiAlgorithm(model, observations, scaled);
+  }
+
+}
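Together, predict, modelLikelihood and decode cover the three use-cases listed in the class javadoc. A minimal round trip using only signatures from this patch (the model sizes, seed and sequence length are illustrative):

    import org.apache.mahout.classifier.sequencelearning.hmm.HmmEvaluator;
    import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;

    public class HmmEvaluatorSketch {
      public static void main(String[] args) {
        HmmModel model = new HmmModel(2, 3, 1337);              // random but valid parameters
        int[] outputs = HmmEvaluator.predict(model, 10, 1337);  // 1) prediction
        double likelihood = HmmEvaluator.modelLikelihood(model, outputs, true);  // 2) likelihood
        int[] hidden = HmmEvaluator.decode(model, outputs, true);                // 3) decoding
        System.out.println("P(outputs | model) = " + likelihood);
        System.out.println("most likely first hidden state: " + hidden[0]);
      }
    }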
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java
new file mode 100644
index 0000000..bc24884
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java
@@ -0,0 +1,383 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Map;
+import java.util.Random;
+
+import com.google.common.collect.BiMap;
+import com.google.common.collect.HashBiMap;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Main class defining a Hidden Markov Model
+ */
+public class HmmModel implements Cloneable {
+
+  /** Bi-directional Map for storing the observed state names */
+  private BiMap<String,Integer> outputStateNames;
+
+  /** Bi-directional Map for storing the hidden state names */
+  private BiMap<String,Integer> hiddenStateNames;
+
+  /** Number of hidden states */
+  private int nrOfHiddenStates;
+
+  /** Number of output states */
+  private int nrOfOutputStates;
+
+  /**
+   * Transition matrix containing the transition probabilities between hidden
+   * states. TransitionMatrix(i,j) is the probability that we change from hidden
+   * state i to hidden state j. In general: P(h(t+1)=h_j | h(t) = h_i) =
+   * transitionMatrix(i,j). Since we have to make sure that each hidden state can
+   * be "left", the following normalization condition has to hold:
+   * sum(transitionMatrix(i,j),j=1..hiddenStates) = 1
+   */
+  private Matrix transitionMatrix;
+
+  /**
+   * Output matrix containing the probabilities that we observe a given output
+   * state given a hidden state.
outputMatrix(i,j) is the probability that we + * observe output state j if we are in hidden state i Formally: P(o(t)=o_j | + * h(t)=h_i) = outputMatrix(i,j) Since we always have an observation for each + * hidden state, the following normalization condition has to hold: + * sum(outputMatrix(i,j),j=1..outputStates) = 1 + */ + private Matrix emissionMatrix; + + /** + * Vector containing the initial hidden state probabilities. That is + * P(h(0)=h_i) = initialProbabilities(i). Since we are dealing with + * probabilities the following normalization condition has to hold: + * sum(initialProbabilities(i),i=1..hiddenStates) = 1 + */ + private Vector initialProbabilities; + + + /** + * Get a copy of this model + */ + @Override + public HmmModel clone() { + HmmModel model = new HmmModel(transitionMatrix.clone(), emissionMatrix.clone(), initialProbabilities.clone()); + if (hiddenStateNames != null) { + model.hiddenStateNames = HashBiMap.create(hiddenStateNames); + } + if (outputStateNames != null) { + model.outputStateNames = HashBiMap.create(outputStateNames); + } + return model; + } + + /** + * Assign the content of another HMM model to this one + * + * @param model The HmmModel that will be assigned to this one + */ + public void assign(HmmModel model) { + this.nrOfHiddenStates = model.nrOfHiddenStates; + this.nrOfOutputStates = model.nrOfOutputStates; + this.hiddenStateNames = model.hiddenStateNames; + this.outputStateNames = model.outputStateNames; + // for now clone the matrix/vectors + this.initialProbabilities = model.initialProbabilities.clone(); + this.emissionMatrix = model.emissionMatrix.clone(); + this.transitionMatrix = model.transitionMatrix.clone(); + } + + /** + * Construct a valid random Hidden-Markov parameter set with the given number + * of hidden and output states using a given seed. + * + * @param nrOfHiddenStates Number of hidden states + * @param nrOfOutputStates Number of output states + * @param seed Seed for the random initialization, if set to 0 the current time + * is used + */ + public HmmModel(int nrOfHiddenStates, int nrOfOutputStates, long seed) { + this.nrOfHiddenStates = nrOfHiddenStates; + this.nrOfOutputStates = nrOfOutputStates; + this.transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates); + this.emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates); + this.initialProbabilities = new DenseVector(nrOfHiddenStates); + // initialize a random, valid parameter set + initRandomParameters(seed); + } + + /** + * Construct a valid random Hidden-Markov parameter set with the given number + * of hidden and output states. + * + * @param nrOfHiddenStates Number of hidden states + * @param nrOfOutputStates Number of output states + */ + public HmmModel(int nrOfHiddenStates, int nrOfOutputStates) { + this(nrOfHiddenStates, nrOfOutputStates, 0); + } + + /** + * Generates a Hidden Markov model using the specified parameters + * + * @param transitionMatrix transition probabilities. + * @param emissionMatrix emission probabilities. + * @param initialProbabilities initial start probabilities. 
+ * @throws IllegalArgumentException If the given parameter set is invalid + */ + public HmmModel(Matrix transitionMatrix, Matrix emissionMatrix, Vector initialProbabilities) { + this.nrOfHiddenStates = initialProbabilities.size(); + this.nrOfOutputStates = emissionMatrix.numCols(); + this.transitionMatrix = transitionMatrix; + this.emissionMatrix = emissionMatrix; + this.initialProbabilities = initialProbabilities; + } + + /** + * Initialize a valid random set of HMM parameters + * + * @param seed seed to use for Random initialization. Use 0 to use Java-built-in-version. + */ + private void initRandomParameters(long seed) { + Random rand; + // initialize the random number generator + if (seed == 0) { + rand = RandomUtils.getRandom(); + } else { + rand = RandomUtils.getRandom(seed); + } + // initialize the initial Probabilities + double sum = 0; // used for normalization + for (int i = 0; i < nrOfHiddenStates; i++) { + double nextRand = rand.nextDouble(); + initialProbabilities.set(i, nextRand); + sum += nextRand; + } + // "normalize" the vector to generate probabilities + initialProbabilities = initialProbabilities.divide(sum); + + // initialize the transition matrix + double[] values = new double[nrOfHiddenStates]; + for (int i = 0; i < nrOfHiddenStates; i++) { + sum = 0; + for (int j = 0; j < nrOfHiddenStates; j++) { + values[j] = rand.nextDouble(); + sum += values[j]; + } + // normalize the random values to obtain probabilities + for (int j = 0; j < nrOfHiddenStates; j++) { + values[j] /= sum; + } + // set this row of the transition matrix + transitionMatrix.set(i, values); + } + + // initialize the output matrix + values = new double[nrOfOutputStates]; + for (int i = 0; i < nrOfHiddenStates; i++) { + sum = 0; + for (int j = 0; j < nrOfOutputStates; j++) { + values[j] = rand.nextDouble(); + sum += values[j]; + } + // normalize the random values to obtain probabilities + for (int j = 0; j < nrOfOutputStates; j++) { + values[j] /= sum; + } + // set this row of the output matrix + emissionMatrix.set(i, values); + } + } + + /** + * Getter Method for the number of hidden states + * + * @return Number of hidden states + */ + public int getNrOfHiddenStates() { + return nrOfHiddenStates; + } + + /** + * Getter Method for the number of output states + * + * @return Number of output states + */ + public int getNrOfOutputStates() { + return nrOfOutputStates; + } + + /** + * Getter function to get the hidden state transition matrix + * + * @return returns the model's transition matrix. + */ + public Matrix getTransitionMatrix() { + return transitionMatrix; + } + + /** + * Getter function to get the output state probability matrix + * + * @return returns the models emission matrix. + */ + public Matrix getEmissionMatrix() { + return emissionMatrix; + } + + /** + * Getter function to return the vector of initial hidden state probabilities + * + * @return returns the model's init probabilities. + */ + public Vector getInitialProbabilities() { + return initialProbabilities; + } + + /** + * Getter method for the hidden state Names map + * + * @return hidden state names. + */ + public Map<String, Integer> getHiddenStateNames() { + return hiddenStateNames; + } + + /** + * Register an array of hidden state Names. We assume that the state name at + * position i has the ID i + * + * @param stateNames names of hidden states. 
+   */
+  public void registerHiddenStateNames(String[] stateNames) {
+    if (stateNames != null) {
+      hiddenStateNames = HashBiMap.create();
+      for (int i = 0; i < stateNames.length; ++i) {
+        hiddenStateNames.put(stateNames[i], i);
+      }
+    }
+  }
+
+  /**
+   * Register a map of hidden state names/state IDs
+   *
+   * @param stateNames <String,Integer> Map that assigns each state name an integer ID
+   */
+  public void registerHiddenStateNames(Map<String, Integer> stateNames) {
+    if (stateNames != null) {
+      hiddenStateNames = HashBiMap.create(stateNames);
+    }
+  }
+
+  /**
+   * Lookup the name for the given hidden state ID
+   *
+   * @param id Integer id of the hidden state
+   * @return String containing the name for the given ID, null if this ID is not
+   *         known or no hidden state names were specified
+   */
+  public String getHiddenStateName(int id) {
+    if (hiddenStateNames == null) {
+      return null;
+    }
+    return hiddenStateNames.inverse().get(id);
+  }
+
+  /**
+   * Lookup the ID for the given hidden state name
+   *
+   * @param name Name of the hidden state
+   * @return int containing the ID for the given name, -1 if this name is not
+   *         known or no hidden state names were specified
+   */
+  public int getHiddenStateID(String name) {
+    if (hiddenStateNames == null) {
+      return -1;
+    }
+    Integer tmp = hiddenStateNames.get(name);
+    return tmp == null ? -1 : tmp;
+  }
+
+  /**
+   * Getter method for the output state Names map
+   *
+   * @return names of output states.
+   */
+  public Map<String, Integer> getOutputStateNames() {
+    return outputStateNames;
+  }
+
+  /**
+   * Register an array of output state names. We assume that the state name at
+   * position i has the ID i
+   *
+   * @param stateNames state names to register.
+   */
+  public void registerOutputStateNames(String[] stateNames) {
+    if (stateNames != null) {
+      outputStateNames = HashBiMap.create();
+      for (int i = 0; i < stateNames.length; ++i) {
+        outputStateNames.put(stateNames[i], i);
+      }
+    }
+  }
+
+  /**
+   * Register a map of output state names/state IDs
+   *
+   * @param stateNames <String,Integer> Map that assigns each state name an integer ID
+   */
+  public void registerOutputStateNames(Map<String, Integer> stateNames) {
+    if (stateNames != null) {
+      outputStateNames = HashBiMap.create(stateNames);
+    }
+  }
+
+  /**
+   * Lookup the name for the given output state id
+   *
+   * @param id Integer id of the output state
+   * @return String containing the name for the given id, null if this id is not
+   *         known or no output state names were specified
+   */
+  public String getOutputStateName(int id) {
+    if (outputStateNames == null) {
+      return null;
+    }
+    return outputStateNames.inverse().get(id);
+  }
+
+  /**
+   * Lookup the ID for the given output state name
+   *
+   * @param name Name of the output state
+   * @return int containing the ID for the given name, -1 if this name is not
+   *         known or no output state names were specified
+   */
+  public int getOutputStateID(String name) {
+    if (outputStateNames == null) {
+      return -1;
+    }
+    Integer tmp = outputStateNames.get(name);
+    return tmp == null ? -1 : tmp;
+  }
+
+}
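Because the matrix-based constructor accepts arbitrary row-normalized parameters, a model can also be specified by hand rather than randomly initialized. A short sketch using the constructor and registration methods above (the concrete probabilities and state names are illustrative):

    import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
    import org.apache.mahout.math.DenseMatrix;
    import org.apache.mahout.math.DenseVector;

    public class HmmModelSketch {
      public static void main(String[] args) {
        // each row sums to 1, matching the normalization conditions in the javadoc
        DenseMatrix transitions = new DenseMatrix(new double[][] {{0.7, 0.3}, {0.4, 0.6}});
        DenseMatrix emissions = new DenseMatrix(new double[][] {{0.9, 0.1}, {0.2, 0.8}});
        DenseVector initial = new DenseVector(new double[] {0.5, 0.5});

        HmmModel model = new HmmModel(transitions, emissions, initial);
        model.registerHiddenStateNames(new String[] {"rainy", "sunny"});
        model.registerOutputStateNames(new String[] {"umbrella", "no-umbrella"});
        System.out.println(model.getHiddenStateID("sunny"));  // prints 1
      }
    }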
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
new file mode 100644
index 0000000..a1cd3e0
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
@@ -0,0 +1,488 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Collection;
+import java.util.Iterator;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Class containing several algorithms used to train a Hidden Markov Model. The
+ * three main algorithms are: supervised learning, unsupervised Viterbi and
+ * unsupervised Baum-Welch.
+ */
+public final class HmmTrainer {
+
+  /**
+   * No public constructor for utility classes.
+   */
+  private HmmTrainer() {
+    // nothing to do here really.
+  }
+
+  /**
+   * Create a supervised initial estimate of an HMM Model based on a sequence
+   * of observed and hidden states.
+   *
+   * @param nrOfHiddenStates The total number of hidden states
+   * @param nrOfOutputStates The total number of output states
+   * @param observedSequence Integer array containing the observed sequence
+   * @param hiddenSequence   Integer array containing the hidden sequence
+   * @param pseudoCount      Value that is assigned to non-occurring transitions to avoid zero
+   *                         probabilities.
+   * @return An initial model using the estimated parameters
+   */
+  public static HmmModel trainSupervised(int nrOfHiddenStates, int nrOfOutputStates, int[] observedSequence,
+                                         int[] hiddenSequence, double pseudoCount) {
+    // make sure the pseudo count is not zero
+    pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;
+
+    // initialize the parameters
+    DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
+    DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
+    // assign a small initial probability that is larger than zero, so
+    // unseen states will not get a zero probability
+    transitionMatrix.assign(pseudoCount);
+    emissionMatrix.assign(pseudoCount);
+    // given no prior knowledge, we have to assume that all initial hidden
+    // states are equally likely
+    DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);
+    initialProbabilities.assign(1.0 / nrOfHiddenStates);
+
+    // now loop over the sequences to count the number of transitions
+    countTransitions(transitionMatrix, emissionMatrix, observedSequence,
+        hiddenSequence);
+
+    // make sure that probabilities are normalized
+    for (int i = 0; i < nrOfHiddenStates; i++) {
+      // compute sum of probabilities for current row of transition matrix
+      double sum = 0;
+      for (int j = 0; j < nrOfHiddenStates; j++) {
+        sum += transitionMatrix.getQuick(i, j);
+      }
+      // normalize current row of transition matrix
+      for (int j = 0; j < nrOfHiddenStates; j++) {
+        transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
+      }
+      // compute sum of probabilities for current row of emission matrix
+      sum = 0;
+      for (int j = 0; j < nrOfOutputStates; j++) {
+        sum += emissionMatrix.getQuick(i, j);
+      }
+      // normalize current row of emission matrix
+      for (int j = 0; j < nrOfOutputStates; j++) {
+        emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
+      }
+    }
+
+    // return a new model using the parameter estimations
+    return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
+  }
+
+  /**
+   * Function that counts the number of state->state and state->output
+   * transitions for the given observed/hidden sequence.
+   *
+   * @param transitionMatrix transition matrix to use.
+   * @param emissionMatrix   emission matrix to use for counting.
+   * @param observedSequence observation sequence to use.
+   * @param hiddenSequence   sequence of hidden states to use.
+   */
+  private static void countTransitions(Matrix transitionMatrix,
+                                       Matrix emissionMatrix, int[] observedSequence, int[] hiddenSequence) {
+    emissionMatrix.setQuick(hiddenSequence[0], observedSequence[0],
+        emissionMatrix.getQuick(hiddenSequence[0], observedSequence[0]) + 1);
+    for (int i = 1; i < observedSequence.length; ++i) {
+      transitionMatrix
+          .setQuick(hiddenSequence[i - 1], hiddenSequence[i], transitionMatrix
+              .getQuick(hiddenSequence[i - 1], hiddenSequence[i]) + 1);
+      emissionMatrix.setQuick(hiddenSequence[i], observedSequence[i],
+          emissionMatrix.getQuick(hiddenSequence[i], observedSequence[i]) + 1);
+    }
+  }
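As countTransitions makes explicit, supervised estimation reduces to counting transitions and normalizing rows. A compact usage sketch of trainSupervised with a toy labeled sequence (the data and the pseudo-count value are illustrative):

    import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
    import org.apache.mahout.classifier.sequencelearning.hmm.HmmTrainer;

    public class SupervisedTrainingSketch {
      public static void main(String[] args) {
        int[] observed = {0, 1, 1, 0, 2, 2, 1, 0};  // toy observed sequence
        int[] hidden = {0, 0, 1, 1, 1, 0, 0, 1};    // matching hidden labels
        // pseudoCount 0.1 keeps unseen transitions from getting probability zero
        HmmModel model = HmmTrainer.trainSupervised(2, 3, observed, hidden, 0.1);
        System.out.println(model.getTransitionMatrix());
      }
    }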
+ * @return An initial model using the estimated parameters + */ + public static HmmModel trainSupervisedSequence(int nrOfHiddenStates, + int nrOfOutputStates, Collection<int[]> hiddenSequences, + Collection<int[]> observedSequences, double pseudoCount) { + + // make sure the pseudo count is not zero + pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount; + + // initialize parameters + DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, + nrOfHiddenStates); + DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, + nrOfOutputStates); + DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates); + + // assign pseudo count to avoid zero probabilities + transitionMatrix.assign(pseudoCount); + emissionMatrix.assign(pseudoCount); + initialProbabilities.assign(pseudoCount); + + // now loop over the sequences to count the number of transitions + Iterator<int[]> hiddenSequenceIt = hiddenSequences.iterator(); + Iterator<int[]> observedSequenceIt = observedSequences.iterator(); + while (hiddenSequenceIt.hasNext() && observedSequenceIt.hasNext()) { + // fetch the current set of sequences + int[] hiddenSequence = hiddenSequenceIt.next(); + int[] observedSequence = observedSequenceIt.next(); + // increase the count for initial probabilities + initialProbabilities.setQuick(hiddenSequence[0], initialProbabilities + .getQuick(hiddenSequence[0]) + 1); + countTransitions(transitionMatrix, emissionMatrix, observedSequence, + hiddenSequence); + } + + // make sure that probabilities are normalized + double isum = 0; // sum of initial probabilities + for (int i = 0; i < nrOfHiddenStates; i++) { + isum += initialProbabilities.getQuick(i); + // compute sum of probabilities for current row of transition matrix + double sum = 0; + for (int j = 0; j < nrOfHiddenStates; j++) { + sum += transitionMatrix.getQuick(i, j); + } + // normalize current row of transition matrix + for (int j = 0; j < nrOfHiddenStates; j++) { + transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum); + } + // compute sum of probabilities for current row of emission matrix + sum = 0; + for (int j = 0; j < nrOfOutputStates; j++) { + sum += emissionMatrix.getQuick(i, j); + } + // normalize current row of emission matrix + for (int j = 0; j < nrOfOutputStates; j++) { + emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum); + } + } + // normalize the initial probabilities + for (int i = 0; i < nrOfHiddenStates; ++i) { + initialProbabilities.setQuick(i, initialProbabilities.getQuick(i) / isum); + } + + // return a new model using the parameter estimates + return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities); + } + + /** + * Iteratively train the parameters of the given initial model with respect to the + * observed sequence using Viterbi training. + * + * @param initialModel The initial model that gets iterated + * @param observedSequence The sequence of observed states + * @param pseudoCount Value that is assigned to non-occurring transitions to avoid zero + * probabilities.
+ * @param epsilon Convergence criterion + * @param maxIterations The maximum number of training iterations + * @param scaled Use the log-scaled implementation; this is computationally more + * expensive but offers better numerical stability for large observed + * sequences + * @return The iterated model + */ + public static HmmModel trainViterbi(HmmModel initialModel, + int[] observedSequence, double pseudoCount, double epsilon, + int maxIterations, boolean scaled) { + + // make sure the pseudo count is not zero + pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount; + + // allocate space for iteration models + HmmModel lastIteration = initialModel.clone(); + HmmModel iteration = initialModel.clone(); + + // allocate space for Viterbi path calculation + int[] viterbiPath = new int[observedSequence.length]; + int[][] phi = new int[observedSequence.length - 1][initialModel + .getNrOfHiddenStates()]; + double[][] delta = new double[observedSequence.length][initialModel + .getNrOfHiddenStates()]; + + // now run the Viterbi training iteration + for (int i = 0; i < maxIterations; ++i) { + // compute the Viterbi path + HmmAlgorithms.viterbiAlgorithm(viterbiPath, delta, phi, lastIteration, + observedSequence, scaled); + // Viterbi iteration uses the Viterbi path to update + // the probabilities + Matrix emissionMatrix = iteration.getEmissionMatrix(); + Matrix transitionMatrix = iteration.getTransitionMatrix(); + + // first, assign the pseudo count + emissionMatrix.assign(pseudoCount); + transitionMatrix.assign(pseudoCount); + + // now count the transitions + countTransitions(transitionMatrix, emissionMatrix, observedSequence, + viterbiPath); + + // and normalize the probabilities + for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) { + double sum = 0; + // normalize the rows of the transition matrix + for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) { + sum += transitionMatrix.getQuick(j, k); + } + for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) { + transitionMatrix + .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum); + } + // normalize the rows of the emission matrix + sum = 0; + for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) { + sum += emissionMatrix.getQuick(j, k); + } + for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) { + emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum); + } + } + // check for convergence + if (checkConvergence(lastIteration, iteration, epsilon)) { + break; + } + // overwrite the last iterated model by the new iteration + lastIteration.assign(iteration); + } + // we are done :) + return iteration; + } + + /** + * Iteratively train the parameters of the given initial model with respect to the + * observed sequence using Baum-Welch training. + * + * @param initialModel The initial model that gets iterated + * @param observedSequence The sequence of observed states + * @param epsilon Convergence criterion + * @param maxIterations The maximum number of training iterations + * @param scaled Use log-scaled implementations of forward/backward algorithm. This + * is computationally more expensive, but offers better numerical + * stability for long output sequences.
+ * @return The iterated model + */ + public static HmmModel trainBaumWelch(HmmModel initialModel, + int[] observedSequence, double epsilon, int maxIterations, boolean scaled) { + // allocate space for the iterations + HmmModel lastIteration = initialModel.clone(); + HmmModel iteration = initialModel.clone(); + + // allocate space for Baum-Welch factors + int hiddenCount = initialModel.getNrOfHiddenStates(); + int visibleCount = observedSequence.length; + Matrix alpha = new DenseMatrix(visibleCount, hiddenCount); + Matrix beta = new DenseMatrix(visibleCount, hiddenCount); + + // now run the Baum-Welch training iteration + for (int it = 0; it < maxIterations; ++it) { + // fetch emission and transition matrix of current iteration + Vector initialProbabilities = iteration.getInitialProbabilities(); + Matrix emissionMatrix = iteration.getEmissionMatrix(); + Matrix transitionMatrix = iteration.getTransitionMatrix(); + + // compute forward and backward factors + HmmAlgorithms.forwardAlgorithm(alpha, iteration, observedSequence, scaled); + HmmAlgorithms.backwardAlgorithm(beta, iteration, observedSequence, scaled); + + if (scaled) { + logScaledBaumWelch(observedSequence, iteration, alpha, beta); + } else { + unscaledBaumWelch(observedSequence, iteration, alpha, beta); + } + // normalize transition, emission and initial probabilities + double isum = 0; + for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) { + double sum = 0; + // normalize the rows of the transition matrix + for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) { + sum += transitionMatrix.getQuick(j, k); + } + for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) { + transitionMatrix + .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum); + } + // normalize the rows of the emission matrix + sum = 0; + for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) { + sum += emissionMatrix.getQuick(j, k); + } + for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) { + emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum); + } + // normalization parameter for initial probabilities + isum += initialProbabilities.getQuick(j); + } + // normalize initial probabilities + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + initialProbabilities.setQuick(i, initialProbabilities.getQuick(i) + / isum); + } + // check for convergence + if (checkConvergence(lastIteration, iteration, epsilon)) { + break; + } + // overwrite the last iterated model by the new iteration + lastIteration.assign(iteration); + } + // we are done :) + return iteration; + } + + private static void unscaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) { + Vector initialProbabilities = iteration.getInitialProbabilities(); + Matrix emissionMatrix = iteration.getEmissionMatrix(); + Matrix transitionMatrix = iteration.getTransitionMatrix(); + double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, false); + + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + initialProbabilities.setQuick(i, alpha.getQuick(0, i) + * beta.getQuick(0, i)); + } + + // recompute transition probabilities + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) { + double temp = 0; + for (int t = 0; t < observedSequence.length - 1; ++t) { + temp += alpha.getQuick(t, i) + * emissionMatrix.getQuick(j, observedSequence[t + 1]) + * beta.getQuick(t + 1, j); + } + transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) + *
temp / modelLikelihood); + } + } + // recompute emission probabilities + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) { + double temp = 0; + for (int t = 0; t < observedSequence.length; ++t) { + // delta tensor + if (observedSequence[t] == j) { + temp += alpha.getQuick(t, i) * beta.getQuick(t, i); + } + } + emissionMatrix.setQuick(i, j, temp / modelLikelihood); + } + } + } + + private static void logScaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) { + Vector initialProbabilities = iteration.getInitialProbabilities(); + Matrix emissionMatrix = iteration.getEmissionMatrix(); + Matrix transitionMatrix = iteration.getTransitionMatrix(); + double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, true); + + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + initialProbabilities.setQuick(i, Math.exp(alpha.getQuick(0, i) + beta.getQuick(0, i))); + } + + // recompute transition probabilities + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) { + double sum = Double.NEGATIVE_INFINITY; // log(0) + for (int t = 0; t < observedSequence.length - 1; ++t) { + double temp = alpha.getQuick(t, i) + + Math.log(emissionMatrix.getQuick(j, observedSequence[t + 1])) + + beta.getQuick(t + 1, j); + if (temp > Double.NEGATIVE_INFINITY) { + // handle 0-probabilities + sum = temp + Math.log1p(Math.exp(sum - temp)); + } + } + transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) + * Math.exp(sum - modelLikelihood)); + } + } + // recompute emission probabilities + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) { + double sum = Double.NEGATIVE_INFINITY; // log(0) + for (int t = 0; t < observedSequence.length; ++t) { + // delta tensor + if (observedSequence[t] == j) { + double temp = alpha.getQuick(t, i) + beta.getQuick(t, i); + if (temp > Double.NEGATIVE_INFINITY) { + // handle 0-probabilities + sum = temp + Math.log1p(Math.exp(sum - temp)); + } + } + } + emissionMatrix.setQuick(i, j, Math.exp(sum - modelLikelihood)); + } + } + } + + /** + * Check convergence of two HMM models by computing a simple distance between + * emission / transition matrices + * + * @param oldModel Old HMM Model + * @param newModel New HMM Model + * @param epsilon Convergence Factor + * @return true if training converged to a stable state. 
+ */ + private static boolean checkConvergence(HmmModel oldModel, HmmModel newModel, + double epsilon) { + // check convergence of transitionProbabilities + Matrix oldTransitionMatrix = oldModel.getTransitionMatrix(); + Matrix newTransitionMatrix = newModel.getTransitionMatrix(); + double diff = 0; + for (int i = 0; i < oldModel.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < oldModel.getNrOfHiddenStates(); ++j) { + double tmp = oldTransitionMatrix.getQuick(i, j) + - newTransitionMatrix.getQuick(i, j); + diff += tmp * tmp; + } + } + double norm = Math.sqrt(diff); + diff = 0; + // check convergence of emissionProbabilities + Matrix oldEmissionMatrix = oldModel.getEmissionMatrix(); + Matrix newEmissionMatrix = newModel.getEmissionMatrix(); + for (int i = 0; i < oldModel.getNrOfHiddenStates(); i++) { + for (int j = 0; j < oldModel.getNrOfOutputStates(); j++) { + double tmp = oldEmissionMatrix.getQuick(i, j) + - newEmissionMatrix.getQuick(i, j); + diff += tmp * tmp; + } + } + norm += Math.sqrt(diff); + // converged if the combined distance falls below epsilon + return norm < epsilon; + } + +}
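To make the API above concrete, here is a minimal usage sketch of the supervised path. The toy sequences, values, and class name are invented for illustration and are not part of this commit; only trainSupervised/trainSupervisedSequence and their signatures come from the code above.

    import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
    import org.apache.mahout.classifier.sequencelearning.hmm.HmmTrainer;

    public class SupervisedHmmSketch {
      public static void main(String[] args) {
        // parallel toy sequences: hidden state labels and the symbols they emitted
        int[] observedSequence = {0, 1, 2, 1, 0, 2};  // symbols in [0, 3)
        int[] hiddenSequence   = {0, 0, 1, 1, 0, 1};  // states in [0, 2)
        // a small pseudo count keeps transitions that never occur in the
        // training data at a tiny non-zero probability instead of zero
        HmmModel model = HmmTrainer.trainSupervised(2, 3, observedSequence, hiddenSequence, 1e-5);
        // with several labeled sequences, trainSupervisedSequence takes
        // Collection<int[]> arguments instead of single arrays
        System.out.println(model.getTransitionMatrix());
      }
    }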
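The two unsupervised trainers follow the same call shape. Below is a sketch under the assumption that HmmModel also offers a random-initialization constructor HmmModel(int nrOfHiddenStates, int nrOfOutputStates); that constructor is not shown in this diff, and failing it, an initial model can be assembled from explicit matrices exactly as trainSupervised does.

    // refine a randomly initialized 2-state/3-symbol model on one observation
    // sequence, once with Viterbi training and once with Baum-Welch
    int[] observed = {0, 1, 2, 2, 1, 0, 1, 2};  // hypothetical symbols in [0, 3)
    HmmModel initial = new HmmModel(2, 3);      // assumed random-init constructor
    HmmModel viaViterbi = HmmTrainer.trainViterbi(initial, observed,
        1e-5 /* pseudoCount */, 1e-4 /* epsilon */, 100 /* maxIterations */, true /* scaled */);
    HmmModel viaBaumWelch = HmmTrainer.trainBaumWelch(initial, observed,
        1e-4 /* epsilon */, 100 /* maxIterations */, true /* scaled */);

The scaled variants trade extra log/exp work for numerical stability, which matters once observation sequences get long enough for the forward/backward factors to underflow.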
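For reference, the convergence test in checkConvergence is a distance on the raw parameters rather than on the likelihood: it sums the Frobenius distances of the transition matrix A and emission matrix B between consecutive iterations and stops once

    ||A_old - A_new||_F + ||B_old - B_new||_F < epsilon

so epsilon bounds how much the parameters may still be moving per iteration; note that the initial probabilities are not part of the test.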
