http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java new file mode 100644 index 0000000..e710816 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java @@ -0,0 +1,360 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; + +import com.google.common.base.Preconditions; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.SparseMatrix; +import org.apache.mahout.math.Vector; + +/** + * A collection of utilities for handling HMMModel objects. + */ +public final class HmmUtils { + + /** + * No public constructor for utility classes. + */ + private HmmUtils() { + // nothing to do here really. + } + + /** + * Compute the cumulative transition probability matrix for the given HMM + * model. Matrix where each row i is the cumulative distribution of the + * transition probability distribution for hidden state i. + * + * @param model The HMM model for which the cumulative transition matrix should be + * computed + * @return The computed cumulative transition matrix. + */ + public static Matrix getCumulativeTransitionMatrix(HmmModel model) { + // fetch the needed parameters from the model + int hiddenStates = model.getNrOfHiddenStates(); + Matrix transitionMatrix = model.getTransitionMatrix(); + // now compute the cumulative transition matrix + Matrix resultMatrix = new DenseMatrix(hiddenStates, hiddenStates); + for (int i = 0; i < hiddenStates; ++i) { + double sum = 0; + for (int j = 0; j < hiddenStates; ++j) { + sum += transitionMatrix.get(i, j); + resultMatrix.set(i, j, sum); + } + resultMatrix.set(i, hiddenStates - 1, 1.0); + // make sure the last + // state has always a + // cumulative + // probability of + // exactly 1.0 + } + return resultMatrix; + } + + /** + * Compute the cumulative output probability matrix for the given HMM model. + * Matrix where each row i is the cumulative distribution of the output + * probability distribution for hidden state i. 
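+   * The final entry of each row is pinned to exactly 1.0 so that accumulated
+   * rounding error cannot leave the cumulative distribution short of 1.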
+ * + * @param model The HMM model for which the cumulative output matrix should be + * computed + * @return The computed cumulative output matrix. + */ + public static Matrix getCumulativeOutputMatrix(HmmModel model) { + // fetch the needed parameters from the model + int hiddenStates = model.getNrOfHiddenStates(); + int outputStates = model.getNrOfOutputStates(); + Matrix outputMatrix = model.getEmissionMatrix(); + // now compute the cumulative output matrix + Matrix resultMatrix = new DenseMatrix(hiddenStates, outputStates); + for (int i = 0; i < hiddenStates; ++i) { + double sum = 0; + for (int j = 0; j < outputStates; ++j) { + sum += outputMatrix.get(i, j); + resultMatrix.set(i, j, sum); + } + resultMatrix.set(i, outputStates - 1, 1.0); + // make sure the last + // output state has + // always a cumulative + // probability of 1.0 + } + return resultMatrix; + } + + /** + * Compute the cumulative distribution of the initial hidden state + * probabilities for the given HMM model. + * + * @param model The HMM model for which the cumulative initial state probabilities + * should be computed + * @return The computed cumulative initial state probability vector. + */ + public static Vector getCumulativeInitialProbabilities(HmmModel model) { + // fetch the needed parameters from the model + int hiddenStates = model.getNrOfHiddenStates(); + Vector initialProbabilities = model.getInitialProbabilities(); + // now compute the cumulative output matrix + Vector resultVector = new DenseVector(initialProbabilities.size()); + double sum = 0; + for (int i = 0; i < hiddenStates; ++i) { + sum += initialProbabilities.get(i); + resultVector.set(i, sum); + } + resultVector.set(hiddenStates - 1, 1.0); // make sure the last initial + // hidden state probability + // has always a cumulative + // probability of 1.0 + return resultVector; + } + + /** + * Validates an HMM model set + * + * @param model model to sanity check. + */ + public static void validate(HmmModel model) { + if (model == null) { + return; // empty models are valid + } + + /* + * The number of hidden states is positive. + */ + Preconditions.checkArgument(model.getNrOfHiddenStates() > 0, + "Error: The number of hidden states has to be greater than 0"); + + /* + * The number of output states is positive. + */ + Preconditions.checkArgument(model.getNrOfOutputStates() > 0, + "Error: The number of output states has to be greater than 0!"); + + /* + * The size of the vector of initial probabilities is equal to the number of + * the hidden states. Each initial probability is non-negative. The sum of + * initial probabilities is equal to 1. + */ + Preconditions.checkArgument(model.getInitialProbabilities() != null + && model.getInitialProbabilities().size() == model.getNrOfHiddenStates(), + "Error: The vector of initial probabilities is not initialized!"); + + double sum = 0; + for (int i = 0; i < model.getInitialProbabilities().size(); i++) { + Preconditions.checkArgument(model.getInitialProbabilities().get(i) >= 0, + "Error: Initial probability of state %d is negative", i); + sum += model.getInitialProbabilities().get(i); + } + Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001, + "Error: Initial probabilities do not add up to 1"); + /* + * The row size of the output matrix is equal to the number of the hidden + * states. The column size is equal to the number of output states. Each + * probability of the matrix is non-negative. The sum of each row is equal + * to 1. 
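+     * (Row sums are checked only to within a tolerance of 1e-5 by the Preconditions below.)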
+ */ + Preconditions.checkNotNull(model.getEmissionMatrix(), "Error: The output state matrix is not initialized!"); + Preconditions.checkArgument(model.getEmissionMatrix().numRows() == model.getNrOfHiddenStates() + && model.getEmissionMatrix().numCols() == model.getNrOfOutputStates(), + "Error: The output state matrix is not of the form nrOfHiddenStates x nrOfOutputStates"); + for (int i = 0; i < model.getEmissionMatrix().numRows(); i++) { + sum = 0; + for (int j = 0; j < model.getEmissionMatrix().numCols(); j++) { + Preconditions.checkArgument(model.getEmissionMatrix().get(i, j) >= 0, + "The output state probability from hidden state " + i + " to output state " + j + " is negative"); + sum += model.getEmissionMatrix().get(i, j); + } + Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001, + "Error: The output state probabilities for hidden state %d don't add up to 1", i); + } + + /* + * The size of both dimension of the transition matrix is equal to the + * number of the hidden states. Each probability of the matrix is + * non-negative. The sum of each row in transition matrix is equal to 1. + */ + Preconditions.checkArgument(model.getTransitionMatrix() != null, + "Error: The hidden state matrix is not initialized!"); + Preconditions.checkArgument(model.getTransitionMatrix().numRows() == model.getNrOfHiddenStates() + && model.getTransitionMatrix().numCols() == model.getNrOfHiddenStates(), + "Error: The output state matrix is not of the form nrOfHiddenStates x nrOfHiddenStates"); + for (int i = 0; i < model.getTransitionMatrix().numRows(); i++) { + sum = 0; + for (int j = 0; j < model.getTransitionMatrix().numCols(); j++) { + Preconditions.checkArgument(model.getTransitionMatrix().get(i, j) >= 0, + "Error: The transition probability from hidden state %d to hidden state %d is negative", i, j); + sum += model.getTransitionMatrix().get(i, j); + } + Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001, + "Error: The transition probabilities for hidden state " + i + " don't add up to 1."); + } + } + + /** + * Encodes a given collection of state names by the corresponding state IDs + * registered in a given model. + * + * @param model Model to provide the encoding for + * @param sequence Collection of state names + * @param observed If set, the sequence is encoded as a sequence of observed states, + * else it is encoded as sequence of hidden states + * @param defaultValue The default value in case a state is not known + * @return integer array containing the encoded state IDs + */ + public static int[] encodeStateSequence(HmmModel model, + Collection<String> sequence, boolean observed, int defaultValue) { + int[] encoded = new int[sequence.size()]; + Iterator<String> seqIter = sequence.iterator(); + for (int i = 0; i < sequence.size(); ++i) { + String nextState = seqIter.next(); + int nextID; + if (observed) { + nextID = model.getOutputStateID(nextState); + } else { + nextID = model.getHiddenStateID(nextState); + } + // if the ID is -1, use the default value + encoded[i] = nextID < 0 ? defaultValue : nextID; + } + return encoded; + } + + /** + * Decodes a given collection of state IDs into the corresponding state names + * registered in a given model. 
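+   * This is the inverse of {@link #encodeStateSequence(HmmModel, Collection, boolean, int)}.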
+ * + * @param model model to use for retrieving state names + * @param sequence int array of state IDs + * @param observed If set, the sequence is encoded as a sequence of observed states, + * else it is encoded as sequence of hidden states + * @param defaultValue The default value in case a state is not known + * @return list containing the decoded state names + */ + public static List<String> decodeStateSequence(HmmModel model, + int[] sequence, + boolean observed, + String defaultValue) { + List<String> decoded = new ArrayList<>(sequence.length); + for (int position : sequence) { + String nextState; + if (observed) { + nextState = model.getOutputStateName(position); + } else { + nextState = model.getHiddenStateName(position); + } + // if null was returned, use the default value + decoded.add(nextState == null ? defaultValue : nextState); + } + return decoded; + } + + /** + * Function used to normalize the probabilities of a given HMM model + * + * @param model model to normalize + */ + public static void normalizeModel(HmmModel model) { + Vector ip = model.getInitialProbabilities(); + Matrix emission = model.getEmissionMatrix(); + Matrix transition = model.getTransitionMatrix(); + // check normalization for all probabilities + double isum = 0; + for (int i = 0; i < model.getNrOfHiddenStates(); ++i) { + isum += ip.getQuick(i); + double sum = 0; + for (int j = 0; j < model.getNrOfHiddenStates(); ++j) { + sum += transition.getQuick(i, j); + } + if (sum != 1.0) { + for (int j = 0; j < model.getNrOfHiddenStates(); ++j) { + transition.setQuick(i, j, transition.getQuick(i, j) / sum); + } + } + sum = 0; + for (int j = 0; j < model.getNrOfOutputStates(); ++j) { + sum += emission.getQuick(i, j); + } + if (sum != 1.0) { + for (int j = 0; j < model.getNrOfOutputStates(); ++j) { + emission.setQuick(i, j, emission.getQuick(i, j) / sum); + } + } + } + if (isum != 1.0) { + for (int i = 0; i < model.getNrOfHiddenStates(); ++i) { + ip.setQuick(i, ip.getQuick(i) / isum); + } + } + } + + /** + * Method to reduce the size of an HMMmodel by converting the models + * DenseMatrix/DenseVectors to sparse implementations and setting every value + * < threshold to 0 + * + * @param model model to truncate + * @param threshold minimum value a model entry must have to be retained. 
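+   *          (only entries strictly greater than this threshold are kept)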
+ * @return Truncated model + */ + public static HmmModel truncateModel(HmmModel model, double threshold) { + Vector ip = model.getInitialProbabilities(); + Matrix em = model.getEmissionMatrix(); + Matrix tr = model.getTransitionMatrix(); + // allocate the sparse data structures + RandomAccessSparseVector sparseIp = new RandomAccessSparseVector(model + .getNrOfHiddenStates()); + SparseMatrix sparseEm = new SparseMatrix(model.getNrOfHiddenStates(), model.getNrOfOutputStates()); + SparseMatrix sparseTr = new SparseMatrix(model.getNrOfHiddenStates(), model.getNrOfHiddenStates()); + // now transfer the values + for (int i = 0; i < model.getNrOfHiddenStates(); ++i) { + double value = ip.getQuick(i); + if (value > threshold) { + sparseIp.setQuick(i, value); + } + for (int j = 0; j < model.getNrOfHiddenStates(); ++j) { + value = tr.getQuick(i, j); + if (value > threshold) { + sparseTr.setQuick(i, j, value); + } + } + + for (int j = 0; j < model.getNrOfOutputStates(); ++j) { + value = em.getQuick(i, j); + if (value > threshold) { + sparseEm.setQuick(i, j, value); + } + } + } + // create a new model + HmmModel sparseModel = new HmmModel(sparseTr, sparseEm, sparseIp); + // normalize the model + normalizeModel(sparseModel); + // register the names + sparseModel.registerHiddenStateNames(model.getHiddenStateNames()); + sparseModel.registerOutputStateNames(model.getOutputStateNames()); + // and return + return sparseModel; + } +}
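For orientation, here is a minimal sketch (not part of this patch) of how these utilities compose: build a tiny two-state model via the three-argument HmmModel constructor used by truncateModel() above, validate it, and sample a successor state by inverting one row of the cumulative transition matrix. The sampling loop is illustrative only.

import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

public class HmmUtilsSketch {
  public static void main(String[] args) {
    // two hidden states, two output states
    Matrix transitions = new DenseMatrix(new double[][] {{0.7, 0.3}, {0.4, 0.6}});
    Matrix emissions = new DenseMatrix(new double[][] {{0.9, 0.1}, {0.2, 0.8}});
    Vector initial = new DenseVector(new double[] {0.5, 0.5});
    HmmModel model = new HmmModel(transitions, emissions, initial);

    HmmUtils.validate(model); // throws IllegalArgumentException if any distribution is malformed

    // sample the successor of hidden state 0 by inverting the CDF in row 0
    Matrix cumulative = HmmUtils.getCumulativeTransitionMatrix(model);
    double u = Math.random();
    int next = 0;
    while (cumulative.get(0, next) < u) {
      next++;
    }
    System.out.println("sampled successor of state 0: " + next);
  }
}

Pinning the last column of each cumulative row to exactly 1.0, as the methods above do, is what guarantees the inversion loop terminates even when a row sums to slightly less than 1 in floating point.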
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java new file mode 100644 index 0000000..d0ae9c2 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixWritable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Utils for serializing Writable parts of HmmModel (that means without hidden state names and so on) + */ +final class LossyHmmSerializer { + + private LossyHmmSerializer() { + } + + static void serialize(HmmModel model, DataOutput output) throws IOException { + MatrixWritable matrix = new MatrixWritable(model.getEmissionMatrix()); + matrix.write(output); + matrix.set(model.getTransitionMatrix()); + matrix.write(output); + + VectorWritable vector = new VectorWritable(model.getInitialProbabilities()); + vector.write(output); + } + + static HmmModel deserialize(DataInput input) throws IOException { + MatrixWritable matrix = new MatrixWritable(); + matrix.readFields(input); + Matrix emissionMatrix = matrix.get(); + + matrix.readFields(input); + Matrix transitionMatrix = matrix.get(); + + VectorWritable vector = new VectorWritable(); + vector.readFields(input); + Vector initialProbabilities = vector.get(); + + return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java new file mode 100644 index 0000000..02baef1 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java @@ -0,0 +1,102 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under 
one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.io.DataInputStream; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.io.Charsets; +import org.apache.mahout.common.CommandLineUtil; + +/** + * Command-line tool for generating random sequences by given HMM + */ +public final class RandomSequenceGenerator { + + private RandomSequenceGenerator() { + } + + public static void main(String[] args) throws IOException { + DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder(); + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + + Option outputOption = optionBuilder.withLongName("output"). + withDescription("Output file with sequence of observed states"). + withShortName("o").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("path").create()).withRequired(false).create(); + + Option modelOption = optionBuilder.withLongName("model"). + withDescription("Path to serialized HMM model"). + withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("path").create()).withRequired(true).create(); + + Option lengthOption = optionBuilder.withLongName("length"). + withDescription("Length of generated sequence"). + withShortName("l").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("number").create()).withRequired(true).create(); + + Group optionGroup = new GroupBuilder(). + withOption(outputOption).withOption(modelOption).withOption(lengthOption). 
+ withName("Options").create(); + + try { + Parser parser = new Parser(); + parser.setGroup(optionGroup); + CommandLine commandLine = parser.parse(args); + + String output = (String) commandLine.getValue(outputOption); + + String modelPath = (String) commandLine.getValue(modelOption); + + int length = Integer.parseInt((String) commandLine.getValue(lengthOption)); + + //reading serialized HMM + HmmModel model; + try (DataInputStream modelStream = new DataInputStream(new FileInputStream(modelPath))){ + model = LossyHmmSerializer.deserialize(modelStream); + } + + //generating observations + int[] observations = HmmEvaluator.predict(model, length, System.currentTimeMillis()); + + //writing output + try (PrintWriter writer = + new PrintWriter(new OutputStreamWriter(new FileOutputStream(output), Charsets.UTF_8), true)){ + for (int observation : observations) { + writer.print(observation); + writer.print(' '); + } + } + } catch (OptionException e) { + CommandLineUtil.printHelp(optionGroup); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java new file mode 100644 index 0000000..317237d --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java @@ -0,0 +1,122 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.io.DataInputStream; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.List; +import java.util.Scanner; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.io.Charsets; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; + +/** + * Command-line tool for Viterbi evaluating + */ +public final class ViterbiEvaluator { + + private ViterbiEvaluator() { + } + + public static void main(String[] args) throws IOException { + DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder(); + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + + Option inputOption = DefaultOptionCreator.inputOption().create(); + + Option outputOption = DefaultOptionCreator.outputOption().create(); + + Option modelOption = optionBuilder.withLongName("model"). + withDescription("Path to serialized HMM model"). + withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1). + withName("path").create()).withRequired(true).create(); + + Option likelihoodOption = optionBuilder.withLongName("likelihood"). + withDescription("Compute likelihood of observed sequence"). + withShortName("l").withRequired(false).create(); + + Group optionGroup = new GroupBuilder().withOption(inputOption). + withOption(outputOption).withOption(modelOption).withOption(likelihoodOption). 
+ withName("Options").create(); + + try { + Parser parser = new Parser(); + parser.setGroup(optionGroup); + CommandLine commandLine = parser.parse(args); + + String input = (String) commandLine.getValue(inputOption); + String output = (String) commandLine.getValue(outputOption); + + String modelPath = (String) commandLine.getValue(modelOption); + + boolean computeLikelihood = commandLine.hasOption(likelihoodOption); + + //reading serialized HMM + ; + HmmModel model; + try (DataInputStream modelStream = new DataInputStream(new FileInputStream(modelPath))) { + model = LossyHmmSerializer.deserialize(modelStream); + } + + //reading observations + List<Integer> observations = new ArrayList<>(); + try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) { + while (scanner.hasNextInt()) { + observations.add(scanner.nextInt()); + } + } + + int[] observationsArray = new int[observations.size()]; + for (int i = 0; i < observations.size(); ++i) { + observationsArray[i] = observations.get(i); + } + + //decoding + int[] hiddenStates = HmmEvaluator.decode(model, observationsArray, true); + + //writing output + try (PrintWriter writer = + new PrintWriter(new OutputStreamWriter(new FileOutputStream(output), Charsets.UTF_8), true)) { + for (int hiddenState : hiddenStates) { + writer.print(hiddenState); + writer.print(' '); + } + } + + if (computeLikelihood) { + System.out.println("Likelihood: " + HmmEvaluator.modelLikelihood(model, observationsArray, true)); + } + } catch (OptionException e) { + CommandLineUtil.printHelp(optionGroup); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java new file mode 100644 index 0000000..0b2c41b --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.function.DoubleFunction; +import org.apache.mahout.math.function.Functions; + +import com.google.common.base.Preconditions; + +/** + * Generic definition of a 1 of n logistic regression classifier that returns probabilities in + * response to a feature vector. This classifier uses 1 of n-1 coding where the 0-th category + * is not stored explicitly. + * <p/> + * Provides the SGD based algorithm for learning a logistic regression, but omits all + * annealing of learning rates. Any extension of this abstract class must define the overall + * and per-term annealing for themselves. + */ +public abstract class AbstractOnlineLogisticRegression extends AbstractVectorClassifier implements OnlineLearner { + // coefficients for the classification. This is a dense matrix + // that is (numCategories-1) x numFeatures + protected Matrix beta; + + // number of categories we are classifying. This should the number of rows of beta plus one. + protected int numCategories; + + protected int step; + + // information about how long since coefficient rows were updated. This allows lazy regularization. + protected Vector updateSteps; + + // information about how many updates we have had on a location. This allows per-term + // annealing a la confidence weighted learning. + protected Vector updateCounts; + + // weight of the prior on beta + private double lambda = 1.0e-5; + protected PriorFunction prior; + + // can we ignore any further regularization when doing classification? + private boolean sealed; + + // by default we don't do any fancy training + private Gradient gradient = new DefaultGradient(); + + /** + * Chainable configuration option. + * + * @param lambda New value of lambda, the weighting factor for the prior distribution. + * @return This, so other configurations can be chained. + */ + public AbstractOnlineLogisticRegression lambda(double lambda) { + this.lambda = lambda; + return this; + } + + /** + * Computes the inverse link function, by default the logistic link function. + * + * @param v The output of the linear combination in a GLM. Note that the value + * of v is disturbed. + * @return A version of v with the link function applied. + */ + public static Vector link(Vector v) { + double max = v.maxValue(); + if (max >= 40) { + // if max > 40, we subtract the large offset first + // the size of the max means that 1+sum(exp(v)) = sum(exp(v)) to within round-off + v.assign(Functions.minus(max)).assign(Functions.EXP); + return v.divide(v.norm(1)); + } else { + v.assign(Functions.EXP); + return v.divide(1 + v.norm(1)); + } + } + + /** + * Computes the binomial logistic inverse link function. + * + * @param r The value to transform. + * @return The logit of r. 
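+   * (Concretely this evaluates the logistic sigmoid 1 / (1 + exp(-r)), the inverse
+   * of the logit; the two branches below avoid overflow in Math.exp for large |r|.)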
+ */ + public static double link(double r) { + if (r < 0.0) { + double s = Math.exp(r); + return s / (1.0 + s); + } else { + double s = Math.exp(-r); + return 1.0 / (1.0 + s); + } + } + + @Override + public Vector classifyNoLink(Vector instance) { + // apply pending regularization to whichever coefficients matter + regularize(instance); + return beta.times(instance); + } + + public double classifyScalarNoLink(Vector instance) { + return beta.viewRow(0).dot(instance); + } + + /** + * Returns n-1 probabilities, one for each category but the 0-th. The probability of the 0-th + * category is 1 - sum(this result). + * + * @param instance A vector of features to be classified. + * @return A vector of probabilities, one for each of the first n-1 categories. + */ + @Override + public Vector classify(Vector instance) { + return link(classifyNoLink(instance)); + } + + /** + * Returns a single scalar probability in the case where we have two categories. Using this + * method avoids an extra vector allocation as opposed to calling classify() or an extra two + * vector allocations relative to classifyFull(). + * + * @param instance The vector of features to be classified. + * @return The probability of the first of two categories. + * @throws IllegalArgumentException If the classifier doesn't have two categories. + */ + @Override + public double classifyScalar(Vector instance) { + Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories"); + + // apply pending regularization to whichever coefficients matter + regularize(instance); + + // result is a vector with one element so we can just use dot product + return link(classifyScalarNoLink(instance)); + } + + @Override + public void train(long trackingKey, String groupKey, int actual, Vector instance) { + unseal(); + + double learningRate = currentLearningRate(); + + // push coefficients back to zero based on the prior + regularize(instance); + + // update each row of coefficients according to result + Vector gradient = this.gradient.apply(groupKey, actual, instance, this); + for (int i = 0; i < numCategories - 1; i++) { + double gradientBase = gradient.get(i); + + // then we apply the gradientBase to the resulting element. 
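+      // concretely, for each non-zero feature j of the instance:
+      //   beta[i][j] += gradientBase * learningRate * perTermLearningRate(j) * x[j]
+      // so this sparse SGD step only touches coefficients whose features are present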
+ for (Element updateLocation : instance.nonZeroes()) { + int j = updateLocation.index(); + + double newValue = beta.getQuick(i, j) + gradientBase * learningRate * perTermLearningRate(j) * instance.get(j); + beta.setQuick(i, j, newValue); + } + } + + // remember that these elements got updated + for (Element element : instance.nonZeroes()) { + int j = element.index(); + updateSteps.setQuick(j, getStep()); + updateCounts.incrementQuick(j, 1); + } + nextStep(); + + } + + @Override + public void train(long trackingKey, int actual, Vector instance) { + train(trackingKey, null, actual, instance); + } + + @Override + public void train(int actual, Vector instance) { + train(0, null, actual, instance); + } + + public void regularize(Vector instance) { + if (updateSteps == null || isSealed()) { + return; + } + + // anneal learning rate + double learningRate = currentLearningRate(); + + // here we lazily apply the prior to make up for our neglect + for (int i = 0; i < numCategories - 1; i++) { + for (Element updateLocation : instance.nonZeroes()) { + int j = updateLocation.index(); + double missingUpdates = getStep() - updateSteps.get(j); + if (missingUpdates > 0) { + double rate = getLambda() * learningRate * perTermLearningRate(j); + double newValue = prior.age(beta.get(i, j), missingUpdates, rate); + beta.set(i, j, newValue); + updateSteps.set(j, getStep()); + } + } + } + } + + // these two abstract methods are how extensions can modify the basic learning behavior of this object. + + public abstract double perTermLearningRate(int j); + + public abstract double currentLearningRate(); + + public void setPrior(PriorFunction prior) { + this.prior = prior; + } + + public void setGradient(Gradient gradient) { + this.gradient = gradient; + } + + public PriorFunction getPrior() { + return prior; + } + + public Matrix getBeta() { + close(); + return beta; + } + + public void setBeta(int i, int j, double betaIJ) { + beta.set(i, j, betaIJ); + } + + @Override + public int numCategories() { + return numCategories; + } + + public int numFeatures() { + return beta.numCols(); + } + + public double getLambda() { + return lambda; + } + + public int getStep() { + return step; + } + + protected void nextStep() { + step++; + } + + public boolean isSealed() { + return sealed; + } + + protected void unseal() { + sealed = false; + } + + private void regularizeAll() { + Vector all = new DenseVector(beta.numCols()); + all.assign(1); + regularize(all); + } + + @Override + public void close() { + if (!sealed) { + step++; + regularizeAll(); + sealed = true; + } + } + + public void copyFrom(AbstractOnlineLogisticRegression other) { + // number of categories we are classifying. This should the number of rows of beta plus one. + Preconditions.checkArgument(numCategories == other.numCategories, + "Can't copy unless number of target categories is the same"); + + beta.assign(other.beta); + + step = other.step; + + updateSteps.assign(other.updateSteps); + updateCounts.assign(other.updateCounts); + } + + public boolean validModel() { + double k = beta.aggregate(Functions.PLUS, new DoubleFunction() { + @Override + public double apply(double v) { + return Double.isNaN(v) || Double.isInfinite(v) ? 
1 : 0; + } + }); + return k < 1; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java new file mode 100644 index 0000000..24e5798 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java @@ -0,0 +1,586 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.ep.EvolutionaryProcess; +import org.apache.mahout.ep.Mapping; +import org.apache.mahout.ep.Payload; +import org.apache.mahout.ep.State; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.stats.OnlineAuc; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.ExecutionException; + +/** + * This is a meta-learner that maintains a pool of ordinary + * {@link org.apache.mahout.classifier.sgd.OnlineLogisticRegression} learners. Each + * member of the pool has different learning rates. Whichever of the learners in the pool falls + * behind in terms of average log-likelihood will be tossed out and replaced with variants of the + * survivors. This will let us automatically derive an annealing schedule that optimizes learning + * speed. Since on-line learners tend to be IO bound anyway, it doesn't cost as much as it might + * seem that it would to maintain multiple learners in memory. Doing this adaptation on-line as we + * learn also decreases the number of learning rate parameters required and replaces the normal + * hyper-parameter search. + * <p/> + * One wrinkle is that the pool of learners that we maintain is actually a pool of + * {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} which themselves contain several OnlineLogisticRegression + * objects. These pools allow estimation + * of performance on the fly even if we make many passes through the data. This does, however, + * increase the cost of training since if we are using 5-fold cross-validation, each vector is used + * 4 times for training and once for classification. 
If this becomes a problem, then we should + * probably use a 2-way unbalanced train/test split rather than full cross validation. With the + * current default settings, we have 100 learners running. This is better than the alternative of + * running hundreds of training passes to find good hyper-parameters because we only have to parse + * and feature-ize our inputs once. If you already have good hyper-parameters, then you might + * prefer to just run one CrossFoldLearner with those settings. + * <p/> + * The fitness used here is AUC. Another alternative would be to try log-likelihood, but it is much + * easier to get bogus values of log-likelihood than with AUC and the results seem to accord pretty + * well. It would be nice to allow the fitness function to be pluggable. This use of AUC means that + * AdaptiveLogisticRegression is mostly suited for binary target variables. This will be fixed + * before long by extending OnlineAuc to handle non-binary cases or by using a different fitness + * value in non-binary cases. + */ +public class AdaptiveLogisticRegression implements OnlineLearner, Writable { + public static final int DEFAULT_THREAD_COUNT = 20; + public static final int DEFAULT_POOL_SIZE = 20; + private static final int SURVIVORS = 2; + + private int record; + private int cutoff = 1000; + private int minInterval = 1000; + private int maxInterval = 1000; + private int currentStep = 1000; + private int bufferSize = 1000; + + private List<TrainingExample> buffer = new ArrayList<>(); + private EvolutionaryProcess<Wrapper, CrossFoldLearner> ep; + private State<Wrapper, CrossFoldLearner> best; + private int threadCount = DEFAULT_THREAD_COUNT; + private int poolSize = DEFAULT_POOL_SIZE; + private State<Wrapper, CrossFoldLearner> seed; + private int numFeatures; + + private boolean freezeSurvivors = true; + + private static final Logger log = LoggerFactory.getLogger(AdaptiveLogisticRegression.class); + + public AdaptiveLogisticRegression() {} + + /** + * Uses {@link #DEFAULT_THREAD_COUNT} and {@link #DEFAULT_POOL_SIZE} + * @param numCategories The number of categories (labels) to train on + * @param numFeatures The number of features used in creating the vectors (i.e. the cardinality of the vector) + * @param prior The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use + * + * @see #AdaptiveLogisticRegression(int, int, org.apache.mahout.classifier.sgd.PriorFunction, int, int) + */ + public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) { + this(numCategories, numFeatures, prior, DEFAULT_THREAD_COUNT, DEFAULT_POOL_SIZE); + } + + /** + * + * @param numCategories The number of categories (labels) to train on + * @param numFeatures The number of features used in creating the vectors (i.e. the cardinality of the vector) + * @param prior The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use + * @param threadCount The number of threads to use for training + * @param poolSize The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} to use. 
+ */ + public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior, int threadCount, + int poolSize) { + this.numFeatures = numFeatures; + this.threadCount = threadCount; + this.poolSize = poolSize; + seed = new State<>(new double[2], 10); + Wrapper w = new Wrapper(numCategories, numFeatures, prior); + seed.setPayload(w); + + Wrapper.setMappings(seed); + seed.setPayload(w); + setPoolSize(this.poolSize); + } + + @Override + public void train(int actual, Vector instance) { + train(record, null, actual, instance); + } + + @Override + public void train(long trackingKey, int actual, Vector instance) { + train(trackingKey, null, actual, instance); + } + + @Override + public void train(long trackingKey, String groupKey, int actual, Vector instance) { + record++; + + buffer.add(new TrainingExample(trackingKey, groupKey, actual, instance)); + //don't train until we have enough examples + if (buffer.size() > bufferSize) { + trainWithBufferedExamples(); + } + } + + private void trainWithBufferedExamples() { + try { + this.best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() { + @Override + public double apply(Payload<CrossFoldLearner> z, double[] params) { + Wrapper x = (Wrapper) z; + for (TrainingExample example : buffer) { + x.train(example); + } + if (x.getLearner().validModel()) { + if (x.getLearner().numCategories() == 2) { + return x.wrapped.auc(); + } else { + return x.wrapped.logLikelihood(); + } + } else { + return Double.NaN; + } + } + }); + } catch (InterruptedException e) { + // ignore ... shouldn't happen + log.warn("Ignoring exception", e); + } catch (ExecutionException e) { + throw new IllegalStateException(e.getCause()); + } + buffer.clear(); + + if (record > cutoff) { + cutoff = nextStep(record); + + // evolve based on new fitness + ep.mutatePopulation(SURVIVORS); + + if (freezeSurvivors) { + // now grossly hack the top survivors so they stick around. Set their + // mutation rates small and also hack their learning rate to be small + // as well. + for (State<Wrapper, CrossFoldLearner> state : ep.getPopulation().subList(0, SURVIVORS)) { + Wrapper.freeze(state); + } + } + } + + } + + public int nextStep(int recordNumber) { + int stepSize = stepSize(recordNumber, 2.6); + if (stepSize < minInterval) { + stepSize = minInterval; + } + + if (stepSize > maxInterval) { + stepSize = maxInterval; + } + + int newCutoff = stepSize * (recordNumber / stepSize + 1); + if (newCutoff < cutoff + currentStep) { + newCutoff = cutoff + currentStep; + } else { + this.currentStep = stepSize; + } + return newCutoff; + } + + public static int stepSize(int recordNumber, double multiplier) { + int[] bumps = {1, 2, 5}; + double log = Math.floor(multiplier * Math.log10(recordNumber)); + int bump = bumps[(int) log % bumps.length]; + int scale = (int) Math.pow(10, Math.floor(log / bumps.length)); + + return bump * scale; + } + + @Override + public void close() { + trainWithBufferedExamples(); + try { + ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() { + @Override + public double apply(Payload<CrossFoldLearner> payload, double[] params) { + CrossFoldLearner learner = ((Wrapper) payload).getLearner(); + learner.close(); + return learner.logLikelihood(); + } + }); + } catch (InterruptedException e) { + log.warn("Ignoring exception", e); + } catch (ExecutionException e) { + throw new IllegalStateException(e); + } finally { + ep.close(); + } + } + + /** + * How often should the evolutionary optimization of learning parameters occur? 
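+   * Note that the chosen interval also caps the size of the internal training buffer
+   * (see {@link #setInterval(int, int)}).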
+ * + * @param interval Number of training examples to use in each epoch of optimization. + */ + public void setInterval(int interval) { + setInterval(interval, interval); + } + + /** + * Starts optimization using the shorter interval and progresses to the longer using the specified + * number of steps per decade. Note that values < 200 are not accepted. Values even that small + * are unlikely to be useful. + * + * @param minInterval The minimum epoch length for the evolutionary optimization + * @param maxInterval The maximum epoch length + */ + public void setInterval(int minInterval, int maxInterval) { + this.minInterval = Math.max(200, minInterval); + this.maxInterval = Math.max(200, maxInterval); + this.cutoff = minInterval * (record / minInterval + 1); + this.currentStep = minInterval; + bufferSize = Math.min(minInterval, bufferSize); + } + + public final void setPoolSize(int poolSize) { + this.poolSize = poolSize; + setupOptimizer(poolSize); + } + + public void setThreadCount(int threadCount) { + this.threadCount = threadCount; + setupOptimizer(poolSize); + } + + public void setAucEvaluator(OnlineAuc auc) { + seed.getPayload().setAucEvaluator(auc); + setupOptimizer(poolSize); + } + + private void setupOptimizer(int poolSize) { + ep = new EvolutionaryProcess<>(threadCount, poolSize, seed); + } + + /** + * Returns the size of the internal feature vector. Note that this is not the same as the number + * of distinct features, especially if feature hashing is being used. + * + * @return The internal feature vector size. + */ + public int numFeatures() { + return numFeatures; + } + + /** + * What is the AUC for the current best member of the population. If no member is best, usually + * because we haven't done any training yet, then the result is set to NaN. + * + * @return The AUC of the best member of the population or NaN if we can't figure that out. + */ + public double auc() { + if (best == null) { + return Double.NaN; + } else { + Wrapper payload = best.getPayload(); + return payload.getLearner().auc(); + } + } + + public State<Wrapper, CrossFoldLearner> getBest() { + return best; + } + + public void setBest(State<Wrapper, CrossFoldLearner> best) { + this.best = best; + } + + public int getRecord() { + return record; + } + + public void setRecord(int record) { + this.record = record; + } + + public int getMinInterval() { + return minInterval; + } + + public int getMaxInterval() { + return maxInterval; + } + + public int getNumCategories() { + return seed.getPayload().getLearner().numCategories(); + } + + public PriorFunction getPrior() { + return seed.getPayload().getLearner().getPrior(); + } + + public void setBuffer(List<TrainingExample> buffer) { + this.buffer = buffer; + } + + public List<TrainingExample> getBuffer() { + return buffer; + } + + public EvolutionaryProcess<Wrapper, CrossFoldLearner> getEp() { + return ep; + } + + public void setEp(EvolutionaryProcess<Wrapper, CrossFoldLearner> ep) { + this.ep = ep; + } + + public State<Wrapper, CrossFoldLearner> getSeed() { + return seed; + } + + public void setSeed(State<Wrapper, CrossFoldLearner> seed) { + this.seed = seed; + } + + public int getNumFeatures() { + return numFeatures; + } + + public void setAveragingWindow(int averagingWindow) { + seed.getPayload().getLearner().setWindowSize(averagingWindow); + setupOptimizer(poolSize); + } + + public void setFreezeSurvivors(boolean freezeSurvivors) { + this.freezeSurvivors = freezeSurvivors; + } + + /** + * Provides a shim between the EP optimization stuff and the CrossFoldLearner. 
The most important + * interface has to do with the parameters of the optimization. These are taken from the double[] + * params in the following order <ul> <li> regularization constant lambda <li> learningRate </ul>. + * All other parameters are set in such a way so as to defeat annealing to the extent possible. + * This lets the evolutionary algorithm handle the annealing. + * <p/> + * Note that per coefficient annealing is still done and no optimization of the per coefficient + * offset is done. + */ + public static class Wrapper implements Payload<CrossFoldLearner> { + private CrossFoldLearner wrapped; + + public Wrapper() { + } + + public Wrapper(int numCategories, int numFeatures, PriorFunction prior) { + wrapped = new CrossFoldLearner(5, numCategories, numFeatures, prior); + } + + @Override + public Wrapper copy() { + Wrapper r = new Wrapper(); + r.wrapped = wrapped.copy(); + return r; + } + + @Override + public void update(double[] params) { + int i = 0; + wrapped.lambda(params[i++]); + wrapped.learningRate(params[i]); + + wrapped.stepOffset(1); + wrapped.alpha(1); + wrapped.decayExponent(0); + } + + public static void freeze(State<Wrapper, CrossFoldLearner> s) { + // radically decrease learning rate + double[] params = s.getParams(); + params[1] -= 10; + + // and cause evolution to hold (almost) + s.setOmni(s.getOmni() / 20); + double[] step = s.getStep(); + for (int i = 0; i < step.length; i++) { + step[i] /= 20; + } + } + + public static void setMappings(State<Wrapper, CrossFoldLearner> x) { + int i = 0; + // set the range for regularization (lambda) + x.setMap(i++, Mapping.logLimit(1.0e-8, 0.1)); + // set the range for learning rate (mu) + x.setMap(i, Mapping.logLimit(1.0e-8, 1)); + } + + public void train(TrainingExample example) { + wrapped.train(example.getKey(), example.getGroupKey(), example.getActual(), example.getInstance()); + } + + public CrossFoldLearner getLearner() { + return wrapped; + } + + @Override + public String toString() { + return String.format(Locale.ENGLISH, "auc=%.2f", wrapped.auc()); + } + + public void setAucEvaluator(OnlineAuc auc) { + wrapped.setAucEvaluator(auc); + } + + @Override + public void write(DataOutput out) throws IOException { + wrapped.write(out); + } + + @Override + public void readFields(DataInput input) throws IOException { + wrapped = new CrossFoldLearner(); + wrapped.readFields(input); + } + } + + public static class TrainingExample implements Writable { + private long key; + private String groupKey; + private int actual; + private Vector instance; + + private TrainingExample() { + } + + public TrainingExample(long key, String groupKey, int actual, Vector instance) { + this.key = key; + this.groupKey = groupKey; + this.actual = actual; + this.instance = instance; + } + + public long getKey() { + return key; + } + + public int getActual() { + return actual; + } + + public Vector getInstance() { + return instance; + } + + public String getGroupKey() { + return groupKey; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeLong(key); + if (groupKey != null) { + out.writeBoolean(true); + out.writeUTF(groupKey); + } else { + out.writeBoolean(false); + } + out.writeInt(actual); + VectorWritable.writeVector(out, instance, true); + } + + @Override + public void readFields(DataInput in) throws IOException { + key = in.readLong(); + if (in.readBoolean()) { + groupKey = in.readUTF(); + } + actual = in.readInt(); + instance = VectorWritable.readVector(in); + } + } + + @Override + public void write(DataOutput out) 
throws IOException { + out.writeInt(record); + out.writeInt(cutoff); + out.writeInt(minInterval); + out.writeInt(maxInterval); + out.writeInt(currentStep); + out.writeInt(bufferSize); + + out.writeInt(buffer.size()); + for (TrainingExample example : buffer) { + example.write(out); + } + + ep.write(out); + + best.write(out); + + out.writeInt(threadCount); + out.writeInt(poolSize); + seed.write(out); + out.writeInt(numFeatures); + + out.writeBoolean(freezeSurvivors); + } + + @Override + public void readFields(DataInput in) throws IOException { + record = in.readInt(); + cutoff = in.readInt(); + minInterval = in.readInt(); + maxInterval = in.readInt(); + currentStep = in.readInt(); + bufferSize = in.readInt(); + + int n = in.readInt(); + buffer = new ArrayList<>(); + for (int i = 0; i < n; i++) { + TrainingExample example = new TrainingExample(); + example.readFields(in); + buffer.add(example); + } + + ep = new EvolutionaryProcess<>(); + ep.readFields(in); + + best = new State<>(); + best.readFields(in); + + threadCount = in.readInt(); + poolSize = in.readInt(); + seed = new State<>(); + seed.readFields(in); + + numFeatures = in.readInt(); + freezeSurvivors = in.readBoolean(); + } +} + http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java new file mode 100644 index 0000000..f56814b --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java @@ -0,0 +1,334 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.DoubleDoubleFunction; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.stats.GlobalOnlineAuc; +import org.apache.mahout.math.stats.OnlineAuc; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Does cross-fold validation of log-likelihood and AUC on several online logistic regression + * models. Each record is passed to all but one of the models for training and to the remaining + * model for evaluation. 
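+ * (The held-out fold for a record is selected as its tracking key modulo the number of folds.)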
In order to maintain proper segregation between the different folds across + * training data iterations, data should either be passed to this learner in the same order each + * time the training data is traversed or a tracking key such as the file offset of the training + * record should be passed with each training example. + */ +public class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Writable { + private int record; + // minimum score to be used for computing log likelihood + private static final double MIN_SCORE = 1.0e-50; + private OnlineAuc auc = new GlobalOnlineAuc(); + private double logLikelihood; + private final List<OnlineLogisticRegression> models = new ArrayList<>(); + + // lambda, learningRate, perTermOffset, perTermExponent + private double[] parameters = new double[4]; + private int numFeatures; + private PriorFunction prior; + private double percentCorrect; + + private int windowSize = Integer.MAX_VALUE; + + public CrossFoldLearner() { + } + + public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) { + this.numFeatures = numFeatures; + this.prior = prior; + for (int i = 0; i < folds; i++) { + OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, prior); + model.alpha(1).stepOffset(0).decayExponent(0); + models.add(model); + } + } + + // -------- builder-like configuration methods + + public CrossFoldLearner lambda(double v) { + for (OnlineLogisticRegression model : models) { + model.lambda(v); + } + return this; + } + + public CrossFoldLearner learningRate(double x) { + for (OnlineLogisticRegression model : models) { + model.learningRate(x); + } + return this; + } + + public CrossFoldLearner stepOffset(int x) { + for (OnlineLogisticRegression model : models) { + model.stepOffset(x); + } + return this; + } + + public CrossFoldLearner decayExponent(double x) { + for (OnlineLogisticRegression model : models) { + model.decayExponent(x); + } + return this; + } + + public CrossFoldLearner alpha(double alpha) { + for (OnlineLogisticRegression model : models) { + model.alpha(alpha); + } + return this; + } + + // -------- training methods + @Override + public void train(int actual, Vector instance) { + train(record, null, actual, instance); + } + + @Override + public void train(long trackingKey, int actual, Vector instance) { + train(trackingKey, null, actual, instance); + } + + @Override + public void train(long trackingKey, String groupKey, int actual, Vector instance) { + record++; + int k = 0; + for (OnlineLogisticRegression model : models) { + if (k == mod(trackingKey, models.size())) { + Vector v = model.classifyFull(instance); + double score = Math.max(v.get(actual), MIN_SCORE); + logLikelihood += (Math.log(score) - logLikelihood) / Math.min(record, windowSize); + + int correct = v.maxValueIndex() == actual ? 1 : 0; + percentCorrect += (correct - percentCorrect) / Math.min(record, windowSize); + if (numCategories() == 2) { + auc.addSample(actual, groupKey, v.get(1)); + } + } else { + model.train(trackingKey, groupKey, actual, instance); + } + k++; + } + } + + private static long mod(long x, int y) { + long r = x % y; + return r < 0 ? 
r + y : r; + } + + @Override + public void close() { + for (OnlineLogisticRegression m : models) { + m.close(); + } + } + + public void resetLineCounter() { + record = 0; + } + + public boolean validModel() { + boolean r = true; + for (OnlineLogisticRegression model : models) { + r &= model.validModel(); + } + return r; + } + + // -------- classification methods + + @Override + public Vector classify(Vector instance) { + Vector r = new DenseVector(numCategories() - 1); + DoubleDoubleFunction scale = Functions.plusMult(1.0 / models.size()); + for (OnlineLogisticRegression model : models) { + r.assign(model.classify(instance), scale); + } + return r; + } + + @Override + public Vector classifyNoLink(Vector instance) { + Vector r = new DenseVector(numCategories() - 1); + DoubleDoubleFunction scale = Functions.plusMult(1.0 / models.size()); + for (OnlineLogisticRegression model : models) { + r.assign(model.classifyNoLink(instance), scale); + } + return r; + } + + @Override + public double classifyScalar(Vector instance) { + double r = 0; + int n = 0; + for (OnlineLogisticRegression model : models) { + n++; + r += model.classifyScalar(instance); + } + return r / n; + } + + // -------- status reporting methods + + @Override + public int numCategories() { + return models.get(0).numCategories(); + } + + public double auc() { + return auc.auc(); + } + + public double logLikelihood() { + return logLikelihood; + } + + public double percentCorrect() { + return percentCorrect; + } + + // -------- evolutionary optimization + + public CrossFoldLearner copy() { + CrossFoldLearner r = new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior); + r.models.clear(); + for (OnlineLogisticRegression model : models) { + model.close(); + OnlineLogisticRegression newModel = + new OnlineLogisticRegression(model.numCategories(), model.numFeatures(), model.prior); + newModel.copyFrom(model); + r.models.add(newModel); + } + return r; + } + + public int getRecord() { + return record; + } + + public void setRecord(int record) { + this.record = record; + } + + public OnlineAuc getAucEvaluator() { + return auc; + } + + public void setAucEvaluator(OnlineAuc auc) { + this.auc = auc; + } + + public double getLogLikelihood() { + return logLikelihood; + } + + public void setLogLikelihood(double logLikelihood) { + this.logLikelihood = logLikelihood; + } + + public List<OnlineLogisticRegression> getModels() { + return models; + } + + public void addModel(OnlineLogisticRegression model) { + models.add(model); + } + + public double[] getParameters() { + return parameters; + } + + public void setParameters(double[] parameters) { + this.parameters = parameters; + } + + public int getNumFeatures() { + return numFeatures; + } + + public void setNumFeatures(int numFeatures) { + this.numFeatures = numFeatures; + } + + public void setWindowSize(int windowSize) { + this.windowSize = windowSize; + auc.setWindowSize(windowSize); + } + + public PriorFunction getPrior() { + return prior; + } + + public void setPrior(PriorFunction prior) { + this.prior = prior; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(record); + PolymorphicWritable.write(out, auc); + out.writeDouble(logLikelihood); + out.writeInt(models.size()); + for (OnlineLogisticRegression model : models) { + model.write(out); + } + + for (double x : parameters) { + out.writeDouble(x); + } + out.writeInt(numFeatures); + PolymorphicWritable.write(out, prior); + out.writeDouble(percentCorrect); + out.writeInt(windowSize); + } + + 
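+  // Serialization layout used by write() above and mirrored, in the same order,
+  // by readFields() below: record, auc (polymorphic), logLikelihood, model count
+  // followed by each model, the four tuning parameters, numFeatures, prior
+  // (polymorphic), percentCorrect, and windowSize.  The two methods must stay in
+  // lockstep if fields are ever added.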
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    record = in.readInt();
+    auc = PolymorphicWritable.read(in, OnlineAuc.class);
+    logLikelihood = in.readDouble();
+    int n = in.readInt();
+    for (int i = 0; i < n; i++) {
+      OnlineLogisticRegression olr = new OnlineLogisticRegression();
+      olr.readFields(in);
+      models.add(olr);
+    }
+    parameters = new double[4];
+    for (int i = 0; i < 4; i++) {
+      parameters[i] = in.readDouble();
+    }
+    numFeatures = in.readInt();
+    prior = PolymorphicWritable.read(in, PriorFunction.class);
+    percentCorrect = in.readDouble();
+    windowSize = in.readInt();
+  }
+}
 http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
new file mode 100644
index 0000000..dbf3198
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+
+import org.apache.commons.csv.CSVUtils;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+import org.apache.mahout.vectorizer.encoders.TextValueEncoder;
+
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+/**
+ * Converts CSV data lines to vectors.
+ *
+ * Use of this class proceeds in a few steps.
+ * <ul>
+ * <li> At construction time, you tell the class about the target variable and provide
+ * a dictionary of the types of the predictor values.  At this point,
+ * the class cannot yet decode inputs because it doesn't know the fields that are in the
+ * data records, nor their order.
+ * <li> Optionally, you tell the parser object about the possible values of the target
+ * variable.  If you don't do this, then you probably should set the number of distinct
+ * values so that the target variable values will be taken from a restricted range.
+ * <li> Later, when you get a list of the fields, typically from the first line of a CSV
+ * file, you tell the factory about these fields and it builds internal data structures
+ * that allow it to decode inputs.  The most important internal state is the field numbers
+ * for the various fields.  After this point, you can use the factory for decoding data.
+ * <li> To encode data as a vector, you present a line of input to the factory and it
+ * mutates a vector that you provide.  The factory also retains trace information so
+ * that it can approximately reverse-engineer vectors later.
+ * <li> After converting data, you can ask for an explanation of the data in terms of
+ * terms and weights.  In order to explain a vector accurately, the factory needs to
+ * have seen the particular values of categorical fields (typically while encoding vectors)
+ * and needs to have a reasonably small number of collisions in the vector encoding.
+ * </ul>
+ */
+public class CsvRecordFactory implements RecordFactory {
+  private static final String INTERCEPT_TERM = "Intercept Term";
+
+  private static final Map<String, Class<? extends FeatureVectorEncoder>> TYPE_DICTIONARY =
+          ImmutableMap.<String, Class<? extends FeatureVectorEncoder>>builder()
+                  .put("continuous", ContinuousValueEncoder.class)
+                  .put("numeric", ContinuousValueEncoder.class)
+                  .put("n", ContinuousValueEncoder.class)
+                  .put("word", StaticWordValueEncoder.class)
+                  .put("w", StaticWordValueEncoder.class)
+                  .put("text", TextValueEncoder.class)
+                  .put("t", TextValueEncoder.class)
+                  .build();
+
+  private final Map<String, Set<Integer>> traceDictionary = new TreeMap<>();
+
+  private int target;
+  private final Dictionary targetDictionary;
+
+  // Which column identifies a line (record) of the CSV file
+  private String idName;
+  private int id = -1;
+
+  private List<Integer> predictors;
+  private Map<Integer, FeatureVectorEncoder> predictorEncoders;
+  private int maxTargetValue = Integer.MAX_VALUE;
+  private final String targetName;
+  private final Map<String, String> typeMap;
+  private List<String> variableNames;
+  private boolean includeBiasTerm;
+  private static final String CANNOT_CONSTRUCT_CONVERTER =
+      "Unable to construct type converter... shouldn't be possible";
+
+  /**
+   * Parses a single line of CSV-formatted text.
+   *
+   * Separated out to make changing this functionality for the entire class easier
+   * in the future.
+   * @param line CSV-formatted text
+   * @return The fields of the line as a list; if parsing fails, a singleton list
+   *         containing the whole line.
+   */
+  private List<String> parseCsvLine(String line) {
+    try {
+      return Arrays.asList(CSVUtils.parseLine(line));
+    } catch (IOException e) {
+      List<String> list = new ArrayList<>();
+      list.add(line);
+      return list;
+    }
+  }
+
+  private List<String> parseCsvLine(CharSequence line) {
+    return parseCsvLine(line.toString());
+  }
+
+  /**
+   * Constructs a parser for CSV lines that encodes the parsed data in vector form.
+   * @param targetName The name of the target variable.
+   * @param typeMap A map describing the types of the predictor variables.
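+   * <p>A minimal construction-and-use sketch (the column names, header, sizes
+   * and data below are illustrative assumptions, not part of this class's
+   * contract):</p>
+   * <pre>
+   *   Map&lt;String, String&gt; types = ImmutableMap.of(
+   *       "age", "numeric",      // ContinuousValueEncoder
+   *       "gender", "word",      // StaticWordValueEncoder
+   *       "note", "text");       // TextValueEncoder
+   *   CsvRecordFactory factory = new CsvRecordFactory("isSpam", types);
+   *   factory.firstLine("isSpam,age,gender,note");   // header fixes column order
+   *   Vector features = new RandomAccessSparseVector(1000);
+   *   int label = factory.processLine("1,42,m,hello world", features);
+   * </pre>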
+ */ + public CsvRecordFactory(String targetName, Map<String, String> typeMap) { + this.targetName = targetName; + this.typeMap = typeMap; + targetDictionary = new Dictionary(); + } + + public CsvRecordFactory(String targetName, String idName, Map<String, String> typeMap) { + this(targetName, typeMap); + this.idName = idName; + } + + /** + * Defines the values and thus the encoding of values of the target variables. Note + * that any values of the target variable not present in this list will be given the + * value of the last member of the list. + * @param values The values the target variable can have. + */ + @Override + public void defineTargetCategories(List<String> values) { + Preconditions.checkArgument( + values.size() <= maxTargetValue, + "Must have less than or equal to " + maxTargetValue + " categories for target variable, but found " + + values.size()); + if (maxTargetValue == Integer.MAX_VALUE) { + maxTargetValue = values.size(); + } + + for (String value : values) { + targetDictionary.intern(value); + } + } + + /** + * Defines the number of target variable categories, but allows this parser to + * pick encodings for them as they appear. + * @param max The number of categories that will be expected. Once this many have been + * seen, all others will get the encoding max-1. + */ + @Override + public CsvRecordFactory maxTargetValue(int max) { + maxTargetValue = max; + return this; + } + + @Override + public boolean usesFirstLineAsSchema() { + return true; + } + + /** + * Processes the first line of a file (which should contain the variable names). The target and + * predictor column numbers are set from the names on this line. + * + * @param line Header line for the file. + */ + @Override + public void firstLine(String line) { + // read variable names, build map of name -> column + final Map<String, Integer> vars = new HashMap<>(); + variableNames = parseCsvLine(line); + int column = 0; + for (String var : variableNames) { + vars.put(var, column++); + } + + // record target column and establish dictionary for decoding target + target = vars.get(targetName); + + // record id column + if (idName != null) { + id = vars.get(idName); + } + + // create list of predictor column numbers + predictors = new ArrayList<>(Collections2.transform(typeMap.keySet(), new Function<String, Integer>() { + @Override + public Integer apply(String from) { + Integer r = vars.get(from); + Preconditions.checkArgument(r != null, "Can't find variable %s, only know about %s", from, vars); + return r; + } + })); + + if (includeBiasTerm) { + predictors.add(-1); + } + Collections.sort(predictors); + + // and map from column number to type encoder for each column that is a predictor + predictorEncoders = new HashMap<>(); + for (Integer predictor : predictors) { + String name; + Class<? extends FeatureVectorEncoder> c; + if (predictor == -1) { + name = INTERCEPT_TERM; + c = ConstantValueEncoder.class; + } else { + name = variableNames.get(predictor); + c = TYPE_DICTIONARY.get(typeMap.get(name)); + } + try { + Preconditions.checkArgument(c != null, "Invalid type of variable %s, wanted one of %s", + typeMap.get(name), TYPE_DICTIONARY.keySet()); + Constructor<? 
extends FeatureVectorEncoder> constructor = c.getConstructor(String.class);
+        Preconditions.checkArgument(constructor != null, "Can't find correct constructor for %s", typeMap.get(name));
+        FeatureVectorEncoder encoder = constructor.newInstance(name);
+        predictorEncoders.put(predictor, encoder);
+        encoder.setTraceDictionary(traceDictionary);
+      } catch (InstantiationException e) {
+        throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+      } catch (IllegalAccessException e) {
+        throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+      } catch (InvocationTargetException e) {
+        throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+      } catch (NoSuchMethodException e) {
+        throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+      }
+    }
+  }
+
+  /**
+   * Decodes a single line of CSV data and records the target and predictor variables in a record.
+   * As a side effect, features are added into the featureVector.  Returns the value of the target
+   * variable.
+   *
+   * @param line The raw data.
+   * @param featureVector Where to fill in the features.  Should be zeroed before calling
+   *                      processLine.
+   * @return The value of the target variable.
+   */
+  @Override
+  public int processLine(String line, Vector featureVector) {
+    List<String> values = parseCsvLine(line);
+
+    int targetValue = targetDictionary.intern(values.get(target));
+    if (targetValue >= maxTargetValue) {
+      targetValue = maxTargetValue - 1;
+    }
+
+    for (Integer predictor : predictors) {
+      String value;
+      if (predictor >= 0) {
+        value = values.get(predictor);
+      } else {
+        value = null;
+      }
+      predictorEncoders.get(predictor).addToVector(value, featureVector);
+    }
+    return targetValue;
+  }
+
+  /***
+   * Decodes a single line of CSV data and records the target (if returnTarget is true)
+   * and predictor variables in a record.  As a side effect, features are added into the
+   * featureVector.  Returns the value of the target variable.  When classifying production
+   * data that has no target value, this method is called with returnTarget = false.
+   * @param line The raw data.
+   * @param featureVector Where to fill in the features.  Should be zeroed before calling
+   *                      processLine.
+   * @param returnTarget Whether to process and return the target value; -1 is returned if false.
+   * @return The value of the target variable, or -1 if returnTarget is false.
+   */
+  public int processLine(CharSequence line, Vector featureVector, boolean returnTarget) {
+    List<String> values = parseCsvLine(line);
+    int targetValue = -1;
+    if (returnTarget) {
+      targetValue = targetDictionary.intern(values.get(target));
+      if (targetValue >= maxTargetValue) {
+        targetValue = maxTargetValue - 1;
+      }
+    }
+
+    for (Integer predictor : predictors) {
+      String value = predictor >= 0 ? values.get(predictor) : null;
+      predictorEncoders.get(predictor).addToVector(value, featureVector);
+    }
+    return targetValue;
+  }
+
+  /***
+   * Extracts the raw target string from a line read from a CSV file.
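+   * For example (illustrative), if the target variable occupies the first column,
+   * the line {@code spam,free,money} yields the raw string {@code spam}.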
+ * @param line the line of content read from CSV file + * @return the raw target value in the corresponding column of CSV line + */ + public String getTargetString(CharSequence line) { + List<String> values = parseCsvLine(line); + return values.get(target); + + } + + /*** + * Extract the corresponding raw target label according to a code + * @param code the integer code encoded during training process + * @return the raw target label + */ + public String getTargetLabel(int code) { + for (String key : targetDictionary.values()) { + if (targetDictionary.intern(key) == code) { + return key; + } + } + return null; + } + + /*** + * Extract the id column value from the CSV record + * @param line the line of content read from CSV file + * @return the id value of the CSV record + */ + public String getIdString(CharSequence line) { + List<String> values = parseCsvLine(line); + return values.get(id); + } + + /** + * Returns a list of the names of the predictor variables. + * + * @return A list of variable names. + */ + @Override + public Iterable<String> getPredictors() { + return Lists.transform(predictors, new Function<Integer, String>() { + @Override + public String apply(Integer v) { + if (v >= 0) { + return variableNames.get(v); + } else { + return INTERCEPT_TERM; + } + } + }); + } + + @Override + public Map<String, Set<Integer>> getTraceDictionary() { + return traceDictionary; + } + + @Override + public CsvRecordFactory includeBiasTerm(boolean useBias) { + includeBiasTerm = useBias; + return this; + } + + @Override + public List<String> getTargetCategories() { + List<String> r = targetDictionary.values(); + if (r.size() > maxTargetValue) { + r.subList(maxTargetValue, r.size()).clear(); + } + return r; + } + + public String getIdName() { + return idName; + } + + public void setIdName(String idName) { + this.idName = idName; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java new file mode 100644 index 0000000..f81d8ce --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +/** + * Implements the basic logistic training law. 
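+ *
+ * For the multinomial logit likelihood, the gradient with respect to the vector of
+ * category scores is the indicator vector of the observed category minus the vector
+ * of predicted probabilities.  Category 0 is the reference category and is omitted,
+ * so the result has numCategories() - 1 entries.  As an illustrative example: with
+ * three categories, predicted probabilities (0.2, 0.5) for categories 1 and 2, and
+ * actual = 2, the returned gradient is (0 - 0.2, 1 - 0.5) = (-0.2, 0.5).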
+ */
+public class DefaultGradient implements Gradient {
+  /**
+   * Provides a default gradient computation useful for logistic regression.
+   *
+   * @param groupKey A grouping key to allow per-group AUC loss to be used for training.
+   * @param actual The target variable value.
+   * @param instance The current feature vector to use for gradient computation.
+   * @param classifier The classifier that can compute scores.
+   * @return The gradient to be applied to beta.
+   */
+  @Override
+  public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    // what does the current model say?
+    Vector v = classifier.classify(instance);
+
+    Vector r = v.like();
+    if (actual != 0) {
+      r.setQuick(actual - 1, 1);
+    }
+    r.assign(v, Functions.MINUS);
+    return r;
+  }
+}
 http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
new file mode 100644
index 0000000..8128370
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements a linear combination of L1 and L2 priors.  This can give an
+ * interesting mixture of sparsity and load-sharing between redundant predictors.
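+ *
+ * Concretely, the code below computes logP(beta) = L1.logP(beta) + alphaByLambda * L2.logP(beta),
+ * and age() applies the matching regularization step: a multiplicative shrink by
+ * (1 - alphaByLambda * learningRate)^generations for the L2 part, followed by a
+ * subtractive step of learningRate * generations for the L1 part, clamped at zero
+ * so that a coefficient is truncated to exactly zero rather than allowed to change sign.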
+ */ +public class ElasticBandPrior implements PriorFunction { + private double alphaByLambda; + private L1 l1; + private L2 l2; + + // Exists for Writable + public ElasticBandPrior() { + this(0.0); + } + + public ElasticBandPrior(double alphaByLambda) { + this.alphaByLambda = alphaByLambda; + l1 = new L1(); + l2 = new L2(1); + } + + @Override + public double age(double oldValue, double generations, double learningRate) { + oldValue *= Math.pow(1 - alphaByLambda * learningRate, generations); + double newValue = oldValue - Math.signum(oldValue) * learningRate * generations; + if (newValue * oldValue < 0.0) { + // don't allow the value to change sign + return 0.0; + } else { + return newValue; + } + } + + @Override + public double logP(double betaIJ) { + return l1.logP(betaIJ) + alphaByLambda * l2.logP(betaIJ); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeDouble(alphaByLambda); + l1.write(out); + l2.write(out); + } + + @Override + public void readFields(DataInput in) throws IOException { + alphaByLambda = in.readDouble(); + l1 = new L1(); + l1.readFields(in); + l2 = new L2(); + l2.readFields(in); + } +}
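The sign clamp in age() is what produces exact zeros, in the manner of soft thresholding. A minimal sketch of that behavior; the demo class, its placement in the same package, and the numbers are illustrative assumptions only:

    package org.apache.mahout.classifier.sgd;

    /** Hypothetical demo (illustrative only): exercises ElasticBandPrior.age(). */
    public final class ElasticBandPriorDemo {
      public static void main(String[] args) {
        PriorFunction prior = new ElasticBandPrior(0.5);
        // Small weight: the L2 shrink of (1 - 0.5 * 0.01)^10 leaves about 0.0476;
        // the L1 step of learningRate * generations = 0.1 would overshoot zero,
        // so the sign clamp returns exactly 0.0.
        System.out.println(prior.age(0.05, 10, 0.01));   // prints 0.0
        // Large weight: 2.0 shrinks to about 1.902, minus 0.1 leaves about 1.802.
        System.out.println(prior.age(2.0, 10, 0.01));
      }
    }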