http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java new file mode 100644 index 0000000..c1d328e --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java @@ -0,0 +1,306 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; + +/** + * Class containing implementations of the three major HMM algorithms: forward, + * backward and Viterbi + */ +public final class HmmAlgorithms { + + + /** + * No public constructors for utility classes. + */ + private HmmAlgorithms() { + // nothing to do here really + } + + /** + * External function to compute a matrix of alpha factors + * + * @param model model to run forward algorithm for. 
+ * @param observations observation sequence to train on. + * @param scaled Should log-scaled beta factors be computed? + * @return matrix of alpha factors. + */ + public static Matrix forwardAlgorithm(HmmModel model, int[] observations, boolean scaled) { + Matrix alpha = new DenseMatrix(observations.length, model.getNrOfHiddenStates()); + forwardAlgorithm(alpha, model, observations, scaled); + + return alpha; + } + + /** + * Internal function to compute the alpha factors + * + * @param alpha matrix to store alpha factors in. + * @param model model to use for alpha factor computation. + * @param observations observation sequence seen. + * @param scaled set to true if log-scaled beta factors should be computed. + */ + static void forwardAlgorithm(Matrix alpha, HmmModel model, int[] observations, boolean scaled) { + + // fetch references to the model parameters + Vector ip = model.getInitialProbabilities(); + Matrix b = model.getEmissionMatrix(); + Matrix a = model.getTransitionMatrix(); + + if (scaled) { // compute log scaled alpha values + // Initialization + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + alpha.setQuick(0, i, Math.log(ip.getQuick(i) * b.getQuick(i, observations[0]))); + } + + // Induction + for (int t = 1; t < observations.length; t++) { + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + double sum = Double.NEGATIVE_INFINITY; // log(0) + for (int j = 0; j < model.getNrOfHiddenStates(); j++) { + double tmp = alpha.getQuick(t - 1, j) + Math.log(a.getQuick(j, i)); + if (tmp > Double.NEGATIVE_INFINITY) { + // make sure we handle log(0) correctly + sum = tmp + Math.log1p(Math.exp(sum - tmp)); + } + } + alpha.setQuick(t, i, sum + Math.log(b.getQuick(i, observations[t]))); + } + } + } else { + + // Initialization + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + alpha.setQuick(0, i, ip.getQuick(i) * b.getQuick(i, observations[0])); + } + + // Induction + for (int t = 1; t < observations.length; t++) { + for (int i = 0; i < 
model.getNrOfHiddenStates(); i++) { + double sum = 0.0; + for (int j = 0; j < model.getNrOfHiddenStates(); j++) { + sum += alpha.getQuick(t - 1, j) * a.getQuick(j, i); + } + alpha.setQuick(t, i, sum * b.getQuick(i, observations[t])); + } + } + } + } + + /** + * External function to compute a matrix of beta factors + * + * @param model model to use for estimation. + * @param observations observation sequence seen. + * @param scaled Set to true if log-scaled beta factors should be computed. + * @return beta factors based on the model and observation sequence. + */ + public static Matrix backwardAlgorithm(HmmModel model, int[] observations, boolean scaled) { + // initialize the matrix + Matrix beta = new DenseMatrix(observations.length, model.getNrOfHiddenStates()); + // compute the beta factors + backwardAlgorithm(beta, model, observations, scaled); + + return beta; + } + + /** + * Internal function to compute the beta factors + * + * @param beta Matrix to store resulting factors in. + * @param model model to use for factor estimation. + * @param observations sequence of observations to estimate. + * @param scaled set to true to compute log-scaled parameters. 
+ */ + static void backwardAlgorithm(Matrix beta, HmmModel model, int[] observations, boolean scaled) { + // fetch references to the model parameters + Matrix b = model.getEmissionMatrix(); + Matrix a = model.getTransitionMatrix(); + + if (scaled) { // compute log-scaled factors + // initialization + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + beta.setQuick(observations.length - 1, i, 0); + } + + // induction + for (int t = observations.length - 2; t >= 0; t--) { + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + double sum = Double.NEGATIVE_INFINITY; // log(0) + for (int j = 0; j < model.getNrOfHiddenStates(); j++) { + double tmp = beta.getQuick(t + 1, j) + Math.log(a.getQuick(i, j)) + + Math.log(b.getQuick(j, observations[t + 1])); + if (tmp > Double.NEGATIVE_INFINITY) { + // handle log(0) + sum = tmp + Math.log1p(Math.exp(sum - tmp)); + } + } + beta.setQuick(t, i, sum); + } + } + } else { + // initialization + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + beta.setQuick(observations.length - 1, i, 1); + } + // induction + for (int t = observations.length - 2; t >= 0; t--) { + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + double sum = 0; + for (int j = 0; j < model.getNrOfHiddenStates(); j++) { + sum += beta.getQuick(t + 1, j) * a.getQuick(i, j) * b.getQuick(j, observations[t + 1]); + } + beta.setQuick(t, i, sum); + } + } + } + } + + /** + * Viterbi algorithm to compute the most likely hidden sequence for a given + * model and observed sequence + * + * @param model HmmModel for which the Viterbi path should be computed + * @param observations Sequence of observations + * @param scaled Use log-scaled computations, this requires higher computational + * effort but is numerically more stable for large observation + * sequences + * @return nrOfObservations 1D int array containing the most likely hidden + * sequence + */ + public static int[] viterbiAlgorithm(HmmModel model, int[] observations, boolean scaled) { + + // 
probability that the most probable hidden states ends at state i at + // time t + double[][] delta = new double[observations.length][model + .getNrOfHiddenStates()]; + + // previous hidden state in the most probable state leading up to state + // i at time t + int[][] phi = new int[observations.length - 1][model.getNrOfHiddenStates()]; + + // initialize the return array + int[] sequence = new int[observations.length]; + + viterbiAlgorithm(sequence, delta, phi, model, observations, scaled); + + return sequence; + } + + /** + * Internal version of the viterbi algorithm, allowing to reuse existing + * arrays instead of allocating new ones + * + * @param sequence NrOfObservations 1D int array for storing the viterbi sequence + * @param delta NrOfObservations x NrHiddenStates 2D double array for storing the + * delta factors + * @param phi NrOfObservations-1 x NrHiddenStates 2D int array for storing the + * phi values + * @param model HmmModel for which the viterbi path should be computed + * @param observations Sequence of observations + * @param scaled Use log-scaled computations, this requires higher computational + * effort but is numerically more stable for large observation + * sequences + */ + static void viterbiAlgorithm(int[] sequence, double[][] delta, int[][] phi, HmmModel model, int[] observations, + boolean scaled) { + // fetch references to the model parameters + Vector ip = model.getInitialProbabilities(); + Matrix b = model.getEmissionMatrix(); + Matrix a = model.getTransitionMatrix(); + + // Initialization + if (scaled) { + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + delta[0][i] = Math.log(ip.getQuick(i) * b.getQuick(i, observations[0])); + } + } else { + + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + delta[0][i] = ip.getQuick(i) * b.getQuick(i, observations[0]); + } + } + + // Induction + // iterate over the time + if (scaled) { + for (int t = 1; t < observations.length; t++) { + // iterate over the hidden states + for (int i 
= 0; i < model.getNrOfHiddenStates(); i++) { + // find the maximum probability and most likely state + // leading up + // to this + int maxState = 0; + double maxProb = delta[t - 1][0] + Math.log(a.getQuick(0, i)); + for (int j = 1; j < model.getNrOfHiddenStates(); j++) { + double prob = delta[t - 1][j] + Math.log(a.getQuick(j, i)); + if (prob > maxProb) { + maxProb = prob; + maxState = j; + } + } + delta[t][i] = maxProb + Math.log(b.getQuick(i, observations[t])); + phi[t - 1][i] = maxState; + } + } + } else { + for (int t = 1; t < observations.length; t++) { + // iterate over the hidden states + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + // find the maximum probability and most likely state + // leading up + // to this + int maxState = 0; + double maxProb = delta[t - 1][0] * a.getQuick(0, i); + for (int j = 1; j < model.getNrOfHiddenStates(); j++) { + double prob = delta[t - 1][j] * a.getQuick(j, i); + if (prob > maxProb) { + maxProb = prob; + maxState = j; + } + } + delta[t][i] = maxProb * b.getQuick(i, observations[t]); + phi[t - 1][i] = maxState; + } + } + } + + // find the most likely end state for initialization + double maxProb; + if (scaled) { + maxProb = Double.NEGATIVE_INFINITY; + } else { + maxProb = 0.0; + } + for (int i = 0; i < model.getNrOfHiddenStates(); i++) { + if (delta[observations.length - 1][i] > maxProb) { + maxProb = delta[observations.length - 1][i]; + sequence[observations.length - 1] = i; + } + } + + // now backtrack to find the most likely hidden sequence + for (int t = observations.length - 2; t >= 0; t--) { + sequence[t] = phi[t][sequence[t + 1]]; + } + } + +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java new file mode 100644 index 0000000..6e2def6 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java @@ -0,0 +1,194 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.util.Random; + +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; + +/** + * The HMMEvaluator class offers several methods to evaluate an HMM Model. The + * following use-cases are covered: 1) Generate a sequence of output states from + * a given model (prediction). 2) Compute the likelihood that a given model + * generated a given sequence of output states (model likelihood). 
3) Compute + * the most likely hidden sequence for a given model and a given observed + * sequence (decoding). + */ +public final class HmmEvaluator { + + /** + * No constructor for utility classes. + */ + private HmmEvaluator() {} + + /** + * Predict a sequence of steps output states for the given HMM model + * + * @param model The Hidden Markov model used to generate the output sequence + * @param steps Size of the generated output sequence + * @return integer array containing a sequence of steps output state IDs, + * generated by the specified model + */ + public static int[] predict(HmmModel model, int steps) { + return predict(model, steps, RandomUtils.getRandom()); + } + + /** + * Predict a sequence of steps output states for the given HMM model + * + * @param model The Hidden Markov model used to generate the output sequence + * @param steps Size of the generated output sequence + * @param seed seed to use for the RNG + * @return integer array containing a sequence of steps output state IDs, + * generated by the specified model + */ + public static int[] predict(HmmModel model, int steps, long seed) { + return predict(model, steps, RandomUtils.getRandom(seed)); + } + /** + * Predict a sequence of steps output states for the given HMM model using the + * given seed for probabilistic experiments + * + * @param model The Hidden Markov model used to generate the output sequence + * @param steps Size of the generated output sequence + * @param rand RNG to use + * @return integer array containing a sequence of steps output state IDs, + * generated by the specified model + */ + private static int[] predict(HmmModel model, int steps, Random rand) { + // fetch the cumulative distributions + Vector cip = HmmUtils.getCumulativeInitialProbabilities(model); + Matrix ctm = HmmUtils.getCumulativeTransitionMatrix(model); + Matrix com = HmmUtils.getCumulativeOutputMatrix(model); + // allocate the result IntArrayList + int[] result = new int[steps]; + // choose the initial 
state + int hiddenState = 0; + + double randnr = rand.nextDouble(); + while (cip.get(hiddenState) < randnr) { + hiddenState++; + } + + // now draw steps output states according to the cumulative + // distributions + for (int step = 0; step < steps; ++step) { + // choose output state to given hidden state + randnr = rand.nextDouble(); + int outputState = 0; + while (com.get(hiddenState, outputState) < randnr) { + outputState++; + } + result[step] = outputState; + // choose the next hidden state + randnr = rand.nextDouble(); + int nextHiddenState = 0; + while (ctm.get(hiddenState, nextHiddenState) < randnr) { + nextHiddenState++; + } + hiddenState = nextHiddenState; + } + return result; + } + + /** + * Returns the likelihood that a given output sequence was produced by the + * given model. Internally, this function calls the forward algorithm to + * compute the alpha values and then uses the overloaded function to compute + * the actual model likelihood. + * + * @param model Model to base the likelihood on. + * @param outputSequence Sequence to compute likelihood for. + * @param scaled Use log-scaled parameters for computation. This is computationally + * more expensive, but offers better numerically stability in case of + * long output sequences + * @return Likelihood that the given model produced the given sequence + */ + public static double modelLikelihood(HmmModel model, int[] outputSequence, boolean scaled) { + return modelLikelihood(HmmAlgorithms.forwardAlgorithm(model, outputSequence, scaled), scaled); + } + + /** + * Computes the likelihood that a given output sequence was computed by a + * given model using the alpha values computed by the forward algorithm. + * // TODO I am a bit confused here - where is the output sequence referenced in the comment above in the code? + * @param alpha Matrix of alpha values + * @param scaled Set to true if the alpha values are log-scaled. + * @return model likelihood. 
+ */ + public static double modelLikelihood(Matrix alpha, boolean scaled) { + double likelihood = 0; + if (scaled) { + for (int i = 0; i < alpha.numCols(); ++i) { + likelihood += Math.exp(alpha.getQuick(alpha.numRows() - 1, i)); + } + } else { + for (int i = 0; i < alpha.numCols(); ++i) { + likelihood += alpha.getQuick(alpha.numRows() - 1, i); + } + } + return likelihood; + } + + /** + * Computes the likelihood that a given output sequence was computed by a + * given model. + * + * @param model model to compute sequence likelihood for. + * @param outputSequence sequence to base computation on. + * @param beta beta parameters. + * @param scaled set to true if betas are log-scaled. + * @return likelihood of the outputSequence given the model. + */ + public static double modelLikelihood(HmmModel model, int[] outputSequence, Matrix beta, boolean scaled) { + double likelihood = 0; + // fetch the emission probabilities + Matrix e = model.getEmissionMatrix(); + Vector pi = model.getInitialProbabilities(); + int firstOutput = outputSequence[0]; + if (scaled) { + for (int i = 0; i < model.getNrOfHiddenStates(); ++i) { + likelihood += pi.getQuick(i) * Math.exp(beta.getQuick(0, i)) * e.getQuick(i, firstOutput); + } + } else { + for (int i = 0; i < model.getNrOfHiddenStates(); ++i) { + likelihood += pi.getQuick(i) * beta.getQuick(0, i) * e.getQuick(i, firstOutput); + } + } + return likelihood; + } + + /** + * Returns the most likely sequence of hidden states for the given model and + * observation + * + * @param model model to use for decoding. 
+ * @param observations integer Array containing a sequence of observed state IDs + * @param scaled Use log-scaled computations, this requires higher computational + * effort but is numerically more stable for large observation + * sequences + * @return integer array containing the most likely sequence of hidden state + * IDs + */ + public static int[] decode(HmmModel model, int[] observations, boolean scaled) { + return HmmAlgorithms.viterbiAlgorithm(model, observations, scaled); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java new file mode 100644 index 0000000..bc24884 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java @@ -0,0 +1,383 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/

package org.apache.mahout.classifier.sequencelearning.hmm;

import java.util.Map;
import java.util.Random;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/**
 * Main class defining a Hidden Markov Model.
 */
public class HmmModel implements Cloneable {

  /** Bi-directional Map for storing the observed state names */
  private BiMap<String,Integer> outputStateNames;

  /** Bi-Directional Map for storing the hidden state names */
  private BiMap<String,Integer> hiddenStateNames;

  /** Number of hidden states */
  private int nrOfHiddenStates;

  /** Number of output states */
  private int nrOfOutputStates;

  /**
   * Transition matrix containing the transition probabilities between hidden
   * states. TransitionMatrix(i,j) is the probability that we change from hidden
   * state i to hidden state j In general: P(h(t+1)=h_j | h(t) = h_i) =
   * transitionMatrix(i,j) Since we have to make sure that each hidden state can
   * be "left", the following normalization condition has to hold:
   * sum(transitionMatrix(i,j),j=1..hiddenStates) = 1
   */
  private Matrix transitionMatrix;

  /**
   * Output matrix containing the probabilities that we observe a given output
   * state given a hidden state. outputMatrix(i,j) is the probability that we
   * observe output state j if we are in hidden state i Formally: P(o(t)=o_j |
   * h(t)=h_i) = outputMatrix(i,j) Since we always have an observation for each
   * hidden state, the following normalization condition has to hold:
   * sum(outputMatrix(i,j),j=1..outputStates) = 1
   */
  private Matrix emissionMatrix;

  /**
   * Vector containing the initial hidden state probabilities. That is
   * P(h(0)=h_i) = initialProbabilities(i). Since we are dealing with
   * probabilities the following normalization condition has to hold:
   * sum(initialProbabilities(i),i=1..hiddenStates) = 1
   */
  private Vector initialProbabilities;


  /**
   * Get a copy of this model. The transition matrix, emission matrix and initial
   * probability vector are copied via their {@code clone()} methods, and the
   * state-name maps (if registered) are copied into new bi-maps.
   */
  @Override
  public HmmModel clone() {
    HmmModel model = new HmmModel(transitionMatrix.clone(), emissionMatrix.clone(), initialProbabilities.clone());
    model.hiddenStateNames = copyOfNames(hiddenStateNames);
    model.outputStateNames = copyOfNames(outputStateNames);
    return model;
  }

  /**
   * Assign the content of another HMM model to this one.
   *
   * <p>Like {@link #clone()}, this copies the parameters and the state-name
   * maps. The maps used to be shared by reference, so edits made through
   * {@link #getHiddenStateNames()} or {@link #getOutputStateNames()} on one model
   * silently changed the other.
   *
   * @param model The HmmModel that will be assigned to this one
   */
  public void assign(HmmModel model) {
    this.nrOfHiddenStates = model.nrOfHiddenStates;
    this.nrOfOutputStates = model.nrOfOutputStates;
    this.hiddenStateNames = copyOfNames(model.hiddenStateNames);
    this.outputStateNames = copyOfNames(model.outputStateNames);
    // for now clone the matrix/vectors
    this.initialProbabilities = model.initialProbabilities.clone();
    this.emissionMatrix = model.emissionMatrix.clone();
    this.transitionMatrix = model.transitionMatrix.clone();
  }

  /**
   * Returns an independent copy of a state-name bi-map, or null if none is registered.
   */
  private static BiMap<String, Integer> copyOfNames(BiMap<String, Integer> names) {
    return names == null ? null : HashBiMap.create(names);
  }

  /**
   * Construct a valid random Hidden-Markov parameter set with the given number
   * of hidden and output states using a given seed.
   *
   * @param nrOfHiddenStates Number of hidden states
   * @param nrOfOutputStates Number of output states
   * @param seed             Seed for the random initialization; 0 selects the default
   *                         generator returned by {@code RandomUtils.getRandom()} instead
   *                         of a generator seeded with this value
   */
  public HmmModel(int nrOfHiddenStates, int nrOfOutputStates, long seed) {
    this.nrOfHiddenStates = nrOfHiddenStates;
    this.nrOfOutputStates = nrOfOutputStates;
    this.transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
    this.emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
    this.initialProbabilities = new DenseVector(nrOfHiddenStates);
    // initialize a random, valid parameter set
    initRandomParameters(seed);
  }

  /**
   * Construct a valid random Hidden-Markov parameter set with the given number
   * of hidden and output states.
   *
   * @param nrOfHiddenStates Number of hidden states
   * @param nrOfOutputStates Number of output states
   */
  public HmmModel(int nrOfHiddenStates, int nrOfOutputStates) {
    this(nrOfHiddenStates, nrOfOutputStates, 0);
  }

  /**
   * Generates a Hidden Markov model using the specified parameters.
   *
   * <p>The arguments are stored by reference and are not validated here. The
   * number of hidden states is taken from {@code initialProbabilities.size()},
   * and the number of output states from {@code emissionMatrix.numCols()}.
   *
   * @param transitionMatrix     transition probabilities.
   * @param emissionMatrix       emission probabilities.
   * @param initialProbabilities initial start probabilities.
   */
  public HmmModel(Matrix transitionMatrix, Matrix emissionMatrix, Vector initialProbabilities) {
    this.nrOfHiddenStates = initialProbabilities.size();
    this.nrOfOutputStates = emissionMatrix.numCols();
    this.transitionMatrix = transitionMatrix;
    this.emissionMatrix = emissionMatrix;
    this.initialProbabilities = initialProbabilities;
  }

  /**
   * Initialize a valid random set of HMM parameters: every row of the
   * transition and emission matrices, and the initial probability vector, is
   * filled with uniform random values normalized to sum to 1.
   *
   * @param seed seed to use for Random initialization. Use 0 to use Java-built-in-version.
   */
  private void initRandomParameters(long seed) {
    Random rand;
    // initialize the random number generator
    if (seed == 0) {
      rand = RandomUtils.getRandom();
    } else {
      rand = RandomUtils.getRandom(seed);
    }
    // initialize the initial Probabilities
    double sum = 0; // used for normalization
    for (int i = 0; i < nrOfHiddenStates; i++) {
      double nextRand = rand.nextDouble();
      initialProbabilities.set(i, nextRand);
      sum += nextRand;
    }
    // "normalize" the vector to generate probabilities
    initialProbabilities = initialProbabilities.divide(sum);

    // initialize the transition matrix
    double[] values = new double[nrOfHiddenStates];
    for (int i = 0; i < nrOfHiddenStates; i++) {
      sum = 0;
      for (int j = 0; j < nrOfHiddenStates; j++) {
        values[j] = rand.nextDouble();
        sum += values[j];
      }
      // normalize the random values to obtain probabilities
      for (int j = 0; j < nrOfHiddenStates; j++) {
        values[j] /= sum;
      }
      // set this row of the transition matrix
      transitionMatrix.set(i, values);
    }

    // initialize the output matrix
    values = new double[nrOfOutputStates];
    for (int i = 0; i < nrOfHiddenStates; i++) {
      sum = 0;
      for (int j = 0; j < nrOfOutputStates; j++) {
        values[j] = rand.nextDouble();
        sum += values[j];
      }
      // normalize the random values to obtain probabilities
      for (int j = 0; j < nrOfOutputStates; j++) {
        values[j] /= sum;
      }
      // set this row of the output matrix
      emissionMatrix.set(i, values);
    }
  }

  /**
   * Getter Method for the number of hidden states.
   *
   * @return Number of hidden states
   */
  public int getNrOfHiddenStates() {
    return nrOfHiddenStates;
  }

  /**
   * Getter Method for the number of output states.
   *
   * @return Number of output states
   */
  public int getNrOfOutputStates() {
    return nrOfOutputStates;
  }

  /**
   * Getter function to get the hidden state transition matrix.
   *
   * @return returns the model's transition matrix (the internal instance, not a copy).
   */
  public Matrix getTransitionMatrix() {
    return transitionMatrix;
  }

  /**
   * Getter function to get the output state probability matrix.
   *
   * @return returns the model's emission matrix (the internal instance, not a copy).
   */
  public Matrix getEmissionMatrix() {
    return emissionMatrix;
  }

  /**
   * Getter function to return the vector of initial hidden state probabilities.
   *
   * @return returns the model's init probabilities (the internal instance, not a copy).
   */
  public Vector getInitialProbabilities() {
    return initialProbabilities;
  }

  /**
   * Getter method for the hidden state Names map.
   *
   * @return hidden state names (the internal map), or null if none were registered.
   */
  public Map<String, Integer> getHiddenStateNames() {
    return hiddenStateNames;
  }

  /**
   * Register an array of hidden state Names. We assume that the state name at
   * position i has the ID i.
   *
   * @param stateNames names of hidden states; ignored if null.
   */
  public void registerHiddenStateNames(String[] stateNames) {
    if (stateNames != null) {
      hiddenStateNames = HashBiMap.create();
      for (int i = 0; i < stateNames.length; ++i) {
        hiddenStateNames.put(stateNames[i], i);
      }
    }
  }

  /**
   * Register a map of hidden state Names/state IDs.
   *
   * @param stateNames <String,Integer> Map that assigns each state name an integer ID; ignored if null
   */
  public void registerHiddenStateNames(Map<String, Integer> stateNames) {
    if (stateNames != null) {
      hiddenStateNames = HashBiMap.create(stateNames);
    }
  }

  /**
   * Lookup the name for the given hidden state ID.
   *
   * @param id Integer id of the hidden state
   * @return String containing the name for the given ID, null if this ID is not
   *         known or no hidden state names were specified
   */
  public String getHiddenStateName(int id) {
    if (hiddenStateNames == null) {
      return null;
    }
    return hiddenStateNames.inverse().get(id);
  }

  /**
   * Lookup the ID for the given hidden state name.
   *
   * @param name Name of the hidden state
   * @return int containing the ID for the given name, -1 if this name is not
   *         known or no hidden state names were specified
   */
  public int getHiddenStateID(String name) {
    if (hiddenStateNames == null) {
      return -1;
    }
    Integer tmp = hiddenStateNames.get(name);
    return tmp == null ? -1 : tmp;
  }

  /**
   * Getter method for the output state Names map.
   *
   * @return names of output states (the internal map), or null if none were registered.
   */
  public Map<String, Integer> getOutputStateNames() {
    return outputStateNames;
  }

  /**
   * Register an array of output state Names. We assume that the state name at
   * position i has the ID i.
   *
   * @param stateNames output state names to register; ignored if null.
   */
  public void registerOutputStateNames(String[] stateNames) {
    if (stateNames != null) {
      outputStateNames = HashBiMap.create();
      for (int i = 0; i < stateNames.length; ++i) {
        outputStateNames.put(stateNames[i], i);
      }
    }
  }

  /**
   * Register a map of output state Names/state IDs.
   *
   * @param stateNames <String,Integer> Map that assigns each output state name an integer ID;
   *                   ignored if null
   */
  public void registerOutputStateNames(Map<String, Integer> stateNames) {
    if (stateNames != null) {
      outputStateNames = HashBiMap.create(stateNames);
    }
  }

  /**
   * Lookup the name for the given output state id.
   *
   * @param id Integer id of the output state
   * @return String containing the name for the given id, null if this id is not
   *         known or no output state names were specified
   */
  public String getOutputStateName(int id) {
    if (outputStateNames == null) {
      return null;
    }
    return outputStateNames.inverse().get(id);
  }

  /**
   * Lookup the ID for the given output state name.
   *
   * @param name Name of the output state
   * @return int containing the ID for the given name, -1 if this name is not
   *         known or no output state names were specified
   */
  public int getOutputStateID(String name) {
    if (outputStateNames == null) {
      return -1;
    }
    Integer tmp = outputStateNames.get(name);
    return tmp == null ? -1 : tmp;
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java new file mode 100644 index 0000000..a1cd3e0 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java @@ -0,0 +1,488 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.util.Collection; +import java.util.Iterator; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; + +/** + * Class containing several algorithms used to train a Hidden Markov Model. The + * three main algorithms are: supervised learning, unsupervised Viterbi and + * unsupervised Baum-Welch. + */ +public final class HmmTrainer { + + /** + * No public constructor for utility classes.
+ */ + private HmmTrainer() { + // nothing to do here really. + } + + /** + * Create an supervised initial estimate of an HMM Model based on a sequence + * of observed and hidden states. + * + * @param nrOfHiddenStates The total number of hidden states + * @param nrOfOutputStates The total number of output states + * @param observedSequence Integer array containing the observed sequence + * @param hiddenSequence Integer array containing the hidden sequence + * @param pseudoCount Value that is assigned to non-occurring transitions to avoid zero + * probabilities. + * @return An initial model using the estimated parameters + */ + public static HmmModel trainSupervised(int nrOfHiddenStates, int nrOfOutputStates, int[] observedSequence, + int[] hiddenSequence, double pseudoCount) { + // make sure the pseudo count is not zero + pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount; + + // initialize the parameters + DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates); + DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates); + // assign a small initial probability that is larger than zero, so + // unseen states will not get a zero probability + transitionMatrix.assign(pseudoCount); + emissionMatrix.assign(pseudoCount); + // given no prior knowledge, we have to assume that all initial hidden + // states are equally likely + DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates); + initialProbabilities.assign(1.0 / nrOfHiddenStates); + + // now loop over the sequences to count the number of transitions + countTransitions(transitionMatrix, emissionMatrix, observedSequence, + hiddenSequence); + + // make sure that probabilities are normalized + for (int i = 0; i < nrOfHiddenStates; i++) { + // compute sum of probabilities for current row of transition matrix + double sum = 0; + for (int j = 0; j < nrOfHiddenStates; j++) { + sum += transitionMatrix.getQuick(i, j); + } + // normalize current row 
of transition matrix + for (int j = 0; j < nrOfHiddenStates; j++) { + transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum); + } + // compute sum of probabilities for current row of emission matrix + sum = 0; + for (int j = 0; j < nrOfOutputStates; j++) { + sum += emissionMatrix.getQuick(i, j); + } + // normalize current row of emission matrix + for (int j = 0; j < nrOfOutputStates; j++) { + emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum); + } + } + + // return a new model using the parameter estimations + return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities); + } + + /** + * Function that counts the number of state->state and state->output + * transitions for the given observed/hidden sequence. + * + * @param transitionMatrix transition matrix to use. + * @param emissionMatrix emission matrix to use for counting. + * @param observedSequence observation sequence to use. + * @param hiddenSequence sequence of hidden states to use. + */ + private static void countTransitions(Matrix transitionMatrix, + Matrix emissionMatrix, int[] observedSequence, int[] hiddenSequence) { + emissionMatrix.setQuick(hiddenSequence[0], observedSequence[0], + emissionMatrix.getQuick(hiddenSequence[0], observedSequence[0]) + 1); + for (int i = 1; i < observedSequence.length; ++i) { + transitionMatrix + .setQuick(hiddenSequence[i - 1], hiddenSequence[i], transitionMatrix + .getQuick(hiddenSequence[i - 1], hiddenSequence[i]) + 1); + emissionMatrix.setQuick(hiddenSequence[i], observedSequence[i], + emissionMatrix.getQuick(hiddenSequence[i], observedSequence[i]) + 1); + } + } + + /** + * Create an supervised initial estimate of an HMM Model based on a number of + * sequences of observed and hidden states. 
+ * + * Sequences are paired up in the iteration order of the two collections, which + * therefore must have the same size. Pairs with an empty observed sequence are + * skipped. + * + * @param nrOfHiddenStates The total number of hidden states + * @param nrOfOutputStates The total number of output states + * @param hiddenSequences Collection of hidden sequences to use for training + * @param observedSequences Collection of observed sequences to use for training associated with hidden sequences. + * @param pseudoCount Value that is assigned to non-occurring transitions to avoid zero + * probabilities. A value of 0 is replaced by {@link Double#MIN_VALUE}. + * @return An initial model using the estimated parameters + * @throws IllegalArgumentException if the two collections differ in size + */ + public static HmmModel trainSupervisedSequence(int nrOfHiddenStates, + int nrOfOutputStates, Collection<int[]> hiddenSequences, + Collection<int[]> observedSequences, double pseudoCount) { + + // every hidden sequence needs its observed counterpart; silently dropping + // unpaired sequences would bias the estimate without any notice + if (hiddenSequences.size() != observedSequences.size()) { + throw new IllegalArgumentException("Got " + hiddenSequences.size() + " hidden sequences but " + observedSequences.size() + " observed sequences"); + } + + // make sure the pseudo count is not zero + pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount; + + // initialize parameters + DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, + nrOfHiddenStates); + DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, + nrOfOutputStates); + DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates); + + // assign pseudo count to avoid zero probabilities + transitionMatrix.assign(pseudoCount); + emissionMatrix.assign(pseudoCount); + initialProbabilities.assign(pseudoCount); + + // now loop over the sequences to count the number of transitions + Iterator<int[]> hiddenSequenceIt = hiddenSequences.iterator(); + Iterator<int[]> observedSequenceIt = observedSequences.iterator(); + while (hiddenSequenceIt.hasNext() && observedSequenceIt.hasNext()) { + // fetch the current set of sequences + int[] hiddenSequence = hiddenSequenceIt.next(); + int[] observedSequence = observedSequenceIt.next(); + if (observedSequence.length == 0) { + continue; // an empty observation sequence contributes no counts + } + // increase the count for initial probabilities + initialProbabilities.setQuick(hiddenSequence[0], initialProbabilities + .getQuick(hiddenSequence[0]) + 1); + countTransitions(transitionMatrix, emissionMatrix, observedSequence, + hiddenSequence); + } + + // make sure that probabilities are normalized + double isum = 0; // sum
of initial probabilities + for (int i = 0; i < nrOfHiddenStates; i++) { + isum += initialProbabilities.getQuick(i); + // compute sum of probabilities for current row of transition matrix + double sum = 0; + for (int j = 0; j < nrOfHiddenStates; j++) { + sum += transitionMatrix.getQuick(i, j); + } + // normalize current row of transition matrix + for (int j = 0; j < nrOfHiddenStates; j++) { + transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum); + } + // compute sum of probabilities for current row of emission matrix + sum = 0; + for (int j = 0; j < nrOfOutputStates; j++) { + sum += emissionMatrix.getQuick(i, j); + } + // normalize current row of emission matrix + for (int j = 0; j < nrOfOutputStates; j++) { + emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum); + } + } + // normalize the initial probabilities + for (int i = 0; i < nrOfHiddenStates; ++i) { + initialProbabilities.setQuick(i, initialProbabilities.getQuick(i) / isum); + } + + // return a new model using the parameter estimates + return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities); + } + + /** + * Iteratively train the parameters of the given initial model with respect to the + * observed sequence using Viterbi training. Only the transition and emission + * probabilities are re-estimated; the initial probabilities are carried over + * from the initial model unchanged. + * + * @param initialModel The initial model that gets iterated + * @param observedSequence The sequence of observed states + * @param pseudoCount Value that is assigned to non-occurring transitions to avoid zero + * probabilities. A value of 0 is replaced by {@link Double#MIN_VALUE}.
+ * @param epsilon Convergence threshold: training stops once the distance between + * two successive models (see checkConvergence) drops below this value + * @param maxIterations The maximum number of training iterations + * @param scaled Use Log-scaled implementation, this is computationally more + * expensive but offers better numerical stability for large observed + * sequences + * @return The iterated model + */ + public static HmmModel trainViterbi(HmmModel initialModel, + int[] observedSequence, double pseudoCount, double epsilon, + int maxIterations, boolean scaled) { + + // make sure the pseudo count is not zero + pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount; + + // allocate space for iteration models; both are clones, so the caller's + // initialModel is presumably left untouched (depends on HmmModel.clone) + HmmModel lastIteration = initialModel.clone(); + HmmModel iteration = initialModel.clone(); + + // allocate space for Viterbi path calculation, reused by every iteration + // NOTE(review): an empty observedSequence makes the phi allocation below + // fail with a NegativeArraySizeException (length - 1 < 0) + int[] viterbiPath = new int[observedSequence.length]; + int[][] phi = new int[observedSequence.length - 1][initialModel + .getNrOfHiddenStates()]; + double[][] delta = new double[observedSequence.length][initialModel + .getNrOfHiddenStates()]; + + // now run the Viterbi training iteration + for (int i = 0; i < maxIterations; ++i) { + // compute the Viterbi path of the previous model + HmmAlgorithms.viterbiAlgorithm(viterbiPath, delta, phi, lastIteration, + observedSequence, scaled); + // Viterbi iteration uses the viterbi path to update + // the probabilities; these are presumably the model's own matrices, + // so the updates below modify iteration in place + Matrix emissionMatrix = iteration.getEmissionMatrix(); + Matrix transitionMatrix = iteration.getTransitionMatrix(); + + // first, assign the pseudo count (this also discards the previous counts) + emissionMatrix.assign(pseudoCount); + transitionMatrix.assign(pseudoCount); + + // now count the transitions along the most likely hidden state path + countTransitions(transitionMatrix, emissionMatrix, observedSequence, + viterbiPath); + + // and normalize the probabilities + for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) { + double sum = 0; + // normalize the rows of the transition matrix + for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) { + sum += transitionMatrix.getQuick(j, k); + } + for (int k = 0; k < iteration.getNrOfHiddenStates(); 
++k) { + transitionMatrix + .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum); + } + // normalize the rows of the emission matrix + sum = 0; + for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) { + sum += emissionMatrix.getQuick(j, k); + } + for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) { + emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum); + } + } + // check for convergence against the model the Viterbi path was computed from + if (checkConvergence(lastIteration, iteration, epsilon)) { + break; + } + // overwrite the last iterated model by the new iteration + lastIteration.assign(iteration); + } + // we are done :) + return iteration; + } + + /** + * Iteratively train the parameters of the given initial model with respect to the + * observed sequence using Baum-Welch training. Initial, transition and + * emission probabilities are all re-estimated in every iteration. + * + * @param initialModel The initial model that gets iterated + * @param observedSequence The sequence of observed states + * @param epsilon Convergence threshold: training stops once the distance between + * two successive models (see checkConvergence) drops below this value + * @param maxIterations The maximum number of training iterations + * @param scaled Use log-scaled implementations of forward/backward algorithm. This + * is computationally more expensive, but offers better numerical + * stability for long output sequences.
+ * @return The iterated model + */ + public static HmmModel trainBaumWelch(HmmModel initialModel, + int[] observedSequence, double epsilon, int maxIterations, boolean scaled) { + // allocate space for the iterations + HmmModel lastIteration = initialModel.clone(); + HmmModel iteration = initialModel.clone(); + + // allocate space for baum-welch factors (one row per time step, one column + // per hidden state), reused by every iteration + int hiddenCount = initialModel.getNrOfHiddenStates(); + int visibleCount = observedSequence.length; + Matrix alpha = new DenseMatrix(visibleCount, hiddenCount); + Matrix beta = new DenseMatrix(visibleCount, hiddenCount); + + // now run the baum Welch training iteration + for (int it = 0; it < maxIterations; ++it) { + // fetch emission and transition matrix of current iteration; presumably the + // same objects the re-estimation helpers update in place (the + // normalization below relies on that) + Vector initialProbabilities = iteration.getInitialProbabilities(); + Matrix emissionMatrix = iteration.getEmissionMatrix(); + Matrix transitionMatrix = iteration.getTransitionMatrix(); + + // compute forward and backward factors + HmmAlgorithms.forwardAlgorithm(alpha, iteration, observedSequence, scaled); + HmmAlgorithms.backwardAlgorithm(beta, iteration, observedSequence, scaled); + + // re-estimate the (still unnormalized) parameters from alpha and beta + if (scaled) { + logScaledBaumWelch(observedSequence, iteration, alpha, beta); + } else { + unscaledBaumWelch(observedSequence, iteration, alpha, beta); + } + // normalize the rows of the transition/emission matrices and accumulate + // the normalization constant of the initial probabilities + double isum = 0; + for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) { + double sum = 0; + // normalize the rows of the transition matrix + for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) { + sum += transitionMatrix.getQuick(j, k); + } + for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) { + transitionMatrix + .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum); + } + // normalize the rows of the emission matrix + sum = 0; + for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) { + sum += emissionMatrix.getQuick(j, k); + } + for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) { +
emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum); + } + // normalization parameter for initial probabilities + isum += initialProbabilities.getQuick(j); + } + // normalize initial probabilities + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + initialProbabilities.setQuick(i, initialProbabilities.getQuick(i) + / isum); + } + // check for convergence + if (checkConvergence(lastIteration, iteration, epsilon)) { + break; + } + // overwrite the last iterated model by the new iteration + lastIteration.assign(iteration); + } + // we are done :) + return iteration; + } + + /** + * Single Baum-Welch re-estimation step on unscaled forward/backward factors. + * Overwrites the initial, transition and emission probabilities of + * {@code iteration} in place with values that still have to be normalized by + * the caller. + * + * @param observedSequence observation sequence the factors were computed for. + * @param iteration model whose parameters get re-estimated in place. + * @param alpha forward factors (time step x hidden state). + * @param beta backward factors (time step x hidden state). + */ + private static void unscaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) { + Vector initialProbabilities = iteration.getInitialProbabilities(); + Matrix emissionMatrix = iteration.getEmissionMatrix(); + Matrix transitionMatrix = iteration.getTransitionMatrix(); + double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, false); + + // alpha_0(i) * beta_0(i); normalized by the caller + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + initialProbabilities.setQuick(i, alpha.getQuick(0, i) + * beta.getQuick(0, i)); + } + + // recompute transition probabilities: + // a_ij <- a_ij * sum_t alpha_t(i) * b_j(o_{t+1}) * beta_{t+1}(j) / modelLikelihood + // this has to happen before the emission update, as it reads the old emission probabilities + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) { + double temp = 0; + for (int t = 0; t < observedSequence.length - 1; ++t) { + temp += alpha.getQuick(t, i) + * emissionMatrix.getQuick(j, observedSequence[t + 1]) + * beta.getQuick(t + 1, j); + } + transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) + * temp / modelLikelihood); + } + } + // recompute emission probabilities: + // b_i(k) <- sum over t with o_t == k of alpha_t(i) * beta_t(i) / modelLikelihood + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) { + double temp = 0; + for (int t = 0; t < observedSequence.length; ++t) { + // delta tensor + if (observedSequence[t] == j) { + temp += alpha.getQuick(t, i) * beta.getQuick(t, i); + } + } + emissionMatrix.setQuick(i, j, temp / modelLikelihood); + }
+ } + } + + /** + * Single Baum-Welch re-estimation step on log-scaled forward/backward factors. + * Sums over time are accumulated in log space with {@link #logAdd} and + * converted back with {@code Math.exp} before being written into + * {@code iteration}; the written values still have to be normalized by the + * caller. + * + * @param observedSequence observation sequence the factors were computed for. + * @param iteration model whose parameters get re-estimated in place. + * @param alpha log-scaled forward factors (time step x hidden state). + * @param beta log-scaled backward factors (time step x hidden state). + */ + private static void logScaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) { + Vector initialProbabilities = iteration.getInitialProbabilities(); + Matrix emissionMatrix = iteration.getEmissionMatrix(); + Matrix transitionMatrix = iteration.getTransitionMatrix(); + double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, true); + + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + initialProbabilities.setQuick(i, Math.exp(alpha.getQuick(0, i) + beta.getQuick(0, i))); + } + + // recompute transition probabilities; this has to happen before the + // emission update, as it reads the old emission probabilities + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) { + double sum = Double.NEGATIVE_INFINITY; // log(0) + for (int t = 0; t < observedSequence.length - 1; ++t) { + double temp = alpha.getQuick(t, i) + + Math.log(emissionMatrix.getQuick(j, observedSequence[t + 1])) + + beta.getQuick(t + 1, j); + if (temp > Double.NEGATIVE_INFINITY) { + // handle 0-probabilities + sum = logAdd(sum, temp); + } + } + transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) + * Math.exp(sum - modelLikelihood)); + } + } + // recompute emission probabilities + for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) { + double sum = Double.NEGATIVE_INFINITY; // log(0) + for (int t = 0; t < observedSequence.length; ++t) { + // delta tensor: only time steps that emitted symbol j contribute + if (observedSequence[t] == j) { + double temp = alpha.getQuick(t, i) + beta.getQuick(t, i); + if (temp > Double.NEGATIVE_INFINITY) { + // handle 0-probabilities + sum = logAdd(sum, temp); + } + } + } + emissionMatrix.setQuick(i, j, Math.exp(sum - modelLikelihood)); + } + } + } + + /** + * Adds two values given in log space, i.e. computes log(exp(logA) + exp(logB)). + * Factoring out the larger argument keeps the exponent passed to Math.exp + * non-positive; the naive form b + log1p(exp(a - b)) overflows to +Infinity + * once a exceeds b by more than about 709 (easily reached with + * Double.MIN_VALUE pseudo counts, whose log is about -744). + * + * @param logA first log value, Double.NEGATIVE_INFINITY stands for log(0) + * @param logB second log value, Double.NEGATIVE_INFINITY stands for log(0) + * @return log of the sum of both probabilities + */ + private static double logAdd(double logA, double logB) { + if (logA == Double.NEGATIVE_INFINITY) { + return logB; + } + if (logB == Double.NEGATIVE_INFINITY) { + return logA; + } + return Math.max(logA, logB) + Math.log1p(Math.exp(-Math.abs(logA - logB))); + } + + /** + * Check convergence of two HMM models by computing a simple distance between + * emission / transition matrices + * + * @param oldModel Old HMM Model + * @param newModel New HMM Model + * @param epsilon Convergence
Factor + * @return true if the distance between the transition matrices plus the distance + * between the emission matrices is below {@code epsilon}, i.e. training + * converged to a stable state. + */ + private static boolean checkConvergence(HmmModel oldModel, HmmModel newModel, + double epsilon) { + // distance of the transition matrices (nrOfHiddenStates x nrOfHiddenStates) + double transitionDistance = frobeniusDistance(oldModel.getTransitionMatrix(), + newModel.getTransitionMatrix(), oldModel.getNrOfHiddenStates(), + oldModel.getNrOfHiddenStates()); + // distance of the emission matrices (nrOfHiddenStates x nrOfOutputStates) + double emissionDistance = frobeniusDistance(oldModel.getEmissionMatrix(), + newModel.getEmissionMatrix(), oldModel.getNrOfHiddenStates(), + oldModel.getNrOfOutputStates()); + return transitionDistance + emissionDistance < epsilon; + } + + /** + * Euclidean (Frobenius) distance between the top-left {@code rows x cols} + * sections of two matrices. + * + * @param first first matrix + * @param second second matrix + * @param rows number of rows to compare + * @param cols number of columns to compare + * @return square root of the sum of squared element-wise differences + */ + private static double frobeniusDistance(Matrix first, Matrix second, int rows, int cols) { + double squares = 0; + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + double delta = first.getQuick(row, col) - second.getQuick(row, col); + squares += delta * delta; + } + } + return Math.sqrt(squares); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java new file mode 100644 index 0000000..521be09 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java @@ -0,0 +1,361 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.util.Collection; +import java.util.Iterator; +import java.util.List; + +import com.google.common.collect.Lists; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.SparseMatrix; +import org.apache.mahout.math.Vector; + +import com.google.common.base.Preconditions; + +/** + * A collection of utilities for handling HmmModel objects: cumulative + * distributions, validation, encoding/decoding of state names, normalization + * and truncation to sparse models. + */ +public final class HmmUtils { + + /** + * No public constructor for utility classes. + */ + private HmmUtils() { + // nothing to do here really. + } + + /** + * Compute the cumulative transition probability matrix for the given HMM + * model: row i holds the cumulative distribution of the transition + * probabilities out of hidden state i, with its last entry set to exactly 1.0. + * + * @param model The HMM model for which the cumulative transition matrix should be + * computed + * @return The computed cumulative transition matrix.
+ */ + public static Matrix getCumulativeTransitionMatrix(HmmModel model) { + int hiddenStates = model.getNrOfHiddenStates(); + Matrix transitionMatrix = model.getTransitionMatrix(); + return cumulativeRows(transitionMatrix, hiddenStates, hiddenStates); + } + + /** + * Compute the cumulative output probability matrix for the given HMM model: + * row i holds the cumulative distribution of the output probabilities of + * hidden state i, with its last entry set to exactly 1.0. + * + * @param model The HMM model for which the cumulative output matrix should be + * computed + * @return The computed cumulative output matrix. + */ + public static Matrix getCumulativeOutputMatrix(HmmModel model) { + int hiddenStates = model.getNrOfHiddenStates(); + int outputStates = model.getNrOfOutputStates(); + Matrix outputMatrix = model.getEmissionMatrix(); + return cumulativeRows(outputMatrix, hiddenStates, outputStates); + } + + /** + * Builds a dense matrix whose row i is the running sum of row i of the given + * probability matrix. The last column of every row is pinned to exactly 1.0 + * so floating-point rounding in the running sum cannot leave it slightly below 1. + * + * @param probabilities source probability matrix + * @param rows number of rows to accumulate + * @param cols number of columns to accumulate + * @return the cumulative matrix of size {@code rows x cols} + */ + private static Matrix cumulativeRows(Matrix probabilities, int rows, int cols) { + Matrix cumulative = new DenseMatrix(rows, cols); + for (int row = 0; row < rows; ++row) { + double runningTotal = 0; + for (int col = 0; col < cols; ++col) { + runningTotal += probabilities.get(row, col); + cumulative.set(row, col, runningTotal); + } + cumulative.set(row, cols - 1, 1.0); + } + return cumulative; + } + + /** + * Compute the cumulative distribution of the initial hidden state + * probabilities for the given HMM model.
+ * + * @param model The HMM model for which the cumulative initial state probabilities + * should be computed + * @return The computed cumulative initial state probability vector. + */ + public static Vector getCumulativeInitialProbabilities(HmmModel model) { + // fetch the needed parameters from the model + int hiddenStates = model.getNrOfHiddenStates(); + Vector initialProbabilities = model.getInitialProbabilities(); + // now compute the cumulative output matrix + Vector resultVector = new DenseVector(initialProbabilities.size()); + double sum = 0; + for (int i = 0; i < hiddenStates; ++i) { + sum += initialProbabilities.get(i); + resultVector.set(i, sum); + } + resultVector.set(hiddenStates - 1, 1.0); // make sure the last initial + // hidden state probability + // has always a cumulative + // probability of 1.0 + return resultVector; + } + + /** + * Validates an HMM model set + * + * @param model model to sanity check. + */ + public static void validate(HmmModel model) { + if (model == null) { + return; // empty models are valid + } + + /* + * The number of hidden states is positive. + */ + Preconditions.checkArgument(model.getNrOfHiddenStates() > 0, + "Error: The number of hidden states has to be greater than 0"); + + /* + * The number of output states is positive. + */ + Preconditions.checkArgument(model.getNrOfOutputStates() > 0, + "Error: The number of output states has to be greater than 0!"); + + /* + * The size of the vector of initial probabilities is equal to the number of + * the hidden states. Each initial probability is non-negative. The sum of + * initial probabilities is equal to 1. 
+ */ + Preconditions.checkArgument(model.getInitialProbabilities() != null + && model.getInitialProbabilities().size() == model.getNrOfHiddenStates(), + "Error: The vector of initial probabilities is not initialized!"); + + double sum = 0; + for (int i = 0; i < model.getInitialProbabilities().size(); i++) { + Preconditions.checkArgument(model.getInitialProbabilities().get(i) >= 0, + "Error: Initial probability of state %d is negative", i); + sum += model.getInitialProbabilities().get(i); + } + Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001, + "Error: Initial probabilities do not add up to 1"); + /* + * The row size of the output matrix is equal to the number of the hidden + * states. The column size is equal to the number of output states. Each + * probability of the matrix is non-negative. The sum of each row is equal + * to 1. + */ + Preconditions.checkNotNull(model.getEmissionMatrix(), "Error: The output state matrix is not initialized!"); + Preconditions.checkArgument(model.getEmissionMatrix().numRows() == model.getNrOfHiddenStates() + && model.getEmissionMatrix().numCols() == model.getNrOfOutputStates(), + "Error: The output state matrix is not of the form nrOfHiddenStates x nrOfOutputStates"); + for (int i = 0; i < model.getEmissionMatrix().numRows(); i++) { + sum = 0; + for (int j = 0; j < model.getEmissionMatrix().numCols(); j++) { + Preconditions.checkArgument(model.getEmissionMatrix().get(i, j) >= 0, + "The output state probability from hidden state " + i + " to output state " + j + " is negative"); + sum += model.getEmissionMatrix().get(i, j); + } + Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001, + "Error: The output state probabilities for hidden state %d don't add up to 1", i); + } + + /* + * The size of both dimension of the transition matrix is equal to the + * number of the hidden states. Each probability of the matrix is + * non-negative. The sum of each row in transition matrix is equal to 1. 
+ */ + Preconditions.checkArgument(model.getTransitionMatrix() != null, + "Error: The hidden state matrix is not initialized!"); + Preconditions.checkArgument(model.getTransitionMatrix().numRows() == model.getNrOfHiddenStates() + && model.getTransitionMatrix().numCols() == model.getNrOfHiddenStates(), + "Error: The output state matrix is not of the form nrOfHiddenStates x nrOfHiddenStates"); + for (int i = 0; i < model.getTransitionMatrix().numRows(); i++) { + sum = 0; + for (int j = 0; j < model.getTransitionMatrix().numCols(); j++) { + Preconditions.checkArgument(model.getTransitionMatrix().get(i, j) >= 0, + "Error: The transition probability from hidden state %d to hidden state %d is negative", i, j); + sum += model.getTransitionMatrix().get(i, j); + } + Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001, + "Error: The transition probabilities for hidden state " + i + " don't add up to 1."); + } + } + + /** + * Encodes a given collection of state names by the corresponding state IDs + * registered in a given model. + * + * @param model Model to provide the encoding for + * @param sequence Collection of state names + * @param observed If set, the sequence is encoded as a sequence of observed states, + * else it is encoded as sequence of hidden states + * @param defaultValue The default value in case a state is not known + * @return integer array containing the encoded state IDs + */ + public static int[] encodeStateSequence(HmmModel model, + Collection<String> sequence, boolean observed, int defaultValue) { + int[] encoded = new int[sequence.size()]; + Iterator<String> seqIter = sequence.iterator(); + for (int i = 0; i < sequence.size(); ++i) { + String nextState = seqIter.next(); + int nextID; + if (observed) { + nextID = model.getOutputStateID(nextState); + } else { + nextID = model.getHiddenStateID(nextState); + } + // if the ID is -1, use the default value + encoded[i] = nextID < 0 ? 
defaultValue : nextID; + } + return encoded; + } + + /** + * Decodes a given collection of state IDs into the corresponding state names + * registered in a given model. + * + * @param model model to use for retrieving state names + * @param sequence int array of state IDs + * @param observed If set, the sequence is encoded as a sequence of observed states, + * else it is encoded as sequence of hidden states + * @param defaultValue The default value in case a state is not known + * @return list containing the decoded state names + */ + public static List<String> decodeStateSequence(HmmModel model, + int[] sequence, + boolean observed, + String defaultValue) { + List<String> decoded = Lists.newArrayListWithCapacity(sequence.length); + for (int position : sequence) { + String nextState; + if (observed) { + nextState = model.getOutputStateName(position); + } else { + nextState = model.getHiddenStateName(position); + } + // if null was returned, use the default value + decoded.add(nextState == null ? 
defaultValue : nextState); + } + return decoded; + } + + /** + * Function used to normalize the probabilities of a given HMM model + * + * @param model model to normalize + */ + public static void normalizeModel(HmmModel model) { + Vector ip = model.getInitialProbabilities(); + Matrix emission = model.getEmissionMatrix(); + Matrix transition = model.getTransitionMatrix(); + // check normalization for all probabilities + double isum = 0; + for (int i = 0; i < model.getNrOfHiddenStates(); ++i) { + isum += ip.getQuick(i); + double sum = 0; + for (int j = 0; j < model.getNrOfHiddenStates(); ++j) { + sum += transition.getQuick(i, j); + } + if (sum != 1.0) { + for (int j = 0; j < model.getNrOfHiddenStates(); ++j) { + transition.setQuick(i, j, transition.getQuick(i, j) / sum); + } + } + sum = 0; + for (int j = 0; j < model.getNrOfOutputStates(); ++j) { + sum += emission.getQuick(i, j); + } + if (sum != 1.0) { + for (int j = 0; j < model.getNrOfOutputStates(); ++j) { + emission.setQuick(i, j, emission.getQuick(i, j) / sum); + } + } + } + if (isum != 1.0) { + for (int i = 0; i < model.getNrOfHiddenStates(); ++i) { + ip.setQuick(i, ip.getQuick(i) / isum); + } + } + } + + /** + * Method to reduce the size of an HMMmodel by converting the models + * DenseMatrix/DenseVectors to sparse implementations and setting every value + * < threshold to 0 + * + * @param model model to truncate + * @param threshold minimum value a model entry must have to be retained. 
+ * @return a new, renormalized model backed by sparse data structures; {@code model} + * itself is only read + */ + public static HmmModel truncateModel(HmmModel model, double threshold) { + int hiddenCount = model.getNrOfHiddenStates(); + int outputCount = model.getNrOfOutputStates(); + Vector denseInitial = model.getInitialProbabilities(); + Matrix denseEmission = model.getEmissionMatrix(); + Matrix denseTransition = model.getTransitionMatrix(); + // sparse counterparts that only receive the entries above the threshold + RandomAccessSparseVector sparseInitial = new RandomAccessSparseVector(hiddenCount); + SparseMatrix sparseEmission = new SparseMatrix(hiddenCount, outputCount); + SparseMatrix sparseTransition = new SparseMatrix(hiddenCount, hiddenCount); + for (int state = 0; state < hiddenCount; ++state) { + double initial = denseInitial.getQuick(state); + if (initial > threshold) { + sparseInitial.setQuick(state, initial); + } + for (int next = 0; next < hiddenCount; ++next) { + double transition = denseTransition.getQuick(state, next); + if (transition > threshold) { + sparseTransition.setQuick(state, next, transition); + } + } + for (int output = 0; output < outputCount; ++output) { + double emission = denseEmission.getQuick(state, output); + if (emission > threshold) { + sparseEmission.setQuick(state, output, emission); + } + } + } + // the surviving entries may no longer sum up to 1, so renormalize the new model + HmmModel sparseModel = new HmmModel(sparseTransition, sparseEmission, sparseInitial); + normalizeModel(sparseModel); + // state names are not part of the numeric parameters, carry them over explicitly + sparseModel.registerHiddenStateNames(model.getHiddenStateNames()); + sparseModel.registerOutputStateNames(model.getOutputStateNames()); + return sparseModel; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java new file mode 100644 index 0000000..d0ae9c2 --- /dev/null +++
b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixWritable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Reads and writes only the numeric parts of an HmmModel: the emission matrix, + * the transition matrix and the initial probability vector, in that order. + * Hidden and output state names are not serialized, hence "lossy". + */ +final class LossyHmmSerializer { + + private LossyHmmSerializer() { + // static helpers only + } + + /** + * Writes emission matrix, transition matrix and initial probabilities of the + * model to {@code output}, in that order. + */ + static void serialize(HmmModel model, DataOutput output) throws IOException { + MatrixWritable matrixWriter = new MatrixWritable(model.getEmissionMatrix()); + matrixWriter.write(output); + matrixWriter.set(model.getTransitionMatrix()); + matrixWriter.write(output); + + new VectorWritable(model.getInitialProbabilities()).write(output); + } + + /** + * Reads a model in the format written by {@link #serialize}; state names are + * not restored. + */ + static HmmModel deserialize(DataInput input) throws IOException { + MatrixWritable matrixReader = new MatrixWritable(); + matrixReader.readFields(input); + Matrix emissions = matrixReader.get(); + +
matrixReader.readFields(input); + Matrix transitions = matrixReader.get(); + + VectorWritable vectorReader = new VectorWritable(); + vectorReader.readFields(input); + Vector initial = vectorReader.get(); + + return new HmmModel(transitions, emissions, initial); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java new file mode 100644 index 0000000..cd2ced1 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.io.DataInputStream; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; + +import com.google.common.base.Charsets; +import com.google.common.io.Closeables; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.mahout.common.CommandLineUtil; + +/** + * Command-line tool that generates a random observation sequence from a serialized HMM and writes + * it as space-separated integers to the file named by {@code --output}, or to standard output when + * that option is omitted. + */ +public final class RandomSequenceGenerator { + + private RandomSequenceGenerator() { + } + + /** + * Parses the command line, loads the model, generates {@code --length} observations and writes + * them. Prints usage help if the arguments cannot be parsed. + * + * @param args {@code --model/-m <path>} (required), {@code --length/-l <number>} (required), + * {@code --output/-o <path>} (optional; defaults to standard output) + * @throws IOException if the model cannot be read or the sequence cannot be written + */ + public static void main(String[] args) throws IOException { + DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder(); + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + + Option outputOption = optionBuilder.withLongName("output"). + withDescription("Output file with sequence of observed states"). + withShortName("o").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("path").create()).withRequired(false).create(); + + Option modelOption = optionBuilder.withLongName("model"). + withDescription("Path to serialized HMM model"). + withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("path").create()).withRequired(true).create(); + + Option lengthOption = optionBuilder.withLongName("length"). + withDescription("Length of generated sequence"). + withShortName("l").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("number").create()).withRequired(true).create(); + + Group optionGroup = new GroupBuilder().
+ withOption(outputOption).withOption(modelOption).withOption(lengthOption). + withName("Options").create(); + + try { + Parser parser = new Parser(); + parser.setGroup(optionGroup); + CommandLine commandLine = parser.parse(args); + + // null when --output is omitted (the option is declared withRequired(false)) + String output = (String) commandLine.getValue(outputOption); + + String modelPath = (String) commandLine.getValue(modelOption); + + int length = Integer.parseInt((String) commandLine.getValue(lengthOption)); + + //reading serialized HMM + DataInputStream modelStream = new DataInputStream(new FileInputStream(modelPath)); + HmmModel model; + try { + model = LossyHmmSerializer.deserialize(modelStream); + } finally { + Closeables.close(modelStream, true); + } + + //generating observations + int[] observations = HmmEvaluator.predict(model, length, System.currentTimeMillis()); + + //writing output: to the requested file, or to stdout when no --output was given + boolean toStdout = output == null; + PrintWriter writer = new PrintWriter(new OutputStreamWriter( + toStdout ? System.out : new FileOutputStream(output), Charsets.UTF_8), true); + try { + for (int observation : observations) { + writer.print(observation); + writer.print(' '); + } + } finally { + if (toStdout) { + // never close System.out; only push the buffered text through + writer.flush(); + } else { + Closeables.close(writer, false); + } + } + // PrintWriter swallows IOExceptions; surface a failed write/close instead of exiting normally + if (writer.checkError()) { + throw new IOException("Failed to write generated sequence"); + } + } catch (OptionException e) { + CommandLineUtil.printHelp(optionGroup); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java new file mode 100644 index 0000000..fb64385 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java @@ -0,0 +1,127 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.io.DataInputStream; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.util.List; +import java.util.Scanner; + +import com.google.common.base.Charsets; +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; + +/** + * Command-line tool for Viterbi evaluating + */ +public final class ViterbiEvaluator { + + private ViterbiEvaluator() { + } + + public static void main(String[] args) throws IOException { + DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder(); + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + + Option inputOption = DefaultOptionCreator.inputOption().create(); + + Option outputOption = 
DefaultOptionCreator.outputOption().create(); + + Option modelOption = optionBuilder.withLongName("model"). + withDescription("Path to serialized HMM model"). + withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("path").create()).withRequired(true).create(); + + Option likelihoodOption = optionBuilder.withLongName("likelihood"). + withDescription("Compute likelihood of observed sequence"). + withShortName("l").withRequired(false).create(); + + Group optionGroup = new GroupBuilder().withOption(inputOption). + withOption(outputOption).withOption(modelOption).withOption(likelihoodOption). + withName("Options").create(); + + try { + Parser parser = new Parser(); + parser.setGroup(optionGroup); + CommandLine commandLine = parser.parse(args); + + String input = (String) commandLine.getValue(inputOption); + String output = (String) commandLine.getValue(outputOption); + + String modelPath = (String) commandLine.getValue(modelOption); + + boolean computeLikelihood = commandLine.hasOption(likelihoodOption); + + //reading serialized HMM + DataInputStream modelStream = new DataInputStream(new FileInputStream(modelPath)); + HmmModel model; + try { + model = LossyHmmSerializer.deserialize(modelStream); + } finally { + Closeables.close(modelStream, true); + } + + //reading observations + List<Integer> observations = Lists.newArrayList(); + try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) { + while (scanner.hasNextInt()) { + observations.add(scanner.nextInt()); + } + } + + int[] observationsArray = new int[observations.size()]; + for (int i = 0; i < observations.size(); ++i) { + observationsArray[i] = observations.get(i); + } + + //decoding + int[] hiddenStates = HmmEvaluator.decode(model, observationsArray, true); + + //writing output + PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(output), Charsets.UTF_8), true); + try { + for (int hiddenState : hiddenStates) { + 
writer.print(hiddenState); + writer.print(' '); + } + } finally { + Closeables.close(writer, false); + } + + if (computeLikelihood) { + System.out.println("Likelihood: " + HmmEvaluator.modelLikelihood(model, observationsArray, true)); + } + } catch (OptionException e) { + CommandLineUtil.printHelp(optionGroup); + } + } +}
