http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java new file mode 100644 index 0000000..ebb0614 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Ordering; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.Vector; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.Queue; +import java.util.Set; + +/** + * Uses sample data to reverse engineer a feature-hashed model. + * + * The result gives approximate weights for features and interactions + * in the original space. 
+ * + * The idea is that the hashed encoders have the option of having a trace dictionary. This + * tells us where each feature is hashed to, or each feature/value combination in the case + * of word-like values. Using this dictionary, we can put values into a synthetic feature + * vector in just the locations specified by a single feature or interaction. Then we can + * push this through a linear part of a model to see the contribution of that input. For + * any generalized linear model like logistic regression, there is a linear part of the + * model that allows this. + * + * What the ModelDissector does is to accept a trace dictionary and a model in an update + * method. It figures out the weights for the elements in the trace dictionary and stashes + * them. Then in a summary method, the biggest weights are returned. This update/flush + * style is used so that the trace dictionary doesn't have to grow to enormous levels, + * but instead can be cleared between updates. + */ +public class ModelDissector { + private final Map<String,Vector> weightMap; + + public ModelDissector() { + weightMap = Maps.newHashMap(); + } + + /** + * Probes a model to determine the effect of a particular variable. This is done + * with the ade of a trace dictionary which has recorded the locations in the feature + * vector that are modified by various variable values. We can set these locations to + * 1 and then look at the resulting score. This tells us the weight the model places + * on that variable. 
+ * @param features A feature vector to use (destructively) + * @param traceDictionary A trace dictionary containing variables and what locations + * in the feature vector are affected by them + * @param learner The model that we are probing to find weights on features + */ + + public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) { + // zero out feature vector + features.assign(0); + for (Map.Entry<String, Set<Integer>> entry : traceDictionary.entrySet()) { + // get a feature and locations where it is stored in the feature vector + String key = entry.getKey(); + Set<Integer> value = entry.getValue(); + + // if we haven't looked at this feature yet + if (!weightMap.containsKey(key)) { + // put probe values in the feature vector + for (Integer where : value) { + features.set(where, 1); + } + + // see what the model says + Vector v = learner.classifyNoLink(features); + weightMap.put(key, v); + + // and zero out those locations again + for (Integer where : value) { + features.set(where, 0); + } + } + } + } + + /** + * Returns the n most important features with their + * weights, most important category and the top few + * categories that they affect. + * @param n How many results to return. + * @return A list of the top variables. 
+ */ + public List<Weight> summary(int n) { + Queue<Weight> pq = new PriorityQueue<Weight>(); + for (Map.Entry<String, Vector> entry : weightMap.entrySet()) { + pq.add(new Weight(entry.getKey(), entry.getValue())); + while (pq.size() > n) { + pq.poll(); + } + } + List<Weight> r = Lists.newArrayList(pq); + Collections.sort(r, Ordering.natural().reverse()); + return r; + } + + private static final class Category implements Comparable<Category> { + private final int index; + private final double weight; + + private Category(int index, double weight) { + this.index = index; + this.weight = weight; + } + + @Override + public int compareTo(Category o) { + int r = Double.compare(Math.abs(weight), Math.abs(o.weight)); + if (r == 0) { + if (o.index < index) { + return -1; + } + if (o.index > index) { + return 1; + } + return 0; + } + return r; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Category)) { + return false; + } + Category other = (Category) o; + return index == other.index && weight == other.weight; + } + + @Override + public int hashCode() { + return RandomUtils.hashDouble(weight) ^ index; + } + + } + + public static class Weight implements Comparable<Weight> { + private final String feature; + private final double value; + private final int maxIndex; + private final List<Category> categories; + + public Weight(String feature, Vector weights) { + this(feature, weights, 3); + } + + public Weight(String feature, Vector weights, int n) { + this.feature = feature; + // pick out the weight with the largest abs value, but don't forget the sign + Queue<Category> biggest = new PriorityQueue<Category>(n + 1, Ordering.natural()); + for (Vector.Element element : weights.all()) { + biggest.add(new Category(element.index(), element.get())); + while (biggest.size() > n) { + biggest.poll(); + } + } + categories = Lists.newArrayList(biggest); + Collections.sort(categories, Ordering.natural().reverse()); + value = categories.get(0).weight; + maxIndex 
= categories.get(0).index; + } + + @Override + public int compareTo(Weight other) { + int r = Double.compare(Math.abs(this.value), Math.abs(other.value)); + if (r == 0) { + return feature.compareTo(other.feature); + } + return r; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Weight)) { + return false; + } + Weight other = (Weight) o; + return feature.equals(other.feature) + && value == other.value + && maxIndex == other.maxIndex + && categories.equals(other.categories); + } + + @Override + public int hashCode() { + return feature.hashCode() ^ RandomUtils.hashDouble(value) ^ maxIndex ^ categories.hashCode(); + } + + public String getFeature() { + return feature; + } + + public double getWeight() { + return value; + } + + public double getWeight(int n) { + return categories.get(n).weight; + } + + public double getCategory(int n) { + return categories.get(n).index; + } + + public int getMaxImpact() { + return maxIndex; + } + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java new file mode 100644 index 0000000..f0150e9 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; + +import com.google.common.io.Closeables; +import org.apache.hadoop.io.Writable; + +/** + * Provides the ability to store SGD model-related objects as binary files. + */ +public final class ModelSerializer { + + // static class ... 
don't instantiate + private ModelSerializer() { + } + + public static void writeBinary(String path, CrossFoldLearner model) throws IOException { + DataOutputStream out = new DataOutputStream(new FileOutputStream(path)); + try { + PolymorphicWritable.write(out, model); + } finally { + Closeables.close(out, false); + } + } + + public static void writeBinary(String path, OnlineLogisticRegression model) throws IOException { + DataOutputStream out = new DataOutputStream(new FileOutputStream(path)); + try { + PolymorphicWritable.write(out, model); + } finally { + Closeables.close(out, false); + } + } + + public static void writeBinary(String path, AdaptiveLogisticRegression model) throws IOException { + DataOutputStream out = new DataOutputStream(new FileOutputStream(path)); + try { + PolymorphicWritable.write(out, model); + } finally { + Closeables.close(out, false); + } + } + + public static <T extends Writable> T readBinary(InputStream in, Class<T> clazz) throws IOException { + DataInput dataIn = new DataInputStream(in); + try { + return PolymorphicWritable.read(dataIn, clazz); + } finally { + Closeables.close(in, false); + } + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java new file mode 100644 index 0000000..7a9ca83 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.MatrixWritable; +import org.apache.mahout.math.VectorWritable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Extends the basic on-line logistic regression learner with a specific set of learning + * rate annealing schedules. 
+ */ +public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression implements Writable { + public static final int WRITABLE_VERSION = 1; + + // these next two control decayFactor^steps exponential type of annealing + // learning rate and decay factor + private double mu0 = 1; + private double decayFactor = 1 - 1.0e-3; + + // these next two control 1/steps^forget type annealing + private int stepOffset = 10; + // -1 equals even weighting of all examples, 0 means only use exponential annealing + private double forgettingExponent = -0.5; + + // controls how per term annealing works + private int perTermAnnealingOffset = 20; + + public OnlineLogisticRegression() { + // private constructor available for serialization, but not normal use + } + + public OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) { + this.numCategories = numCategories; + this.prior = prior; + + updateSteps = new DenseVector(numFeatures); + updateCounts = new DenseVector(numFeatures).assign(perTermAnnealingOffset); + beta = new DenseMatrix(numCategories - 1, numFeatures); + } + + /** + * Chainable configuration option. + * + * @param alpha New value of decayFactor, the exponential decay rate for the learning rate. + * @return This, so other configurations can be chained. + */ + public OnlineLogisticRegression alpha(double alpha) { + this.decayFactor = alpha; + return this; + } + + @Override + public OnlineLogisticRegression lambda(double lambda) { + // we only over-ride this to provide a more restrictive return type + super.lambda(lambda); + return this; + } + + /** + * Chainable configuration option. + * + * @param learningRate New value of initial learning rate. + * @return This, so other configurations can be chained. 
+ */ + public OnlineLogisticRegression learningRate(double learningRate) { + this.mu0 = learningRate; + return this; + } + + public OnlineLogisticRegression stepOffset(int stepOffset) { + this.stepOffset = stepOffset; + return this; + } + + public OnlineLogisticRegression decayExponent(double decayExponent) { + if (decayExponent > 0) { + decayExponent = -decayExponent; + } + this.forgettingExponent = decayExponent; + return this; + } + + + @Override + public double perTermLearningRate(int j) { + return Math.sqrt(perTermAnnealingOffset / updateCounts.get(j)); + } + + @Override + public double currentLearningRate() { + return mu0 * Math.pow(decayFactor, getStep()) * Math.pow(getStep() + stepOffset, forgettingExponent); + } + + public void copyFrom(OnlineLogisticRegression other) { + super.copyFrom(other); + mu0 = other.mu0; + decayFactor = other.decayFactor; + + stepOffset = other.stepOffset; + forgettingExponent = other.forgettingExponent; + + perTermAnnealingOffset = other.perTermAnnealingOffset; + } + + public OnlineLogisticRegression copy() { + close(); + OnlineLogisticRegression r = new OnlineLogisticRegression(numCategories(), numFeatures(), prior); + r.copyFrom(this); + return r; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(WRITABLE_VERSION); + out.writeDouble(mu0); + out.writeDouble(getLambda()); + out.writeDouble(decayFactor); + out.writeInt(stepOffset); + out.writeInt(step); + out.writeDouble(forgettingExponent); + out.writeInt(perTermAnnealingOffset); + out.writeInt(numCategories); + MatrixWritable.writeMatrix(out, beta); + PolymorphicWritable.write(out, prior); + VectorWritable.writeVector(out, updateCounts); + VectorWritable.writeVector(out, updateSteps); + } + + @Override + public void readFields(DataInput in) throws IOException { + int version = in.readInt(); + if (version == WRITABLE_VERSION) { + mu0 = in.readDouble(); + lambda(in.readDouble()); + decayFactor = in.readDouble(); + stepOffset = in.readInt(); 
+ step = in.readInt(); + forgettingExponent = in.readDouble(); + perTermAnnealingOffset = in.readInt(); + numCategories = in.readInt(); + beta = MatrixWritable.readMatrix(in); + prior = PolymorphicWritable.read(in, PriorFunction.class); + + updateCounts = VectorWritable.readVector(in); + updateSteps = VectorWritable.readVector(in); + } else { + throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java new file mode 100644 index 0000000..c51361c --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixWritable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Online passive aggressive learner that tries to minimize the label ranking hinge loss. + * Implements a multi-class linear classifier minimizing rank loss. + * based on "Online passive aggressive algorithms" by Cramer et al, 2006. + * Note: Its better to use classifyNoLink because the loss function is based + * on ensuring that the score of the good label is larger than the next + * highest label by some margin. The conversion to probability is just done + * by exponentiating and dividing by the sum and is empirical at best. + * Your features should be pre-normalized in some sensible range, for example, + * by subtracting the mean and standard deviation, if they are very + * different in magnitude from each other. + */ +public class PassiveAggressive extends AbstractVectorClassifier implements OnlineLearner, Writable { + + private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class); + + public static final int WRITABLE_VERSION = 1; + + // the learning rate of the algorithm + private double learningRate = 0.1; + + // loss statistics. + private int lossCount = 0; + private double lossSum = 0; + + // coefficients for the classification. This is a dense matrix + // that is (numCategories ) x numFeatures + private Matrix weights; + + // number of categories we are classifying. 
+ private int numCategories; + + public PassiveAggressive(int numCategories, int numFeatures) { + this.numCategories = numCategories; + weights = new DenseMatrix(numCategories, numFeatures); + weights.assign(0.0); + } + + /** + * Chainable configuration option. + * + * @param learningRate New value of initial learning rate. + * @return This, so other configurations can be chained. + */ + public PassiveAggressive learningRate(double learningRate) { + this.learningRate = learningRate; + return this; + } + + public void copyFrom(PassiveAggressive other) { + learningRate = other.learningRate; + numCategories = other.numCategories; + weights = other.weights; + } + + @Override + public int numCategories() { + return numCategories; + } + + @Override + public Vector classify(Vector instance) { + Vector result = classifyNoLink(instance); + // Convert to probabilities by exponentiation. + double max = result.maxValue(); + result.assign(Functions.minus(max)).assign(Functions.EXP); + result = result.divide(result.norm(1)); + + return result.viewPart(1, result.size() - 1); + } + + @Override + public Vector classifyNoLink(Vector instance) { + Vector result = new DenseVector(weights.numRows()); + result.assign(0); + for (int i = 0; i < weights.numRows(); i++) { + result.setQuick(i, weights.viewRow(i).dot(instance)); + } + return result; + } + + @Override + public double classifyScalar(Vector instance) { + double v1 = weights.viewRow(0).dot(instance); + double v2 = weights.viewRow(1).dot(instance); + v1 = Math.exp(v1); + v2 = Math.exp(v2); + return v2 / (v1 + v2); + } + + public int numFeatures() { + return weights.numCols(); + } + + public PassiveAggressive copy() { + close(); + PassiveAggressive r = new PassiveAggressive(numCategories(), numFeatures()); + r.copyFrom(this); + return r; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(WRITABLE_VERSION); + out.writeDouble(learningRate); + out.writeInt(numCategories); + 
MatrixWritable.writeMatrix(out, weights); + } + + @Override + public void readFields(DataInput in) throws IOException { + int version = in.readInt(); + if (version == WRITABLE_VERSION) { + learningRate = in.readDouble(); + numCategories = in.readInt(); + weights = MatrixWritable.readMatrix(in); + } else { + throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version); + } + } + + @Override + public void close() { + // This is an online classifier, nothing to do. + } + + @Override + public void train(long trackingKey, String groupKey, int actual, Vector instance) { + if (lossCount > 1000) { + log.info("Avg. Loss = {}", lossSum / lossCount); + lossCount = 0; + lossSum = 0; + } + Vector result = classifyNoLink(instance); + double myScore = result.get(actual); + // Find the highest score that is not actual. + int otherIndex = result.maxValueIndex(); + double otherValue = result.get(otherIndex); + if (otherIndex == actual) { + result.setQuick(otherIndex, Double.NEGATIVE_INFINITY); + otherIndex = result.maxValueIndex(); + otherValue = result.get(otherIndex); + } + double loss = 1.0 - myScore + otherValue; + lossCount += 1; + if (loss >= 0) { + lossSum += loss; + double tau = loss / (instance.dot(instance) + 0.5 / learningRate); + Vector delta = instance.clone(); + delta.assign(Functions.mult(tau)); + weights.viewRow(actual).assign(delta, Functions.PLUS); +// delta.addTo(weights.viewRow(actual)); + delta.assign(Functions.mult(-1)); + weights.viewRow(otherIndex).assign(delta, Functions.PLUS); +// delta.addTo(weights.viewRow(otherIndex)); + } + } + + @Override + public void train(long trackingKey, int actual, Vector instance) { + train(trackingKey, null, actual, instance); + } + + @Override + public void train(int actual, Vector instance) { + train(0, null, actual, instance); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java 
---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java new file mode 100644 index 0000000..90062a6 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.common.ClassUtils; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Utilities that write a class name and then serialize using writables. + */ +public final class PolymorphicWritable { + + private PolymorphicWritable() { + } + + public static <T extends Writable> void write(DataOutput dataOutput, T value) throws IOException { + dataOutput.writeUTF(value.getClass().getName()); + value.write(dataOutput); + } + + public static <T extends Writable> T read(DataInput dataInput, Class<? 
extends T> clazz) throws IOException { + String className = dataInput.readUTF(); + T r = ClassUtils.instantiateAs(className, clazz); + r.readFields(dataInput); + return r; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java new file mode 100644 index 0000000..857f061 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.hadoop.io.Writable; + +/** + * A prior is used to regularize the learning algorithm. This allows a trade-off to + * be made between complexity of the model being learned and the accuracy with which + * the model fits the training data. There are different definitions of complexity + * which can be approximated using different priors. 
/**
 * A prior is used to regularize the learning algorithm.  This allows a trade-off to
 * be made between complexity of the model being learned and the accuracy with which
 * the model fits the training data.  There are different definitions of complexity
 * which can be approximated using different priors.  For large sparse systems, such
 * as text classification, the L1 prior is often used which favors sparse models.
 *
 * Implementations are {@link Writable}, so a prior can be stored together with the
 * model that uses it.
 */
public interface PriorFunction extends Writable {
  /**
   * Applies the regularization to a coefficient.
   *
   * @param oldValue        The previous value.
   * @param generations     The number of generations.
   * @param learningRate    The learning rate with lambda baked in.
   * @return                The new coefficient value after regularization.
   */
  double age(double oldValue, double generations, double learningRate);

  /**
   * Returns the log of the probability of a particular coefficient value according to the prior.
   *
   * @param betaIJ          The coefficient.
   * @return                The log probability.
   */
  double logP(double betaIJ);
}
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.Lists; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; + +/** + * Uses the difference between this instance and recent history to get a + * gradient that optimizes ranking performance. Essentially this is the + * same as directly optimizing AUC. It isn't expected that this would + * be used alone, but rather that a MixedGradient would use it and a + * DefaultGradient together to combine both ranking and log-likelihood + * goals. + */ +public class RankingGradient implements Gradient { + + private static final Gradient BASIC = new DefaultGradient(); + + private int window = 10; + + private final List<Deque<Vector>> history = Lists.newArrayList(); + + public RankingGradient(int window) { + this.window = window; + } + + @Override + public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) { + addToHistory(actual, instance); + + // now compute average gradient versus saved vectors from the other side + Deque<Vector> otherSide = history.get(1 - actual); + int n = otherSide.size(); + + Vector r = null; + for (Vector other : otherSide) { + Vector g = BASIC.apply(groupKey, actual, instance.minus(other), classifier); + + if (r == null) { + r = g; + } else { + r.assign(g, Functions.plusMult(1.0 / n)); + } + } + return r; + } + + public void addToHistory(int actual, Vector instance) { + while (history.size() <= actual) { + history.add(new ArrayDeque<Vector>(window)); + } + // save this instance + Deque<Vector> ourSide = history.get(actual); + ourSide.add(instance); + while (ourSide.size() >= window) { + ourSide.pollFirst(); + } + } + + public Gradient 
getBaseGradient() { + return BASIC; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java new file mode 100644 index 0000000..fbc825d --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.math.Vector; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * A record factor understands how to convert a line of data into fields and then into a vector. 
/**
 * A record factory understands how to convert a line of data into fields and then into a vector.
 *
 * NOTE(review): only the signatures are visible here; the per-method notes below are
 * inferred from names and types and should be confirmed against the implementations.
 */
public interface RecordFactory {
  /** Supplies the list of target category values (presumably the allowed labels — confirm). */
  void defineTargetCategories(List<String> values);

  /** Sets a maximum for the target value; returns a RecordFactory, presumably for call chaining. */
  RecordFactory maxTargetValue(int max);

  /** Tells whether the first input line is to be treated as a schema rather than as data. */
  boolean usesFirstLineAsSchema();

  /**
   * Processes one line of input, using {@code featureVector} as the output vector.
   *
   * @return an int; presumably the encoded target value of the line — TODO confirm
   */
  int processLine(String line, Vector featureVector);

  /** Returns the predictor names as strings. */
  Iterable<String> getPredictors();

  /** Returns a map from a name to a set of integers (presumably feature-vector indexes — confirm). */
  Map<String, Set<Integer>> getTraceDictionary();

  /** Enables or disables a bias term; returns a RecordFactory, presumably for call chaining. */
  RecordFactory includeBiasTerm(boolean useBias);

  /** Returns the target categories as strings. */
  List<String> getTargetCategories();

  /** Hands the first line of input to the factory (see {@link #usesFirstLineAsSchema()}). */
  void firstLine(String line);
}
+ */ +public class TPrior implements PriorFunction { + private double df; + + public TPrior(double df) { + this.df = df; + } + + @Override + public double age(double oldValue, double generations, double learningRate) { + for (int i = 0; i < generations; i++) { + oldValue -= learningRate * oldValue * (df + 1.0) / (df + oldValue * oldValue); + } + return oldValue; + } + + @Override + public double logP(double betaIJ) { + return Gamma.logGamma((df + 1.0) / 2.0) + - Math.log(df * Math.PI) + - Gamma.logGamma(df / 2.0) + - (df + 1.0) / 2.0 * Math.log1p(betaIJ * betaIJ); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeDouble(df); + } + + @Override + public void readFields(DataInput in) throws IOException { + df = in.readDouble(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java new file mode 100644 index 0000000..23c812f --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * A uniform prior. This is an improper prior that corresponds to no regularization at all. + */ +public class UniformPrior implements PriorFunction { + @Override + public double age(double oldValue, double generations, double learningRate) { + return oldValue; + } + + @Override + public double logP(double betaIJ) { + return 0; + } + + @Override + public void write(DataOutput dataOutput) throws IOException { + // nothing to write + } + + @Override + public void readFields(DataInput dataInput) throws IOException { + // stateless class is trivial to read + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java new file mode 100644 index 0000000..c2ad966 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java @@ -0,0 +1,23 @@ +/** + * <p>Implements a variety of on-line logistric regression classifiers using SGD-based algorithms. + * SGD stands for Stochastic Gradient Descent and refers to a class of learning algorithms + * that make it relatively easy to build high speed on-line learning algorithms for a variety + * of problems, notably including supervised learning for classification.</p> + * + * <p>The primary class of interest in the this package is + * {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} which contains a + * number (typically 5) of sub-learners, each of which is given a different portion of the + * training data. 
Each of these sub-learners can then be evaluated on the data it was not + * trained on. This allows fully incremental learning while still getting cross-validated + * performance estimates.</p> + * + * <p>The CrossFoldLearner implements {@link org.apache.mahout.classifier.OnlineLearner} + * and thus expects to be fed input in the form + * of a target variable and a feature vector. The target variable is simply an integer in the + * half-open interval [0..numCategories) where numCategories is defined when the CrossFoldLearner + * is constructed. The creation of feature vectors is facilitated by the classes that inherit + * from {@link org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder}. + * These classes currently implement a form of feature hashing with + * multiple probes to limit feature ambiguity.</p> + */ +package org.apache.mahout.classifier.sgd; http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java b/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java new file mode 100644 index 0000000..cc05beb --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java @@ -0,0 +1,391 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.HashMap; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.hadoop.conf.Configuration; +import org.apache.mahout.common.parameters.Parameter; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.function.SquareRootFunction; +import org.codehaus.jackson.map.ObjectMapper; + +public abstract class AbstractCluster implements Cluster { + + // cluster persistent state + private int id; + + private long numObservations; + + private long totalObservations; + + private Vector center; + + private Vector radius; + + // the observation statistics + private double s0; + + private Vector s1; + + private Vector s2; + + private static final ObjectMapper jxn = new ObjectMapper(); + + protected AbstractCluster() {} + + protected AbstractCluster(Vector point, int id2) { + this.numObservations = (long) 0; + this.totalObservations = (long) 0; + this.center = point.clone(); + this.radius = center.like(); + this.s0 = (double) 0; + this.s1 = center.like(); + this.s2 = center.like(); + this.id = 
id2; + } + + protected AbstractCluster(Vector center2, Vector radius2, int id2) { + this.numObservations = (long) 0; + this.totalObservations = (long) 0; + this.center = new RandomAccessSparseVector(center2); + this.radius = new RandomAccessSparseVector(radius2); + this.s0 = (double) 0; + this.s1 = center.like(); + this.s2 = center.like(); + this.id = id2; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(id); + out.writeLong(getNumObservations()); + out.writeLong(getTotalObservations()); + VectorWritable.writeVector(out, getCenter()); + VectorWritable.writeVector(out, getRadius()); + out.writeDouble(s0); + VectorWritable.writeVector(out, s1); + VectorWritable.writeVector(out, s2); + } + + @Override + public void readFields(DataInput in) throws IOException { + this.id = in.readInt(); + this.setNumObservations(in.readLong()); + this.setTotalObservations(in.readLong()); + this.setCenter(VectorWritable.readVector(in)); + this.setRadius(VectorWritable.readVector(in)); + this.setS0(in.readDouble()); + this.setS1(VectorWritable.readVector(in)); + this.setS2(VectorWritable.readVector(in)); + } + + @Override + public void configure(Configuration job) { + // nothing to do + } + + @Override + public Collection<Parameter<?>> getParameters() { + return Collections.emptyList(); + } + + @Override + public void createParameters(String prefix, Configuration jobConf) { + // nothing to do + } + + @Override + public int getId() { + return id; + } + + /** + * @param id + * the id to set + */ + protected void setId(int id) { + this.id = id; + } + + @Override + public long getNumObservations() { + return numObservations; + } + + /** + * @param l + * the numPoints to set + */ + protected void setNumObservations(long l) { + this.numObservations = l; + } + + @Override + public long getTotalObservations() { + return totalObservations; + } + + protected void setTotalObservations(long totalPoints) { + this.totalObservations = totalPoints; + } + + 
@Override + public Vector getCenter() { + return center; + } + + /** + * @param center + * the center to set + */ + protected void setCenter(Vector center) { + this.center = center; + } + + @Override + public Vector getRadius() { + return radius; + } + + /** + * @param radius + * the radius to set + */ + protected void setRadius(Vector radius) { + this.radius = radius; + } + + /** + * @return the s0 + */ + protected double getS0() { + return s0; + } + + protected void setS0(double s0) { + this.s0 = s0; + } + + /** + * @return the s1 + */ + protected Vector getS1() { + return s1; + } + + protected void setS1(Vector s1) { + this.s1 = s1; + } + + /** + * @return the s2 + */ + protected Vector getS2() { + return s2; + } + + protected void setS2(Vector s2) { + this.s2 = s2; + } + + @Override + public void observe(Model<VectorWritable> x) { + AbstractCluster cl = (AbstractCluster) x; + setS0(getS0() + cl.getS0()); + setS1(getS1().plus(cl.getS1())); + setS2(getS2().plus(cl.getS2())); + } + + @Override + public void observe(VectorWritable x) { + observe(x.get()); + } + + @Override + public void observe(VectorWritable x, double weight) { + observe(x.get(), weight); + } + + public void observe(Vector x, double weight) { + if (weight == 1.0) { + observe(x); + } else { + setS0(getS0() + weight); + Vector weightedX = x.times(weight); + if (getS1() == null) { + setS1(weightedX); + } else { + getS1().assign(weightedX, Functions.PLUS); + } + Vector x2 = x.times(x).times(weight); + if (getS2() == null) { + setS2(x2); + } else { + getS2().assign(x2, Functions.PLUS); + } + } + } + + public void observe(Vector x) { + setS0(getS0() + 1); + if (getS1() == null) { + setS1(x.clone()); + } else { + getS1().assign(x, Functions.PLUS); + } + Vector x2 = x.times(x); + if (getS2() == null) { + setS2(x2); + } else { + getS2().assign(x2, Functions.PLUS); + } + } + + + @Override + public void computeParameters() { + if (getS0() == 0) { + return; + } + setNumObservations((long) getS0()); + 
setTotalObservations(getTotalObservations() + getNumObservations()); + setCenter(getS1().divide(getS0())); + // compute the component stds + if (getS0() > 1) { + setRadius(getS2().times(getS0()).minus(getS1().times(getS1())).assign(new SquareRootFunction()).divide(getS0())); + } + setS0(0); + setS1(center.like()); + setS2(center.like()); + } + + @Override + public String asFormatString(String[] bindings) { + String fmtString = ""; + try { + fmtString = jxn.writeValueAsString(asJson(bindings)); + } catch (IOException e) { + log.error("Error writing JSON as String.", e); + } + return fmtString; + } + + public Map<String,Object> asJson(String[] bindings) { + Map<String,Object> dict = new HashMap<>(); + dict.put("identifier", getIdentifier()); + dict.put("n", getNumObservations()); + if (getCenter() != null) { + try { + dict.put("c", formatVectorAsJson(getCenter(), bindings)); + } catch (IOException e) { + log.error("IOException: ", e); + } + } + if (getRadius() != null) { + try { + dict.put("r", formatVectorAsJson(getRadius(), bindings)); + } catch (IOException e) { + log.error("IOException: ", e); + } + } + return dict; + } + + public abstract String getIdentifier(); + + /** + * Compute the centroid by averaging the pointTotals + * + * @return the new centroid + */ + public Vector computeCentroid() { + return getS0() == 0 ? 
getCenter() : getS1().divide(getS0()); + } + + /** + * Return a human-readable formatted string representation of the vector, not + * intended to be complete nor usable as an input/output representation + */ + public static String formatVector(Vector v, String[] bindings) { + String fmtString = ""; + try { + fmtString = jxn.writeValueAsString(formatVectorAsJson(v, bindings)); + } catch (IOException e) { + log.error("Error writing JSON as String.", e); + } + return fmtString; + } + + /** + * Create a List of HashMaps containing vector terms and weights + * + * @return List<Object> + */ + public static List<Object> formatVectorAsJson(Vector v, String[] bindings) throws IOException { + + boolean hasBindings = bindings != null; + boolean isSparse = !v.isDense() && v.getNumNondefaultElements() != v.size(); + + // we assume sequential access in the output + Vector provider = v.isSequentialAccess() ? v : new SequentialAccessSparseVector(v); + + List<Object> terms = Lists.newLinkedList(); + String term = ""; + + for (Element elem : provider.nonZeroes()) { + + if (hasBindings && bindings.length >= elem.index() + 1 && bindings[elem.index()] != null) { + term = bindings[elem.index()]; + } else if (hasBindings || isSparse) { + term = String.valueOf(elem.index()); + } + + Map<String, Object> term_entry = Maps.newHashMap(); + double roundedWeight = (double) Math.round(elem.get() * 1000) / 1000; + if (hasBindings || isSparse) { + term_entry.put(term, roundedWeight); + terms.add(term_entry); + } else { + terms.add(roundedWeight); + } + } + + return terms; + } + + @Override + public boolean isConverged() { + // Convergence has no meaning yet, perhaps in subclasses + return false; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/Cluster.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/Cluster.java 
b/mr/src/main/java/org/apache/mahout/clustering/Cluster.java new file mode 100644 index 0000000..07d6927 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/Cluster.java @@ -0,0 +1,90 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Implementations of this interface have a printable representation and certain
 * attributes that are common across all clustering implementations
 */
public interface Cluster extends Model<VectorWritable>, Parametered {

  // default directory for initial clusters to prime iterative clustering
  // algorithms
  String INITIAL_CLUSTERS_DIR = "clusters-0";

  // default directory for output of clusters per iteration
  String CLUSTERS_DIR = "clusters-";

  // default suffix for output of clusters for final iteration
  String FINAL_ITERATION_SUFFIX = "-final";

  /**
   * Get the id of the Cluster
   *
   * @return a unique integer
   */
  int getId();

  /**
   * Get the "center" of the Cluster as a Vector
   *
   * @return a Vector
   */
  Vector getCenter();

  /**
   * Get the "radius" of the Cluster as a Vector. Usually the radius is the
   * standard deviation expressed as a Vector of size equal to the center. Some
   * clusters may return zero values if not appropriate.
   *
   * @return a Vector
   */
  Vector getRadius();

  /**
   * Produce a custom, human-friendly, printable representation of the Cluster.
   *
   * @param bindings
   *          an optional String[] containing labels used to format the primary
   *          Vector/s of this implementation.
   * @return a String
   */
  String asFormatString(String[] bindings);

  /**
   * Produce a JSON representation of the Cluster.
   *
   * @param bindings
   *          an optional String[] containing labels used to format the primary
   *          Vector/s of this implementation.
   * @return a Map
   */
  Map<String,Object> asJson(String[] bindings);

  /**
   * @return if the receiver has converged, or false if that has no meaning for
   *         the implementation
   */
  boolean isConverged();

}
+ */ + +package org.apache.mahout.clustering; + +import java.util.List; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.distance.EuclideanDistanceMeasure; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.WeightedVector; +import org.apache.mahout.math.neighborhood.BruteSearch; +import org.apache.mahout.math.neighborhood.ProjectionSearch; +import org.apache.mahout.math.neighborhood.Searcher; +import org.apache.mahout.math.neighborhood.UpdatableSearcher; +import org.apache.mahout.math.random.WeightedThing; +import org.apache.mahout.math.stats.OnlineSummarizer; + +public final class ClusteringUtils { + private ClusteringUtils() { + } + + /** + * Computes the summaries for the distances in each cluster. + * @param datapoints iterable of datapoints. + * @param centroids iterable of Centroids. + * @return a list of OnlineSummarizers where the i-th element is the summarizer corresponding to the cluster whose + * index is i. + */ + public static List<OnlineSummarizer> summarizeClusterDistances(Iterable<? extends Vector> datapoints, + Iterable<? 
extends Vector> centroids, + DistanceMeasure distanceMeasure) { + UpdatableSearcher searcher = new ProjectionSearch(distanceMeasure, 3, 1); + searcher.addAll(centroids); + List<OnlineSummarizer> summarizers = Lists.newArrayList(); + if (searcher.size() == 0) { + return summarizers; + } + for (int i = 0; i < searcher.size(); ++i) { + summarizers.add(new OnlineSummarizer()); + } + for (Vector v : datapoints) { + Centroid closest = (Centroid)searcher.search(v, 1).get(0).getValue(); + OnlineSummarizer summarizer = summarizers.get(closest.getIndex()); + summarizer.add(distanceMeasure.distance(v, closest)); + } + return summarizers; + } + + /** + * Adds up the distances from each point to its closest cluster and returns the sum. + * @param datapoints iterable of datapoints. + * @param centroids iterable of Centroids. + * @return the total cost described above. + */ + public static double totalClusterCost(Iterable<? extends Vector> datapoints, Iterable<? extends Vector> centroids) { + DistanceMeasure distanceMeasure = new EuclideanDistanceMeasure(); + UpdatableSearcher searcher = new ProjectionSearch(distanceMeasure, 3, 1); + searcher.addAll(centroids); + return totalClusterCost(datapoints, searcher); + } + + /** + * Adds up the distances from each point to its closest cluster and returns the sum. + * @param datapoints iterable of datapoints. + * @param centroids searcher of Centroids. + * @return the total cost described above. + */ + public static double totalClusterCost(Iterable<? extends Vector> datapoints, Searcher centroids) { + double totalCost = 0; + for (Vector vector : datapoints) { + totalCost += centroids.searchFirst(vector, false).getWeight(); + } + return totalCost; + } + + /** + * Estimates the distance cutoff. In StreamingKMeans, the distance between two vectors divided + * by this value is used as a probability threshold when deciding whether to form a new cluster + * or not. 
+ * Small values (comparable to the minimum distance between two points) are preferred as they + * guarantee with high likelihood that all but very close points are put in separate clusters + * initially. The clusters themselves are actually collapsed periodically when their number goes + * over the maximum number of clusters and the distanceCutoff is increased. + * So, the returned value is only an initial estimate. + * @param data the datapoints whose distance is to be estimated. + * @param distanceMeasure the distance measure used to compute the distance between two points. + * @return the minimum distance between the first sampleLimit points + * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans#clusterInternal(Iterable, boolean) + */ + public static double estimateDistanceCutoff(List<? extends Vector> data, DistanceMeasure distanceMeasure) { + BruteSearch searcher = new BruteSearch(distanceMeasure); + searcher.addAll(data); + double minDistance = Double.POSITIVE_INFINITY; + for (Vector vector : data) { + double closest = searcher.searchFirst(vector, true).getWeight(); + if (minDistance > 0 && closest < minDistance) { + minDistance = closest; + } + searcher.add(vector); + } + return minDistance; + } + + public static <T extends Vector> double estimateDistanceCutoff( + Iterable<T> data, DistanceMeasure distanceMeasure, int sampleLimit) { + return estimateDistanceCutoff(Lists.newArrayList(Iterables.limit(data, sampleLimit)), distanceMeasure); + } + + /** + * Computes the Davies-Bouldin Index for a given clustering. + * See http://en.wikipedia.org/wiki/Clustering_algorithm#Internal_evaluation + * @param centroids list of centroids + * @param distanceMeasure distance measure for inter-cluster distances + * @param clusterDistanceSummaries summaries of the clusters; See summarizeClusterDistances + * @return the Davies-Bouldin Index + */ + public static double daviesBouldinIndex(List<? 
extends Vector> centroids, DistanceMeasure distanceMeasure, + List<OnlineSummarizer> clusterDistanceSummaries) { + Preconditions.checkArgument(centroids.size() == clusterDistanceSummaries.size(), + "Number of centroids and cluster summaries differ."); + int n = centroids.size(); + double totalDBIndex = 0; + // The inner loop shouldn't be reduced for j = i + 1 to n because the computation of the Davies-Bouldin + // index is not really symmetric. + // For a given cluster i, we look for a cluster j that maximizes the ratio of the sum of average distances + // from points in cluster i to its center and and points in cluster j to its center to the distance between + // cluster i and cluster j. + // The maximization is the key issue, as the cluster that maximizes this ratio might be j for i but is NOT + // NECESSARILY i for j. + for (int i = 0; i < n; ++i) { + double averageDistanceI = clusterDistanceSummaries.get(i).getMean(); + double maxDBIndex = 0; + for (int j = 0; j < n; ++j) { + if (i != j) { + double dbIndex = (averageDistanceI + clusterDistanceSummaries.get(j).getMean()) + / distanceMeasure.distance(centroids.get(i), centroids.get(j)); + if (dbIndex > maxDBIndex) { + maxDBIndex = dbIndex; + } + } + } + totalDBIndex += maxDBIndex; + } + return totalDBIndex / n; + } + + /** + * Computes the Dunn Index of a given clustering. See http://en.wikipedia.org/wiki/Dunn_index + * @param centroids list of centroids + * @param distanceMeasure distance measure to compute inter-centroid distance with + * @param clusterDistanceSummaries summaries of the clusters; See summarizeClusterDistances + * @return the Dunn Index + */ + public static double dunnIndex(List<? 
extends Vector> centroids, DistanceMeasure distanceMeasure, + List<OnlineSummarizer> clusterDistanceSummaries) { + Preconditions.checkArgument(centroids.size() == clusterDistanceSummaries.size(), + "Number of centroids and cluster summaries differ."); + int n = centroids.size(); + // Intra-cluster distances will come from the OnlineSummarizer, and will be the median distance (noting that + // the median for just one value is that value). + // A variety of metrics can be used for the intra-cluster distance including max distance between two points, + // mean distance, etc. Median distance was chosen as this is more robust to outliers and characterizes the + // distribution of distances (from a point to the center) better. + double maxIntraClusterDistance = 0; + for (OnlineSummarizer summarizer : clusterDistanceSummaries) { + if (summarizer.getCount() > 0) { + double intraClusterDistance; + if (summarizer.getCount() == 1) { + intraClusterDistance = summarizer.getMean(); + } else { + intraClusterDistance = summarizer.getMedian(); + } + if (maxIntraClusterDistance < intraClusterDistance) { + maxIntraClusterDistance = intraClusterDistance; + } + } + } + double minDunnIndex = Double.POSITIVE_INFINITY; + for (int i = 0; i < n; ++i) { + // Distances are symmetric, so d(i, j) = d(j, i). + for (int j = i + 1; j < n; ++j) { + double dunnIndex = distanceMeasure.distance(centroids.get(i), centroids.get(j)); + if (minDunnIndex > dunnIndex) { + minDunnIndex = dunnIndex; + } + } + } + return minDunnIndex / maxIntraClusterDistance; + } + + public static double choose2(double n) { + return n * (n - 1) / 2; + } + + /** + * Creates a confusion matrix by searching for the closest cluster of both the row clustering and column clustering + * of a point and adding its weight to that cell of the matrix. + * It doesn't matter which clustering is the row clustering and which is the column clustering. If they're + * interchanged, the resulting matrix is the transpose of the original one. 
+ * @param rowCentroids clustering one + * @param columnCentroids clustering two + * @param datapoints datapoints whose closest cluster we need to find + * @param distanceMeasure distance measure to use + * @return the confusion matrix + */ + public static Matrix getConfusionMatrix(List<? extends Vector> rowCentroids, List<? extends Vector> columnCentroids, + Iterable<? extends Vector> datapoints, DistanceMeasure distanceMeasure) { + Searcher rowSearcher = new BruteSearch(distanceMeasure); + rowSearcher.addAll(rowCentroids); + Searcher columnSearcher = new BruteSearch(distanceMeasure); + columnSearcher.addAll(columnCentroids); + + int numRows = rowCentroids.size(); + int numCols = columnCentroids.size(); + Matrix confusionMatrix = new DenseMatrix(numRows, numCols); + + for (Vector vector : datapoints) { + WeightedThing<Vector> closestRowCentroid = rowSearcher.search(vector, 1).get(0); + WeightedThing<Vector> closestColumnCentroid = columnSearcher.search(vector, 1).get(0); + int row = ((Centroid) closestRowCentroid.getValue()).getIndex(); + int column = ((Centroid) closestColumnCentroid.getValue()).getIndex(); + double vectorWeight; + if (vector instanceof WeightedVector) { + vectorWeight = ((WeightedVector) vector).getWeight(); + } else { + vectorWeight = 1; + } + confusionMatrix.set(row, column, confusionMatrix.get(row, column) + vectorWeight); + } + + return confusionMatrix; + } + + /** + * Computes the Adjusted Rand Index for a given confusion matrix. 
+ * @param confusionMatrix confusion matrix; not to be confused with the more restrictive ConfusionMatrix class + * @return the Adjusted Rand Index + */ + public static double getAdjustedRandIndex(Matrix confusionMatrix) { + int numRows = confusionMatrix.numRows(); + int numCols = confusionMatrix.numCols(); + double rowChoiceSum = 0; + double columnChoiceSum = 0; + double totalChoiceSum = 0; + double total = 0; + for (int i = 0; i < numRows; ++i) { + double rowSum = 0; + for (int j = 0; j < numCols; ++j) { + rowSum += confusionMatrix.get(i, j); + totalChoiceSum += choose2(confusionMatrix.get(i, j)); + } + total += rowSum; + rowChoiceSum += choose2(rowSum); + } + for (int j = 0; j < numCols; ++j) { + double columnSum = 0; + for (int i = 0; i < numRows; ++i) { + columnSum += confusionMatrix.get(i, j); + } + columnChoiceSum += choose2(columnSum); + } + double rowColumnChoiceSumDivTotal = rowChoiceSum * columnChoiceSum / choose2(total); + return (totalChoiceSum - rowColumnChoiceSumDivTotal) + / ((rowChoiceSum + columnChoiceSum) / 2 - rowColumnChoiceSumDivTotal); + } + + /** + * Computes the total weight of the points in the given Vector iterable. + * @param data iterable of points + * @return total weight + */ + public static double totalWeight(Iterable<? 
extends Vector> data) { + double sum = 0; + for (Vector row : data) { + Preconditions.checkNotNull(row); + if (row instanceof WeightedVector) { + sum += ((WeightedVector)row).getWeight(); + } else { + sum++; + } + } + return sum; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java b/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java new file mode 100644 index 0000000..c25e039 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering; + +import org.apache.mahout.math.Vector; + +public interface GaussianAccumulator { + + /** + * @return the number of observations + */ + double getN(); + + /** + * @return the mean of the observations + */ + Vector getMean(); + + /** + * @return the std of the observations + */ + Vector getStd(); + + /** + * @return the average of the vector std elements + */ + double getAverageStd(); + + /** + * @return the variance of the observations + */ + Vector getVariance(); + + /** + * Observe the vector + * + * @param x a Vector + * @param weight the double observation weight (usually 1.0) + */ + void observe(Vector x, double weight); + + /** + * Compute the mean, variance and standard deviation + */ + void compute(); + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/Model.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/Model.java b/mr/src/main/java/org/apache/mahout/clustering/Model.java new file mode 100644 index 0000000..79dab30 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/Model.java @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.math.VectorWritable; + +/** + * A model is a probability distribution over observed data points and allows + * the probability of any data point to be computed. All Models have a + * persistent representation and extend + * WritablesampleFromPosterior(Model<VectorWritable>[]) + */ +public interface Model<O> extends Writable { + + /** + * Return the probability that the observation is described by this model + * + * @param x + * an Observation from the posterior + * @return the probability that x is in the receiver + */ + double pdf(O x); + + /** + * Observe the given observation, retaining information about it + * + * @param x + * an Observation from the posterior + */ + void observe(O x); + + /** + * Observe the given observation, retaining information about it + * + * @param x + * an Observation from the posterior + * @param weight + * a double weighting factor + */ + void observe(O x, double weight); + + /** + * Observe the given model, retaining information about its observations + * + * @param x + * a Model<0> + */ + void observe(Model<O> x); + + /** + * Compute a new set of posterior parameters based upon the Observations that + * have been observed since my creation + */ + void computeParameters(); + + /** + * Return the number of observations that this model has seen since its + * parameters were last computed + * + * @return a long + */ + long getNumObservations(); + + /** + * Return the number of observations that this model has seen over its + * lifetime + * + * @return a long + */ + long getTotalObservations(); + + /** + * @return a sample of my posterior model + */ + Model<VectorWritable> sampleFromPosterior(); + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java b/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java new file mode 100644 index 0000000..d77bf40 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering; + +/** A model distribution allows us to sample a model from its prior distribution. 
*/ +public interface ModelDistribution<O> { + + /** + * Return a list of models sampled from the prior + * + * @param howMany + * the int number of models to return + * @return a Model<Observation>[] representing what is known apriori + */ + Model<O>[] sampleFromPrior(int howMany); + + /** + * Return a list of models sampled from the posterior + * + * @param posterior + * the Model<Observation>[] after observations + * @return a Model<Observation>[] representing what is known apriori + */ + Model<O>[] sampleFromPosterior(Model<O>[] posterior); + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java b/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java new file mode 100644 index 0000000..b76e00f --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java @@ -0,0 +1,107 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering; + +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.SquareRootFunction; + +/** + * An online Gaussian statistics accumulator based upon Knuth (who cites Welford) which is declared to be + * numerically-stable. See http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + */ +public class OnlineGaussianAccumulator implements GaussianAccumulator { + + private double sumWeight; + private Vector mean; + private Vector s; + private Vector variance; + + @Override + public double getN() { + return sumWeight; + } + + @Override + public Vector getMean() { + return mean; + } + + @Override + public Vector getStd() { + return variance.clone().assign(new SquareRootFunction()); + } + + /* from Wikipedia: http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + * + * Weighted incremental algorithm + * + * def weighted_incremental_variance(dataWeightPairs): + * mean = 0 + * S = 0 + * sumweight = 0 + * for x, weight in dataWeightPairs: # Alternately "for x in zip(data, weight):" + * temp = weight + sumweight + * Q = x - mean + * R = Q * weight / temp + * S = S + sumweight * Q * R + * mean = mean + R + * sumweight = temp + * Variance = S / (sumweight-1) # if sample is the population, omit -1 + * return Variance + */ + @Override + public void observe(Vector x, double weight) { + double temp = weight + sumWeight; + Vector q; + if (mean == null) { + mean = x.like(); + q = x.clone(); + } else { + q = x.minus(mean); + } + Vector r = q.times(weight).divide(temp); + if (s == null) { + s = q.times(sumWeight).times(r); + } else { + s = s.plus(q.times(sumWeight).times(r)); + } + mean = mean.plus(r); + sumWeight = temp; + variance = s.divide(sumWeight - 1); // # if sample is the population, omit -1 + } + + @Override + public void compute() { + // nothing to do here! 
+ } + + @Override + public double getAverageStd() { + if (sumWeight == 0.0) { + return 0.0; + } else { + Vector std = getStd(); + return std.zSum() / std.size(); + } + } + + @Override + public Vector getVariance() { + return variance; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java b/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java new file mode 100644 index 0000000..138e830 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java @@ -0,0 +1,90 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering; + +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.function.SquareRootFunction; + +/** + * An online Gaussian accumulator that uses a running power sums approach as reported + * on http://en.wikipedia.org/wiki/Standard_deviation + * Suffers from overflow, underflow and roundoff error but has minimal observe-time overhead + */ +public class RunningSumsGaussianAccumulator implements GaussianAccumulator { + + private double s0; + private Vector s1; + private Vector s2; + private Vector mean; + private Vector std; + + @Override + public double getN() { + return s0; + } + + @Override + public Vector getMean() { + return mean; + } + + @Override + public Vector getStd() { + return std; + } + + @Override + public double getAverageStd() { + if (s0 == 0.0) { + return 0.0; + } else { + return std.zSum() / std.size(); + } + } + + @Override + public Vector getVariance() { + return std.times(std); + } + + @Override + public void observe(Vector x, double weight) { + s0 += weight; + Vector weightedX = x.times(weight); + if (s1 == null) { + s1 = weightedX; + } else { + s1.assign(weightedX, Functions.PLUS); + } + Vector x2 = x.times(x).times(weight); + if (s2 == null) { + s2 = x2; + } else { + s2.assign(x2, Functions.PLUS); + } + } + + @Override + public void compute() { + if (s0 != 0.0) { + mean = s1.divide(s0); + std = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction()).divide(s0); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java b/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java new file mode 100644 index 0000000..ef43e1b --- /dev/null +++ 
b/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java @@ -0,0 +1,136 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering; + +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.distribution.RealDistribution; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.RandomWrapper; + +public final class UncommonDistributions { + + private static final RandomWrapper RANDOM = RandomUtils.getRandom(); + + private UncommonDistributions() {} + + // =============== start of BSD licensed code. See LICENSE.txt + /** + * Returns a double sampled according to this distribution. Uniformly fast for all k > 0. (Reference: + * Non-Uniform Random Variate Generation, Devroye http://cgm.cs.mcgill.ca/~luc/rnbookindex.html) Uses + * Cheng's rejection algorithm (GB) for k>=1, rejection from Weibull distribution for 0 < k < 1. 
+ */ + public static double rGamma(double k, double lambda) { + boolean accept = false; + if (k >= 1.0) { + // Cheng's algorithm + double b = k - Math.log(4.0); + double c = k + Math.sqrt(2.0 * k - 1.0); + double lam = Math.sqrt(2.0 * k - 1.0); + double cheng = 1.0 + Math.log(4.5); + double x; + do { + double u = RANDOM.nextDouble(); + double v = RANDOM.nextDouble(); + double y = 1.0 / lam * Math.log(v / (1.0 - v)); + x = k * Math.exp(y); + double z = u * v * v; + double r = b + c * y - x; + if (r >= 4.5 * z - cheng || r >= Math.log(z)) { + accept = true; + } + } while (!accept); + return x / lambda; + } else { + // Weibull algorithm + double c = 1.0 / k; + double d = (1.0 - k) * Math.pow(k, k / (1.0 - k)); + double x; + do { + double u = RANDOM.nextDouble(); + double v = RANDOM.nextDouble(); + double z = -Math.log(u); + double e = -Math.log(v); + x = Math.pow(z, c); + if (z + e >= d + x) { + accept = true; + } + } while (!accept); + return x / lambda; + } + } + + // ============= end of BSD licensed code + + /** + * Returns a random sample from a beta distribution with the given shapes + * + * @param shape1 + * a double representing shape1 + * @param shape2 + * a double representing shape2 + * @return a Vector of samples + */ + public static double rBeta(double shape1, double shape2) { + double gam1 = rGamma(shape1, 1.0); + double gam2 = rGamma(shape2, 1.0); + return gam1 / (gam1 + gam2); + + } + + /** + * Return a random value from a normal distribution with the given mean and standard deviation + * + * @param mean + * a double mean value + * @param sd + * a double standard deviation + * @return a double sample + */ + public static double rNorm(double mean, double sd) { + RealDistribution dist = new NormalDistribution(RANDOM.getRandomGenerator(), + mean, + sd, + NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); + return dist.sample(); + } + + /** + * Returns an integer sampled according to this distribution. Takes time proportional to np + 1. 
(Reference: + * Non-Uniform Random Variate Generation, Devroye http://cgm.cs.mcgill.ca/~luc/rnbookindex.html) Second + * time-waiting algorithm. + */ + public static int rBinomial(int n, double p) { + if (p >= 1.0) { + return n; // needed to avoid infinite loops and negative results + } + double q = -Math.log1p(-p); + double sum = 0.0; + int x = 0; + while (sum <= q) { + double u = RANDOM.nextDouble(); + double e = -Math.log(u); + sum += e / (n - x); + x++; + } + if (x == 0) { + return 0; + } + return x - 1; + } + +}
