Author: tdunning
Date: Tue Aug 17 16:22:09 2010
New Revision: 986373
URL: http://svn.apache.org/viewvc?rev=986373&view=rev
Log:
Evolutionary on-line parameter tuning
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=986373&r1=986372&r2=986373&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
Tue Aug 17 16:22:09 2010
@@ -251,4 +251,18 @@ public abstract class AbstractOnlineLogi
sealed = true;
}
}
+
+ /**
+  * Copies the state of another model into this one: the coefficient matrix {@code beta},
+  * the step counter, and the per-term {@code updateSteps} and {@code updateCounts}.
+  *
+  * @param other the model to copy from; it must have the same number of target categories
+  * @throws IllegalArgumentException if the two models disagree on the number of categories
+  */
+ public void copyFrom(AbstractOnlineLogisticRegression other) {
+   // Validate before touching any state so that a rejected copy leaves this model unchanged.
+   // numCategories is the number of categories we are classifying; it should be the number
+   // of rows of beta plus one.
+   if (numCategories != other.numCategories) {
+     throw new IllegalArgumentException("Can't copy unless number of target categories is the same");
+   }
+
+   beta.assign(other.beta);
+
+   step = other.step;
+
+   updateSteps.assign(other.updateSteps);
+   updateCounts.assign(other.updateCounts);
+ }
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java?rev=986373&r1=986372&r2=986373&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java
Tue Aug 17 16:22:09 2010
@@ -3,8 +3,8 @@ package org.apache.mahout.classifier.sgd
import com.google.common.collect.Lists;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.jet.random.NegativeBinomial;
import org.apache.mahout.math.jet.random.engine.MersenneTwister;
+import org.apache.mahout.math.jet.random.engine.RandomEngine;
import java.util.Collections;
import java.util.List;
@@ -23,13 +23,16 @@ public class AdaptiveAnnealedLogisticReg
private int record = 0;
private List<CrossFoldLearner> pool = Lists.newArrayList();
private int evaluationInterval = 1000;
+ private RandomEngine rand;
+ private int depth = 10;
public AdaptiveAnnealedLogisticRegression(int poolSize, int numCategories,
int numFeatures, PriorFunction prior) {
for (int i = 0; i < poolSize; i++) {
CrossFoldLearner model = new CrossFoldLearner(5, numCategories,
numFeatures, prior);
pool.add(model);
}
- NegativeBinomial nb = new NegativeBinomial(10, 0.1, new MersenneTwister());
+ depth = poolSize / 5;
+ rand = new MersenneTwister();
}
@Override
@@ -40,8 +43,12 @@ public class AdaptiveAnnealedLogisticReg
record++;
if (record % evaluationInterval == 0) {
Collections.sort(pool);
- // pick a parent from the top half of the pool weighted toward the top
few
+ for (int i = pool.size() / 2; i < pool.size(); i++) {
+ // pick a parent from the top half of the pool weighted toward the top
few
+ int n = ((int) Math.floor(-depth * Math.log(rand.nextDouble()))) %
pool.size();
+ pool.get(i).copyFrom(pool.get(n));
+ }
}
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=986373&r1=986372&r2=986373&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
Tue Aug 17 16:22:09 2010
@@ -139,6 +139,18 @@ class CrossFoldLearner extends AbstractV
public double logLikelihood() {
return logLikelihood;
}
+
+ // -------- evolutionary optimization
+
+ /**
+  * Overwrites this learner's state with that of {@code other}: every entry of
+  * {@code models} copies from the entry at the same position in {@code other.models},
+  * and then {@code other.parameters} is copied over {@code parameters}.
+  *
+  * @param other the learner whose models and parameters are copied into this one
+  */
+ public void copyFrom(CrossFoldLearner other) {
+   for (int k = 0; k < models.size(); k++) {
+     models.get(k).copyFrom(other.models.get(k));
+   }
+   System.arraycopy(other.parameters, 0, parameters, 0, parameters.length);
+   // TODO mutate parameters
+ }
+
// -------- general object and ordering stuff
/**
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java?rev=986373&r1=986372&r2=986373&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
Tue Aug 17 16:22:09 2010
@@ -28,7 +28,7 @@ import org.apache.mahout.math.Vector;
public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression
{
// these next two control decayFactor^steps exponential type of annealing
// learning rate and decay factor
- private double mu_0 = 1;
+ private double mu0 = 1;
private double decayFactor = 1 - 1e-3;
@@ -76,7 +76,7 @@ public class OnlineLogisticRegression ex
* @return This, so other configurations can be chained.
*/
public OnlineLogisticRegression learningRate(double learningRate) {
- this.mu_0 = learningRate;
+ this.mu0 = learningRate;
return this;
}
@@ -101,11 +101,22 @@ public class OnlineLogisticRegression ex
@Override
public double currentLearningRate() {
- return mu_0 * Math.pow(decayFactor, getStep()) * Math.pow(getStep() +
stepOffset, forgettingExponent);
+ return mu0 * Math.pow(decayFactor, getStep()) * Math.pow(getStep() +
stepOffset, forgettingExponent);
}
@Override
public void train(int trackingKey, int actual, Vector instance) {
train(actual, instance);
}
+
+ /**
+  * Copies the state of another model into this one. The superclass copy runs first,
+  * followed by the learning-rate settings read by {@code currentLearningRate()} and the
+  * per-term annealing offset.
+  *
+  * @param other the model to copy from
+  */
+ public void copyFrom(OnlineLogisticRegression other) {
+   super.copyFrom(other);
+
+   // initial learning rate and decay factor of the exponential annealing
+   this.mu0 = other.mu0;
+   this.decayFactor = other.decayFactor;
+
+   // offset and exponent applied to the step count in the learning rate
+   this.stepOffset = other.stepOffset;
+   this.forgettingExponent = other.forgettingExponent;
+
+   this.perTermAnnealingOffset = other.perTermAnnealingOffset;
+ }
}