Author: isabel Date: Sun Jan 10 11:15:04 2010 New Revision: 897617 URL: http://svn.apache.org/viewvc?rev=897617&view=rev Log: MAHOUT-85 Added implementation for perceptron and winnow trainer.
Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearModel.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainingMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/TrainingException.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/LinearModelTest.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearModel.java?rev=897617&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearModel.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearModel.java Sun Jan 10 11:15:04 2010
@@ -0,0 +1,103 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.IndexException;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Classifies a data point using a hyperplane. A point belongs to the class
+ * modeled here when {@code hyperplane . point + bias > threshold}.
+ */
+public class LinearModel {
+  /** Logger for this class. */
+  private static final Logger LOG = LoggerFactory.getLogger(LinearModel.class);
+  /** Represents the direction of the hyperplane found during training. */
+  private Vector hyperplane;
+  /** Displacement of hyperplane from origin. */
+  private double bias;
+  /** Classification threshold; fixed at construction. */
+  private final double threshold;
+
+  /**
+   * Initializes a linear model with a hyperplane, a displacement (bias) and a
+   * classification threshold.
+   *
+   * @param hyperplane direction of the hyperplane; stored by reference, not copied.
+   * @param displacement initial bias added to the dot product.
+   * @param threshold value the biased dot product must exceed for a point to
+   *          be classified as belonging to this model.
+   */
+  public LinearModel(final Vector hyperplane, final double displacement, final double threshold) {
+    this.hyperplane = hyperplane;
+    this.bias = displacement;
+    this.threshold = threshold;
+  }
+
+  /**
+   * Initializes a linear model with zero displacement and a threshold of 0.5.
+   *
+   * @param hyperplane direction of the hyperplane; stored by reference, not copied.
+   */
+  public LinearModel(final Vector hyperplane) {
+    this(hyperplane, 0, 0.5);
+  }
+
+  /**
+   * Classifies a point as either belonging to the class modeled by this
+   * linear model or not.
+   *
+   * @param dataPoint the data point to classify.
+   * @return true if {@code hyperplane . dataPoint + bias} is strictly greater
+   *         than the threshold.
+   * @throws CardinalityException propagated from {@code Vector.dot},
+   *           presumably when point and hyperplane differ in size.
+   */
+  public boolean classify(final Vector dataPoint) throws CardinalityException, IndexException {
+    double product = this.hyperplane.dot(dataPoint);
+    // Guarded because the message renders the whole hyperplane via toString().
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("model: " + this + " product: " + product + " Bias: " + this.bias + " threshold: " + this.threshold);
+    }
+    return (product + this.bias) > this.threshold;
+  }
+
+  /**
+   * Updates the hyperplane by adding delta.
+   *
+   * @param delta the delta to add to the hyperplane vector.
+   */
+  public void addDelta(final Vector delta) {
+    this.hyperplane = this.hyperplane.plus(delta);
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder builder = new StringBuilder("Model: ");
+    for (int i = 0; i < this.hyperplane.size(); i++) {
+      builder.append(" ").append(this.hyperplane.get(i));
+    }
+    builder.append(" C: ").append(this.bias);
+    return builder.toString();
+  }
+
+  /**
+   * Shifts the bias of the model by adding the given amount to it.
+   *
+   * @param factor amount to add to the bias; negative values lower it.
+   */
+  public synchronized void shiftBias(double factor) {
+    this.bias += factor;
+  }
+
+  /**
+   * Multiplies the weight at index by delta, modifying the current hyperplane
+   * vector in place.
+   *
+   * @param index the index of the element to update.
+   * @param delta the factor to multiply the element with.
+   */
+  public void timesDelta(int index, double delta) {
+    double element = this.hyperplane.get(index);
+    this.hyperplane.setQuick(index, element * delta);
+  }
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java?rev=897617&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java Sun Jan 10 11:15:04 2010
@@ -0,0 +1,129 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.IndexException;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Implementors of this class need to provide a way to train linear
+ * discriminative classifiers: subclasses supply the update rule, this class
+ * drives the error-correcting training loop.
+ *
+ * As this is just the reference implementation we assume that the dataset fits
+ * into main memory - this should be the first thing to change when switching to
+ * Hadoop.
+ */
+public abstract class LinearTrainer {
+
+  /** Logger for this class. */
+  private static final Logger LOG = LoggerFactory
+      .getLogger(LinearTrainer.class);
+  /** Maximum number of passes over the training set before giving up. */
+  private static final int MAX_ITERATIONS = 1000;
+  /** The model to train. */
+  private LinearModel model;
+
+  /**
+   * Initializes the trainer. All weights are represented through a dense
+   * vector of the given dimension, every entry set to {@code init}.
+   *
+   * @param dimension
+   *          number of expected features.
+   * @param threshold
+   *          threshold to use for classification.
+   * @param init
+   *          initial value of every entry of the weight vector.
+   * @param initBias
+   *          initial classification bias.
+   */
+  public LinearTrainer(final int dimension, final double threshold,
+      final double init, final double initBias) throws CardinalityException {
+    DenseVector initialWeights = new DenseVector(dimension);
+    initialWeights.assign(init);
+    this.model = new LinearModel(initialWeights, initBias, threshold);
+  }
+
+  /**
+   * Trains the model. Runs through all data points in the training set and
+   * calls {@link #update(double, Vector, LinearModel)} whenever a point is
+   * misclassified. Passes are repeated until one pass produces no error.
+   *
+   * A label {@code > 0} marks a positive example, a label {@code <= 0} a
+   * negative one. Can be called multiple times; each call continues from the
+   * current state of the model.
+   *
+   * @param labelset
+   *          the set of labels, one for each data point. If the cardinalities
+   *          of data- and labelset do not match, a CardinalityException is
+   *          thrown.
+   * @param dataset
+   *          the dataset to train on. Each column is treated as one point.
+   * @throws TrainingException
+   *           if no separating hyperplane was found within 1000 passes, e.g.
+   *           because the data is not linearly separable.
+   */
+  public void train(final Vector labelset, final Matrix dataset)
+      throws IndexException, CardinalityException, TrainingException {
+    final int columnCount = dataset.size()[1];
+    if (labelset.size() != columnCount) {
+      throw new CardinalityException();
+    }
+
+    boolean converged = false;
+    int iteration = 0;
+    while (!converged) {
+      if (iteration >= MAX_ITERATIONS) {
+        throw new TrainingException(
+            "Too many iterations needed to find hyperplane.");
+      }
+      // Count the pass; without this the guard above can never fire and the
+      // loop spins forever on data that is not linearly separable.
+      iteration++;
+
+      converged = true;
+      for (int i = 0; i < columnCount; i++) {
+        Vector dataPoint = dataset.getColumn(i);
+        // Parameterized so the vector is only rendered when debug is enabled.
+        LOG.debug("Training point: {}", dataPoint);
+
+        synchronized (this.model) {
+          boolean prediction = this.model.classify(dataPoint);
+          double label = labelset.get(i);
+          if ((label <= 0 && prediction) || (label > 0 && !prediction)) {
+            LOG.debug("updating");
+            converged = false;
+            update(label, dataPoint, this.model);
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Retrieves the trained model if called after train, otherwise the raw model.
+   *
+   * @return the model held by this trainer (not a copy).
+   */
+  public LinearModel getModel() {
+    return this.model;
+  }
+
+  /**
+   * Implement this method to match your training strategy. Called by
+   * {@link #train(Vector, Matrix)} only for misclassified points, while
+   * holding the model's monitor.
+   *
+   * @param label
+   *          the target label of the wrongly classified data point.
+   * @param dataPoint
+   *          the data point that was classified incorrectly.
+   * @param model
+   *          the model to update.
+   */
+  protected abstract void update(final double label, final Vector dataPoint,
+      LinearModel model);
+
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainer.java?rev=897617&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainer.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainer.java Sun Jan 10 11:15:04 2010
@@ -0,0 +1,70 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Implements training according to the perceptron update rule.
+ */
+public class PerceptronTrainer extends LinearTrainer {
+
+  /** Logger for this class. */
+  private static final Logger LOG = LoggerFactory
+      .getLogger(PerceptronTrainer.class);
+  /** Rate the model is to be updated with at each step. */
+  private final double learningRate;
+
+  /**
+   * Initializes a perceptron trainer over a dense weight vector.
+   *
+   * @param dimension
+   *          number of expected features.
+   * @param threshold
+   *          threshold to use for classification.
+   * @param learningRate
+   *          rate to update the model with at each step.
+   * @param init
+   *          initial value of every entry of the weight vector.
+   * @param initBias
+   *          initial classification bias.
+   */
+  public PerceptronTrainer(int dimension, double threshold,
+      double learningRate, double init, double initBias) throws CardinalityException {
+    super(dimension, threshold, init, initBias);
+    this.learningRate = learningRate;
+  }
+
+  /**
+   * Perceptron update, called only for misclassified points.
+   *
+   * In case the prediction was positive but should have been negative
+   * ({@code label <= 0}), the example multiplied by the learning rate is
+   * subtracted from the weight vector and the learning rate is subtracted
+   * from the bias.
+   *
+   * In case the prediction was negative but should have been positive
+   * ({@code label > 0}), the example multiplied by the learning rate is added
+   * to the weight vector and the learning rate is added to the bias.
+   */
+  @Override
+  protected void update(final double label, final Vector dataPoint,
+      final LinearModel model) {
+    // Same label convention as LinearTrainer.train: any label <= 0 is a
+    // negative example, so e.g. -1/+1 labels are corrected in the right direction.
+    double factor = label <= 0 ? -1.0 : 1.0;
+
+    // One scaled copy of the example instead of two intermediate vectors.
+    Vector updateVector = dataPoint.times(factor * this.learningRate);
+    LOG.debug("Updatevec: {}", updateVector);
+
+    model.addDelta(updateVector);
+    model.shiftBias(factor * this.learningRate);
+    LOG.debug("{}", model);
+  }
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainingMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainingMapper.java?rev=897617&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainingMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/PerceptronTrainingMapper.java Sun Jan 10 11:15:04 2010
@@ -0,0 +1,23 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Mapper for perceptron training. Strategy for parallelization:
+ * 1) Train separate models on training data samples. Each training
+ *    data sample must fit into main memory.
+ * 2) Average all trained models into one.
+ */
+public class PerceptronTrainingMapper extends
+    Mapper<Boolean, Vector, Text, LinearModel> {
+
+  /**
+   * Not implemented yet: this stub currently consumes each record without
+   * emitting any output.
+   *
+   * TODO: train a model on the sample and emit it so the models can be
+   * averaged, per the parallelization strategy in the class comment.
+   */
+  @Override
+  protected void map(final Boolean key, final Vector value, Context context) throws IOException, InterruptedException {
+
+  }
+
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/TrainingException.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/TrainingException.java?rev=897617&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/TrainingException.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/TrainingException.java Sun Jan 10 11:15:04 2010
@@ -0,0 +1,18 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+/**
+ * This exception is thrown in case training fails, e.g. when an algorithm
+ * that can only find linear separating hyperplanes is trained on a training
+ * set that is not linearly separable.
+ */
+public class TrainingException extends Exception {
+  /** Serialization id. */
+  private static final long serialVersionUID = 388611231310145397L;
+
+  /**
+   * Init with message string describing the cause of the exception.
+   *
+   * @param message human-readable description of why training failed.
+   */
+  public TrainingException(final String message) {
+    super(message);
+  }
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java?rev=897617&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/WinnowTrainer.java Sun Jan 10 11:15:04 2010
@@ -0,0 +1,90 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import java.util.Iterator;
+
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This class implements training according to the winnow update algorithm:
+ * on a misclassification, the weights of the example's active features are
+ * multiplicatively promoted or demoted by {@code promotionStep}.
+ */
+public class WinnowTrainer extends LinearTrainer {
+
+  /** Logger for this class. */
+  private static final Logger LOG = LoggerFactory
+      .getLogger(WinnowTrainer.class);
+  /** Promotion step to multiply weights with on update. */
+  private final double promotionStep;
+
+  /**
+   * Initializes a winnow trainer over a dense weight vector.
+   *
+   * @param dimension
+   *          number of features.
+   * @param promotionStep
+   *          factor weights are multiplied by on promotion and divided by on
+   *          demotion.
+   * @param threshold
+   *          threshold to use for classification.
+   * @param init
+   *          initial value of every entry of the weight vector.
+   * @param initBias
+   *          initial classification bias.
+   */
+  public WinnowTrainer(final int dimension, final double promotionStep,
+      final double threshold, final double init, final double initBias) {
+    super(dimension, threshold, init, initBias);
+    this.promotionStep = promotionStep;
+  }
+
+  /**
+   * Initializes with threshold 0.5, all weights 1 and bias 0.
+   *
+   * @param dimension
+   *          number of features.
+   * @param promotionStep
+   *          factor weights are multiplied by on promotion and divided by on
+   *          demotion.
+   */
+  public WinnowTrainer(final int dimension, final double promotionStep)
+      throws CardinalityException {
+    this(dimension, promotionStep, 0.5, 1, 0);
+  }
+
+  /**
+   * Initializes with dimension and promotionStep of 2 (threshold 0.5, all
+   * weights 1, bias 0).
+   *
+   * @param dimension
+   *          number of features.
+   */
+  public WinnowTrainer(final int dimension) {
+    this(dimension, 2);
+  }
+
+  /**
+   * Winnow update, called only for misclassified points. Every weight whose
+   * feature is non-zero in the example is multiplied by the feature value
+   * times a step, which for 0/1 features is exactly the step.
+   *
+   * In case the prediction was negative but should have been positive
+   * ({@code label > 0}), the step is {@code promotionStep}: with the default
+   * step of 2 these weights are doubled.
+   *
+   * In case the prediction was positive but should have been negative
+   * ({@code label <= 0}), the step is {@code 1 / promotionStep}: with the
+   * default step of 2 these weights are halved.
+   */
+  @Override
+  protected void update(final double label, final Vector dataPoint,
+      LinearModel model) {
+    // Promote on a missed positive, demote on a false positive. Using
+    // 1 / promotionStep for a missed positive would shrink weights whose score
+    // is already too low, so training on that point could never converge.
+    boolean promote = label > 0;
+    double step = promote ? this.promotionStep : 1 / this.promotionStep;
+    Vector updateVector = dataPoint.times(step);
+    LOG.debug(promote ? "Winnow update positive: {}" : "Winnow update negative: {}", updateVector);
+    Iterator<Element> iter = updateVector.iterateNonZero();
+    while (iter.hasNext()) {
+      Element element = iter.next();
+      // Only active features may be touched; an explicitly stored zero
+      // would otherwise wipe out the corresponding weight.
+      if (element.get() != 0.0) {
+        model.timesDelta(element.index(), element.get());
+      }
+    }
+    LOG.debug("{}", model);
+  }
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/LinearModelTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/LinearModelTest.java?rev=897617&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/LinearModelTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/LinearModelTest.java Sun Jan 10 11:15:04 2010
@@ -0,0 +1,61 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+import junit.framework.TestCase;
+
+public class LinearModelTest extends TestCase {
+
+  private LinearModel model;
+  private Vector hyperplane;
+
+  @Override
+  protected void setUp() throws Exception {
+    super.setUp();
+    double[] values = {0.0, 1.0, 0.0, 1.0, 0.0};
+    this.hyperplane = new DenseVector(values);
+    this.model = new LinearModel(this.hyperplane, 0.1, 0.5);
+  }
+
+  public void testClassify() {
+    double[] valuesFalse = {1.0, 0.0, 1.0, 0.0, 1.0};
+    Vector dataPointFalse = new DenseVector(valuesFalse);
+    assertFalse(this.model.classify(dataPointFalse));
+
+    double[] valuesTrue = {0.0, 1.0, 0.0, 1.0, 0.0};
+    Vector dataPointTrue = new DenseVector(valuesTrue);
+    assertTrue(this.model.classify(dataPointTrue));
+  }
+
+  public void testAddDelta() {
+    double[] values = {1.0, -1.0, 1.0, -1.0, 1.0};
+    this.model.addDelta(new DenseVector(values));
+
+    double[] valuesFalse = {1.0, 0.0, 1.0, 0.0, 1.0};
+    Vector dataPointFalse = new DenseVector(valuesFalse);
+    assertTrue(this.model.classify(dataPointFalse));
+
+    double[] valuesTrue = {0.0, 1.0, 0.0, 1.0, 0.0};
+    Vector dataPointTrue = new DenseVector(valuesTrue);
+    assertFalse(this.model.classify(dataPointTrue));
+  }
+
+  public void testTimesDelta() {
+    double[] values = {-1.0, -1.0, -1.0, -1.0, -1.0};
+    this.model.addDelta(new DenseVector(values));
+    double[] dotval = {-1.0, -1.0, -1.0, -1.0, -1.0};
+
+    for (int i = 0; i < dotval.length; i++) {
+      this.model.timesDelta(i, dotval[i]);
+    }
+
+    double[] valuesFalse = {1.0, 0.0, 1.0, 0.0, 1.0};
+    Vector dataPointFalse = new DenseVector(valuesFalse);
+    assertTrue(this.model.classify(dataPointFalse));
+
+    double[] valuesTrue = {0.0, 1.0, 0.0, 1.0, 0.0};
+    Vector dataPointTrue = new DenseVector(valuesTrue);
+    assertFalse(this.model.classify(dataPointTrue));
+  }
+
+  public void testShiftBias() {
+    double[] values = {0.0, 1.0, 0.0, 0.0, 0.0};
+    Vector dataPoint = new DenseVector(values);
+    assertTrue(this.model.classify(dataPoint));
+
+    // The bias is shifted additively: 1.0 + (0.1 - 1.0) = 0.1, below 0.5.
+    this.model.shiftBias(-1.0);
+    assertFalse(this.model.classify(dataPoint));
+  }
+
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java?rev=897617&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java Sun Jan 10 11:15:04 2010
@@ -0,0 +1,40 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+import junit.framework.TestCase;
+
+public class PerceptronTrainerTest extends TestCase {
+
+  private PerceptronTrainer trainer;
+
+  @Override
+  protected void setUp() throws Exception {
+    super.setUp();
+    trainer = new PerceptronTrainer(3, 0.5, 0.1, 1.0, 1.0);
+  }
+
+  public void testUpdate() throws TrainingException {
+    double[] labels = { 1.0, 1.0, 1.0, 0.0 };
+    Vector labelset = new DenseVector(labels);
+    double[][] values = new double[3][4];
+    for (int i = 0; i < 3; i++) {
+      values[i][0] = 1.0;
+      values[i][1] = 1.0;
+      values[i][2] = 1.0;
+      values[i][3] = 1.0;
+    }
+    values[1][0] = 0.0;
+    values[2][0] = 0.0;
+    values[1][1] = 0.0;
+    values[2][2] = 0.0;
+
+    Matrix dataset = new DenseMatrix(values);
+    this.trainer.train(labelset, dataset);
+    assertFalse(this.trainer.getModel().classify(dataset.getColumn(3)));
+    assertTrue(this.trainer.getModel().classify(dataset.getColumn(0)));
+  }
+
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java?rev=897617&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java Sun Jan 10 11:15:04 2010
@@ -0,0 +1,41 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.discriminative;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+import junit.framework.TestCase;
+
+public class WinnowTrainerTest extends TestCase {
+
+  private WinnowTrainer trainer;
+
+  @Override
+  protected void setUp() throws Exception {
+    super.setUp();
+    trainer = new WinnowTrainer(3);
+  }
+
+  public void testUpdate() throws Exception {
+    double[] labels = { 0.0, 0.0, 0.0, 1.0 };
+    Vector labelset = new DenseVector(labels);
+    double[][] values = new double[3][4];
+    for (int i = 0; i < 3; i++) {
+      values[i][0] = 1.0;
+      values[i][1] = 1.0;
+      values[i][2] = 1.0;
+      values[i][3] = 1.0;
+    }
+    values[1][0] = 0.0;
+    values[2][0] = 0.0;
+    values[1][1] = 0.0;
+    values[2][2] = 0.0;
+
+    Matrix dataset = new DenseMatrix(values);
+    trainer.train(labelset, dataset);
+    assertTrue(trainer.getModel().classify(dataset.getColumn(3)));
+    assertFalse(trainer.getModel().classify(dataset.getColumn(0)));
+  }
+
+  public void testPromotion() throws Exception {
+    // Weights start too small to reach the threshold, so the positive example
+    // can only be learned by promoting them: the score goes 0.2 -> 0.4 -> 0.8.
+    WinnowTrainer promoting = new WinnowTrainer(2, 2.0, 0.5, 0.1, 0.0);
+    Vector labelset = new DenseVector(new double[] {1.0});
+    Matrix dataset = new DenseMatrix(new double[][] {{1.0}, {1.0}});
+    promoting.train(labelset, dataset);
+    assertTrue(promoting.getModel().classify(dataset.getColumn(0)));
+  }
+
+}