Author: tdunning
Date: Fri Aug 20 03:23:22 2010
New Revision: 987370

URL: http://svn.apache.org/viewvc?rev=987370&view=rev
Log:
MAHOUT-228 Cleans up initialization of ALRs (AdaptiveLogisticRegression)

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=987370&r1=987369&r2=987370&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java Fri Aug 20 03:23:22 2010
@@ -39,13 +39,18 @@ public class AdaptiveLogisticRegression 
   List<TrainingExample> buffer = Lists.newArrayList();
   private EvolutionaryProcess<Wrapper> ep;
   private State<Wrapper> best;
-
-  public AdaptiveLogisticRegression(int poolSize, int numCategories, int 
numFeatures, PriorFunction prior) {
-    State<Wrapper> s0 = new State<Wrapper>(new double[2], 10);
+  private int threadCount = 20;
+  private int poolSize = 20;
+  private State<Wrapper> seed;
+  private int numFeatures;
+
+  public AdaptiveLogisticRegression(int numCategories, int numFeatures, 
PriorFunction prior) {
+    this.numFeatures = numFeatures;
+    seed = new State<Wrapper>(new double[2], 10);
     Wrapper w = new Wrapper(numCategories, numFeatures, prior);
-    s0.setPayload(w);
-    w.setMappings(s0);
-    ep = new EvolutionaryProcess<Wrapper>(20, poolSize, s0);
+    w.setMappings(seed);
+    seed.setPayload(w);
+    setPoolSize(poolSize);
   }
 
   @Override
@@ -113,6 +118,30 @@ public class AdaptiveLogisticRegression 
     this.evaluationInterval = interval;
   }
 
+  public void setPoolSize(int poolSize) {
+    this.poolSize = poolSize;
+    setupOptimizer(poolSize);
+  }
+
+  public void setThreadCount(int threadCount) {
+    this.threadCount = threadCount;
+    setupOptimizer(poolSize);
+  }
+
+  private void setupOptimizer(int poolSize) {
+    ep = new EvolutionaryProcess<Wrapper>(threadCount, poolSize, seed);
+  }
+
+  /**
+   * Returns the size of the internal feature vector.  Note that this is not 
the
+   * same as the number of distinct features, especially if feature hashing is
+   * being used.
+   * @return The internal feature vector size.
+   */
+  public int numFeatures() {
+    return numFeatures;
+  }
+
   /**
    * What is the AUC for the current best member of the population.  If no 
member is best,
    * usually because we haven't done any training yet, then the result is set 
to NaN.

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java?rev=987370&r1=987369&r2=987370&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java Fri Aug 20 03:23:22 2010
@@ -37,7 +37,7 @@ public class AdaptiveLogisticRegressionT
       }
     }
 
-    AdaptiveLogisticRegression x = new AdaptiveLogisticRegression(20, 2, 200, 
new L1());
+    AdaptiveLogisticRegression x = new AdaptiveLogisticRegression(2, 200, new 
L1());
     x.setInterval(1000);
 
     final Normal norm = new Normal(0, 1, gen);


Reply via email to