Author: tdunning
Date: Wed Sep 29 21:52:20 2010
New Revision: 1002863

URL: http://svn.apache.org/viewvc?rev=1002863&view=rev
Log:
Adjusted class structure and API for classifiers to allow a group key to be passed 
down into the training.
Also broke out the gradient computation to allow alternative training objectives, 
such as AUC, in addition to logistic loss.
Added the gradient to the serialized content.
Made a new PolymorphicTypeAdapter to limit the amount of repeated code.

Added:
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
    
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/OnlineLearner.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/OnlineLearner.java 
(original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/OnlineLearner.java 
Wed Sep 29 21:52:20 2010
@@ -38,8 +38,30 @@ public interface OnlineLearner {
    * the original data record in a data file.
    *
    * @param trackingKey The tracking key for this training example.
+   * @param groupKey     An optional value that allows examples to be grouped 
in the computation of
+   * the update to the model.
    * @param actual   The value of the target variable.  This value should be 
in the half-open
- *                 interval [0..n) where n is the number of target categories.
+   *                 interval [0..n) where n is the number of target 
categories.
+   * @param instance The feature vector for this example.
+   */
+  void train(long trackingKey, String groupKey, int actual, Vector instance);
+
+  /**
+   * Updates the model using a particular target variable value and a feature 
vector.
+   * <p/>
+   * There may be an assumption that if multiple passes through the training data 
are necessary that
+   * the tracking key for a record will be the same for each pass and that 
there will be a
+   * relatively large number of distinct tracking keys and that the low-order 
bits of the tracking
+   * keys will not correlate with any of the input variables.  This tracking 
key is used to assign
+   * training examples to different test/training splits.
+   * <p/>
+   * Examples of useful tracking keys include id-numbers for the training 
records derived from
+   * a database id for the base table from which the record is derived, or 
the offset of
+   * the original data record in a data file.
+   *
+   * @param trackingKey The tracking key for this training example.
+   * @param actual   The value of the target variable.  This value should be 
in the half-open
+   *                 interval [0..n) where n is the number of target 
categories.
    * @param instance The feature vector for this example.
    */
   void train(long trackingKey, int actual, Vector instance);

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
 Wed Sep 29 21:52:20 2010
@@ -60,6 +60,9 @@ public abstract class AbstractOnlineLogi
   // can we ignore any further regularization when doing classification?
   private boolean sealed = false;
 
+  // by default we don't do any fancy training
+  private Gradient gradient = new DefaultGradient();
+
   /**
    * Chainable configuration option.
    *
@@ -149,11 +152,7 @@ public abstract class AbstractOnlineLogi
   }
 
   @Override
-  public void train(long trackingKey, int actual, Vector instance) {
-    train(actual, instance);
-  }
-
-  public void train(int actual, Vector instance) {
+  public void train(long trackingKey, String groupKey, int actual, Vector 
instance) {
     unseal();
 
     double learningRate = currentLearningRate();
@@ -165,12 +164,9 @@ public abstract class AbstractOnlineLogi
     Vector v = classify(instance);
 
     // update each row of coefficients according to result
+    Vector gradient = this.gradient.apply(groupKey, actual, v);
     for (int i = 0; i < numCategories - 1; i++) {
-      double gradientBase = -v.getQuick(i);
-      // the use of i+1 instead of i here is what makes the 0-th category be 
the one without coefficients
-      if ((i + 1) == actual) {
-        gradientBase += 1;
-      }
+      double gradientBase = gradient.get(i);
 
       // then we apply the gradientBase to the resulting element.
       Iterator<Vector.Element> nonZeros = instance.iterateNonZero();
@@ -195,6 +191,16 @@ public abstract class AbstractOnlineLogi
 
   }
 
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
   public void regularize(Vector instance) {
     if (updateSteps == null || isSealed()) {
       return;
@@ -230,6 +236,10 @@ public abstract class AbstractOnlineLogi
     this.prior = prior;
   }
 
+  public void setGradient(Gradient gradient) {
+    this.gradient = gradient;
+  }
+
   public PriorFunction getPrior() {
     return prior;
   }
@@ -308,4 +318,25 @@ public abstract class AbstractOnlineLogi
     });
     return k < 1;
   }
+
+  public static class DefaultGradient implements Gradient {
+    /**
+     * Provides a default gradient computation useful for logistic regression. 
 This
+     * can be replaced via setGradient to incorporate AUC-driven learning.
+     * <p>
+     * See 
www.eecs.tufts.edu/~dsculley/papers/combined-ranking-and-regression.pdf
+     * @param groupKey     A grouping key to allow per-group AUC loss to 
be used for training.
+     * @param actual       The target variable value.
+     * @param v            The current score vector.  Returns the gradient vector.
+     */
+    @Override
+    public final Vector apply(String groupKey, int actual, Vector v) {
+      Vector r = v.like();
+      if (actual != 0) {
+        r.setQuick(actual - 1, 1);
+      }
+      r.assign(v, Functions.MINUS);
+      return r;
+    }
+  }
 }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
 Wed Sep 29 21:52:20 2010
