http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
new file mode 100644
index 0000000..524fc06
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Provides the ability to inject a gradient into the SGD logistic regression.
+ * Typical uses of this are to use a ranking score such as AUC instead of a
+ * normal loss function.
+ */
+public interface Gradient {
+  Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier);
+}
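To make the contract above concrete, here is a minimal sketch of an implementation that returns the plain log-likelihood gradient. It is illustrative only and not part of this commit: the class name LogLikelihoodGradient is hypothetical, and it assumes classifyFull(instance) returns the model's full per-category probability vector, as AbstractVectorClassifier provides.

import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.math.Vector;

public class LogLikelihoodGradient implements Gradient {
  @Override
  public Vector apply(String groupKey, int actual, Vector instance,
                      AbstractVectorClassifier classifier) {
    // p holds one probability per category for this instance.
    Vector p = classifier.classifyFull(instance);
    // The log-likelihood gradient with respect to the scores is
    // (1 if k == actual else 0) - p_k for each category k.
    Vector grad = p.like();
    grad.setQuick(actual, 1.0);
    for (int k = 0; k < grad.size(); k++) {
      grad.setQuick(k, grad.getQuick(k) - p.getQuick(k));
    }
    return grad;
  }
}

An SGD learner would then scale such a gradient by its learning rate and add it to the coefficients, which is the role the apply(...) contract plays in the classes that follow.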
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
new file mode 100644
index 0000000..90ef7a8
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Random;
+
+/**
+ * Online gradient machine learner that tries to minimize the label ranking hinge loss.
+ * Implements a gradient machine with one sigmoid hidden layer.
+ * It tries to minimize the ranking loss of some given set of labels,
+ * so this can be used for multi-class, multi-label
+ * or auto-encoding of sparse data (e.g. text).
+ */
+public class GradientMachine extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+  public static final int WRITABLE_VERSION = 1;
+
+  // the learning rate of the algorithm
+  private double learningRate = 0.1;
+
+  // the regularization term, a positive number that controls the size of the weight vector
+  private double regularization = 0.1;
+
+  // the sparsity term, a positive number that controls the sparsity of the hidden layer. (0 - 1)
+  private double sparsity = 0.1;
+
+  // the sparsity learning rate.
+  private double sparsityLearningRate = 0.1;
+
+  // the number of features
+  private int numFeatures = 10;
+  // the number of hidden nodes
+  private int numHidden = 100;
+  // the number of output nodes
+  private int numOutput = 2;
+
+  // coefficients for the input to hidden layer.
+  // There are numHidden Vectors of dimension numFeatures.
+  private Vector[] hiddenWeights;
+
+  // coefficients for the hidden to output layer.
+  // There are numOutput Vectors of dimension numHidden.
+ private Vector[] outputWeights; + + // hidden unit bias + private Vector hiddenBias; + + // output unit bias + private Vector outputBias; + + private final Random rnd; + + public GradientMachine(int numFeatures, int numHidden, int numOutput) { + this.numFeatures = numFeatures; + this.numHidden = numHidden; + this.numOutput = numOutput; + hiddenWeights = new DenseVector[numHidden]; + for (int i = 0; i < numHidden; i++) { + hiddenWeights[i] = new DenseVector(numFeatures); + hiddenWeights[i].assign(0); + } + hiddenBias = new DenseVector(numHidden); + hiddenBias.assign(0); + outputWeights = new DenseVector[numOutput]; + for (int i = 0; i < numOutput; i++) { + outputWeights[i] = new DenseVector(numHidden); + outputWeights[i].assign(0); + } + outputBias = new DenseVector(numOutput); + outputBias.assign(0); + rnd = RandomUtils.getRandom(); + } + + /** + * Initialize weights. + * + * @param gen random number generator. + */ + public void initWeights(Random gen) { + double hiddenFanIn = 1.0 / Math.sqrt(numFeatures); + for (int i = 0; i < numHidden; i++) { + for (int j = 0; j < numFeatures; j++) { + double val = (2.0 * gen.nextDouble() - 1.0) * hiddenFanIn; + hiddenWeights[i].setQuick(j, val); + } + } + double outputFanIn = 1.0 / Math.sqrt(numHidden); + for (int i = 0; i < numOutput; i++) { + for (int j = 0; j < numHidden; j++) { + double val = (2.0 * gen.nextDouble() - 1.0) * outputFanIn; + outputWeights[i].setQuick(j, val); + } + } + } + + /** + * Chainable configuration option. + * + * @param learningRate New value of initial learning rate. + * @return This, so other configurations can be chained. + */ + public GradientMachine learningRate(double learningRate) { + this.learningRate = learningRate; + return this; + } + + /** + * Chainable configuration option. + * + * @param regularization A positive value that controls the weight vector size. + * @return This, so other configurations can be chained. + */ + public GradientMachine regularization(double regularization) { + this.regularization = regularization; + return this; + } + + /** + * Chainable configuration option. + * + * @param sparsity A value between zero and one that controls the fraction of hidden units + * that are activated on average. + * @return This, so other configurations can be chained. + */ + public GradientMachine sparsity(double sparsity) { + this.sparsity = sparsity; + return this; + } + + /** + * Chainable configuration option. + * + * @param sparsityLearningRate New value of initial learning rate for sparsity. + * @return This, so other configurations can be chained. 
+   */
+  public GradientMachine sparsityLearningRate(double sparsityLearningRate) {
+    this.sparsityLearningRate = sparsityLearningRate;
+    return this;
+  }
+
+  public void copyFrom(GradientMachine other) {
+    numFeatures = other.numFeatures;
+    numHidden = other.numHidden;
+    numOutput = other.numOutput;
+    learningRate = other.learningRate;
+    regularization = other.regularization;
+    sparsity = other.sparsity;
+    sparsityLearningRate = other.sparsityLearningRate;
+    hiddenWeights = new DenseVector[numHidden];
+    for (int i = 0; i < numHidden; i++) {
+      hiddenWeights[i] = other.hiddenWeights[i].clone();
+    }
+    hiddenBias = other.hiddenBias.clone();
+    outputWeights = new DenseVector[numOutput];
+    for (int i = 0; i < numOutput; i++) {
+      outputWeights[i] = other.outputWeights[i].clone();
+    }
+    outputBias = other.outputBias.clone();
+  }
+
+  @Override
+  public int numCategories() {
+    return numOutput;
+  }
+
+  public int numFeatures() {
+    return numFeatures;
+  }
+
+  public int numHidden() {
+    return numHidden;
+  }
+
+  /**
+   * Feeds forward from input to hidden units.
+   *
+   * @return Hidden unit activations.
+   */
+  public DenseVector inputToHidden(Vector input) {
+    DenseVector activations = new DenseVector(numHidden);
+    for (int i = 0; i < numHidden; i++) {
+      activations.setQuick(i, hiddenWeights[i].dot(input));
+    }
+    activations.assign(hiddenBias, Functions.PLUS);
+    activations.assign(Functions.min(40.0)).assign(Functions.max(-40));
+    activations.assign(Functions.SIGMOID);
+    return activations;
+  }
+
+  /**
+   * Feeds forward from hidden to output units.
+   *
+   * @return Output unit activations.
+   */
+  public DenseVector hiddenToOutput(Vector hiddenActivation) {
+    DenseVector activations = new DenseVector(numOutput);
+    for (int i = 0; i < numOutput; i++) {
+      activations.setQuick(i, outputWeights[i].dot(hiddenActivation));
+    }
+    activations.assign(outputBias, Functions.PLUS);
+    return activations;
+  }
+
+  /**
+   * Updates using ranking loss.
+   *
+   * @param hiddenActivation the hidden unit's activation
+   * @param goodLabels the labels you want ranked above others.
+   * @param numTrials how many times you want to search for the highest scoring bad label.
+   * @param gen Random number generator.
+   */
+  public void updateRanking(Vector hiddenActivation,
+                            Collection<Integer> goodLabels,
+                            int numTrials,
+                            Random gen) {
+    // All the labels are good, do nothing.
+    if (goodLabels.size() >= numOutput) {
+      return;
+    }
+    for (Integer good : goodLabels) {
+      double goodScore = outputWeights[good].dot(hiddenActivation);
+      int highestBad = -1;
+      double highestBadScore = Double.NEGATIVE_INFINITY;
+      for (int i = 0; i < numTrials; i++) {
+        int bad = gen.nextInt(numOutput);
+        while (goodLabels.contains(bad)) {
+          bad = gen.nextInt(numOutput);
+        }
+        double badScore = outputWeights[bad].dot(hiddenActivation);
+        if (badScore > highestBadScore) {
+          highestBadScore = badScore;
+          highestBad = bad;
+        }
+      }
+      int bad = highestBad;
+      double loss = 1.0 - goodScore + highestBadScore;
+      if (loss < 0.0) {
+        continue;
+      }
+      // Note: from the loss above, the gradient dloss/dy, where y is a label's score,
+      // is -1 for the good label and +1 for the bad label.
+      // Since y = h' * w + b, dy/dh is just w, so by the chain rule the signal propagated
+      // back to the hidden layer is dloss/dh = dloss/dy * dy/dh = -w(good) + w(bad).
+      // For the regularization part, 0.5 * lambda * w' * w, the gradient is lambda * w.
+      // dy/db = 1.
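+      // (The updates below rescale the good and bad labels' weight vectors with the
+      // regularization baked into the step size, nudge the good label's bias up and the
+      // bad label's down, and propagate propHidden (= -w(good) + w(bad)) through the
+      // sigmoid derivative down to the input-to-hidden weights.)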
+ Vector gradGood = outputWeights[good].clone(); + gradGood.assign(Functions.NEGATE); + Vector propHidden = gradGood.clone(); + Vector gradBad = outputWeights[bad].clone(); + propHidden.assign(gradBad, Functions.PLUS); + gradGood.assign(Functions.mult(-learningRate * (1.0 - regularization))); + outputWeights[good].assign(gradGood, Functions.PLUS); + gradBad.assign(Functions.mult(-learningRate * (1.0 + regularization))); + outputWeights[bad].assign(gradBad, Functions.PLUS); + outputBias.setQuick(good, outputBias.get(good) + learningRate); + outputBias.setQuick(bad, outputBias.get(bad) - learningRate); + // Gradient of sigmoid is s * (1 -s). + Vector gradSig = hiddenActivation.clone(); + gradSig.assign(Functions.SIGMOIDGRADIENT); + // Multiply by the change caused by the ranking loss. + for (int i = 0; i < numHidden; i++) { + gradSig.setQuick(i, gradSig.get(i) * propHidden.get(i)); + } + for (int i = 0; i < numHidden; i++) { + for (int j = 0; j < numFeatures; j++) { + double v = hiddenWeights[i].get(j); + v -= learningRate * (gradSig.get(i) + regularization * v); + hiddenWeights[i].setQuick(j, v); + } + } + } + } + + @Override + public Vector classify(Vector instance) { + Vector result = classifyNoLink(instance); + // Find the max value's index. + int max = result.maxValueIndex(); + result.assign(0); + result.setQuick(max, 1.0); + return result.viewPart(1, result.size() - 1); + } + + @Override + public Vector classifyNoLink(Vector instance) { + DenseVector hidden = inputToHidden(instance); + return hiddenToOutput(hidden); + } + + @Override + public double classifyScalar(Vector instance) { + Vector output = classifyNoLink(instance); + if (output.get(0) > output.get(1)) { + return 0; + } + return 1; + } + + public GradientMachine copy() { + close(); + GradientMachine r = new GradientMachine(numFeatures(), numHidden(), numCategories()); + r.copyFrom(this); + return r; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(WRITABLE_VERSION); + out.writeDouble(learningRate); + out.writeDouble(regularization); + out.writeDouble(sparsity); + out.writeDouble(sparsityLearningRate); + out.writeInt(numFeatures); + out.writeInt(numHidden); + out.writeInt(numOutput); + VectorWritable.writeVector(out, hiddenBias); + for (int i = 0; i < numHidden; i++) { + VectorWritable.writeVector(out, hiddenWeights[i]); + } + VectorWritable.writeVector(out, outputBias); + for (int i = 0; i < numOutput; i++) { + VectorWritable.writeVector(out, outputWeights[i]); + } + } + + @Override + public void readFields(DataInput in) throws IOException { + int version = in.readInt(); + if (version == WRITABLE_VERSION) { + learningRate = in.readDouble(); + regularization = in.readDouble(); + sparsity = in.readDouble(); + sparsityLearningRate = in.readDouble(); + numFeatures = in.readInt(); + numHidden = in.readInt(); + numOutput = in.readInt(); + hiddenWeights = new DenseVector[numHidden]; + hiddenBias = VectorWritable.readVector(in); + for (int i = 0; i < numHidden; i++) { + hiddenWeights[i] = VectorWritable.readVector(in); + } + outputWeights = new DenseVector[numOutput]; + outputBias = VectorWritable.readVector(in); + for (int i = 0; i < numOutput; i++) { + outputWeights[i] = VectorWritable.readVector(in); + } + } else { + throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version); + } + } + + @Override + public void close() { + // This is an online classifier, nothing to do. 
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    Vector hiddenActivation = inputToHidden(instance);
+    hiddenToOutput(hiddenActivation);
+    Collection<Integer> goodLabels = new HashSet<>();
+    goodLabels.add(actual);
+    updateRanking(hiddenActivation, goodLabels, 2, rnd);
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
new file mode 100644
index 0000000..28a05f2
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements the Laplacian or bi-exponential prior. This prior has a strong tendency to set coefficients to zero
+ * and thus is useful as an alternative to variable selection. This version implements truncation which prevents
+ * a coefficient from changing sign. If a correction would change the sign, the coefficient is truncated to zero.
+ *
+ * Note that a scale parameter is unnecessary for this distribution: after taking the derivative of logP, a scale
+ * would have the same effect as the lambda coefficient used to combine the prior with the observations, so changing
+ * the scale would be equivalent to just changing lambda.
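+ *
+ * For example, with learning rate r, age() moves a coefficient by -sign(beta) * r per
+ * generation (the constant sub-gradient of |beta|) and truncates to zero if the step
+ * would flip the coefficient's sign.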
+ */ +public class L1 implements PriorFunction { + @Override + public double age(double oldValue, double generations, double learningRate) { + double newValue = oldValue - Math.signum(oldValue) * learningRate * generations; + if (newValue * oldValue < 0) { + // don't allow the value to change sign + return 0; + } else { + return newValue; + } + } + + @Override + public double logP(double betaIJ) { + return -Math.abs(betaIJ); + } + + @Override + public void write(DataOutput out) throws IOException { + // stateless class has nothing to serialize + } + + @Override + public void readFields(DataInput dataInput) throws IOException { + // stateless class has nothing to serialize + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java new file mode 100644 index 0000000..3dfb9fc --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Implements the Gaussian prior. This prior has a tendency to decrease large coefficients toward zero, but + * doesn't tend to set them to exactly zero. 
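+ *
+ * Concretely, with learning rate r and scale s, age() shrinks a coefficient by the factor
+ * (1 - r / s^2) each generation, the multiplicative weight decay that a Gaussian prior induces.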
+ */ +public class L2 implements PriorFunction { + + private static final double HALF_LOG_2PI = Math.log(2.0 * Math.PI) / 2.0; + + private double s2; + private double s; + + public L2(double scale) { + s = scale; + s2 = scale * scale; + } + + public L2() { + s = 1.0; + s2 = 1.0; + } + + @Override + public double age(double oldValue, double generations, double learningRate) { + return oldValue * Math.pow(1.0 - learningRate / s2, generations); + } + + @Override + public double logP(double betaIJ) { + return -betaIJ * betaIJ / s2 / 2.0 - Math.log(s) - HALF_LOG_2PI; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeDouble(s2); + out.writeDouble(s); + } + + @Override + public void readFields(DataInput in) throws IOException { + s2 = in.readDouble(); + s = in.readDouble(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java new file mode 100644 index 0000000..a290b22 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.Vector; + +import java.util.Random; + +/** + * <p>Provides a stochastic mixture of ranking updates and normal logistic updates. 
This uses a + * combination of AUC driven learning to improve ranking performance and traditional log-loss driven + * learning to improve log-likelihood.</p> + * + * <p>See www.eecs.tufts.edu/~dsculley/papers/combined-ranking-and-regression.pdf</p> + * + * <p>This implementation only makes sense for the binomial case.</p> + */ +public class MixedGradient implements Gradient { + + private final double alpha; + private final RankingGradient rank; + private final Gradient basic; + private final Random random = RandomUtils.getRandom(); + private boolean hasZero; + private boolean hasOne; + + public MixedGradient(double alpha, int window) { + this.alpha = alpha; + this.rank = new RankingGradient(window); + this.basic = this.rank.getBaseGradient(); + } + + @Override + public Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) { + if (random.nextDouble() < alpha) { + // one option is to apply a ranking update relative to our recent history + if (!hasZero || !hasOne) { + throw new IllegalStateException(); + } + return rank.apply(groupKey, actual, instance, classifier); + } else { + hasZero |= actual == 0; + hasOne |= actual == 1; + // the other option is a normal update, but we have to update our history on the way + rank.addToHistory(actual, instance); + return basic.apply(groupKey, actual, instance, classifier); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java new file mode 100644 index 0000000..bcd2ebc --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.Ordering; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.Vector; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.Queue; +import java.util.Set; + +/** + * Uses sample data to reverse engineer a feature-hashed model. + * + * The result gives approximate weights for features and interactions + * in the original space. + * + * The idea is that the hashed encoders have the option of having a trace dictionary. 
+ * This tells us where each feature is hashed to, or each feature/value combination in the case
+ * of word-like values. Using this dictionary, we can put values into a synthetic feature
+ * vector in just the locations specified by a single feature or interaction. Then we can
+ * push this through a linear part of a model to see the contribution of that input. For
+ * any generalized linear model like logistic regression, there is a linear part of the
+ * model that allows this.
+ *
+ * What the ModelDissector does is to accept a trace dictionary and a model in an update
+ * method. It figures out the weights for the elements in the trace dictionary and stashes
+ * them. Then in a summary method, the biggest weights are returned. This update/flush
+ * style is used so that the trace dictionary doesn't have to grow to enormous levels,
+ * but instead can be cleared between updates.
+ */
+public class ModelDissector {
+  private final Map<String,Vector> weightMap;
+
+  public ModelDissector() {
+    weightMap = new HashMap<>();
+  }
+
+  /**
+   * Probes a model to determine the effect of a particular variable. This is done
+   * with the aid of a trace dictionary which has recorded the locations in the feature
+   * vector that are modified by various variable values. We can set these locations to
+   * 1 and then look at the resulting score. This tells us the weight the model places
+   * on that variable.
+   * @param features A feature vector to use (destructively)
+   * @param traceDictionary A trace dictionary containing variables and what locations
+   *                        in the feature vector are affected by them
+   * @param learner The model that we are probing to find weights on features
+   */
+
+  public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) {
+    // zero out feature vector
+    features.assign(0);
+    for (Map.Entry<String, Set<Integer>> entry : traceDictionary.entrySet()) {
+      // get a feature and locations where it is stored in the feature vector
+      String key = entry.getKey();
+      Set<Integer> value = entry.getValue();
+
+      // if we haven't looked at this feature yet
+      if (!weightMap.containsKey(key)) {
+        // put probe values in the feature vector
+        for (Integer where : value) {
+          features.set(where, 1);
+        }
+
+        // see what the model says
+        Vector v = learner.classifyNoLink(features);
+        weightMap.put(key, v);
+
+        // and zero out those locations again
+        for (Integer where : value) {
+          features.set(where, 0);
+        }
+      }
+    }
+  }
+
+  /**
+   * Returns the n most important features with their
+   * weights, most important category and the top few
+   * categories that they affect.
+   * @param n How many results to return.
+   * @return A list of the top variables.
+ */ + public List<Weight> summary(int n) { + Queue<Weight> pq = new PriorityQueue<>(); + for (Map.Entry<String, Vector> entry : weightMap.entrySet()) { + pq.add(new Weight(entry.getKey(), entry.getValue())); + while (pq.size() > n) { + pq.poll(); + } + } + List<Weight> r = new ArrayList<>(pq); + Collections.sort(r, Ordering.natural().reverse()); + return r; + } + + private static final class Category implements Comparable<Category> { + private final int index; + private final double weight; + + private Category(int index, double weight) { + this.index = index; + this.weight = weight; + } + + @Override + public int compareTo(Category o) { + int r = Double.compare(Math.abs(weight), Math.abs(o.weight)); + if (r == 0) { + if (o.index < index) { + return -1; + } + if (o.index > index) { + return 1; + } + return 0; + } + return r; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Category)) { + return false; + } + Category other = (Category) o; + return index == other.index && weight == other.weight; + } + + @Override + public int hashCode() { + return RandomUtils.hashDouble(weight) ^ index; + } + + } + + public static class Weight implements Comparable<Weight> { + private final String feature; + private final double value; + private final int maxIndex; + private final List<Category> categories; + + public Weight(String feature, Vector weights) { + this(feature, weights, 3); + } + + public Weight(String feature, Vector weights, int n) { + this.feature = feature; + // pick out the weight with the largest abs value, but don't forget the sign + Queue<Category> biggest = new PriorityQueue<>(n + 1, Ordering.natural()); + for (Vector.Element element : weights.all()) { + biggest.add(new Category(element.index(), element.get())); + while (biggest.size() > n) { + biggest.poll(); + } + } + categories = new ArrayList<>(biggest); + Collections.sort(categories, Ordering.natural().reverse()); + value = categories.get(0).weight; + maxIndex = categories.get(0).index; + } + + @Override + public int compareTo(Weight other) { + int r = Double.compare(Math.abs(this.value), Math.abs(other.value)); + if (r == 0) { + return feature.compareTo(other.feature); + } + return r; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Weight)) { + return false; + } + Weight other = (Weight) o; + return feature.equals(other.feature) + && value == other.value + && maxIndex == other.maxIndex + && categories.equals(other.categories); + } + + @Override + public int hashCode() { + return feature.hashCode() ^ RandomUtils.hashDouble(value) ^ maxIndex ^ categories.hashCode(); + } + + public String getFeature() { + return feature; + } + + public double getWeight() { + return value; + } + + public double getWeight(int n) { + return categories.get(n).weight; + } + + public double getCategory(int n) { + return categories.get(n).index; + } + + public int getMaxImpact() { + return maxIndex; + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java new file mode 100644 index 0000000..f89b245 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java @@ -0,0 +1,67 @@ +/** + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; + +import com.google.common.io.Closeables; +import org.apache.hadoop.io.Writable; + +/** + * Provides the ability to store SGD model-related objects as binary files. + */ +public final class ModelSerializer { + + // static class ... don't instantiate + private ModelSerializer() { + } + + public static void writeBinary(String path, CrossFoldLearner model) throws IOException { + try (DataOutputStream out = new DataOutputStream(new FileOutputStream(path))) { + PolymorphicWritable.write(out, model); + } + } + + public static void writeBinary(String path, OnlineLogisticRegression model) throws IOException { + try (DataOutputStream out = new DataOutputStream(new FileOutputStream(path))) { + PolymorphicWritable.write(out, model); + } + } + + public static void writeBinary(String path, AdaptiveLogisticRegression model) throws IOException { + try (DataOutputStream out = new DataOutputStream(new FileOutputStream(path))){ + PolymorphicWritable.write(out, model); + } + } + + public static <T extends Writable> T readBinary(InputStream in, Class<T> clazz) throws IOException { + DataInput dataIn = new DataInputStream(in); + try { + return PolymorphicWritable.read(dataIn, clazz); + } finally { + Closeables.close(in, false); + } + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java new file mode 100644 index 0000000..7a9ca83 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Extends the basic on-line logistic regression learner with a specific set of learning
+ * rate annealing schedules.
+ */
+public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression implements Writable {
+  public static final int WRITABLE_VERSION = 1;
+
+  // these next two control decayFactor^steps exponential type of annealing
+  // learning rate and decay factor
+  private double mu0 = 1;
+  private double decayFactor = 1 - 1.0e-3;
+
+  // these next two control 1/steps^forget type annealing
+  private int stepOffset = 10;
+  // -1 equals even weighting of all examples, 0 means only use exponential annealing
+  private double forgettingExponent = -0.5;
+
+  // controls how per term annealing works
+  private int perTermAnnealingOffset = 20;
+
+  public OnlineLogisticRegression() {
+    // no-argument constructor provided for serialization; not for normal use
+  }
+
+  public OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+    this.numCategories = numCategories;
+    this.prior = prior;
+
+    updateSteps = new DenseVector(numFeatures);
+    updateCounts = new DenseVector(numFeatures).assign(perTermAnnealingOffset);
+    beta = new DenseMatrix(numCategories - 1, numFeatures);
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param alpha New value of decayFactor, the exponential decay rate for the learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public OnlineLogisticRegression alpha(double alpha) {
+    this.decayFactor = alpha;
+    return this;
+  }
+
+  @Override
+  public OnlineLogisticRegression lambda(double lambda) {
+    // we only override this to provide a more restrictive return type
+    super.lambda(lambda);
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+ */ + public OnlineLogisticRegression learningRate(double learningRate) { + this.mu0 = learningRate; + return this; + } + + public OnlineLogisticRegression stepOffset(int stepOffset) { + this.stepOffset = stepOffset; + return this; + } + + public OnlineLogisticRegression decayExponent(double decayExponent) { + if (decayExponent > 0) { + decayExponent = -decayExponent; + } + this.forgettingExponent = decayExponent; + return this; + } + + + @Override + public double perTermLearningRate(int j) { + return Math.sqrt(perTermAnnealingOffset / updateCounts.get(j)); + } + + @Override + public double currentLearningRate() { + return mu0 * Math.pow(decayFactor, getStep()) * Math.pow(getStep() + stepOffset, forgettingExponent); + } + + public void copyFrom(OnlineLogisticRegression other) { + super.copyFrom(other); + mu0 = other.mu0; + decayFactor = other.decayFactor; + + stepOffset = other.stepOffset; + forgettingExponent = other.forgettingExponent; + + perTermAnnealingOffset = other.perTermAnnealingOffset; + } + + public OnlineLogisticRegression copy() { + close(); + OnlineLogisticRegression r = new OnlineLogisticRegression(numCategories(), numFeatures(), prior); + r.copyFrom(this); + return r; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(WRITABLE_VERSION); + out.writeDouble(mu0); + out.writeDouble(getLambda()); + out.writeDouble(decayFactor); + out.writeInt(stepOffset); + out.writeInt(step); + out.writeDouble(forgettingExponent); + out.writeInt(perTermAnnealingOffset); + out.writeInt(numCategories); + MatrixWritable.writeMatrix(out, beta); + PolymorphicWritable.write(out, prior); + VectorWritable.writeVector(out, updateCounts); + VectorWritable.writeVector(out, updateSteps); + } + + @Override + public void readFields(DataInput in) throws IOException { + int version = in.readInt(); + if (version == WRITABLE_VERSION) { + mu0 = in.readDouble(); + lambda(in.readDouble()); + decayFactor = in.readDouble(); + stepOffset = in.readInt(); + step = in.readInt(); + forgettingExponent = in.readDouble(); + perTermAnnealingOffset = in.readInt(); + numCategories = in.readInt(); + beta = MatrixWritable.readMatrix(in); + prior = PolymorphicWritable.read(in, PriorFunction.class); + + updateCounts = VectorWritable.readVector(in); + updateSteps = VectorWritable.readVector(in); + } else { + throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java new file mode 100644 index 0000000..c51361c --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Online passive aggressive learner that tries to minimize the label ranking hinge loss.
+ * Implements a multi-class linear classifier minimizing rank loss, based on
+ * "Online Passive-Aggressive Algorithms" by Crammer et al., 2006.
+ * Note: it's better to use classifyNoLink because the loss function is based
+ * on ensuring that the score of the good label is larger than the next
+ * highest label by some margin. The conversion to probability is just done
+ * by exponentiating and dividing by the sum and is empirical at best.
+ * Your features should be pre-normalized into some sensible range, for example,
+ * by subtracting the mean and dividing by the standard deviation, if they are very
+ * different in magnitude from each other.
+ */
+public class PassiveAggressive extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+  private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class);
+
+  public static final int WRITABLE_VERSION = 1;
+
+  // the learning rate of the algorithm
+  private double learningRate = 0.1;
+
+  // loss statistics.
+  private int lossCount = 0;
+  private double lossSum = 0;
+
+  // coefficients for the classification. This is a dense matrix
+  // that is (numCategories) x numFeatures
+  private Matrix weights;
+
+  // number of categories we are classifying.
+  private int numCategories;
+
+  public PassiveAggressive(int numCategories, int numFeatures) {
+    this.numCategories = numCategories;
+    weights = new DenseMatrix(numCategories, numFeatures);
+    weights.assign(0.0);
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public PassiveAggressive learningRate(double learningRate) {
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  public void copyFrom(PassiveAggressive other) {
+    learningRate = other.learningRate;
+    numCategories = other.numCategories;
+    weights = other.weights;
+  }
+
+  @Override
+  public int numCategories() {
+    return numCategories;
+  }
+
+  @Override
+  public Vector classify(Vector instance) {
+    Vector result = classifyNoLink(instance);
+    // Convert to probabilities by exponentiation.
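+    // (Subtracting the max score before exponentiating is the usual overflow guard;
+    // the exponentiated scores are then normalized by their L1 norm.)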
+ double max = result.maxValue(); + result.assign(Functions.minus(max)).assign(Functions.EXP); + result = result.divide(result.norm(1)); + + return result.viewPart(1, result.size() - 1); + } + + @Override + public Vector classifyNoLink(Vector instance) { + Vector result = new DenseVector(weights.numRows()); + result.assign(0); + for (int i = 0; i < weights.numRows(); i++) { + result.setQuick(i, weights.viewRow(i).dot(instance)); + } + return result; + } + + @Override + public double classifyScalar(Vector instance) { + double v1 = weights.viewRow(0).dot(instance); + double v2 = weights.viewRow(1).dot(instance); + v1 = Math.exp(v1); + v2 = Math.exp(v2); + return v2 / (v1 + v2); + } + + public int numFeatures() { + return weights.numCols(); + } + + public PassiveAggressive copy() { + close(); + PassiveAggressive r = new PassiveAggressive(numCategories(), numFeatures()); + r.copyFrom(this); + return r; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(WRITABLE_VERSION); + out.writeDouble(learningRate); + out.writeInt(numCategories); + MatrixWritable.writeMatrix(out, weights); + } + + @Override + public void readFields(DataInput in) throws IOException { + int version = in.readInt(); + if (version == WRITABLE_VERSION) { + learningRate = in.readDouble(); + numCategories = in.readInt(); + weights = MatrixWritable.readMatrix(in); + } else { + throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version); + } + } + + @Override + public void close() { + // This is an online classifier, nothing to do. + } + + @Override + public void train(long trackingKey, String groupKey, int actual, Vector instance) { + if (lossCount > 1000) { + log.info("Avg. Loss = {}", lossSum / lossCount); + lossCount = 0; + lossSum = 0; + } + Vector result = classifyNoLink(instance); + double myScore = result.get(actual); + // Find the highest score that is not actual. 
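+    // (If the argmax is the actual label itself, mask it out with -infinity and
+    // take the runner-up.)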
+ int otherIndex = result.maxValueIndex(); + double otherValue = result.get(otherIndex); + if (otherIndex == actual) { + result.setQuick(otherIndex, Double.NEGATIVE_INFINITY); + otherIndex = result.maxValueIndex(); + otherValue = result.get(otherIndex); + } + double loss = 1.0 - myScore + otherValue; + lossCount += 1; + if (loss >= 0) { + lossSum += loss; + double tau = loss / (instance.dot(instance) + 0.5 / learningRate); + Vector delta = instance.clone(); + delta.assign(Functions.mult(tau)); + weights.viewRow(actual).assign(delta, Functions.PLUS); +// delta.addTo(weights.viewRow(actual)); + delta.assign(Functions.mult(-1)); + weights.viewRow(otherIndex).assign(delta, Functions.PLUS); +// delta.addTo(weights.viewRow(otherIndex)); + } + } + + @Override + public void train(long trackingKey, int actual, Vector instance) { + train(trackingKey, null, actual, instance); + } + + @Override + public void train(int actual, Vector instance) { + train(0, null, actual, instance); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java new file mode 100644 index 0000000..90062a6 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.common.ClassUtils; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Utilities that write a class name and then serialize using writables. + */ +public final class PolymorphicWritable { + + private PolymorphicWritable() { + } + + public static <T extends Writable> void write(DataOutput dataOutput, T value) throws IOException { + dataOutput.writeUTF(value.getClass().getName()); + value.write(dataOutput); + } + + public static <T extends Writable> T read(DataInput dataInput, Class<? 
extends T> clazz) throws IOException { + String className = dataInput.readUTF(); + T r = ClassUtils.instantiateAs(className, clazz); + r.readFields(dataInput); + return r; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java new file mode 100644 index 0000000..857f061 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.hadoop.io.Writable; + +/** + * A prior is used to regularize the learning algorithm. This allows a trade-off to + * be made between complexity of the model being learned and the accuracy with which + * the model fits the training data. There are different definitions of complexity + * which can be approximated using different priors. For large sparse systems, such + * as text classification, the L1 prior is often used which favors sparse models. + */ +public interface PriorFunction extends Writable { + /** + * Applies the regularization to a coefficient. + * @param oldValue The previous value. + * @param generations The number of generations. + * @param learningRate The learning rate with lambda baked in. + * @return The new coefficient value after regularization. + */ + double age(double oldValue, double generations, double learningRate); + + /** + * Returns the log of the probability of a particular coefficient value according to the prior. + * @param betaIJ The coefficient. + * @return The log probability. + */ + double logP(double betaIJ); +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java new file mode 100644 index 0000000..a04fc8b --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; + +/** + * Uses the difference between this instance and recent history to get a + * gradient that optimizes ranking performance. Essentially this is the + * same as directly optimizing AUC. It isn't expected that this would + * be used alone, but rather that a MixedGradient would use it and a + * DefaultGradient together to combine both ranking and log-likelihood + * goals. + */ +public class RankingGradient implements Gradient { + + private static final Gradient BASIC = new DefaultGradient(); + + private int window = 10; + + private final List<Deque<Vector>> history = new ArrayList<>(); + + public RankingGradient(int window) { + this.window = window; + } + + @Override + public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) { + addToHistory(actual, instance); + + // now compute average gradient versus saved vectors from the other side + Deque<Vector> otherSide = history.get(1 - actual); + int n = otherSide.size(); + + Vector r = null; + for (Vector other : otherSide) { + Vector g = BASIC.apply(groupKey, actual, instance.minus(other), classifier); + + if (r == null) { + r = g; + } else { + r.assign(g, Functions.plusMult(1.0 / n)); + } + } + return r; + } + + public void addToHistory(int actual, Vector instance) { + while (history.size() <= actual) { + history.add(new ArrayDeque<Vector>(window)); + } + // save this instance + Deque<Vector> ourSide = history.get(actual); + ourSide.add(instance); + while (ourSide.size() >= window) { + ourSide.pollFirst(); + } + } + + public Gradient getBaseGradient() { + return BASIC; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java new file mode 100644 index 0000000..fbc825d --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * A record factory understands how to convert a line of data into fields and then into a vector.
+ */
+public interface RecordFactory {
+  void defineTargetCategories(List<String> values);
+
+  RecordFactory maxTargetValue(int max);
+
+  boolean usesFirstLineAsSchema();
+
+  int processLine(String line, Vector featureVector);
+
+  Iterable<String> getPredictors();
+
+  Map<String, Set<Integer>> getTraceDictionary();
+
+  RecordFactory includeBiasTerm(boolean useBias);
+
+  List<String> getTargetCategories();
+
+  void firstLine(String line);
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
new file mode 100644
index 0000000..0a7b6a7
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.commons.math3.special.Gamma;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Provides a t-distribution as a prior.
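+ * Heavier tails than the Gaussian mean that large coefficients are penalized less harshly
+ * than under L2; the degrees of freedom df interpolate between Cauchy-like behavior (df = 1)
+ * and a Gaussian limit as df grows.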
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
new file mode 100644
index 0000000..0a7b6a7
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.commons.math3.special.Gamma;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Provides a t-distribution as a prior.
+ */
+public class TPrior implements PriorFunction {
+  private double df;
+
+  public TPrior(double df) {
+    this.df = df;
+  }
+
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    for (int i = 0; i < generations; i++) {
+      oldValue -= learningRate * oldValue * (df + 1.0) / (df + oldValue * oldValue);
+    }
+    return oldValue;
+  }
+
+  @Override
+  public double logP(double betaIJ) {
+    return Gamma.logGamma((df + 1.0) / 2.0)
+        - Math.log(df * Math.PI)
+        - Gamma.logGamma(df / 2.0)
+        - (df + 1.0) / 2.0 * Math.log1p(betaIJ * betaIJ);
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeDouble(df);
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    df = in.readDouble();
+  }
+}
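For reference (a standard derivation, not part of the commit): with \nu = df, the unit-scale Student-t log-density is

    \log p(\beta) = \log\Gamma\left(\frac{\nu+1}{2}\right) - \log\Gamma\left(\frac{\nu}{2}\right)
                    - \frac{1}{2}\log(\nu\pi) - \frac{\nu+1}{2}\log\left(1 + \frac{\beta^2}{\nu}\right)

whose derivative is

    \frac{d}{d\beta}\log p(\beta) = -\frac{(\nu+1)\,\beta}{\nu + \beta^2}

age() applies exactly learningRate times this derivative once per generation, which is why a t-prior shrinks small coefficients toward zero while penalizing large coefficients less severely than a Gaussian prior would. Note that logP() appears to deviate from the textbook constants: it uses log(nu*pi) rather than half that, and log(1 + beta^2) rather than log(1 + beta^2/nu). That affects reported log-likelihood values but not the age()-driven regularization, which is what enters the SGD updates.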
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
new file mode 100644
index 0000000..23c812f
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * A uniform prior. This is an improper prior that corresponds to no regularization at all.
+ */
+public class UniformPrior implements PriorFunction {
+  @Override
+  public double age(double oldValue, double generations, double learningRate) {
+    return oldValue;
+  }
+
+  @Override
+  public double logP(double betaIJ) {
+    return 0;
+  }
+
+  @Override
+  public void write(DataOutput dataOutput) throws IOException {
+    // nothing to write
+  }
+
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    // stateless class is trivial to read
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java
new file mode 100644
index 0000000..c2ad966
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java
@@ -0,0 +1,23 @@
+/**
+ * <p>Implements a variety of on-line logistic regression classifiers using SGD-based algorithms.
+ * SGD stands for Stochastic Gradient Descent and refers to a class of learning algorithms
+ * that make it relatively easy to build high speed on-line learning algorithms for a variety
+ * of problems, notably including supervised learning for classification.</p>
+ *
+ * <p>The primary class of interest in this package is
+ * {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} which contains a
+ * number (typically 5) of sub-learners, each of which is given a different portion of the
+ * training data. Each of these sub-learners can then be evaluated on the data it was not
+ * trained on. This allows fully incremental learning while still getting cross-validated
+ * performance estimates.</p>
+ *
+ * <p>The CrossFoldLearner implements {@link org.apache.mahout.classifier.OnlineLearner}
+ * and thus expects to be fed input in the form
+ * of a target variable and a feature vector. The target variable is simply an integer in the
+ * half-open interval [0..numCategories) where numCategories is defined when the CrossFoldLearner
+ * is constructed. The creation of feature vectors is facilitated by the classes that inherit
+ * from {@link org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder}.
+ * These classes currently implement a form of feature hashing with
+ * multiple probes to limit feature ambiguity.</p>
+ */
+package org.apache.mahout.classifier.sgd;
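A compressed sketch of the workflow the package doc describes (the class and method names are Mahout's; the particular constructor arguments and toy data here are assumptions):

    import org.apache.mahout.classifier.sgd.CrossFoldLearner;
    import org.apache.mahout.classifier.sgd.L1;
    import org.apache.mahout.math.RandomAccessSparseVector;
    import org.apache.mahout.math.Vector;
    import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;

    public class CrossFoldSketch {
      public static void main(String[] args) {
        // 5 folds, 2 target categories, 100 hashed feature slots, L1 prior
        CrossFoldLearner learner = new CrossFoldLearner(5, 2, 100, new L1());
        StaticWordValueEncoder words = new StaticWordValueEncoder("words");

        Vector v = new RandomAccessSparseVector(100);
        words.addToVector("hello", v);   // feature hashing with multiple probes
        words.addToVector("world", v);

        learner.train(1, v);                // target is an integer in [0..2)
        System.out.println(learner.auc());  // cross-validated estimate from held-out folds
      }
    }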
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
new file mode 100644
index 0000000..be7ed2a
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
@@ -0,0 +1,390 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.SquareRootFunction;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public abstract class AbstractCluster implements Cluster {
+
+  // used by asFormatString() and asJson() below to report JSON errors
+  private static final Logger log = LoggerFactory.getLogger(AbstractCluster.class);
+
+  // cluster persistent state
+  private int id;
+
+  private long numObservations;
+
+  private long totalObservations;
+
+  private Vector center;
+
+  private Vector radius;
+
+  // the observation statistics
+  private double s0;
+
+  private Vector s1;
+
+  private Vector s2;
+
+  private static final ObjectMapper jxn = new ObjectMapper();
+
+  protected AbstractCluster() {}
+
+  protected AbstractCluster(Vector point, int id2) {
+    this.numObservations = (long) 0;
+    this.totalObservations = (long) 0;
+    this.center = point.clone();
+    this.radius = center.like();
+    this.s0 = (double) 0;
+    this.s1 = center.like();
+    this.s2 = center.like();
+    this.id = id2;
+  }
+
+  protected AbstractCluster(Vector center2, Vector radius2, int id2) {
+    this.numObservations = (long) 0;
+    this.totalObservations = (long) 0;
+    this.center = new RandomAccessSparseVector(center2);
+    this.radius = new RandomAccessSparseVector(radius2);
+    this.s0 = (double) 0;
+    this.s1 = center.like();
+    this.s2 = center.like();
+    this.id = id2;
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(id);
+    out.writeLong(getNumObservations());
+    out.writeLong(getTotalObservations());
+    VectorWritable.writeVector(out, getCenter());
+    VectorWritable.writeVector(out, getRadius());
+    out.writeDouble(s0);
+    VectorWritable.writeVector(out, s1);
+    VectorWritable.writeVector(out, s2);
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    this.id = in.readInt();
+    this.setNumObservations(in.readLong());
+    this.setTotalObservations(in.readLong());
+    this.setCenter(VectorWritable.readVector(in));
+    this.setRadius(VectorWritable.readVector(in));
+    this.setS0(in.readDouble());
+    this.setS1(VectorWritable.readVector(in));
+    this.setS2(VectorWritable.readVector(in));
+  }
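// A round-trip sketch (standard Hadoop idiom; an assumption, not part of this
// commit). write() and readFields() above must stay symmetric, serializing in
// the order id, numObservations, totalObservations, center, radius, s0, s1, s2:
//
//   ByteArrayOutputStream bytes = new ByteArrayOutputStream();
//   cluster.write(new DataOutputStream(bytes));
//   copy.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
//
// after which 'copy' carries the same persistent state as 'cluster'.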
+
+  @Override
+  public void configure(Configuration job) {
+    // nothing to do
+  }
+
+  @Override
+  public Collection<Parameter<?>> getParameters() {
+    return Collections.emptyList();
+  }
+
+  @Override
+  public void createParameters(String prefix, Configuration jobConf) {
+    // nothing to do
+  }
+
+  @Override
+  public int getId() {
+    return id;
+  }
+
+  /**
+   * @param id
+   *          the id to set
+   */
+  protected void setId(int id) {
+    this.id = id;
+  }
+
+  @Override
+  public long getNumObservations() {
+    return numObservations;
+  }
+
+  /**
+   * @param l
+   *          the numPoints to set
+   */
+  protected void setNumObservations(long l) {
+    this.numObservations = l;
+  }
+
+  @Override
+  public long getTotalObservations() {
+    return totalObservations;
+  }
+
+  protected void setTotalObservations(long totalPoints) {
+    this.totalObservations = totalPoints;
+  }
+
+  @Override
+  public Vector getCenter() {
+    return center;
+  }
+
+  /**
+   * @param center
+   *          the center to set
+   */
+  protected void setCenter(Vector center) {
+    this.center = center;
+  }
+
+  @Override
+  public Vector getRadius() {
+    return radius;
+  }
+
+  /**
+   * @param radius
+   *          the radius to set
+   */
+  protected void setRadius(Vector radius) {
+    this.radius = radius;
+  }
+
+  /**
+   * @return the s0
+   */
+  protected double getS0() {
+    return s0;
+  }
+
+  protected void setS0(double s0) {
+    this.s0 = s0;
+  }
+
+  /**
+   * @return the s1
+   */
+  protected Vector getS1() {
+    return s1;
+  }
+
+  protected void setS1(Vector s1) {
+    this.s1 = s1;
+  }
+
+  /**
+   * @return the s2
+   */
+  protected Vector getS2() {
+    return s2;
+  }
+
+  protected void setS2(Vector s2) {
+    this.s2 = s2;
+  }
+
+  @Override
+  public void observe(Model<VectorWritable> x) {
+    AbstractCluster cl = (AbstractCluster) x;
+    setS0(getS0() + cl.getS0());
+    setS1(getS1().plus(cl.getS1()));
+    setS2(getS2().plus(cl.getS2()));
+  }
+
+  @Override
+  public void observe(VectorWritable x) {
+    observe(x.get());
+  }
+
+  @Override
+  public void observe(VectorWritable x, double weight) {
+    observe(x.get(), weight);
+  }
+
+  public void observe(Vector x, double weight) {
+    if (weight == 1.0) {
+      observe(x);
+    } else {
+      setS0(getS0() + weight);
+      Vector weightedX = x.times(weight);
+      if (getS1() == null) {
+        setS1(weightedX);
+      } else {
+        getS1().assign(weightedX, Functions.PLUS);
+      }
+      Vector x2 = x.times(x).times(weight);
+      if (getS2() == null) {
+        setS2(x2);
+      } else {
+        getS2().assign(x2, Functions.PLUS);
+      }
+    }
+  }
+
+  public void observe(Vector x) {
+    setS0(getS0() + 1);
+    if (getS1() == null) {
+      setS1(x.clone());
+    } else {
+      getS1().assign(x, Functions.PLUS);
+    }
+    Vector x2 = x.times(x);
+    if (getS2() == null) {
+      setS2(x2);
+    } else {
+      getS2().assign(x2, Functions.PLUS);
+    }
+  }
+
+  @Override
+  public void computeParameters() {
+    if (getS0() == 0) {
+      return;
+    }
+    setNumObservations((long) getS0());
+    setTotalObservations(getTotalObservations() + getNumObservations());
+    setCenter(getS1().divide(getS0()));
+    // compute the component stds
+    if (getS0() > 1) {
+      setRadius(getS2().times(getS0()).minus(getS1().times(getS1())).assign(new SquareRootFunction()).divide(getS0()));
+    }
+    setS0(0);
+    setS1(center.like());
+    setS2(center.like());
+  }
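// Unpacking the arithmetic above (a restatement of the code, with s0 = sum of
// weights, s1 = weighted sum of points, s2 = weighted sum of squared points,
// as accumulated by the observe() methods):
//
//   center_j = s1_j / s0
//   radius_j = sqrt(s0 * s2_j - s1_j^2) / s0 = sqrt(s2_j/s0 - (s1_j/s0)^2)
//
// so the radius is the component-wise population standard deviation of the
// observed points. The statistics are then reset so the next iteration of a
// clustering algorithm starts from a clean slate.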
+
+  @Override
+  public String asFormatString(String[] bindings) {
+    String fmtString = "";
+    try {
+      fmtString = jxn.writeValueAsString(asJson(bindings));
+    } catch (IOException e) {
+      log.error("Error writing JSON as String.", e);
+    }
+    return fmtString;
+  }
+
+  public Map<String,Object> asJson(String[] bindings) {
+    Map<String,Object> dict = new HashMap<>();
+    dict.put("identifier", getIdentifier());
+    dict.put("n", getNumObservations());
+    if (getCenter() != null) {
+      try {
+        dict.put("c", formatVectorAsJson(getCenter(), bindings));
+      } catch (IOException e) {
+        log.error("IOException: ", e);
+      }
+    }
+    if (getRadius() != null) {
+      try {
+        dict.put("r", formatVectorAsJson(getRadius(), bindings));
+      } catch (IOException e) {
+        log.error("IOException: ", e);
+      }
+    }
+    return dict;
+  }
+
+  public abstract String getIdentifier();
+
+  /**
+   * Compute the centroid by averaging the pointTotals
+   *
+   * @return the new centroid
+   */
+  public Vector computeCentroid() {
+    return getS0() == 0 ? getCenter() : getS1().divide(getS0());
+  }
+
+  /**
+   * Return a human-readable formatted string representation of the vector, not
+   * intended to be complete nor usable as an input/output representation
+   */
+  public static String formatVector(Vector v, String[] bindings) {
+    String fmtString = "";
+    try {
+      fmtString = jxn.writeValueAsString(formatVectorAsJson(v, bindings));
+    } catch (IOException e) {
+      log.error("Error writing JSON as String.", e);
+    }
+    return fmtString;
+  }
+
+  /**
+   * Create a List of HashMaps containing vector terms and weights
+   *
+   * @return List<Object>
+   */
+  public static List<Object> formatVectorAsJson(Vector v, String[] bindings) throws IOException {
+
+    boolean hasBindings = bindings != null;
+    boolean isSparse = v.getNumNonZeroElements() != v.size();
+
+    // we assume sequential access in the output
+    Vector provider = v.isSequentialAccess() ? v : new SequentialAccessSparseVector(v);
+
+    List<Object> terms = new LinkedList<>();
+    String term = "";
+
+    for (Element elem : provider.nonZeroes()) {
+
+      if (hasBindings && bindings.length >= elem.index() + 1 && bindings[elem.index()] != null) {
+        term = bindings[elem.index()];
+      } else if (hasBindings || isSparse) {
+        term = String.valueOf(elem.index());
+      }
+
+      Map<String, Object> termEntry = new HashMap<>();
+      double roundedWeight = (double) Math.round(elem.get() * 1000) / 1000;
+      if (hasBindings || isSparse) {
+        termEntry.put(term, roundedWeight);
+        terms.add(termEntry);
+      } else {
+        terms.add(roundedWeight);
+      }
+    }
+
+    return terms;
+  }
+
+  @Override
+  public boolean isConverged() {
+    // Convergence has no meaning yet, perhaps in subclasses
+    return false;
+  }
+}
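Tracing formatVectorAsJson() on a small input makes the output shape concrete (a sketch; the rendered strings follow from the code above):

    import org.apache.mahout.math.SequentialAccessSparseVector;
    import org.apache.mahout.math.Vector;

    Vector v = new SequentialAccessSparseVector(20);
    v.set(3, 0.5);
    v.set(17, 1.2344);
    // sparse input, no bindings: each nonzero renders as {index: weight},
    // with weights rounded to three decimals
    String s = AbstractCluster.formatVector(v, null);  // [{"3":0.5},{"17":1.234}]

A dense vector with no bindings renders as a bare list of rounded weights, and supplying a bindings array replaces the indexes with the matching labels.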
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/Cluster.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/Cluster.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/Cluster.java
new file mode 100644
index 0000000..07d6927
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/Cluster.java
@@ -0,0 +1,90 @@
+/* Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.common.parameters.Parametered;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.util.Map;
+
+/**
+ * Implementations of this interface have a printable representation and certain
+ * attributes that are common across all clustering implementations
+ */
+public interface Cluster extends Model<VectorWritable>, Parametered {
+
+  // default directory for initial clusters to prime iterative clustering
+  // algorithms
+  String INITIAL_CLUSTERS_DIR = "clusters-0";
+
+  // default directory for output of clusters per iteration
+  String CLUSTERS_DIR = "clusters-";
+
+  // default suffix for output of clusters for final iteration
+  String FINAL_ITERATION_SUFFIX = "-final";
+
+  /**
+   * Get the id of the Cluster
+   *
+   * @return a unique integer
+   */
+  int getId();
+
+  /**
+   * Get the "center" of the Cluster as a Vector
+   *
+   * @return a Vector
+   */
+  Vector getCenter();
+
+  /**
+   * Get the "radius" of the Cluster as a Vector. Usually the radius is the
+   * standard deviation expressed as a Vector of size equal to the center. Some
+   * clusters may return zero values if not appropriate.
+   *
+   * @return a Vector
+   */
+  Vector getRadius();
+
+  /**
+   * Produce a custom, human-friendly, printable representation of the Cluster.
+   *
+   * @param bindings
+   *          an optional String[] containing labels used to format the primary
+   *          Vector/s of this implementation.
+   * @return a String
+   */
+  String asFormatString(String[] bindings);
+
+  /**
+   * Produce a JSON representation of the Cluster.
+   *
+   * @param bindings
+   *          an optional String[] containing labels used to format the primary
+   *          Vector/s of this implementation.
+   * @return a Map
+   */
+  Map<String,Object> asJson(String[] bindings);
+
+  /**
+   * @return true if the receiver has converged, or false if that has no meaning
+   *         for the implementation
+   */
+  boolean isConverged();
+
+}
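The directory-name constants above combine in the obvious way; for example:

    String initial = Cluster.INITIAL_CLUSTERS_DIR;                            // "clusters-0"
    String third = Cluster.CLUSTERS_DIR + 3;                                  // "clusters-3"
    String last = Cluster.CLUSTERS_DIR + 3 + Cluster.FINAL_ITERATION_SUFFIX;  // "clusters-3-final"

so a clustering run's output typically holds clusters-0 through clusters-N, with the final iteration's directory carrying the -final suffix.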
