Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java Fri Sep 24 11:17:13 2010
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import junit.framework.Assert;
+
+import org.junit.Test;
+
+/**
+ * Exercises construction and JSON round-tripping of {@link HmmModel}.
+ */
+public class HMMModelTest extends HMMTestBase {
+
+  /**
+   * A model built through the two-int constructor has to pass validation.
+   */
+  @Test
+  public void testRandomModelGeneration() {
+    // named differently from the inherited fixture field to avoid shadowing it
+    HmmModel randomModel = new HmmModel(10, 20);
+    HmmUtils.validate(randomModel);
+  }
+
+  /**
+   * Serializing the fixture model, parsing it back and serializing it again
+   * must reproduce the first serialization string.
+   */
+  @Test
+  public void testSerialization() {
+    String firstPass = model.toJson();
+    String secondPass = HmmModel.fromJson(firstPass).toJson();
+    // the underlying objects have no equals methods, so identity is checked
+    // via the serialization strings
+    Assert.assertEquals(firstPass, secondPass);
+  }
+
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java Fri Sep 24 11:17:13 2010
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+
+/**
+ * Shared fixture for the HMM tests: {@link #setUp()} builds a small, fully
+ * specified model that the concrete test classes operate on.
+ */
+public class HMMTestBase extends MahoutTestCase {
+
+  /** Model described in {@link #setUp()}; replaced by every setUp() call. */
+  protected HmmModel model;
+  /** Encoded observation sequence "O1" "O0" "O2" "O2" "O0" "O0" "O1". */
+  protected int[] sequence = {1, 0, 2, 2, 0, 0, 1};
+
+  /**
+   * Initializes a new HMM model using the following parameters:
+   * <pre>
+   * # hidden states: 4 ("H0","H1","H2","H3")
+   * # output states: 3 ("O0","O1","O2")
+   * # transition matrix
+   *         to:  H0   H1   H2   H3
+   *   from: H0   0.5  0.1  0.1  0.3
+   *         H1   0.4  0.4  0.1  0.1
+   *         H2   0.1  0.0  0.8  0.1
+   *         H3   0.1  0.1  0.1  0.7
+   * # output matrix
+   *         to:  O0   O1   O2
+   *   from: H0   0.8  0.1  0.1
+   *         H1   0.6  0.1  0.3
+   *         H2   0.1  0.8  0.1
+   *         H3   0.0  0.1  0.9
+   * # initial probabilities
+   *         H0   0.2
+   *         H1   0.1
+   *         H2   0.4
+   *         H3   0.3
+   * </pre>
+   * The observation sequence held in {@link #sequence} is
+   * "O1" "O0" "O2" "O2" "O0" "O0" "O1".
+   *
+   * @throws Exception if the set-up of the parent class fails
+   */
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    // initialize the hidden/output state names
+    String[] hiddenNames = {"H0", "H1", "H2", "H3"};
+    String[] outputNames = {"O0", "O1", "O2"};
+    // initialize the transition matrix
+    double[][] transitionP = {{0.5, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
+        {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
+    // initialize the emission matrix
+    double[][] emissionP = {{0.8, 0.1, 0.1}, {0.6, 0.1, 0.3},
+        {0.1, 0.8, 0.1}, {0.0, 0.1, 0.9}};
+    // initialize the initial probability vector
+    double[] initialP = {0.2, 0.1, 0.4, 0.3};
+    // now generate the model
+    model = new HmmModel(new DenseMatrix(transitionP),
+        new DenseMatrix(emissionP), new DenseVector(initialP));
+    model.registerHiddenStateNames(hiddenNames);
+    model.registerOutputStateNames(outputNames);
+    // make sure the model is valid :)
+    HmmUtils.validate(model);
+  }
+
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java Fri Sep 24 11:17:13 2010
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import junit.framework.Assert;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Trains the fixture model of {@link HMMTestBase} with the Viterbi and the
+ * Baum-Welch trainer, each with and without scaling, and compares the trained
+ * parameters with values computed by external HMM packages.
+ */
+public class HMMTrainerTest extends HMMTestBase {
+
+  /** Tolerance used when checking the Viterbi results. */
+  private static final double VITERBI_DELTA = 0.00001;
+
+  /** Tolerance used when checking the Baum-Welch results. */
+  private static final double BAUM_WELCH_DELTA = 0.0001;
+
+  /** Expected transition matrix after Viterbi training (from R). */
+  private static final double[][] VITERBI_TRANSITION = {
+    {0.3125, 0.0625, 0.3125, 0.3125},
+    {0.25, 0.25, 0.25, 0.25},
+    {0.5, 0.071429, 0.357143, 0.071429},
+    {0.5, 0.1, 0.1, 0.3}};
+
+  /** Expected emission matrix after Viterbi training (from R). */
+  private static final double[][] VITERBI_EMISSION = {
+    {0.882353, 0.058824, 0.058824},
+    {0.333333, 0.333333, 0.3333333},
+    {0.076923, 0.846154, 0.076923},
+    {0.111111, 0.111111, 0.777778}};
+
+  /** Expected Baum-Welch initial probabilities (Matlab / R HMM package). */
+  private static final double[] BAUM_WELCH_INITIAL = {0, 0, 1.0, 0};
+
+  /** Expected Baum-Welch transition matrix (Matlab / R HMM package). */
+  private static final double[][] BAUM_WELCH_TRANSITION = {
+    {0.2319, 0.0993, 0.0005, 0.6683},
+    {0.0001, 0.3345, 0.6654, 0},
+    {0.5975, 0, 0.4025, 0},
+    {0.0024, 0.6657, 0, 0.3319}};
+
+  /** Expected Baum-Welch emission matrix (Matlab / R HMM package). */
+  private static final double[][] BAUM_WELCH_EMISSION = {
+    {0.9995, 0.0004, 0.0001},
+    {0.9943, 0.0036, 0.0021},
+    {0.0059, 0.9941, 0},
+    {0, 0, 1}};
+
+  /**
+   * Returns a fresh copy of the output sequence the fixture model is trained
+   * on, so that no test can be affected by another one modifying the array.
+   *
+   * @return the observed output state indices
+   */
+  private static int[] observedSequence() {
+    return new int[] {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+  }
+
+  @Test
+  public void testViterbiTraining() {
+    HmmModel trained = HmmTrainer.trainViterbi(model, observedSequence(),
+        0.5, 0.1, 10, false);
+    assertViterbiResult(trained);
+  }
+
+  @Test
+  public void testScaledViterbiTraining() {
+    HmmModel trained = HmmTrainer.trainViterbi(model, observedSequence(),
+        0.5, 0.1, 10, true);
+    assertViterbiResult(trained);
+  }
+
+  @Test
+  public void testBaumWelchTraining() {
+    HmmModel trained = HmmTrainer.trainBaumWelch(model, observedSequence(),
+        0.1, 10, false);
+    assertBaumWelchResult(trained);
+  }
+
+  @Test
+  public void testScaledBaumWelchTraining() {
+    HmmModel trained = HmmTrainer.trainBaumWelch(model, observedSequence(),
+        0.1, 10, true);
+    assertBaumWelchResult(trained);
+  }
+
+  /**
+   * Checks transition and emission matrix of a Viterbi-trained model; the
+   * initial probabilities are not part of the Viterbi expectations.
+   */
+  private static void assertViterbiResult(HmmModel trained) {
+    assertMatrixEquals(VITERBI_TRANSITION, trained.getTransitionMatrix(),
+        trained.getNrOfHiddenStates(), trained.getNrOfHiddenStates(),
+        VITERBI_DELTA);
+    assertMatrixEquals(VITERBI_EMISSION, trained.getEmissionMatrix(),
+        trained.getNrOfHiddenStates(), trained.getNrOfOutputStates(),
+        VITERBI_DELTA);
+  }
+
+  /**
+   * Checks initial probabilities, transition matrix and emission matrix of a
+   * Baum-Welch-trained model.
+   */
+  private static void assertBaumWelchResult(HmmModel trained) {
+    Vector initialProbabilities = trained.getInitialProbabilities();
+    for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+      Assert.assertEquals(BAUM_WELCH_INITIAL[i], initialProbabilities.get(i),
+          BAUM_WELCH_DELTA);
+    }
+    assertMatrixEquals(BAUM_WELCH_TRANSITION, trained.getTransitionMatrix(),
+        trained.getNrOfHiddenStates(), trained.getNrOfHiddenStates(),
+        BAUM_WELCH_DELTA);
+    assertMatrixEquals(BAUM_WELCH_EMISSION, trained.getEmissionMatrix(),
+        trained.getNrOfHiddenStates(), trained.getNrOfOutputStates(),
+        BAUM_WELCH_DELTA);
+  }
+
+  /**
+   * Asserts that the upper-left rows x columns part of the given matrix
+   * matches the expected values.
+   *
+   * @param expected expected values, indexed [row][column]
+   * @param actual matrix taken from the trained model
+   * @param rows number of rows to compare
+   * @param columns number of columns to compare
+   * @param delta maximum tolerated absolute difference per entry
+   */
+  private static void assertMatrixEquals(double[][] expected, Matrix actual,
+                                         int rows, int columns, double delta) {
+    for (int i = 0; i < rows; ++i) {
+      for (int j = 0; j < columns; ++j) {
+        // expected value goes first so that failure messages read correctly
+        Assert.assertEquals("entry (" + i + ", " + j + ")", expected[i][j],
+            actual.getQuick(i, j), delta);
+      }
+    }
+  }
+
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java Fri Sep 24 11:17:13 2010
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor
license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Arrays;
+import java.util.List;
+
+import junit.framework.Assert;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Tests validation, state name encoding/decoding, normalization and truncation in {@link HmmUtils}.
+ */
+public class HMMUtilsTest extends HMMTestBase {
+
+  /** Tolerance for comparing probabilities. */
+  private static final double DELTA = 1.0e-12;
+
+  Matrix legal2_2;
+  Matrix legal2_3;
+  Matrix legal3_3;
+  Vector legal2;
+  Matrix illegal2_2;
+
+  /** Adds the matrices and the vector used by the validator tests to the inherited fixture. */
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    legal2_2 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}});
+    legal2_3 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6}, {0.3, 0.3, 0.4}});
+    legal3_3 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8},
+        {0.1, 0.2, 0.7}, {0.2, 0.3, 0.5}});
+    legal2 = new DenseVector(new double[]{0.4, 0.6});
+    illegal2_2 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}});
+  }
+
+  /** A consistent model has to pass validation without an exception. */
+  @Test
+  public void testValidatorLegal() {
+    HmmUtils.validate(new HmmModel(legal2_2, legal2_3, legal2));
+  }
+
+  /** A 3x3 transition matrix must not be accepted together with a 2x3 emission matrix. */
+  @Test
+  public void testValidatorDimensionError() {
+    try {
+      HmmUtils.validate(new HmmModel(legal3_3, legal2_3, legal2));
+    } catch (IllegalArgumentException expected) {
+      return; // success
+    }
+    Assert.fail("model with mismatching matrix dimensions was accepted");
+  }
+
+  /** A transition matrix with the entries 1, 2, 3, 4 must not be accepted. */
+  @Test
+  public void testValidatorIllegelMatrixError() {
+    try {
+      HmmUtils.validate(new HmmModel(illegal2_2, legal2_3, legal2));
+    } catch (IllegalArgumentException expected) {
+      return; // success
+    }
+    Assert.fail("model with an illegal transition matrix was accepted");
+  }
+
+  /** State names that are not registered must be encoded as the default value -1. */
+  @Test
+  public void testEncodeStateSequence() {
+    String[] hiddenSequence = {"H1", "H2", "H0", "H3", "H4"};
+    String[] outputSequence = {"O1", "O2", "O4", "O0"};
+    // encode the hidden and the output state names
+    int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(model,
+        Arrays.asList(hiddenSequence), false, -1);
+    int[] outputSequenceEnc = HmmUtils.encodeStateSequence(model,
+        Arrays.asList(outputSequence), true, -1);
+    // expected state sequences
+    int[] hiddenSequenceExp = {1, 2, 0, 3, -1};
+    int[] outputSequenceExp = {1, 2, -1, 0};
+    // compare whole arrays, so that a wrong length is noticed as well
+    Assert.assertTrue("hidden: " + Arrays.toString(hiddenSequenceEnc),
+        Arrays.equals(hiddenSequenceExp, hiddenSequenceEnc));
+    Assert.assertTrue("output: " + Arrays.toString(outputSequenceEnc),
+        Arrays.equals(outputSequenceExp, outputSequenceEnc));
+  }
+
+  /** State indices that are not registered must be decoded as the default value "unknown". */
+  @Test
+  public void testDecodeStateSequence() {
+    int[] hiddenSequence = {1, 2, 0, 3, 10};
+    int[] outputSequence = {1, 2, 10, 0};
+    // decode the hidden and the output state indices
+    List<String> hiddenSequenceDec = HmmUtils.decodeStateSequence(model,
+        hiddenSequence, false, "unknown");
+    List<String> outputSequenceDec = HmmUtils.decodeStateSequence(model,
+        outputSequence, true, "unknown");
+    // compare whole lists, so that a wrong length is noticed as well
+    Assert.assertEquals(Arrays.asList("H1", "H2", "H0", "H3", "unknown"),
+        hiddenSequenceDec);
+    Assert.assertEquals(Arrays.asList("O1", "O2", "unknown", "O0"),
+        outputSequenceDec);
+  }
+
+  /** Normalizing a model with unnormalized entries has to make it valid. */
+  @Test
+  public void testNormalizeModel() {
+    DenseVector ip = new DenseVector(new double[]{10, 20});
+    DenseMatrix tr = new DenseMatrix(new double[][]{{10, 10}, {20, 25}});
+    DenseMatrix em = new DenseMatrix(new double[][]{{5, 7}, {10, 15}});
+    HmmModel unnormalized = new HmmModel(tr, em, ip);
+    HmmUtils.normalizeModel(unnormalized);
+    // the model should be valid now
+    HmmUtils.validate(unnormalized);
+  }
+
+  /** Truncation is expected to zero the 0.0001 entries and rescale the rest to 1.0. */
+  @Test
+  public void testTruncateModel() {
+    DenseVector ip = new DenseVector(new double[]{0.0001, 0.0001, 0.9998});
+    DenseMatrix tr = new DenseMatrix(new double[][]{{0.9998, 0.0001, 0.0001},
+        {0.0001, 0.9998, 0.0001}, {0.0001, 0.0001, 0.9998}});
+    DenseMatrix em = new DenseMatrix(new double[][]{{0.9998, 0.0001, 0.0001},
+        {0.0001, 0.9998, 0.0001}, {0.0001, 0.0001, 0.9998}});
+    // now truncate the model
+    HmmModel sparseModel = HmmUtils.truncateModel(new HmmModel(tr, em, ip), 0.01);
+    // first make sure this is a valid model
+    HmmUtils.validate(sparseModel);
+    // now check whether the values are as expected (both matrices are 3x3)
+    Vector sparseIp = sparseModel.getInitialProbabilities();
+    Matrix sparseTr = sparseModel.getTransitionMatrix();
+    Matrix sparseEm = sparseModel.getEmissionMatrix();
+    for (int i = 0; i < sparseModel.getNrOfHiddenStates(); ++i) {
+      Assert.assertEquals(i == 2 ? 1.0 : 0.0, sparseIp.getQuick(i), DELTA);
+      for (int j = 0; j < sparseModel.getNrOfHiddenStates(); ++j) {
+        double expected = i == j ? 1.0 : 0.0;
+        Assert.assertEquals(expected, sparseTr.getQuick(i, j), DELTA);
+        Assert.assertEquals(expected, sparseEm.getQuick(i, j), DELTA);
+      }
+    }
+  }
+
+}

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java?rev=1000807&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java Fri Sep 24 11:17:13 2010
@@ -0,0 +1,337 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.net.URL;
+import java.net.URLConnection;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.mahout.math.Matrix;
+
+/**
+ * This class implements a sample program that uses a pre-tagged training data
+ * set to train an HMM model as a POS tagger. The training data is automatically
+ * downloaded from the following URL:
+ * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt It then
+ * trains an HMM Model using supervised learning and tests the model on the
+ * following test data set:
+ * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt Further
+ * details regarding the data files can be found at
+ * http://flexcrfs.sourceforge.net/#Case_Study
+ *
+ * @author mheimel
+ */
+public final class PosTagger {
+
+  /**
+   * Logger for this class.
+   */
+  private static final Log LOG = LogFactory.getLog(PosTagger.class);
+
+  /**
+   * Model trained in the example.
+   */
+  private static HmmModel taggingModel;
+
+  /**
+   * Map for storing the IDs for the POS tags (hidden states)
+   */
+  private static Map<String, Integer> tagIDs;
+
+  /**
+   * Counter for the next assigned POS tag ID. The value of 0 is reserved for
+   * "unknown POS tag".
+   */
+  private static int nextTagId = 1; // 0 is reserved for "unknown POS tag"
+
+  /**
+   * Map for storing the IDs for observed words (observed states)
+   */
+  private static Map<String, Integer> wordIDs;
+
+  /**
+   * Counter for the next assigned word ID. The value of 0 is reserved for
+   * "unknown word".
+   */
+  private static int nextWordId = 1; // 0 is reserved for "unknown word"
+
+  /**
+   * Used for storing a list of POS tags of read sentences.
+   */
+  private static List<int[]> hiddenSequences;
+
+  /**
+   * Used for storing a list of word tags of read sentences.
+   */
+  private static List<int[]> observedSequences;
+
+  /**
+   * Number of read lines.
+   */
+  private static int readLines;
+
+  /**
+   * No public constructors for utility classes.
+   */
+  private PosTagger() {
+    // nothing to do here really.
+  }
+
+  /**
+   * Given an URL, this function fetches the data file, parses it, assigns POS
+   * Tag/word IDs and fills the hiddenSequences/observedSequences lists with
+   * data from those files. The data is expected to be in the following format
+   * (one word per line): word pos-tag np-tag. An empty line ends the current
+   * sentence.
+   *
+   * @param url Where the data file is stored
+   * @param assignIDs Should IDs for unknown words/tags be assigned? (Needed for
+   *        training data, not needed for test data)
+   * @throws IOException in case data file cannot be read or contains a line
+   *         with fewer than two fields.
+   */
+  private static void readFromURL(String url, boolean assignIDs) throws IOException {
+    URLConnection connection = (new URL(url)).openConnection();
+    BufferedReader input = new BufferedReader(new InputStreamReader(connection.getInputStream()));
+    try {
+      // initialize the data structure
+      hiddenSequences = new ArrayList<int[]>();
+      observedSequences = new ArrayList<int[]>();
+      readLines = 0;
+
+      // now read line by line of the input file
+      List<Integer> observedSequence = new ArrayList<Integer>();
+      List<Integer> hiddenSequence = new ArrayList<Integer>();
+      String line;
+      while ((line = input.readLine()) != null) {
+        if (line.isEmpty()) {
+          // new sentence starts
+          registerSentence(observedSequence, hiddenSequence);
+          continue;
+        }
+        readLines++;
+        // we expect the format [word] [POS tag] [NP tag]
+        String[] tags = line.split(" ");
+        if (tags.length < 2) {
+          throw new IOException("Malformed data line, expected '[word] [POS tag] [NP tag]': " + line);
+        }
+        // when analyzing the training set, assign IDs
+        if (assignIDs) {
+          if (!wordIDs.containsKey(tags[0])) {
+            wordIDs.put(tags[0], nextWordId++);
+          }
+          if (!tagIDs.containsKey(tags[1])) {
+            tagIDs.put(tags[1], nextTagId++);
+          }
+        }
+        // determine the IDs
+        Integer wordID = wordIDs.get(tags[0]);
+        Integer tagID = tagIDs.get(tags[1]);
+        // handle unknown values
+        observedSequence.add(wordID == null ? 0 : wordID);
+        hiddenSequence.add(tagID == null ? 0 : tagID);
+      }
+      // if there is still something in the pipe, register it
+      registerSentence(observedSequence, hiddenSequence);
+    } finally {
+      // release the connection's stream even if parsing fails
+      input.close();
+    }
+  }
+
+  /**
+   * Converts the buffered IDs of one sentence into arrays, appends them to
+   * observedSequences/hiddenSequences and clears both buffers. Empty buffers
+   * (e.g. after consecutive empty lines) are skipped so that no zero-length
+   * sequence is registered.
+   *
+   * @param observedSequence word IDs of the current sentence
+   * @param hiddenSequence POS tag IDs of the current sentence
+   */
+  private static void registerSentence(List<Integer> observedSequence, List<Integer> hiddenSequence) {
+    if (observedSequence.isEmpty()) {
+      return;
+    }
+    int[] observedSequenceArray = new int[observedSequence.size()];
+    int[] hiddenSequenceArray = new int[hiddenSequence.size()];
+    for (int i = 0; i < observedSequenceArray.length; ++i) {
+      observedSequenceArray[i] = observedSequence.get(i);
+      hiddenSequenceArray[i] = hiddenSequence.get(i);
+    }
+    // now register those arrays
+    hiddenSequences.add(hiddenSequenceArray);
+    observedSequences.add(observedSequenceArray);
+    // and reset the buffers
+    observedSequence.clear();
+    hiddenSequence.clear();
+  }
+
+  /**
+   * Reads the training data, trains the tagging model on it and raises the
+   * probability that an unknown word is tagged as NNP.
+   *
+   * @param trainingURL Where the training data file is stored
+   * @throws IOException in case the training data cannot be read.
+   */
+  private static void trainModel(String trainingURL) throws IOException {
+    tagIDs = new HashMap<String, Integer>(44); // we expect 44 distinct tags
+    wordIDs = new HashMap<String, Integer>(19122); // we expect 19122 distinct words
+    // restart the numbering together with the maps; 0 stays reserved for "unknown"
+    nextTagId = 1;
+    nextWordId = 1;
+    LOG.info("Reading and parsing training data file from URL: " + trainingURL);
+    long start = System.currentTimeMillis();
+    readFromURL(trainingURL, true);
+    long end = System.currentTimeMillis();
+    double duration = (end - start) / 1000.0;
+    LOG.info("Parsing done in " + duration + " seconds!");
+    LOG.info("Read " + readLines + " lines containing "
+        + hiddenSequences.size() + " sentences with a total of "
+        + (nextWordId - 1) + " distinct words and " + (nextTagId - 1)
+        + " distinct POS tags.");
+    start = System.currentTimeMillis();
+    taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId,
+        hiddenSequences, observedSequences, 0.05);
+    // we have to adjust the model a bit,
+    // since we assume a higher probability that a given unknown word is NNP
+    // than anything else
+    Matrix emissions = taggingModel.getEmissionMatrix();
+    int nrOfHiddenStates = taggingModel.getNrOfHiddenStates();
+    for (int i = 0; i < nrOfHiddenStates; ++i) {
+      emissions.setQuick(i, 0, 0.1 / nrOfHiddenStates);
+    }
+    Integer nnpTag = tagIDs.get("NNP");
+    if (nnpTag == null) {
+      LOG.warn("Training data contains no NNP tag, unknown words are not biased towards it.");
+    } else {
+      emissions.setQuick(nnpTag, 0, 1.0 / nrOfHiddenStates);
+    }
+    // re-normalize the emission probabilities
+    HmmUtils.normalizeModel(taggingModel);
+    // now register the names
+    taggingModel.registerHiddenStateNames(tagIDs);
+    taggingModel.registerOutputStateNames(wordIDs);
+    end = System.currentTimeMillis();
+    duration = (end - start) / 1000.0;
+    LOG.info("Trained HMM model in " + duration + " seconds!");
+  }
+
+  /**
+   * Reads the test data, tags every sentence with the trained model and logs
+   * the share of wrongly tagged words.
+   *
+   * @param testingURL Where the test data file is stored
+   * @throws IOException in case the test data cannot be read.
+   */
+  private static void testModel(String testingURL) throws IOException {
+    LOG.info("Reading and parsing test data file from URL:" + testingURL);
+    long start = System.currentTimeMillis();
+    readFromURL(testingURL, false);
+    long end = System.currentTimeMillis();
+    double duration = (end - start) / 1000.0;
+    LOG.info("Parsing done in " + duration + " seconds!");
+    LOG.info("Read " + readLines + " lines containing "
+        + hiddenSequences.size() + " sentences.");
+
+    start = System.currentTimeMillis();
+    int errorCount = 0;
+    int totalCount = 0;
+    for (int i = 0; i < observedSequences.size(); ++i) {
+      // fetch the viterbi path as the POS tag for this observed sequence
+      int[] posEstimate = HmmEvaluator.decode(taggingModel,
+          observedSequences.get(i), false);
+      // compare with the expected
+      int[] posExpected = hiddenSequences.get(i);
+      for (int j = 0; j < posExpected.length; ++j) {
+        totalCount++;
+        if (posEstimate[j] != posExpected[j]) {
+          errorCount++;
+        }
+      }
+    }
+    end = System.currentTimeMillis();
+    duration = (end - start) / 1000.0;
+    LOG.info("POS tagged test file in " + duration + " seconds!");
+    double errorRate = (double) errorCount / (double) totalCount;
+    LOG.info("Tagged the test file with an error rate of: " + errorRate);
+  }
+
+  /**
+   * Splits a sentence into the tokens that are looked up as words.
+   *
+   * @param sentence the raw sentence
+   * @return the tokens, with every punctuation character as a token of its own
+   */
+  private static String[] tokenize(String sentence) {
+    // first, we need to isolate all punctuation characters, so that they
+    // can be recognized
+    String isolated = sentence.replaceAll("[,.!?:;\"]", " $0 ");
+    isolated = isolated.replaceAll("''", " '' ");
+    // now we tokenize the sentence; trimming avoids an empty leading token
+    return isolated.trim().split("[ ]+");
+  }
+
+  /**
+   * Assigns a POS tag to every token of the given sentence.
+   *
+   * @param sentence the raw sentence
+   * @return one tag name per token of {@link #tokenize(String)}; presumably
+   *         null where the decoded state has no registered name
+   */
+  private static List<String> tagSentence(String sentence) {
+    // generate the observed sequence; unknown words are mapped to ID 0
+    int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel,
+        Arrays.asList(tokenize(sentence)), true, 0);
+    // POS tag this observedSequence
+    int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence,
+        false);
+    // and now decode the tag names
+    return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false,
+        null);
+  }
+
+  /**
+   * Trains the tagger, evaluates it on the test data and tags an exemplary sentence.
+   *
+   * @param args not used
+   * @throws IOException in case one of the data files cannot be read.
+   */
+  public static void main(String[] args) throws IOException {
+    // generate the model from URL
+    trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
+    testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
+    // tag an exemplary sentence
+    String test = "McDonalds is a huge company with many employees .";
+    // use the tagger's own tokenization so that words and tags line up
+    String[] testWords = tokenize(test);
+    List<String> posTags = tagSentence(test);
+    for (int i = 0; i < posTags.size(); ++i) {
+      LOG.info(testWords[i] + "[" + posTags.get(i) + "]");
+    }
+  }
+
+}