@@ -69,19 +69,18 @@ public class AdaptiveLogisticRegression 
   private int currentStep = 1000;
   private int bufferSize = 1000;
 
-  // transient here is a signal to GSON not to serialize pending records
-  private transient List<TrainingExample> buffer = Lists.newArrayList();
+  private List<TrainingExample> buffer = Lists.newArrayList();
   private EvolutionaryProcess<Wrapper> ep;
   private State<Wrapper> best;
   private int threadCount = 20;
   private int poolSize = 20;
   private State<Wrapper> seed;
   private int numFeatures;
-  //private double averagingWindow;
 
   private boolean freezeSurvivors = true;
 
   // for GSON
+  @SuppressWarnings({"UnusedDeclaration"})
   private AdaptiveLogisticRegression() {
   }
 
@@ -96,7 +95,7 @@ public class AdaptiveLogisticRegression 
 
   @Override
   public void train(int actual, Vector instance) {
-    train(record, actual, instance);
+    train(record, null, actual, instance);
   }
 
   @Override
@@ -104,7 +103,7 @@ public class AdaptiveLogisticRegression 
     train(trackingKey, null, actual, instance);
   }
 
-
+  @Override
   public void train(long trackingKey, String groupKey, int actual, Vector 
instance) {
     record++;
 
@@ -424,6 +423,7 @@ public class AdaptiveLogisticRegression 
     private Vector instance;
 
     // for GSON
+    @SuppressWarnings({"UnusedDeclaration"})
     private TrainingExample() {
     }
 

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
 Wed Sep 29 21:52:20 2010
@@ -90,7 +90,7 @@ public class CrossFoldLearner extends Ab
   // -------- training methods
   @Override
   public void train(int actual, Vector instance) {
-    train(record, actual, instance);
+    train(record, null, actual, instance);
   }
 
   @Override
@@ -98,6 +98,7 @@ public class CrossFoldLearner extends Ab
     train(trackingKey, null, actual, instance);
   }
 
+  @Override
   public void train(long trackingKey, String groupKey, int actual, Vector 
instance) {
     record++;
     int k = 0;
@@ -113,7 +114,7 @@ public class CrossFoldLearner extends Ab
           auc.addSample(actual, groupKey, v.get(1));
         }
       } else {
-        model.train(trackingKey, actual, instance);
+        model.train(trackingKey, groupKey, actual, instance);
       }
       k++;
     }

Added: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java?rev=1002863&view=auto
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java 
(added)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java 
Wed Sep 29 21:52:20 2010
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * Provides the ability to inject a gradient into the SGD logistic regression.
+ * Typical uses of this are to use a ranking score such as AUC instead of a
+ * normal loss function.
+ */
+public interface Gradient {
+  Vector apply(String groupKey, int actual, Vector v);
+}

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
 Wed Sep 29 21:52:20 2010
