Author: tdunning
Date: Thu Aug 19 07:26:25 2010
New Revision: 987048
URL: http://svn.apache.org/viewvc?rev=987048&view=rev
Log:
Added copyable to CFL and OLR
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/ep/Copyable.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=987048&r1=987047&r2=987048&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java Thu Aug 19 07:26:25 2010
@@ -3,6 +3,7 @@ package org.apache.mahout.classifier.sgd
import com.google.common.collect.Lists;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.ep.Copyable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
@@ -18,7 +19,7 @@ import java.util.List;
* time the training data is traversed or a tracking key such as the file offset of the training
* record should be passed with each training example.
*/
-class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Comparable<CrossFoldLearner> {
+public class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Copyable<CrossFoldLearner>, Comparable<CrossFoldLearner> {
private static volatile int nextId = 0;
private final int id = nextId++;
@@ -29,8 +30,12 @@ class CrossFoldLearner extends AbstractV
// lambda, learningRate, perTermOffset, perTermExponent
private double[] parameters = new double[4];
+ private int numFeatures;
+ private PriorFunction prior;
CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
+ this.numFeatures = numFeatures;
+ this.prior = prior;
for (int i = 0; i < folds; i++) {
OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, prior);
model.alpha(1).stepOffset(0).decayExponent(0);
@@ -175,4 +180,14 @@ class CrossFoldLearner extends AbstractV
public boolean equals(Object other) {
return other instanceof CrossFoldLearner && id == ((CrossFoldLearner) other).id;
}
+
+ @Override
+ public CrossFoldLearner copy() {
+ CrossFoldLearner r = new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior);
+ r.models.clear();
+ for (OnlineLogisticRegression model : models) {
+ r.models.add(model.copy());
+ }
+ return r;
+ }
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java?rev=987048&r1=987047&r2=987048&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java Thu Aug 19 07:26:25 2010
@@ -17,6 +17,8 @@
package org.apache.mahout.classifier.sgd;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.ep.Copyable;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
@@ -25,7 +27,7 @@ import org.apache.mahout.math.Vector;
* Extends the basic on-line logistic regression learner with a specific set of learning
* rate annealing schedules.
*/
-public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression {
+public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression implements Copyable<OnlineLogisticRegression> {
// these next two control decayFactor^steps exponential type of annealing
// learning rate and decay factor
private double mu0 = 1;
@@ -119,4 +121,11 @@ public class OnlineLogisticRegression ex
perTermAnnealingOffset = other.perTermAnnealingOffset;
}
+
+ @Override
+ public OnlineLogisticRegression copy() {
+ OnlineLogisticRegression r = new OnlineLogisticRegression(numCategories(), numFeatures(), prior);
+ r.copyFrom(this);
+ return r;
+ }
}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/ep/Copyable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/Copyable.java?rev=987048&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/ep/Copyable.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/ep/Copyable.java Thu Aug 19 07:26:25 2010
@@ -0,0 +1,5 @@
+package org.apache.mahout.ep;
+
+public interface Copyable<T> {
+ public T copy();
+}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java?rev=987048&r1=987047&r2=987048&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java Thu Aug 19 07:26:25 2010
@@ -7,7 +7,7 @@ import java.util.Random;
* Recorded step evolutionary optimization. You provide the value, this class provides the
* mutation.
*/
-public class State implements Comparable<State> {
+public class State<T extends Copyable<T>> implements Copyable<State<T>>, Comparable<State<T>> {
// object count is kept to break ties in comparison.
static volatile int objectCount = 0;
private Random gen = new Random();
@@ -29,6 +29,11 @@ public class State implements Comparable
// current fitness value
private double value;
+ private T payload;
+
+ private State() {
+ }
+
/**
* Invent a new state with no momentum (yet).
*/
@@ -40,17 +45,17 @@ public class State implements Comparable
}
/**
- * Deep clones a state, useful in mutation.
- *
- * @param params Current state
- * @param omni Current omni-directional mutation
- * @param step The step taken to get to this point
+ * Deep copies a state, useful in mutation.
*/
- private State(double[] params, double omni, double[] step, Mapping[] maps) {
- this.params = Arrays.copyOf(params, params.length);
- this.omni = omni;
- this.step = Arrays.copyOf(step, step.length);
- this.maps = Arrays.copyOf(maps, maps.length);
+ public State<T> copy() {
+ State<T> r = new State<T>();
+ r.params = Arrays.copyOf(this.params, this.params.length);
+ r.omni = this.omni;
+ r.step = Arrays.copyOf(this.step, this.step.length);
+ r.maps = Arrays.copyOf(this.maps, this.maps.length);
+ r.payload = this.payload.copy();
+ r.gen = this.gen;
+ return r;
}
/**
@@ -58,14 +63,14 @@ public class State implements Comparable
*
* @return A new state.
*/
- public State mutate() {
+ public State<T> mutate() {
double sum = 0;
for (double v : step) {
sum += v * v;
}
sum = Math.sqrt(sum);
double lambda = 0.9 + gen.nextGaussian();
- State r = new State(params, omni, step, maps);
+ State<T> r = this.copy();
r.omni = -Math.log(1 - gen.nextDouble()) * (0.9 * omni + sum / 10);
for (int i = 0; i < step.length; i++) {
r.step[i] = lambda * step[i] + r.omni * gen.nextGaussian();
@@ -150,4 +155,16 @@ public class State implements Comparable
public void setValue(double v) {
value = v;
}
+
+ public T getPayload() {
+ return payload;
+ }
+
+ public double getValue() {
+ return value;
+ }
+
+ public void setPayload(T payload) {
+ this.payload = payload;
+ }
}