[
https://issues.apache.org/jira/browse/MAHOUT-734?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
]
Sergey Bartunov updated MAHOUT-734:
-----------------------------------
Status: Patch Available (was: Open)
>From 7be42824a0767d4208b9dcd7da49beee06ff15ee Mon Sep 17 00:00:00 2001
From: Sergey Bartunov <[email protected]>
Date: Wed, 15 Jun 2011 01:04:39 +0400
Subject: [PATCH 3/5] command-line util for baum-welch algorithm on HMM
---
.../sequencelearning/hmm/BaumWelchTrainer.java | 127 ++++++++++++++++++++
.../sequencelearning/hmm/LossyHmmSerializer.java | 57 +++++++++
src/conf/driver.classes.props | 3 +-
3 files changed, 186 insertions(+), 1 deletions(-)
create mode 100644
core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
create mode 100644
core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
diff --git
a/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
b/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
new file mode 100644
index 0000000..410fcad
--- /dev/null
+++
b/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
@@ -0,0 +1,127 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+
+import java.io.DataOutputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.List;
+import java.util.Scanner;
+
+/**
+ * A class for EM training of HMM from console
+ */
+public class BaumWelchTrainer {
+ public static void main(String[] args) throws IOException {
+ final DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ final ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+ final Option inputOption = optionBuilder.withLongName("input").
+ withDescription("Text file with space-separated integers to train on").
+
withShortName("i").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("path").create()).withRequired(true).create();
+
+ final Option outputOption = optionBuilder.withLongName("output").
+ withDescription("Path trained HMM model should be serialized to").
+
withShortName("o").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("path").create()).withRequired(true).create();
+
+ final Option stateNumberOption =
optionBuilder.withLongName("nrOfHiddenStates").
+ withDescription("Number of hidden states").
+
withShortName("nh").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ final Option observedStateNumberOption =
optionBuilder.withLongName("nrOfObservedStates").
+ withDescription("Number of observed states").
+
withShortName("no").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ final Option epsilonOption = optionBuilder.withLongName("epsilon").
+ withDescription("Convergence threshold").
+
withShortName("e").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ final Option iterationsOption =
optionBuilder.withLongName("max-iterations").
+ withDescription("Maximum iterations number").
+
withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ final Group optionGroup = new GroupBuilder().withOption(inputOption).
+
withOption(outputOption).withOption(stateNumberOption).withOption(observedStateNumberOption).
+ withOption(epsilonOption).withOption(iterationsOption).
+ withName("Options").create();
+
+ try {
+ final Parser parser = new Parser();
+ parser.setGroup(optionGroup);
+ final CommandLine commandLine = parser.parse(args);
+
+ final String input = (String) commandLine.getValue(inputOption);
+ final String output = (String) commandLine.getValue(outputOption);
+
+ final int nrOfHiddenStates = Integer.parseInt((String)
commandLine.getValue(stateNumberOption));
+ final int nrOfObservedStates = Integer.parseInt((String)
commandLine.getValue(observedStateNumberOption));
+
+ final double epsilon = Double.parseDouble((String)
commandLine.getValue(epsilonOption));
+ final int maxIterations = Integer.parseInt((String)
commandLine.getValue(iterationsOption));
+
+ //constructing random-generated HMM
+ final HmmModel model = new HmmModel(nrOfHiddenStates,
nrOfObservedStates, new Date().getTime());
+ final List<Integer> observations = new ArrayList<Integer>();
+
+ //reading observations
+ final FileInputStream inputStream = new FileInputStream(input);
+ final Scanner scanner = new Scanner(inputStream);
+
+ while (scanner.hasNextInt()) {
+ observations.add(scanner.nextInt());
+ }
+
+ scanner.close();
+ inputStream.close();
+
+ final int[] observationsArray = new int[observations.size()];
+ for (int i = 0; i < observations.size(); ++i)
+ observationsArray[i] = observations.get(i);
+
+ //training
+ final HmmModel trainedModel = HmmTrainer.trainBaumWelch(model,
+ observationsArray, epsilon, maxIterations, true);
+
+ //serializing trained model
+ final DataOutputStream stream = new DataOutputStream(new
FileOutputStream(output));
+ LossyHmmSerializer.serialize(trainedModel, stream);
+ stream.close();
+ } catch (OptionException e) {
+ CommandLineUtil.printHelp(optionGroup);
+ }
+ }
+}
diff --git
a/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
b/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
new file mode 100644
index 0000000..8bbb814
--- /dev/null
+++
b/core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Utils for serializing Writable parts of HmmModel (that means without hidden
state names and so on)
+ */
+public class LossyHmmSerializer {
+ public static void serialize(HmmModel model, DataOutput output) throws
IOException {
+ final MatrixWritable matrix = new
MatrixWritable(model.getEmissionMatrix());
+ matrix.write(output);
+ matrix.set(model.getTransitionMatrix());
+ matrix.write(output);
+
+ final VectorWritable vector = new
VectorWritable(model.getInitialProbabilities());
+ vector.write(output);
+ }
+
+ public static HmmModel deserialize(DataInput input) throws IOException {
+ final MatrixWritable matrix = new MatrixWritable();
+ matrix.readFields(input);
+ final Matrix emissionMatrix = matrix.get();
+
+ matrix.readFields(input);
+ final Matrix transitionMatrix = matrix.get();
+
+ final VectorWritable vector = new VectorWritable();
+ vector.readFields(input);
+ final Vector initialProbabilities = vector.get();
+
+ return new HmmModel(transitionMatrix, emissionMatrix,
initialProbabilities);
+ }
+}
\ No newline at end of file
diff --git a/src/conf/driver.classes.props b/src/conf/driver.classes.props
index ed72253..cc29fd3 100644
--- a/src/conf/driver.classes.props
+++ b/src/conf/driver.classes.props
@@ -37,4 +37,5 @@ org.apache.mahout.math.hadoop.stochasticsvd.SSVDCli = ssvd :
Stochastic SVD
org.apache.mahout.clustering.spectral.eigencuts.EigencutsDriver = eigencuts :
Eigencuts spectral clustering
org.apache.mahout.clustering.spectral.kmeans.SpectralKMeansDriver =
spectralkmeans : Spectral k-means clustering
org.apache.mahout.cf.taste.hadoop.als.ParallelALSFactorizationJob =
parallelALS : ALS-WR factorization of a rating matrix
-org.apache.mahout.cf.taste.hadoop.als.PredictionJob = predictFromFactorization
: predict preferences from a factorization of a rating matrix
\ No newline at end of file
+org.apache.mahout.cf.taste.hadoop.als.PredictionJob = predictFromFactorization
: predict preferences from a factorization of a rating matrix
+org.apache.mahout.classifier.sequencelearning.hmm.BaumWelchTrainer = baumwelch
: Baum-Welch algorithm for unsupervised HMM training
--
1.7.1
> Command-line utils for HMM
> --------------------------
>
> Key: MAHOUT-734
> URL: https://issues.apache.org/jira/browse/MAHOUT-734
> Project: Mahout
> Issue Type: New Feature
> Components: Classification
> Affects Versions: 0.5
> Reporter: Sergey Bartunov
> Priority: Minor
> Labels: hmm
> Fix For: 0.6
>
>
> Mahout already has HMM functionality, but it is only available through the API.
> Command-line tools should be added and registered in driver.classes.props
> [this is my "training" issue in Jira to learn how to submit patches to
> Mahout, so please be merciful]
--
This message is automatically generated by JIRA.
For more information on JIRA, see: http://www.atlassian.com/software/jira