@@ -55,9 +55,10 @@ public final class ModelSerializer {
   static {
     final GsonBuilder gb = new GsonBuilder();
     gb.registerTypeAdapter(AdaptiveLogisticRegression.class, new 
AdaptiveLogisticRegressionTypeAdapter());
-    gb.registerTypeAdapter(Mapping.class, new MappingTypeAdapter());
-    gb.registerTypeAdapter(PriorFunction.class, new PriorTypeAdapter());
-    gb.registerTypeAdapter(OnlineAuc.class, new AucTypeAdapter());
+    gb.registerTypeAdapter(Mapping.class, new 
PolymorphicTypeAdapter<Mapping>());
+    gb.registerTypeAdapter(PriorFunction.class, new 
PolymorphicTypeAdapter<PriorFunction>());
+    gb.registerTypeAdapter(OnlineAuc.class, new 
PolymorphicTypeAdapter<OnlineAuc>());
+    gb.registerTypeAdapter(Gradient.class, new 
PolymorphicTypeAdapter<Gradient>());
     gb.registerTypeAdapter(CrossFoldLearner.class, new 
CrossFoldLearnerTypeAdapter());
     gb.registerTypeAdapter(Vector.class, new VectorTypeAdapter());
     gb.registerTypeAdapter(Matrix.class, new MatrixTypeAdapter());
@@ -92,69 +93,22 @@ public final class ModelSerializer {
    * Reads a model in JSON format.
    *
    * @param in Where to read the model from.
-   * @param clazz
+   * @param clazz The class of the object we expect to read.
    * @return The LogisticModelParameters object that we read.
    */
   public static AdaptiveLogisticRegression loadJsonFrom(Reader in, 
Class<AdaptiveLogisticRegression> clazz) {
     return gson().fromJson(in, clazz);
   }
 
-  private static class MappingTypeAdapter implements 
JsonDeserializer<Mapping>, JsonSerializer<Mapping> {
+  private static class PolymorphicTypeAdapter<T> implements 
JsonDeserializer<T>, JsonSerializer<T> {
     @Override
-    public Mapping deserialize(JsonElement jsonElement,
-                               Type type,
-                               JsonDeserializationContext 
jsonDeserializationContext) {
-      JsonObject x = jsonElement.getAsJsonObject();
-      try {
-        return jsonDeserializationContext.deserialize(x.get("value"), 
Class.forName(x.get("class").getAsString()));
-      } catch (ClassNotFoundException e) {
-        throw new IllegalStateException("Can't understand serialized data, 
found bad type: "
-            + x.get("class").getAsString());
-      }
-    }
-
-    @Override
-    public JsonElement serialize(Mapping mapping, Type type, 
JsonSerializationContext jsonSerializationContext) {
-      JsonObject r = new JsonObject();
-      r.add("class", new JsonPrimitive(mapping.getClass().getName()));
-      r.add("value", jsonSerializationContext.serialize(mapping));
-      return r;
-    }
-  }
-
-  private static class PriorTypeAdapter implements 
JsonDeserializer<PriorFunction>, JsonSerializer<PriorFunction> {
-    @Override
-    public PriorFunction deserialize(JsonElement jsonElement,
-                                     Type type,
-                                     JsonDeserializationContext 
jsonDeserializationContext) {
-      JsonObject x = jsonElement.getAsJsonObject();
-      try {
-        return jsonDeserializationContext.deserialize(x.get("value"), 
Class.forName(x.get("class").getAsString()));
-      } catch (ClassNotFoundException e) {
-        throw new IllegalStateException("Can't understand serialized data, 
found bad type: "
-            + x.get("class").getAsString());
-      }
-    }
-
-    @Override
-    public JsonElement serialize(PriorFunction priorFunction,
-                                 Type type,
-                                 JsonSerializationContext 
jsonSerializationContext) {
-      JsonObject r = new JsonObject();
-      r.add("class", new JsonPrimitive(priorFunction.getClass().getName()));
-      r.add("value", jsonSerializationContext.serialize(priorFunction));
-      return r;
-    }
-  }
-
-  private static class AucTypeAdapter implements JsonDeserializer<OnlineAuc>, 
JsonSerializer<OnlineAuc> {
-    @Override
-    public OnlineAuc deserialize(JsonElement jsonElement,
+    public T deserialize(JsonElement jsonElement,
                                      Type type,
                                      JsonDeserializationContext 
jsonDeserializationContext) {
       JsonObject x = jsonElement.getAsJsonObject();
       try {
-        return jsonDeserializationContext.deserialize(x.get("value"), 
Class.forName(x.get("class").getAsString()));
+        //noinspection RedundantTypeArguments
+        return jsonDeserializationContext.<T>deserialize(x.get("value"), 
Class.forName(x.get("class").getAsString()));
       } catch (ClassNotFoundException e) {
         throw new IllegalStateException("Can't understand serialized data, 
found bad type: "
             + x.get("class").getAsString());
@@ -162,12 +116,12 @@ public final class ModelSerializer {
     }
 
     @Override
-    public JsonElement serialize(OnlineAuc auc,
+    public JsonElement serialize(T x,
                                  Type type,
                                  JsonSerializationContext 
jsonSerializationContext) {
       JsonObject r = new JsonObject();
-      r.add("class", new JsonPrimitive(auc.getClass().getName()));
-      r.add("value", jsonSerializationContext.serialize(auc));
+      r.add("class", new JsonPrimitive(x.getClass().getName()));
+      r.add("value", jsonSerializationContext.serialize(x));
       return r;
     }
   }
@@ -196,6 +150,59 @@ public final class ModelSerializer {
     }
   }
 
+  private static class AdaptiveLogisticRegressionTypeAdapter implements 
JsonSerializer<AdaptiveLogisticRegression>,
+    JsonDeserializer<AdaptiveLogisticRegression> {
+
+    @Override
+    public AdaptiveLogisticRegression deserialize(JsonElement element, Type 
type, JsonDeserializationContext jdc) {
+      JsonObject x = element.getAsJsonObject();
+      AdaptiveLogisticRegression r =
+          new AdaptiveLogisticRegression(x.get("numCategories").getAsInt(),
+                                         x.get("numFeatures").getAsInt(),
+                                         
jdc.<PriorFunction>deserialize(x.get("prior"), PriorFunction.class));
+      Type stateType = new 
TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
+      if (x.get("evaluationInterval")!=null) {
+        r.setInterval(x.get("evaluationInterval").getAsInt());
+      } else {
+        r.setInterval(x.get("minInterval").getAsInt(), 
x.get("minInterval").getAsInt());
+      }
+      r.setRecord(x.get("record").getAsInt());
+
+      Type epType = new 
TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() 
{}.getType();
+      
r.setEp(jdc.<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("ep"),
 epType));
+      
r.setSeed(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("seed"),
 stateType));
+      if (x.get("best") != null) {
+        
r.setBest(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("best"),
 stateType));
+      }
+
+      if (x.get("buffer") != null) {
+        
r.setBuffer(jdc.<List<AdaptiveLogisticRegression.TrainingExample>>deserialize(x.get("buffer"),
+          new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {
+          }.getType()));
+      }
+      return r;
+    }
+
+    @Override
+    public JsonElement serialize(AdaptiveLogisticRegression x, Type type, 
JsonSerializationContext jsc) {
+      JsonObject r = new JsonObject();
+      r.add("ep", jsc.serialize(x.getEp(),
+          new 
TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() 
{}.getType()));
+      r.add("minInterval", jsc.serialize(x.getMinInterval()));
+      r.add("maxInterval", jsc.serialize(x.getMaxInterval()));
+      Type stateType = new 
TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
+      r.add("best", jsc.serialize(x.getBest(), stateType));
+      r.add("numFeatures", jsc.serialize(x.getNumFeatures()));
+      r.add("numCategories", jsc.serialize(x.getNumCategories()));
+      PriorFunction prior = x.getPrior();
+      JsonElement pf = jsc.serialize(prior, PriorFunction.class);
+      r.add("prior", pf);
+      r.add("record", jsc.serialize(x.getRecord()));
+      r.add("seed", jsc.serialize(x.getSeed(), stateType));
+      return r;
+    }
+  }
+
   /**
    * Tells GSON how to (de)serialize a Mahout matrix.  We assume on 
deserialization that the matrix
    * is dense.
@@ -329,59 +336,6 @@ public final class ModelSerializer {
     }
   }
 
-  private static class AdaptiveLogisticRegressionTypeAdapter implements 
JsonSerializer<AdaptiveLogisticRegression>,
-    JsonDeserializer<AdaptiveLogisticRegression> {
-
-    @Override
-    public AdaptiveLogisticRegression deserialize(JsonElement element, Type 
type, JsonDeserializationContext jdc) {
-      JsonObject x = element.getAsJsonObject();
-      AdaptiveLogisticRegression r =
-          new AdaptiveLogisticRegression(x.get("numCategories").getAsInt(),
-                                         x.get("numFeatures").getAsInt(),
-                                         
jdc.<PriorFunction>deserialize(x.get("prior"), PriorFunction.class));
-      Type stateType = new 
TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
-      if (x.get("evaluationInterval")!=null) {
-        r.setInterval(x.get("evaluationInterval").getAsInt());
-      } else {
-        r.setInterval(x.get("minInterval").getAsInt(), 
x.get("minInterval").getAsInt());
-      }
-      r.setRecord(x.get("record").getAsInt());
-
-      Type epType = new 
TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() 
{}.getType();
-      
r.setEp(jdc.<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("ep"),
 epType));
-      
r.setSeed(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("seed"),
 stateType));
-      if (x.get("best") != null) {
-        
r.setBest(jdc.<State<AdaptiveLogisticRegression.Wrapper>>deserialize(x.get("best"),
 stateType));
-      }
-
-      if (x.get("buffer") != null) {
-        
r.setBuffer(jdc.<List<AdaptiveLogisticRegression.TrainingExample>>deserialize(x.get("buffer"),
-          new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {
-          }.getType()));
-      }
-      return r;
-    }
-
-    @Override
-    public JsonElement serialize(AdaptiveLogisticRegression x, Type type, 
JsonSerializationContext jsc) {
-      JsonObject r = new JsonObject();
-      r.add("ep", jsc.serialize(x.getEp(),
-          new 
TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() 
{}.getType()));
-      r.add("minInterval", jsc.serialize(x.getMinInterval()));
-      r.add("maxInterval", jsc.serialize(x.getMaxInterval()));
-      Type stateType = new 
TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
-      r.add("best", jsc.serialize(x.getBest(), stateType));
-      r.add("numFeatures", jsc.serialize(x.getNumFeatures()));
-      r.add("numCategories", jsc.serialize(x.getNumCategories()));
-      PriorFunction prior = x.getPrior();
-      JsonElement pf = jsc.serialize(prior, PriorFunction.class);
-      r.add("prior", pf);
-      r.add("record", jsc.serialize(x.getRecord()));
-      r.add("seed", jsc.serialize(x.getSeed(), stateType));
-      return r;
-    }
-  }
-
   private static class EvolutionaryProcessTypeAdapter implements
     InstanceCreator<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>,
     JsonDeserializer<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>,
@@ -434,5 +388,4 @@ public final class ModelSerializer {
     }
     return params;
   }
-
 }

Modified: 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java?rev=1002863&r1=1002862&r2=1002863&view=diff
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
 (original)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
 Wed Sep 29 21:52:20 2010
@@ -66,7 +66,7 @@ public class LogisticModelParameters {
    * the model.  If the input isn't CSV, then calling setTargetCategories 
before calling saveTo will
    * suffice.
    *
-   * @return
+   * @return The CsvRecordFactory.
    */
   public CsvRecordFactory getCsvRecordFactory() {
     if (csv == null) {
@@ -83,7 +83,7 @@ public class LogisticModelParameters {
   /**
    * Creates a logistic regression trainer using the parameters collected here.
    *
-   * @return
+   * @return The newly allocated OnlineLogisticRegression object
    */
   public OnlineLogisticRegression createRegression() {
     if (lr == null) {
@@ -113,7 +113,7 @@ public class LogisticModelParameters {
    * trainer and the dictionary for the target categories.
    *
    * @param out Where to write the model.
-   * @throws IOException
+   * @throws IOException If we can't write the model.
    */
   public void saveTo(Writer out) throws IOException {
     if (lr != null) {
@@ -180,7 +180,7 @@ public class LogisticModelParameters {
   /**
    * Sets the target variable.  If you don't use the CSV record factory, then 
this is irrelevant.
    *
-   * @param targetVariable
+   * @param targetVariable The name of the target variable.
    */
   public void setTargetVariable(String targetVariable) {
     this.targetVariable = targetVariable;
@@ -189,7 +189,7 @@ public class LogisticModelParameters {
   /**
    * Sets the number of target categories to be considered.
    *
-   * @param maxTargetCategories
+   * @param maxTargetCategories The number of target categories.
    */
   public void setMaxTargetCategories(int maxTargetCategories) {
     this.maxTargetCategories = maxTargetCategories;


Reply via email to