Author: isabel
Date: Sun Jan 10 11:15:04 2010
New Revision: 897617

URL: http://svn.apache.org/viewvc?rev=897617&view=rev
Log:
MAHOUT-85 Added implementation for perceptron and winnow trainer.

Added:
    
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/
    
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearModel.java
    
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java
    
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainer.java
    
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainingMapper.java
    
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/TrainingException.java
    
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java
    
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/
    
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/LinearModelTest.java
    
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java
    
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java

Added: 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearModel.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearModel.java?rev=897617&view=auto
==============================================================================
--- 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearModel.java
 (added)
+++ 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearModel.java
 Sun Jan 10 11:15:04 2010
@@ -0,0 +1,103 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.IndexException;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Classifies a data point using a hyperplane.
+ */
+public class LinearModel {
+  /** Logger for this class. */
+  private static final Logger LOG = LoggerFactory.getLogger(LinearModel.class);
+  /** Represents the direction of the hyperplane found during training.*/
+  private Vector hyperplane;
+  /** Displacement of hyperplane from origin.*/
+  private double bias;
+  /** Classification threshold. */
+  private double threshold;
+
+  /**
+   * Init a linear model with a hyperplane, distance and displacement.
+   * */
+  public LinearModel(final Vector hyperplane, final double displacement, final 
double threshold) {
+    this.hyperplane = hyperplane;
+    this.bias = displacement;
+    this.threshold = threshold;
+  }
+
+  /**
+   * Init a linear model with zero displacement and a threshold of 0.5.
+   * */
+  public LinearModel(final Vector hyperplane) {
+    this(hyperplane, 0, 0.5);
+  }
+
+  /**
+   * Classify a point to either belong to the class modeled by this linear 
model or not.
+   * @param dataPoint the data point to classify.
+   * @return returns true if data point should be classified as belonging to 
this model.
+   * */
+  public boolean classify(final Vector dataPoint) throws CardinalityException, 
IndexException {
+    double product = this.hyperplane.dot(dataPoint);
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("model: " + this + " product: " + product + " Bias: " + 
this.bias + " threshold: " + this.threshold);
+    }
+    return ((product + this.bias) > this.threshold);
+  }
+
+  /**
+   * Update the hyperplane by adding delta.
+   * @param delta the delta to add to the hyperplane vector.
+   * */
+  public void addDelta(final Vector delta) {
+         this.hyperplane = this.hyperplane.plus(delta);
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder builder = new StringBuilder("Model: ");
+    for (int i = 0; i < this.hyperplane.size(); i++) {
+      builder.append("  ").append(this.hyperplane.get(i));
+    }
+    builder.append(" C: ").append(this.bias);
+    return builder.toString();
+  }
+
+  /**
+   * Shift the bias of the model.
+   * @param factor factor to multiply the bias by.
+   * */
+  public synchronized void shiftBias(double factor) {
+    this.bias = this.bias + factor;
+  }
+
+  /**
+   * Multiply the weight at index by delta.
+   * @param index the index of the element to update.
+   * @param delta the delta to multiply the element with.
+   * */
+  public void timesDelta(int index, double delta) {
+    double element = this.hyperplane.get(index);
+    element = element * delta;
+    this.hyperplane.setQuick(index, element);
+  }
+}

Added: 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java?rev=897617&view=auto
==============================================================================
--- 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java
 (added)
+++ 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java
 Sun Jan 10 11:15:04 2010
