Author: tdunning
Date: Thu Aug 19 07:26:25 2010
New Revision: 987048

URL: http://svn.apache.org/viewvc?rev=987048&view=rev
Log:
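Added Copyable to CrossFoldLearner (CFL) and OnlineLogisticRegression (OLR); State is now generic and carries a Copyable payload.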

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/ep/Copyable.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=987048&r1=987047&r2=987048&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java Thu Aug 19 07:26:25 2010
@@ -3,6 +3,7 @@ package org.apache.mahout.classifier.sgd
 import com.google.common.collect.Lists;
 import org.apache.mahout.classifier.AbstractVectorClassifier;
 import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.ep.Copyable;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.function.Functions;
@@ -18,7 +19,7 @@ import java.util.List;
  * time the training data is traversed or a tracking key such as the file 
offset of the training
  * record should be passed with each training example.
  */
-class CrossFoldLearner extends AbstractVectorClassifier implements 
OnlineLearner, Comparable<CrossFoldLearner> {
+public class CrossFoldLearner extends AbstractVectorClassifier implements 
OnlineLearner, Copyable<CrossFoldLearner>, Comparable<CrossFoldLearner> {
   private static volatile int nextId = 0;
 
   private final int id = nextId++;
@@ -29,8 +30,12 @@ class CrossFoldLearner extends AbstractV
 
   // lambda, learningRate, perTermOffset, perTermExponent
   private double[] parameters = new double[4];
+  private int numFeatures;  // kept so copy() can construct a learner with the same feature count
+  private PriorFunction prior;  // kept so copy() can hand the same prior to the new fold models
 
   CrossFoldLearner(int folds, int numCategories, int numFeatures, 
PriorFunction prior) {
+    this.numFeatures = numFeatures;
+    this.prior = prior;
     for (int i = 0; i < folds; i++) {
       OnlineLogisticRegression model = new 
OnlineLogisticRegression(numCategories, numFeatures, prior);
       model.alpha(1).stepOffset(0).decayExponent(0);
@@ -175,4 +180,24 @@ class CrossFoldLearner extends AbstractV
   public boolean equals(Object other) {
     return other instanceof CrossFoldLearner && id == ((CrossFoldLearner) other).id;
   }
+
+  /**
+   * Returns a copy of this learner.  Every fold model is replaced by its own
+   * {@link OnlineLogisticRegression#copy()}, and the hyper-parameter array (lambda, learning
+   * rate, per-term offset, per-term exponent) is cloned so the two learners never share it.
+   * The copy gets a fresh id and is therefore never {@code equals} to this learner.
+   */
+  @Override
+  public CrossFoldLearner copy() {
+    CrossFoldLearner r = new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior);
+    // the constructor builds brand-new fold models; swap them for copies of ours
+    r.models.clear();
+    for (OnlineLogisticRegression model : models) {
+      r.models.add(model.copy());
+    }
+    // carry over the hyper-parameter values; clone so the arrays are not shared
+    r.parameters = parameters.clone();
+    // NOTE(review): any other learner state declared outside this diff is not copied -- confirm
+    return r;
+  }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java?rev=987048&r1=987047&r2=987048&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java Thu Aug 19 07:26:25 2010
@@ -17,6 +17,8 @@
 
 package org.apache.mahout.classifier.sgd;
 
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.ep.Copyable;
 import org.apache.mahout.math.DenseMatrix;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Vector;
@@ -25,7 +27,7 @@ import org.apache.mahout.math.Vector;
  * Extends the basic on-line logistic regression learner with a specific set 
of learning
  * rate annealing schedules.
  */
-public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression 
{
+public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression 
implements Copyable<OnlineLogisticRegression> {
   // these next two control decayFactor^steps exponential type of annealing
   // learning rate and decay factor
   private double mu0 = 1;
@@ -119,4 +121,15 @@ public class OnlineLogisticRegression ex

     perTermAnnealingOffset = other.perTermAnnealingOffset;
   }
+
+  /**
+   * Builds a new learner with the same number of categories, number of features and prior,
+   * then transfers this learner's state into it via {@link #copyFrom}.
+   */
+  @Override
+  public OnlineLogisticRegression copy() {
+    OnlineLogisticRegression twin = new OnlineLogisticRegression(numCategories(), numFeatures(), prior);
+    twin.copyFrom(this);
+    return twin;
+  }
 }

Added: mahout/trunk/core/src/main/java/org/apache/mahout/ep/Copyable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/Copyable.java?rev=987048&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/ep/Copyable.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/ep/Copyable.java Thu Aug 19 07:26:25 2010
@@ -0,0 +1,13 @@
+package org.apache.mahout.ep;
+
+/**
+ * Implemented by objects that can produce a copy of themselves.
+ *
+ * @param <T> the type of the copy, normally the implementing class itself
+ */
+public interface Copyable<T> {
+  /**
+   * @return a new object holding a copy of this object's state
+   */
+  T copy();
+}

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java?rev=987048&r1=987047&r2=987048&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java Thu Aug 19 07:26:25 2010
@@ -7,7 +7,7 @@ import java.util.Random;
  * Recorded step evolutionary optimization.  You provide the value, this class 
provides the
  * mutation.
  */
-public class State implements Comparable<State> {
+public class State<T extends Copyable<T>> implements Copyable<State<T>>, 
Comparable<State<T>> {
   // object count is kept to break ties in comparison.
   static volatile int objectCount = 0;
   private Random gen = new Random();
@@ -29,6 +29,11 @@ public class State implements Comparable
   // current fitness value
   private double value;
 
+  private T payload;  // caller-supplied object attached via setPayload; copy() copies it too
+
+  private State() {  // used by copy(), which fills in the new instance's fields itself
+  }
+
   /**
    * Invent a new state with no momentum (yet).
    */
@@ -40,17 +45,17 @@ public class State implements Comparable
   }
 
   /**
-   * Deep clones a state, useful in mutation.
-   *
-   * @param params Current state
-   * @param omni   Current omni-directional mutation
-   * @param step   The step taken to get to this point
+   * Deep copies a state, useful in mutation.
    */
-  private State(double[] params, double omni, double[] step, Mapping[] maps) {
-    this.params = Arrays.copyOf(params, params.length);
-    this.omni = omni;
-    this.step = Arrays.copyOf(step, step.length);
-    this.maps = Arrays.copyOf(maps, maps.length);
+  public State<T> copy() {
+    State<T> r = new State<T>();
+    r.params = Arrays.copyOf(this.params, this.params.length);
+    r.omni = this.omni;
+    r.step = Arrays.copyOf(this.step, this.step.length);
+    r.maps = Arrays.copyOf(this.maps, this.maps.length);  // new array, same Mapping objects
+    r.payload = this.payload == null ? null : this.payload.copy();  // null unless setPayload was called
+    r.gen = this.gen;  // share the generator instead of the one from the field initializer
+    return r;
+  }
 
   /**
@@ -58,14 +63,14 @@ public class State implements Comparable
    *
    * @return A new state.
    */
-  public State mutate() {
+  public State<T> mutate() {
     double sum = 0;
     for (double v : step) {
       sum += v * v;
     }
     sum = Math.sqrt(sum);
     double lambda = 0.9 + gen.nextGaussian();
-    State r = new State(params, omni, step, maps);
+    State<T> r = this.copy();
     r.omni = -Math.log(1 - gen.nextDouble()) * (0.9 * omni + sum / 10);
     for (int i = 0; i < step.length; i++) {
       r.step[i] = lambda * step[i] + r.omni * gen.nextGaussian();
@@ -150,4 +155,25 @@ public class State implements Comparable
   public void setValue(double v) {
     value = v;
   }
+
+  /**
+   * @return the current fitness value of this state
+   */
+  public double getValue() {
+    return value;
+  }
+
+  /**
+   * @return the object attached to this state, or {@code null} if none has been attached
+   */
+  public T getPayload() {
+    return payload;
+  }
+
+  /**
+   * Attaches an object to this state; {@link #copy()} copies it along with the state.
+   */
+  public void setPayload(T payload) {
+    this.payload = payload;
+  }
 }


Reply via email to