Author: tdunning
Date: Wed Sep 29 16:41:17 2010
New Revision: 1002728
URL: http://svn.apache.org/viewvc?rev=1002728&view=rev
Log:
Adjust bufferSize in AdaptiveLogisticRegression (ALR) when setting custom step sizes.
Removed the training-example buffer from the saved (serialized) model.
Re-enabled the adaptiveLogisticRegressionRoundTrip test.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=1002728&r1=1002727&r2=1002728&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
Wed Sep 29 16:41:17 2010
@@ -210,9 +210,7 @@ public class AdaptiveLogisticRegression
* @param interval Number of training examples to use in each epoch of
optimization.
*/
public void setInterval(int interval) {
- this.minInterval = interval;
- this.maxInterval = interval;
- this.cutoff = interval * (record / interval + 1);
+ setInterval(interval, interval);
}
/**
@@ -227,6 +225,8 @@ public class AdaptiveLogisticRegression
this.minInterval = Math.max(200, minInterval);
this.maxInterval = Math.max(200, maxInterval);
this.cutoff = minInterval * (record / minInterval + 1);
+ this.currentStep = minInterval;
+ bufferSize = Math.min(minInterval, bufferSize);
}
public void setPoolSize(int poolSize) {
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1002728&r1=1002727&r2=1002728&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
Wed Sep 29 16:41:17 2010
@@ -350,10 +350,15 @@ public final class ModelSerializer {
Type epType = new
TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>()
{}.getType();
r.setEp(jdc.<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("ep"),
epType));
r.setSeed(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("seed"),
stateType));
-
r.setBest(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("best"),
stateType));
+ if (x.get("best") != null) {
+
r.setBest(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("best"),
stateType));
+ }
-
r.setBuffer(jdc.<List<AdaptiveLogisticRegression.TrainingExample>>deserialize(x.get("buffer"),
- new
TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {}.getType()));
+ if (x.get("buffer") != null) {
+
r.setBuffer(jdc.<List<AdaptiveLogisticRegression.TrainingExample>>deserialize(x.get("buffer"),
+ new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {
+ }.getType()));
+ }
return r;
}
@@ -362,8 +367,6 @@ public final class ModelSerializer {
JsonObject r = new JsonObject();
r.add("ep", jsc.serialize(x.getEp(),
new
TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>()
{}.getType()));
- r.add("buffer", jsc.serialize(x.getBuffer(),
- new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>()
{}.getType()));
r.add("minInterval", jsc.serialize(x.getMinInterval()));
r.add("maxInterval", jsc.serialize(x.getMaxInterval()));
Type stateType = new
TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java?rev=1002728&r1=1002727&r2=1002728&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
Wed Sep 29 16:41:17 2010
@@ -140,7 +140,7 @@ public final class ModelSerializerTest e
assertTrue(auc2 > auc1);
}
-// @Test
+ @Test
public void adaptiveLogisticRegressionRoundTrip() {
AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5,
new L1());
learner.setInterval(200);