@@ -0,0 +1,129 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.IndexException;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Implementors of this class need to provide a way to train linear
+ * discriminative classifiers.
+ * 
+ * As this is just the reference implementation we assume that the dataset fits
+ * into main memory - this should be the first thing to change when switching 
to
+ * Hadoop.
+ */
+public abstract class LinearTrainer {
+
+  /** Logger for this class. */
+  private static final Logger LOG = LoggerFactory
+      .getLogger(LinearTrainer.class);
+  /** The model to train. */
+  private LinearModel model;
+
+  /**
+   * Initialize the trainer. Distance is initialized to cosine distance, all
+   * weights are represented through a dense vector.
+   * 
+   * 
+   * @param dimension
+   *          number of expected features.
+   * @param threshold
+   *          threshold to use for classification.
+   * @param init
+   *          initial value of weight vector.
+   * @param initBias
+   *          initial classification bias.
+   * */
+  public LinearTrainer(final int dimension, final double threshold,
+      final double init, final double initBias) throws CardinalityException {
+    DenseVector initialWeights = new DenseVector(dimension);
+    initialWeights.assign(init);
+    this.model = new LinearModel(initialWeights, initBias, threshold);
+  }
+
+  /**
+   * Initializes training. Runs through all data points in the training set and
+   * updates the weight vector whenever a classification error occurs.
+   * 
+   * Can be called multiple times.
+   * 
+   * @param dataset
+   *          the dataset to train on. Each column is treated as point.
+   * @param labelset
+   *          the set of labels, one for each data point. If the cardinalities
+   *          of data- and labelset do not match, a CardinalityException is
+   *          thrown
+   * */
+  public void train(final Vector labelset, final Matrix dataset)
+      throws IndexException, CardinalityException, TrainingException {
+    if (labelset.size() != dataset.size()[1]) {
+      throw new CardinalityException();
+    }
+
+    boolean converged = false;
+    int iteration = 0;
+    while (!converged) {
+      if (iteration > 1000)
+        throw new TrainingException(
+            "Too many iterations needed to find hyperplane.");
+
+      converged = true;
+      int columnCount = dataset.size()[1];
+      for (int i = 0; i < columnCount; i++) {
+        Vector dataPoint = dataset.getColumn(i);
+        LOG.debug("Training point: " + dataPoint);
+
+        synchronized (this.model) {
+          boolean prediction = model.classify(dataPoint);
+          double label = labelset.get(i);
+          if ((label <= 0 && prediction) || (label > 0 && !prediction)) {
+            LOG.debug("updating");
+            converged = false;
+            update(label, dataPoint, this.model);
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Retrieves the trained model if called after train, otherwise the raw 
model.
+   * */
+  public LinearModel getModel() {
+    return this.model;
+  }
+
+  /**
+   * Implement this method to match your training strategy.
+   * 
+   * @param model
+   *          the model to update.
+   * @param label
+   *          the target label of the wrongly classified data point.
+   * @param dataPoint
+   *          the data point that was classified incorrectly.
+   * */
+  protected abstract void update(final double label, final Vector dataPoint,
+      LinearModel model);
+
+}

Added: 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainer.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainer.java?rev=897617&view=auto
==============================================================================
--- 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainer.java
 (added)
+++ 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainer.java
 Sun Jan 10 11:15:04 2010
@@ -0,0 +1,70 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Implements training accoring to the perceptron update rule.
+ * */
+public class PerceptronTrainer extends LinearTrainer {
+
+  /** Logger for this class. */
+  private static final Logger LOG = LoggerFactory
+      .getLogger(PerceptronTrainer.class);
+  /** Rate the model is to be updated with at each step. */
+  private final double learningRate;
+
+  /**
+   * {...@inheritdoc}
+   * 
+   * @param learningRate
+   *          rate to update the model with at each step.
+   * */
+  public PerceptronTrainer(int dimension, double threshold,
+      double learningRate, double init, double initBias) throws 
CardinalityException {
+    super(dimension, threshold, init, initBias);
+    this.learningRate = learningRate;
+  }
+
+  /**
+   * {...@inheritdoc} Perceptron update works such that in case the predicted 
label
+   * does not match the real label, the weight vector is updated as follows: In
+   * case the prediction was positive but should have been negative, the 
weight vector
+   * is set to the sum of weight vector and example (multiplied by the 
learning rate).
+   * 
+   * In case the prediction was negative but should have been positive, the 
example
+   * vector (multiplied by the learning rate) is subtracted from the weight 
vector.
+   * */
+  @Override
+  protected void update(final double label, final Vector dataPoint,
+      final LinearModel model) {
+    double factor = 1.0;
+    if (label == 0.0)
+      factor = -1.0;
+
+    Vector updateVector = dataPoint.times(factor).times(this.learningRate);
+    LOG.debug("Updatevec: " + updateVector);
+
+    model.addDelta(updateVector);
+    model.shiftBias(factor * this.learningRate);
+    LOG.debug(model.toString());
+  }
+}

Added: 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainingMapper.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainingMapper.java?rev=897617&view=auto
==============================================================================
--- 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainingMapper.java
 (added)
+++ 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainingMapper.java
 Sun Jan 10 11:15:04 2010
@@ -0,0 +1,23 @@
+package org.apache.mahout.classifier.discriminative;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Mapper for perceptron training. Strategy for parallelization:
+ * 1) Train separate models on training data samples. Each training
+ *    data sample must fit into main memory.
+ * 2) Average all trained models into one.
+ * */
+public class PerceptronTrainingMapper extends
+    Mapper<Boolean, Vector, Text, LinearModel> {
+
+  @Override
+  protected void map(final Boolean key, final Vector value, Context context) 
throws IOException, InterruptedException {
+    
+  }
+  
+}

Added: 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/TrainingException.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/TrainingException.java?rev=897617&view=auto
==============================================================================
--- 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/TrainingException.java
 (added)
+++ 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/TrainingException.java
 Sun Jan 10 11:15:04 2010
@@ -0,0 +1,18 @@
+package org.apache.mahout.classifier.discriminative;
+
/**
 * This exception is thrown in case training fails, e.g. when an algorithm
 * that can only find linear separating hyperplanes is trained on a training
 * set that is not linearly separable.
 */
public class TrainingException extends Exception {
  /** Serialization id. */
  private static final long serialVersionUID = 388611231310145397L;

  /**
   * Init with message string describing the cause of the exception.
   *
   * @param message
   *          description of why training failed.
   */
  public TrainingException(final String message) {
    super(message);
  }

  /**
   * Init with a message and the underlying cause, so the original stack trace
   * is preserved when a lower-level failure is wrapped.
   *
   * @param message
   *          description of why training failed.
   * @param cause
   *          the exception that caused the failure; may be null.
   */
  public TrainingException(final String message, final Throwable cause) {
    super(message, cause);
  }
}

Added: 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java?rev=897617&view=auto
==============================================================================
--- 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java
 (added)
+++ 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java
 Sun Jan 10 11:15:04 2010
@@ -0,0 +1,90 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
import java.util.Iterator;

import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+
+/**
+ * This class implements training according to the winnow update algorithm.
+ */
+public class WinnowTrainer extends LinearTrainer {
+
+  /** Promotion step to multiply weights with on update. */
+  private final double promotionStep;
+
+  public WinnowTrainer(final int dimension, final double promotionStep,
+      final double threshold, final double init, final double initBias) {
+    super(dimension, threshold, init, initBias);
+    this.promotionStep = promotionStep;
+  }
+
+  /** {...@inheritdoc} */
+  public WinnowTrainer(final int dimension, final double promotionStep)
+      throws CardinalityException {
+    this(dimension, promotionStep, 0.5, 1, 0);
+  }
+
+  /**
+   * Initializes with dimension and promotionStep of 2.
+   * 
+   * @param dimension
+   *          number of features.
+   * */
+  public WinnowTrainer(final int dimension) {
+    this(dimension, 2);
+  }
+
+  /**
+   * {...@inheritdoc} Winnow update works such that in case the predicted label
+   * does not match the real label, the weight vector is updated as follows: In
+   * case the prediction was positiv but should have been negative, all entries
+   * in the weight vector that correspond to non null features in the example
+   * are doubled.
+   * 
+   * In case the prediction was negative but should have been positive, all
+   * entries in the weight vector that correspond to non null features in the
+   * example are halfed.
+   * */
+  @Override
+  protected void update(final double label, final Vector dataPoint,
+      LinearModel model) {
+    if (label > 0) {
+      // case one
+      Vector updateVector = dataPoint.times(1 / this.promotionStep);
+      System.out.println("Winnow update positive: " + updateVector);
+      Iterator<Element> iter = updateVector.iterateNonZero();
+      while (iter.hasNext()) {
+        Element element = iter.next();
+        model.timesDelta(element.index(), element.get());
+      }
+    } else {
+      // case two
+      Vector updateVector = dataPoint.times(1 / this.promotionStep);
+      System.out.println("Winnow update negative: " + updateVector);
+      Iterator<Element> iter = updateVector.iterateNonZero();
+      while (iter.hasNext()) {
+        Element element = iter.next();
+        model.timesDelta(element.index(), element.get());
+      }
+    }
+    System.out.println(model);
+  }
+}

Added: 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/LinearModelTest.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/LinearModelTest.java?rev=897617&view=auto
==============================================================================
--- 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/LinearModelTest.java
 (added)
+++ 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/LinearModelTest.java
 Sun Jan 10 11:15:04 2010
@@ -0,0 +1,61 @@
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+import junit.framework.TestCase;
+
+public class LinearModelTest extends TestCase {
+
+  private LinearModel model;
+  private Vector hyperplane;
+
+  protected void setUp() throws Exception {
+    super.setUp();
+    double[] values = {0.0, 1.0, 0.0, 1.0, 0.0};
+    this.hyperplane = new DenseVector(values);
+    this.model = new LinearModel(this.hyperplane, 0.1, 0.5);
+  }
+
+  public void testClassify() {
+    double[] valuesFalse = {1.0, 0.0, 1.0, 0.0, 1.0};
+    Vector dataPointFalse = new DenseVector(valuesFalse);
+    assertFalse(this.model.classify(dataPointFalse));
+
+    double[] valuesTrue = {0.0, 1.0, 0.0, 1.0, 0.0};
+    Vector dataPointTrue = new DenseVector(valuesTrue);
+    assertTrue(this.model.classify(dataPointTrue));
+  }
+
+  public void testAddDelta() {
+    double[] values = {1.0, -1.0, 1.0, -1.0, 1.0};
+    this.model.addDelta(new DenseVector(values));
+
+    double[] valuesFalse = {1.0, 0.0, 1.0, 0.0, 1.0};
+    Vector dataPointFalse = new DenseVector(valuesFalse);
+    assertTrue(this.model.classify(dataPointFalse));
+
+    double[] valuesTrue = {0.0, 1.0, 0.0, 1.0, 0.0};
+    Vector dataPointTrue = new DenseVector(valuesTrue);
+    assertFalse(this.model.classify(dataPointTrue));
+  }
+
+  public void testTimesDelta() {
+    double[] values = {-1.0, -1.0, -1.0, -1.0, -1.0};
+    this.model.addDelta(new DenseVector(values));
+    double[] dotval = {-1.0, -1.0, -1.0, -1.0, -1.0};
+    
+    for (int i = 0; i < dotval.length; i++) {
+      this.model.timesDelta(i, dotval[i]);
+    }
+
+    double[] valuesFalse = {1.0, 0.0, 1.0, 0.0, 1.0};
+    Vector dataPointFalse = new DenseVector(valuesFalse);
+    assertTrue(this.model.classify(dataPointFalse));
+
+    double[] valuesTrue = {0.0, 1.0, 0.0, 1.0, 0.0};
+    Vector dataPointTrue = new DenseVector(valuesTrue);
+    assertFalse(this.model.classify(dataPointTrue));
+  }
+
+}

Added: 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java?rev=897617&view=auto
==============================================================================
--- 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java
 (added)
+++ 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java
 Sun Jan 10 11:15:04 2010
@@ -0,0 +1,40 @@
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+import junit.framework.TestCase;
+
+public class PerceptronTrainerTest extends TestCase {
+
+  private PerceptronTrainer trainer;
+
+  protected void setUp() throws Exception {
+    super.setUp();
+    trainer = new PerceptronTrainer(3, 0.5, 0.1, 1.0, 1.0);
+  }
+
+  public void testUpdate() throws TrainingException {
+    double[] labels = { 1.0, 1.0, 1.0, 0.0 };
+    Vector labelset = new DenseVector(labels);
+    double[][] values = new double[3][4];
+    for (int i = 0; i < 3; i++) {
+      values[i][0] = 1.0;
+      values[i][1] = 1.0;
+      values[i][2] = 1.0;
+      values[i][3] = 1.0;
+    }
+    values[1][0] = 0.0;
+    values[2][0] = 0.0;
+    values[1][1] = 0.0;
+    values[2][2] = 0.0;
+
+    Matrix dataset = new DenseMatrix(values);
+    this.trainer.train(labelset, dataset);
+    assertFalse(this.trainer.getModel().classify(dataset.getColumn(3)));
+    assertTrue(this.trainer.getModel().classify(dataset.getColumn(0)));
+  }
+
+}

Added: 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java
URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java?rev=897617&view=auto
==============================================================================
--- 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java
 (added)
+++ 
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java
 Sun Jan 10 11:15:04 2010
@@ -0,0 +1,41 @@
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+import junit.framework.TestCase;
+
+
+public class WinnowTrainerTest extends TestCase {
+
+  private WinnowTrainer trainer;
+
+  protected void setUp() throws Exception {
+    super.setUp();
+    trainer = new WinnowTrainer(3);
+  }
+
+  public void testUpdate() throws Exception {
+    double[] labels = { 0.0, 0.0, 0.0, 1.0 };
+    Vector labelset = new DenseVector(labels);
+    double[][] values = new double[3][4];
+    for (int i = 0; i < 3; i++) {
+      values[i][0] = 1.0;
+      values[i][1] = 1.0;
+      values[i][2] = 1.0;
+      values[i][3] = 1.0;
+    }
+    values[1][0] = 0.0;
+    values[2][0] = 0.0;
+    values[1][1] = 0.0;
+    values[2][2] = 0.0;
+
+    Matrix dataset = new DenseMatrix(values);
+    trainer.train(labelset, dataset);
+    assertTrue(trainer.getModel().classify(dataset.getColumn(3)));
+    assertFalse(trainer.getModel().classify(dataset.getColumn(0)));
+  }
+
+}


Reply via email to