Author: tdunning
Date: Tue Aug 17 16:22:09 2010
New Revision: 986373

URL: http://svn.apache.org/viewvc?rev=986373&view=rev
Log:
Evolutionary on-line parameter tuning

Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=986373&r1=986372&r2=986373&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
 Tue Aug 17 16:22:09 2010
@@ -251,4 +251,18 @@ public abstract class AbstractOnlineLogi
       sealed = true;
     }
   }
+
+  public void copyFrom(AbstractOnlineLogisticRegression other) {
+    beta.assign(other.beta);
+
+    // number of categories we are classifying.  This should be the number of 
rows of beta plus one.
+    if (numCategories != other.numCategories) {
+      throw new IllegalArgumentException("Can't copy unless number of target 
categories is the same");
+    }
+
+    step = other.step;
+
+    updateSteps.assign(other.updateSteps);
+    updateCounts.assign(other.updateCounts);
+  }
 }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java?rev=986373&r1=986372&r2=986373&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveAnnealedLogisticRegression.java
 Tue Aug 17 16:22:09 2010
@@ -3,8 +3,8 @@ package org.apache.mahout.classifier.sgd
 import com.google.common.collect.Lists;
 import org.apache.mahout.classifier.OnlineLearner;
 import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.jet.random.NegativeBinomial;
 import org.apache.mahout.math.jet.random.engine.MersenneTwister;
+import org.apache.mahout.math.jet.random.engine.RandomEngine;
 
 import java.util.Collections;
 import java.util.List;
@@ -23,13 +23,16 @@ public class AdaptiveAnnealedLogisticReg
   private int record = 0;
   private List<CrossFoldLearner> pool = Lists.newArrayList();
   private int evaluationInterval = 1000;
+  private RandomEngine rand;
+  private int depth = 10;
 
   public AdaptiveAnnealedLogisticRegression(int poolSize, int numCategories, 
int numFeatures, PriorFunction prior) {
     for (int i = 0; i < poolSize; i++) {
       CrossFoldLearner model = new CrossFoldLearner(5, numCategories, 
numFeatures, prior);
       pool.add(model);
     }
-    NegativeBinomial nb = new NegativeBinomial(10, 0.1, new MersenneTwister());
+    depth = poolSize / 5;
+    rand = new MersenneTwister();
   }
 
   @Override
@@ -40,8 +43,12 @@ public class AdaptiveAnnealedLogisticReg
     record++;
     if (record % evaluationInterval == 0) {
       Collections.sort(pool);
-      // pick a parent from the top half of the pool weighted toward the top 
few
+      for (int i = pool.size() / 2; i < pool.size(); i++) {
+        // pick a parent from the top half of the pool weighted toward the top 
few
+        int n = ((int) Math.floor(-depth * Math.log(rand.nextDouble()))) % 
pool.size();
 
+        pool.get(i).copyFrom(pool.get(n));
+      }
     }
   }
 

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=986373&r1=986372&r2=986373&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
 Tue Aug 17 16:22:09 2010
@@ -139,6 +139,18 @@ class CrossFoldLearner extends AbstractV
   public double logLikelihood() {
     return logLikelihood;
   }
+
+  // -------- evolutionary optimization
+
+  public void copyFrom(CrossFoldLearner other) {
+    int i = 0;
+    for (OnlineLogisticRegression model : models) {
+      model.copyFrom(other.models.get(i++));
+    }
+    System.arraycopy(other.parameters, 0, parameters, 0, parameters.length);
+    // TODO mutate parameters
+  }
+
   // -------- general object and ordering stuff
 
   /**

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java?rev=986373&r1=986372&r2=986373&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
 Tue Aug 17 16:22:09 2010
@@ -28,7 +28,7 @@ import org.apache.mahout.math.Vector;
 public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression 
{
   // these next two control decayFactor^steps exponential type of annealing
   // learning rate and decay factor
-  private double mu_0 = 1;
+  private double mu0 = 1;
   private double decayFactor = 1 - 1e-3;
 
 
@@ -76,7 +76,7 @@ public class OnlineLogisticRegression ex
    * @return This, so other configurations can be chained.
    */
   public OnlineLogisticRegression learningRate(double learningRate) {
-    this.mu_0 = learningRate;
+    this.mu0 = learningRate;
     return this;
   }
 
@@ -101,11 +101,22 @@ public class OnlineLogisticRegression ex
 
   @Override
   public double currentLearningRate() {
-    return mu_0 * Math.pow(decayFactor, getStep()) * Math.pow(getStep() + 
stepOffset, forgettingExponent);
+    return mu0 * Math.pow(decayFactor, getStep()) * Math.pow(getStep() + 
stepOffset, forgettingExponent);
   }
 
   @Override
   public void train(int trackingKey, int actual, Vector instance) {
     train(actual, instance);
   }
+
+  public void copyFrom(OnlineLogisticRegression other) {
+    super.copyFrom(other);
+    mu0 = other.mu0;
+    decayFactor = other.decayFactor;
+
+    stepOffset = other.stepOffset;
+    forgettingExponent = other.forgettingExponent;
+
+    perTermAnnealingOffset = other.perTermAnnealingOffset;
+  }
 }


Reply via email to