Author: tdunning
Date: Fri Aug 20 03:23:22 2010
New Revision: 987370
URL: http://svn.apache.org/viewvc?rev=987370&view=rev
Log:
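MAHOUT-228 Cleans up initialization of ALRs (AdaptiveLogisticRegression): the pool size is no longer a constructor argument; the pool size and thread count are now set through setPoolSize/setThreadCount.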
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=987370&r1=987369&r2=987370&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java Fri Aug 20 03:23:22 2010
@@ -39,13 +39,18 @@ public class AdaptiveLogisticRegression
List<TrainingExample> buffer = Lists.newArrayList();
private EvolutionaryProcess<Wrapper> ep;
private State<Wrapper> best;
-
- public AdaptiveLogisticRegression(int poolSize, int numCategories, int numFeatures, PriorFunction prior) {
- State<Wrapper> s0 = new State<Wrapper>(new double[2], 10);
+ private int threadCount = 20;
+ private int poolSize = 20;
+ // Template state from which every optimizer built by setupOptimizer starts.
+ private final State<Wrapper> seed;
+ private final int numFeatures;
+
+ /**
+ * Creates an adaptive learner. A seed {@link State} carrying a {@link Wrapper}
+ * payload is built here, and the evolutionary optimizer is created from it using
+ * the default pool size and thread count (both 20). Use {@link #setPoolSize(int)}
+ * and {@link #setThreadCount(int)} to change those afterwards.
+ *
+ * @param numCategories number of target categories, passed to the Wrapper
+ * @param numFeatures size of the internal feature vector (see {@link #numFeatures()})
+ * @param prior prior function passed to the Wrapper
+ */
+ public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+ this.numFeatures = numFeatures;
+ // NOTE(review): the meaning of the 2-element vector and the 10 is defined by State; confirm there.
+ seed = new State<Wrapper>(new double[2], 10);
Wrapper w = new Wrapper(numCategories, numFeatures, prior);
- s0.setPayload(w);
- w.setMappings(s0);
- ep = new EvolutionaryProcess<Wrapper>(20, poolSize, s0);
+ w.setMappings(seed);
+ seed.setPayload(w);
+ // Use the private helper rather than the public, overridable setPoolSize(int):
+ // a subclass override must never run on a partially constructed object.
+ setupOptimizer(poolSize);
}
@Override
@@ -113,6 +118,30 @@ public class AdaptiveLogisticRegression
this.evaluationInterval = interval;
}
+ /**
+ * Sets the number of members in the evolutionary population. The current
+ * optimizer is replaced by a new one built from the seed state, so the previous
+ * optimizer (and whatever it had evolved) is no longer used.
+ *
+ * @param poolSize the population size; must be at least 1
+ * @throws IllegalArgumentException if {@code poolSize} is less than 1; no state is changed in that case
+ */
+ public void setPoolSize(int poolSize) {
+ // Validate before mutating so a rejected call leaves the field and optimizer consistent.
+ if (poolSize < 1) {
+ throw new IllegalArgumentException("Pool size must be positive but was " + poolSize);
+ }
+ this.poolSize = poolSize;
+ setupOptimizer(poolSize);
+ }
+
+ /**
+ * Sets the number of threads the evolutionary optimizer uses. Like
+ * {@link #setPoolSize(int)}, this rebuilds the optimizer from the seed state.
+ *
+ * @param threadCount the thread count; must be at least 1
+ * @throws IllegalArgumentException if {@code threadCount} is less than 1; no state is changed in that case
+ */
+ public void setThreadCount(int threadCount) {
+ if (threadCount < 1) {
+ throw new IllegalArgumentException("Thread count must be positive but was " + threadCount);
+ }
+ this.threadCount = threadCount;
+ setupOptimizer(poolSize);
+ }
+
+ /**
+ * Replaces {@code ep} with a fresh EvolutionaryProcess built from the current
+ * thread count, the given population size and the seed state.
+ * NOTE(review): {@code best} is not reset here and may still refer to a member
+ * of the previous optimizer's population; confirm that is intended.
+ *
+ * @param populationSize number of population members for the new optimizer
+ */
+ private void setupOptimizer(int populationSize) {
+ ep = new EvolutionaryProcess<Wrapper>(threadCount, populationSize, seed);
+ }
+
+ /**
+ * Returns the size of the internal feature vector, as given to the constructor.
+ * Note that this is not necessarily the same as the number of distinct features,
+ * particularly when feature hashing is used.
+ *
+ * @return The internal feature vector size.
+ */
+ public int numFeatures() {
+ return numFeatures;
+ }
+
/**
* What is the AUC for the current best member of the population. If no
member is best,
* usually because we haven't done any training yet, then the result is set
to NaN.
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java?rev=987370&r1=987369&r2=987370&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java Fri Aug 20 03:23:22 2010
@@ -37,7 +37,7 @@ public class AdaptiveLogisticRegressionT
}
}
- AdaptiveLogisticRegression x = new AdaptiveLogisticRegression(20, 2, 200,
new L1());
+ AdaptiveLogisticRegression x = new AdaptiveLogisticRegression(2, 200, new
L1());
x.setInterval(1000);
final Normal norm = new Normal(0, 1, gen);