Author: smarthi
Date: Thu Dec 19 19:29:02 2013
New Revision: 1552403

URL: http://svn.apache.org/r1552403
Log:
MAHOUT-1265: Multilayer Perceptron
Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java
Modified:
    mahout/trunk/CHANGELOG

Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1552403&r1=1552402&r2=1552403&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Thu Dec 19 19:29:02 2013
@@ -76,6 +76,8 @@ Release 0.9 - unreleased
   MAHOUT-1275: Dropped bz2 distribution format for source and binaries (sslavic)

+  MAHOUT-1265: Multilayer Perceptron (Yexi Jiang via smarthi)
+
   MAHOUT-1261: TasteHadoopUtils.idToIndex can return an int that has size Integer.MAX_VALUE (Carl Clark, smarthi)

   MAHOUT-1249: Clusterdumper/loadTermDictionary crashes when highest index in (sparse) dictionary vector is larger than dictionary vector size (Andrew Musselman via smarthi)

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * A Multilayer Perceptron (MLP) is a kind of feed-forward artificial neural
+ * network, which is a mathematical model inspired by the biological neural
+ * network. The multilayer perceptron can be used for various machine learning
+ * tasks such as classification and regression.
+ *
+ * A detailed introduction to MLPs can be found at
+ * http://ufldl.stanford.edu/wiki/index.php/Neural_Networks.
+ *
+ * For this particular implementation, the users can freely control the topology
+ * of the MLP, including:
+ * 1. The size of the input layer;
+ * 2. The number of hidden layers;
+ * 3. The size of each hidden layer;
+ * 4. The size of the output layer;
+ * 5. The cost function;
+ * 6. The squashing function.
+ *
+ * The model is trained in an online learning approach, where the weights of
+ * the neurons in the MLP are updated incrementally using the back-propagation
+ * algorithm proposed by (Rumelhart, D. E., Hinton, G. E., and Williams, R. J.
+ * (1986) Learning representations by back-propagating errors. Nature, 323,
+ * 533--536.)
+ */
+public class MultilayerPerceptron extends NeuralNetwork implements OnlineLearner {
+
+  /**
+   * The default constructor.
+   */
+  public MultilayerPerceptron() {
+    super();
+  }
+
+  /**
+   * Initialize the MLP by specifying the location of the model.
+   *
+   * @param modelPath The path of the model.
+   */
+  public MultilayerPerceptron(String modelPath) {
+    super(modelPath);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    // construct the training instance by appending the label 'actual' to the features
+    Vector trainingInstance = new DenseVector(instance.size() + 1);
+    for (int i = 0; i < instance.size(); ++i) {
+      trainingInstance.setQuick(i, instance.getQuick(i));
+    }
+    trainingInstance.setQuick(instance.size(), actual);
+    this.trainOnline(trainingInstance);
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual,
+      Vector instance) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void close() {
+    // do nothing
+  }
+
+}
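For reviewers, a minimal usage sketch of the class above (it mirrors the XOR setup in TestMultilayerPerceptron below; the layer sizes, function names, and learning rate are illustrative, not prescriptive):

    MultilayerPerceptron mlp = new MultilayerPerceptron();
    mlp.addLayer(2, false, "Sigmoid"); // input layer: 2 features, bias neuron added internally
    mlp.addLayer(3, false, "Sigmoid"); // hidden layer
    mlp.addLayer(1, true, "Sigmoid");  // output layer
    mlp.setCostFunction("Minus_Squared").setLearningRate(0.2);

    // online training: the label is passed separately from the feature vector
    mlp.train(1, new DenseVector(new double[] {0, 1}));

    // query the trained model
    Vector output = mlp.getOutput(new DenseVector(new double[] {0, 1}));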
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,740 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * NeuralNetwork defines the general operations for a neural network
+ * based model. Typically, all derivative models such as the Multilayer
+ * Perceptron and the Autoencoder consist of neurons and the weights
+ * between neurons.
+ */
+public abstract class NeuralNetwork {
+
+  /* The default learning rate */
+  private static final double DEFAULT_LEARNING_RATE = 0.5;
+  /* The default regularization weight */
+  private static final double DEFAULT_REGULARIZATION_WEIGHT = 0;
+  /* The default momentum weight */
+  private static final double DEFAULT_MOMENTUM_WEIGHT = 0.1;
+
+  public static enum TrainingMethod {
+    GRADIENT_DESCENT
+  }
+
+  /* The name of the model */
+  protected String modelType;
+
+  /* The path to store the model */
+  protected String modelPath;
+
+  protected double learningRate;
+
+  /* The weight of regularization */
+  protected double regularizationWeight;
+
+  /* The momentum weight */
+  protected double momentumWeight;
+
+  /* The cost function of the model */
+  protected String costFunctionName;
+
+  /* Record the size of each layer */
+  protected List<Integer> layerSizeList;
+
+  /* Training method used for training the model */
+  protected TrainingMethod trainingMethod;
+
+  /* Weights between neurons at adjacent layers */
+  protected List<Matrix> weightMatrixList;
+
+  /* Previous weight updates between neurons at adjacent layers */
+  protected List<Matrix> prevWeightUpdatesList;
+
+  /* Different layers can have different squashing functions */
+  protected List<String> squashingFunctionList;
+
+  /* The index of the final layer */
+  protected int finalLayerIdx;
+
+  /**
+   * The default constructor that initializes the learning rate, regularization
+   * weight, and momentum weight with default values.
+   */
+  public NeuralNetwork() {
+    this.learningRate = DEFAULT_LEARNING_RATE;
+    this.regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT;
+    this.momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
+    this.trainingMethod = TrainingMethod.GRADIENT_DESCENT;
+    this.costFunctionName = "Minus_Squared";
+    this.modelType = this.getClass().getSimpleName();
+
+    this.layerSizeList = Lists.newArrayList();
+    this.weightMatrixList = Lists.newArrayList();
+    this.prevWeightUpdatesList = Lists.newArrayList();
+    this.squashingFunctionList = Lists.newArrayList();
+  }
+
+  /**
+   * Initialize the NeuralNetwork by specifying the learning rate, momentum weight
+   * and regularization weight.
+   *
+   * @param learningRate The learning rate.
+   * @param momentumWeight The momentum weight.
+   * @param regularizationWeight The regularization weight.
+   */
+  public NeuralNetwork(double learningRate, double momentumWeight, double regularizationWeight) {
+    this();
+    this.setLearningRate(learningRate);
+    this.setMomentumWeight(momentumWeight);
+    this.setRegularizationWeight(regularizationWeight);
+  }
+
+  /**
+   * Initialize the NeuralNetwork by specifying the location of the model.
+   *
+   * @param modelPath The location where the model is stored.
+   */
+  public NeuralNetwork(String modelPath) {
+    try {
+      this.modelPath = modelPath;
+      this.readFromModel();
+    } catch (IOException e) {
+      throw new IllegalStateException("Failed to load the model from " + modelPath, e);
+    }
+  }
+
+  /**
+   * Get the type of the model.
+   *
+   * @return The name of the model.
+   */
+  public String getModelType() {
+    return this.modelType;
+  }
+
+  /**
+   * Set the degree of aggression during model training. A larger learning rate
+   * can increase the training speed, but it also decreases the chance that the
+   * model converges.
+   *
+   * @param learningRate The learning rate, must be a positive value. The recommended range is (0, 0.5).
+   * @return The model instance.
+   */
+  public NeuralNetwork setLearningRate(double learningRate) {
+    Preconditions.checkArgument(learningRate > 0, "Learning rate must be larger than 0.");
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Get the value of the learning rate.
+   *
+   * @return The value of the learning rate.
+   */
+  public double getLearningRate() {
+    return this.learningRate;
+  }
+
+  /**
+   * Set the regularization weight. The larger the weight, the more strongly
+   * complex models are penalized.
+   *
+   * @param regularizationWeight The regularization weight, must be in the range [0, 0.1).
+   * @return The model instance.
+   */
+  public NeuralNetwork setRegularizationWeight(double regularizationWeight) {
+    Preconditions.checkArgument(regularizationWeight >= 0
+        && regularizationWeight < 0.1, "Regularization weight must be in range [0, 0.1)");
+    this.regularizationWeight = regularizationWeight;
+    return this;
+  }
+
+  /**
+   * Get the weight of the regularization.
+   *
+   * @return The weight of regularization.
+   */
+  public double getRegularizationWeight() {
+    return this.regularizationWeight;
+  }
+
+  /**
+   * Set the momentum weight for the model.
+   *
+   * @param momentumWeight The momentum weight, must be in the range [0, 1.0].
+   * @return The model instance.
+   */
+  public NeuralNetwork setMomentumWeight(double momentumWeight) {
+    Preconditions.checkArgument(momentumWeight >= 0 && momentumWeight <= 1.0,
+        "Momentum weight must be in range [0, 1.0]");
+    this.momentumWeight = momentumWeight;
+    return this;
+  }
+
+  /**
+   * Get the momentum weight.
+   *
+   * @return The value of momentum.
+   */
+  public double getMomentumWeight() {
+    return this.momentumWeight;
+  }
+
+  /**
+   * Set the training method.
+   *
+   * @param method The training method; currently only GRADIENT_DESCENT is supported.
+   * @return The instance of the model.
+   */
+  public NeuralNetwork setTrainingMethod(TrainingMethod method) {
+    this.trainingMethod = method;
+    return this;
+  }
+
+  /**
+   * Get the training method.
+   *
+   * @return The training method enumeration.
+   */
+  public TrainingMethod getTrainingMethod() {
+    return this.trainingMethod;
+  }
+
+  /**
+   * Set the cost function for the model.
+   *
+   * @param costFunction The name of the cost function. Currently supports
+   *          "Minus_Squared" and "Cross_Entropy".
+   * @return The model instance.
+   */
+  public NeuralNetwork setCostFunction(String costFunction) {
+    this.costFunctionName = costFunction;
+    return this;
+  }
+
+  /**
+   * Add a layer of neurons with the specified size. If the added layer is not
+   * the first layer, its neurons are automatically connected to the neurons of
+   * the previous layer.
+   *
+   * @param size The size of the layer (bias neuron excluded).
+   * @param isFinalLayer If false, add a bias neuron.
+   * @param squashingFunctionName The squashing function for this layer; the
+   *          input layer uses f(x) = x by default.
+   * @return The layer index, starting with 0.
+   */
+  public int addLayer(int size, boolean isFinalLayer, String squashingFunctionName) {
+    Preconditions.checkArgument(size > 0, "Size of layer must be larger than 0.");
+    int actualSize = size;
+    if (!isFinalLayer) {
+      actualSize += 1;
+    }
+
+    this.layerSizeList.add(actualSize);
+    int layerIdx = this.layerSizeList.size() - 1;
+    if (isFinalLayer) {
+      this.finalLayerIdx = layerIdx;
+    }
+
+    // add weights between the current layer and the previous layer; the input
+    // layer has no squashing function
+    if (layerIdx > 0) {
+      int sizePrevLayer = this.layerSizeList.get(layerIdx - 1);
+      // the row count equals the size of the current layer, and the column
+      // count equals the size of the previous layer
+      int row = isFinalLayer ? actualSize : actualSize - 1;
+      Matrix weightMatrix = new DenseMatrix(row, sizePrevLayer);
+      // initialize the weights randomly in [-0.5, 0.5)
+      final RandomWrapper rnd = RandomUtils.getRandom();
+      weightMatrix.assign(new DoubleFunction() {
+        @Override
+        public double apply(double value) {
+          return rnd.nextDouble() - 0.5;
+        }
+      });
+      this.weightMatrixList.add(weightMatrix);
+      this.prevWeightUpdatesList.add(new DenseMatrix(row, sizePrevLayer));
+      this.squashingFunctionList.add(squashingFunctionName);
+    }
+    return layerIdx;
+  }
+
+  /**
+   * Get the size of a particular layer.
+   *
+   * @param layer The index of the layer, starting from 0.
+   * @return The size of the corresponding layer.
+   */
+  public int getLayerSize(int layer) {
+    Preconditions.checkArgument(layer >= 0 && layer < this.layerSizeList.size(),
+        String.format("Input must be in range [0, %d]\n", this.layerSizeList.size() - 1));
+    return this.layerSizeList.get(layer);
+  }
+
+  /**
+   * Get the layer size list.
+   *
+   * @return The sizes of the layers.
+   */
+  protected List<Integer> getLayerSizeList() {
+    return this.layerSizeList;
+  }
+
+  /**
+   * Get the weights between layer layerIdx and layer layerIdx + 1.
+   *
+   * @param layerIdx The index of the layer.
+   * @return The weights in form of a {@link Matrix}.
+   */
+  public Matrix getWeightsByLayer(int layerIdx) {
+    return this.weightMatrixList.get(layerIdx);
+  }
+
+  /**
+   * Update the weight matrices with the given matrices.
+   *
+   * @param matrices The weight update matrices, must have the same dimensions
+   *          as the existing matrices.
+   */
+  public void updateWeightMatrices(Matrix[] matrices) {
+    for (int i = 0; i < matrices.length; ++i) {
+      Matrix matrix = this.weightMatrixList.get(i);
+      this.weightMatrixList.set(i, matrix.plus(matrices[i]));
+    }
+  }
+
+  /**
+   * Set the weight matrices.
+   *
+   * @param matrices The weight matrices, must have the same dimensions as the
+   *          existing matrices.
+   */
+  public void setWeightMatrices(Matrix[] matrices) {
+    this.weightMatrixList = Lists.newArrayList();
+    Collections.addAll(this.weightMatrixList, matrices);
+  }
+
+  /**
+   * Set the weight matrix for a specified layer.
+   *
+   * @param index The index of the matrix, starting from 0 (between layer 0 and 1).
+   * @param matrix The instance of {@link Matrix}.
+   */
+  public void setWeightMatrix(int index, Matrix matrix) {
+    Preconditions.checkArgument(0 <= index && index < this.weightMatrixList.size(),
+        String.format("index [%s] should be in range [%s, %s).", index, 0, this.weightMatrixList.size()));
+    this.weightMatrixList.set(index, matrix);
+  }
+
+  /**
+   * Get all the weight matrices.
+   *
+   * @return The weight matrices.
+   */
+  public Matrix[] getWeightMatrices() {
+    Matrix[] matrices = new Matrix[this.weightMatrixList.size()];
+    this.weightMatrixList.toArray(matrices);
+    return matrices;
+  }
+
+  /**
+   * Get the output calculated by the model.
+   *
+   * @param instance The feature instance in form of a {@link Vector}; each
+   *          dimension contains the value of the corresponding feature.
+   * @return The output vector.
+   */
+  public Vector getOutput(Vector instance) {
+    Preconditions.checkArgument(this.layerSizeList.get(0) == instance.size() + 1,
+        String.format("The dimension of input instance should be %d, but the input has dimension %d.",
+            this.layerSizeList.get(0) - 1, instance.size()));
+
+    // add the bias feature
+    Vector instanceWithBias = new DenseVector(instance.size() + 1);
+    // set the bias to be a little bit less than 1.0
+    instanceWithBias.set(0, 0.99999);
+    for (int i = 1; i < instanceWithBias.size(); ++i) {
+      instanceWithBias.set(i, instance.get(i - 1));
+    }
+
+    List<Vector> outputCache = getOutputInternal(instanceWithBias);
+    // return the output of the last layer
+    Vector result = outputCache.get(outputCache.size() - 1);
+    // remove the bias
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  /**
+   * Calculate the output internally; the intermediate output of each layer is
+   * stored.
+   *
+   * @param instance The feature instance in form of a {@link Vector}; each
+   *          dimension contains the value of the corresponding feature.
+   * @return The cached output of each layer.
+   */
+  protected List<Vector> getOutputInternal(Vector instance) {
+    List<Vector> outputCache = Lists.newArrayList();
+    // fill with the instance
+    Vector intermediateOutput = instance;
+    outputCache.add(intermediateOutput);
+
+    for (int i = 0; i < this.layerSizeList.size() - 1; ++i) {
+      intermediateOutput = forward(i, intermediateOutput);
+      outputCache.add(intermediateOutput);
+    }
+    return outputCache;
+  }
+
+  /**
+   * Forward the calculation for one layer.
+   *
+   * @param fromLayer The index of the previous layer.
+   * @param intermediateOutput The intermediate output of the previous layer.
+   * @return The intermediate results of the current layer.
+   */
+  protected Vector forward(int fromLayer, Vector intermediateOutput) {
+    Matrix weightMatrix = this.weightMatrixList.get(fromLayer);
+
+    Vector vec = weightMatrix.times(intermediateOutput);
+    vec = vec.assign(NeuralNetworkFunctions.getDoubleFunction(this.squashingFunctionList.get(fromLayer)));
+
+    // add the bias
+    Vector vecWithBias = new DenseVector(vec.size() + 1);
+    vecWithBias.set(0, 1);
+    for (int i = 0; i < vec.size(); ++i) {
+      vecWithBias.set(i + 1, vec.get(i));
+    }
+    return vecWithBias;
+  }
+
+  /**
+   * Train the neural network incrementally with the given training instance.
+   *
+   * @param trainingInstance A training instance, including the features and the
+   *          label(s). Its dimension must equal the size of the input layer
+   *          (bias neuron excluded) plus the size of the output layer (i.e. the
+   *          dimension of the labels).
+   */
+  public void trainOnline(Vector trainingInstance) {
+    Matrix[] matrices = this.trainByInstance(trainingInstance);
+    this.updateWeightMatrices(matrices);
+  }
+
+  /**
+   * Get the updated weights using one training instance.
+   *
+   * @param trainingInstance A training instance, including the features and the
+   *          label(s). Its dimension must equal the size of the input layer
+   *          (bias neuron excluded) plus the size of the output layer (i.e. the
+   *          dimension of the labels).
+   * @return The update of each weight, in form of a {@link Matrix} list.
+   */
+  public Matrix[] trainByInstance(Vector trainingInstance) {
+    // validate the training instance
+    int inputDimension = this.layerSizeList.get(0) - 1;
+    int outputDimension = this.layerSizeList.get(this.layerSizeList.size() - 1);
+    Preconditions.checkArgument(inputDimension + outputDimension == trainingInstance.size(),
+        String.format("The dimension of training instance is %d, but requires %d.", trainingInstance.size(),
+            inputDimension + outputDimension));
+
+    if (this.trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
+      return this.trainByInstanceGradientDescent(trainingInstance);
+    }
+    throw new IllegalArgumentException("Training method is not supported.");
+  }
+
+  /**
+   * Train by gradient descent. Get the updated weights using one training
+   * instance.
+   *
+   * @param trainingInstance A training instance, including the features and the
+   *          label(s). Its dimension must equal the size of the input layer
+   *          (bias neuron excluded) plus the size of the output layer (i.e. the
+   *          dimension of the labels).
+   * @return The weight update matrices.
+   */
+  private Matrix[] trainByInstanceGradientDescent(Vector trainingInstance) {
+    int inputDimension = this.layerSizeList.get(0) - 1;
+
+    Vector inputInstance = new DenseVector(this.layerSizeList.get(0));
+    inputInstance.set(0, 1); // add the bias
+    for (int i = 0; i < inputDimension; ++i) {
+      inputInstance.set(i + 1, trainingInstance.get(i));
+    }
+
+    Vector labels =
+        trainingInstance.viewPart(inputInstance.size() - 1, trainingInstance.size() - inputInstance.size() + 1);
+
+    // initialize the weight update matrices
+    Matrix[] weightUpdateMatrices = new Matrix[this.weightMatrixList.size()];
+    for (int m = 0; m < weightUpdateMatrices.length; ++m) {
+      weightUpdateMatrices[m] =
+          new DenseMatrix(this.weightMatrixList.get(m).rowSize(), this.weightMatrixList.get(m).columnSize());
+    }
+
+    List<Vector> internalResults = this.getOutputInternal(inputInstance);
+
+    Vector deltaVec = new DenseVector(this.layerSizeList.get(this.layerSizeList.size() - 1));
+    Vector output = internalResults.get(internalResults.size() - 1);
+
+    final DoubleFunction derivativeSquashingFunction =
+        NeuralNetworkFunctions.getDerivativeDoubleFunction(this.squashingFunctionList.get(this.squashingFunctionList.size() - 1));
+
+    final DoubleDoubleFunction costFunction =
+        NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction(this.costFunctionName);
+
+    Matrix lastWeightMatrix = this.weightMatrixList.get(this.weightMatrixList.size() - 1);
+
+    for (int i = 0; i < deltaVec.size(); ++i) {
+      double costFuncDerivative = costFunction.apply(labels.get(i), output.get(i + 1));
+      // add the regularization
+      costFuncDerivative += this.regularizationWeight * lastWeightMatrix.viewRow(i).zSum();
+      deltaVec.set(i, costFuncDerivative);
+      deltaVec.set(i, deltaVec.get(i) * derivativeSquashingFunction.apply(output.get(i + 1)));
+    }
+
+    // start from the layer previous to the output layer
+    for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
+      deltaVec = backPropagate(layer, deltaVec, internalResults, weightUpdateMatrices[layer]);
+    }
+
+    this.prevWeightUpdatesList = Arrays.asList(weightUpdateMatrices);
+
+    return weightUpdateMatrices;
+  }
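+
+  // Note: the update that backPropagate below fills into weightUpdateMatrix is
+  // plain gradient descent with a momentum term. For the weight connecting
+  // neuron j of the current layer to neuron i of the next layer:
+  //
+  //   update(i, j) = -learningRate * nextLayerDelta(i) * curLayerOutput(j)
+  //                  + momentumWeight * prevWeightUpdate(i, j)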
+
+  /**
+   * Back-propagate the errors from the next layer to the current layer. The
+   * weight update information is stored in weightUpdateMatrix, and the delta
+   * of the current layer is returned.
+   *
+   * @param curLayerIdx Index of the current layer.
+   * @param nextLayerDelta Delta of the next layer.
+   * @param outputCache The cached output of each layer.
+   * @param weightUpdateMatrix The weight update, in form of a {@link Matrix}.
+   * @return The delta of the current layer.
+   */
+  private Vector backPropagate(int curLayerIdx, Vector nextLayerDelta,
+      List<Vector> outputCache, Matrix weightUpdateMatrix) {
+
+    // get layer related information
+    final DoubleFunction derivativeSquashingFunction =
+        NeuralNetworkFunctions.getDerivativeDoubleFunction(this.squashingFunctionList.get(curLayerIdx));
+    Vector curLayerOutput = outputCache.get(curLayerIdx);
+    Matrix weightMatrix = this.weightMatrixList.get(curLayerIdx);
+    Matrix prevWeightMatrix = this.prevWeightUpdatesList.get(curLayerIdx);
+
+    // if the next layer is not the output layer, remove the delta of the bias neuron
+    if (curLayerIdx != this.layerSizeList.size() - 2) {
+      nextLayerDelta = nextLayerDelta.viewPart(1, nextLayerDelta.size() - 1);
+    }
+
+    Vector delta = weightMatrix.transpose().times(nextLayerDelta);
+
+    delta = delta.assign(curLayerOutput, new DoubleDoubleFunction() {
+      @Override
+      public double apply(double deltaElem, double curLayerOutputElem) {
+        return deltaElem * derivativeSquashingFunction.apply(curLayerOutputElem);
+      }
+    });
+
+    // update the weights
+    for (int i = 0; i < weightUpdateMatrix.rowSize(); ++i) {
+      for (int j = 0; j < weightUpdateMatrix.columnSize(); ++j) {
+        weightUpdateMatrix.set(i, j, -learningRate * nextLayerDelta.get(i) *
+            curLayerOutput.get(j) + this.momentumWeight * prevWeightMatrix.get(i, j));
+      }
+    }
+
+    return delta;
+  }
+
+  /**
+   * Read the model meta-data from the specified location.
+   *
+   * @throws IOException If the model cannot be read from the model path.
+   */
+  protected void readFromModel() throws IOException {
+    Preconditions.checkArgument(this.modelPath != null, "Model path has not been set.");
+    FSDataInputStream is = null;
+    try {
+      Path path = new Path(this.modelPath);
+      FileSystem fs = path.getFileSystem(new Configuration());
+      is = new FSDataInputStream(fs.open(path));
+      this.readFields(is);
+    } finally {
+      Closeables.close(is, true);
+    }
+  }
+
+  /**
+   * Write the model data to the specified location.
+   *
+   * @throws IOException If the model cannot be written to the model path.
+   */
+  public void writeModelToFile() throws IOException {
+    Preconditions.checkArgument(this.modelPath != null, "Model path has not been set.");
+    FSDataOutputStream stream = null;
+    try {
+      Path path = new Path(this.modelPath);
+      FileSystem fs = path.getFileSystem(new Configuration());
+      stream = fs.create(path, true);
+      this.write(stream);
+    } finally {
+      Closeables.close(stream, false);
+    }
+  }
+
+  /**
+   * Set the model path.
+   *
+   * @param modelPath The path of the model.
+   */
+  public void setModelPath(String modelPath) {
+    this.modelPath = modelPath;
+  }
+
+  /**
+   * Get the model path.
+   *
+   * @return The path of the model.
+   */
+  public String getModelPath() {
+    return this.modelPath;
+  }
+
+  /**
+   * Write the fields of the model to the output.
+   *
+   * @param output The output instance.
+   * @throws IOException If the fields cannot be written.
+   */
+  public void write(DataOutput output) throws IOException {
+    // write the model type
+    WritableUtils.writeString(output, modelType);
+    // write the learning rate
+    output.writeDouble(learningRate);
+    // write the model path
+    if (this.modelPath != null) {
+      WritableUtils.writeString(output, modelPath);
+    } else {
+      WritableUtils.writeString(output, "null");
+    }
+
+    // write the regularization weight
+    output.writeDouble(this.regularizationWeight);
+    // write the momentum weight
+    output.writeDouble(this.momentumWeight);
+
+    // write the cost function
+    WritableUtils.writeString(output, this.costFunctionName);
+
+    // write the layer size list
+    output.writeInt(this.layerSizeList.size());
+    for (Integer layerSize : this.layerSizeList) {
+      output.writeInt(layerSize);
+    }
+
+    WritableUtils.writeEnum(output, this.trainingMethod);
+
+    // write the squashing functions
+    output.writeInt(this.squashingFunctionList.size());
+    for (String squashingFunction : this.squashingFunctionList) {
+      WritableUtils.writeString(output, squashingFunction);
+    }
+
+    // write the weight matrices
+    output.writeInt(this.weightMatrixList.size());
+    for (Matrix weightMatrix : this.weightMatrixList) {
+      MatrixWritable.writeMatrix(output, weightMatrix);
+    }
+  }
+
+  /**
+   * Read the fields of the model from the input.
+   *
+   * @param input The input instance.
+   * @throws IOException If the fields cannot be read.
+   */
+  public void readFields(DataInput input) throws IOException {
+    // read the model type
+    this.modelType = WritableUtils.readString(input);
+    if (!this.modelType.equals(this.getClass().getSimpleName())) {
+      throw new IllegalArgumentException("The specified location does not contain a valid NeuralNetwork model.");
+    }
+    // read the learning rate
+    this.learningRate = input.readDouble();
+    // read the model path
+    this.modelPath = WritableUtils.readString(input);
+    if (this.modelPath.equals("null")) {
+      this.modelPath = null;
+    }
+
+    // read the regularization weight
+    this.regularizationWeight = input.readDouble();
+    // read the momentum weight
+    this.momentumWeight = input.readDouble();
+
+    // read the cost function
+    this.costFunctionName = WritableUtils.readString(input);
+
+    // read the layer size list
+    int numLayers = input.readInt();
+    this.layerSizeList = Lists.newArrayList();
+    for (int i = 0; i < numLayers; i++) {
+      this.layerSizeList.add(input.readInt());
+    }
+
+    this.trainingMethod = WritableUtils.readEnum(input, TrainingMethod.class);
+
+    // read the squashing functions
+    int squashingFunctionSize = input.readInt();
+    this.squashingFunctionList = Lists.newArrayList();
+    for (int i = 0; i < squashingFunctionSize; i++) {
+      this.squashingFunctionList.add(WritableUtils.readString(input));
+    }
+
+    // read the weights and construct the matrices of previous updates
+    int numOfMatrices = input.readInt();
+    this.weightMatrixList = Lists.newArrayList();
+    this.prevWeightUpdatesList = Lists.newArrayList();
+    for (int i = 0; i < numOfMatrices; i++) {
+      Matrix matrix = MatrixWritable.readMatrix(input);
+      this.weightMatrixList.add(matrix);
+      this.prevWeightUpdatesList.add(new DenseMatrix(matrix.rowSize(), matrix.columnSize()));
+    }
+  }
+
+}
\ No newline at end of file
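Since persistence goes through Hadoop's FileSystem API, the model path may point at the local file system or HDFS. A minimal round trip might look like the sketch below (the path is illustrative; see also TestNeuralNetwork#testReadWrite further down):

    NeuralNetwork ann = new MultilayerPerceptron();
    ann.addLayer(2, false, "Identity");
    ann.addLayer(1, true, "Identity");
    ann.setModelPath("/tmp/ann.model"); // hypothetical path, local FS or HDFS
    ann.writeModelToFile();

    // the String constructor restores the serialized model from the same path
    NeuralNetwork restored = new MultilayerPerceptron("/tmp/ann.model");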
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,150 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * The functions that are used by the NeuralNetwork.
+ */
+public class NeuralNetworkFunctions {
+
+  /**
+   * The derivative of the identity function (f(x) = x).
+   */
+  public static final DoubleFunction derivativeIdentityFunction = new DoubleFunction() {
+    @Override
+    public double apply(double x) {
+      return 1;
+    }
+  };
+
+  /**
+   * The derivative of the minus-squared cost function (f(t, o) = (o - t)^2).
+   */
+  public static final DoubleDoubleFunction derivativeMinusSquared = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      return 2 * (output - target);
+    }
+  };
+
+  /**
+   * The cross-entropy cost function (f(t, o) = -t * log(o) - (1 - t) * log(1 - o)).
+   */
+  public static final DoubleDoubleFunction crossEntropy = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      return -target * Math.log(output) - (1 - target) * Math.log(1 - output);
+    }
+  };
+
+  /**
+   * The derivative of the cross-entropy cost function (f(t, o) = -t * log(o)
+   * - (1 - t) * log(1 - o)).
+   */
+  public static final DoubleDoubleFunction derivativeCrossEntropy = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      double adjustedTarget = target;
+      double adjustedActual = output;
+      // clamp the values away from 0 and 1 to avoid division by zero
+      if (adjustedActual == 1) {
+        adjustedActual = 0.999;
+      } else if (adjustedActual == 0) {
+        adjustedActual = 0.001;
+      }
+      if (adjustedTarget == 1) {
+        adjustedTarget = 0.999;
+      } else if (adjustedTarget == 0) {
+        adjustedTarget = 0.001;
+      }
+      return -adjustedTarget / adjustedActual + (1 - adjustedTarget) / (1 - adjustedActual);
+    }
+  };
+
+  /**
+   * Get the corresponding function by its name.
+   * Currently supports: "Identity", "Sigmoid".
+   *
+   * @param function The name of the function.
+   * @return The corresponding double function.
+   */
+  public static DoubleFunction getDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Identity")) {
+      return Functions.IDENTITY;
+    } else if (function.equalsIgnoreCase("Sigmoid")) {
+      return Functions.SIGMOID;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the derivative of a function by its name.
+   * Currently supports: "Identity", "Sigmoid".
+   *
+   * @param function The name of the function.
+   * @return The double function.
+   */
+  public static DoubleFunction getDerivativeDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Identity")) {
+      return derivativeIdentityFunction;
+    } else if (function.equalsIgnoreCase("Sigmoid")) {
+      return Functions.SIGMOIDGRADIENT;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the corresponding double-double function by its name.
+   * Currently supports: "Minus_Squared", "Cross_Entropy".
+   *
+   * @param function The name of the function.
+   * @return The double-double function.
+   */
+  public static DoubleDoubleFunction getDoubleDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Minus_Squared")) {
+      return Functions.MINUS_SQUARED;
+    } else if (function.equalsIgnoreCase("Cross_Entropy")) {
+      return crossEntropy;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the derivative of the corresponding double-double function by its name.
+   * Currently supports: "Minus_Squared", "Cross_Entropy".
+   *
+   * @param function The name of the function.
+   * @return The double-double function.
+   */
+  public static DoubleDoubleFunction getDerivativeDoubleDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Minus_Squared")) {
+      return derivativeMinusSquared;
+    } else if (function.equalsIgnoreCase("Cross_Entropy")) {
+      return derivativeCrossEntropy;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+}
\ No newline at end of file
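These string-keyed lookups are how NeuralNetwork resolves the squashing and cost functions configured by name; resolving one directly works the same way (the sample values below are illustrative):

    DoubleFunction sigmoid = NeuralNetworkFunctions.getDoubleFunction("Sigmoid");
    double activation = sigmoid.apply(0.0); // logistic sigmoid at 0 is 0.5

    DoubleDoubleFunction costDerivative =
        NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction("Minus_Squared");
    double gradient = costDerivative.apply(1.0, 0.8); // 2 * (0.8 - 1.0) = -0.4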
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Arrays;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Test the functionality of {@link MultilayerPerceptron}.
+ */
+public class TestMultilayerPerceptron extends MahoutTestCase {
+
+  @Test
+  public void testMLP() throws IOException {
+    testMLP("testMLPXORLocal", false, false, 8000);
+    testMLP("testMLPXORLocalWithMomentum", true, false, 4000);
+    testMLP("testMLPXORLocalWithRegularization", true, true, 2000);
+  }
+
+  private void testMLP(String modelFilename, boolean useMomentum,
+      boolean useRegularization, int iterations) throws IOException {
+    MultilayerPerceptron mlp = new MultilayerPerceptron();
+    mlp.addLayer(2, false, "Sigmoid");
+    mlp.addLayer(3, false, "Sigmoid");
+    mlp.addLayer(1, true, "Sigmoid");
+    mlp.setCostFunction("Minus_Squared").setLearningRate(0.2);
+    if (useMomentum) {
+      mlp.setMomentumWeight(0.6);
+    }
+
+    if (useRegularization) {
+      mlp.setRegularizationWeight(0.01);
+    }
+
+    double[][] instances = { { 0, 1, 1 }, { 0, 0, 0 }, { 1, 0, 1 }, { 1, 1, 0 } };
+    for (int i = 0; i < iterations; ++i) {
+      for (double[] instance : instances) {
+        Vector features = new DenseVector(Arrays.copyOf(instance, instance.length - 1));
+        mlp.train((int) instance[2], features);
+      }
+    }
+
+    for (double[] instance : instances) {
+      Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+      // the expected label is the last element of the array
+      double expected = instance[2];
+      double actual = mlp.getOutput(input).get(0);
+      assertTrue(expected < 0.5 && actual < 0.5 || expected >= 0.5 && actual >= 0.5);
+    }
+
+    // write the model into a file and read it back out
+    File modelFile = this.getTestTempFile(modelFilename);
+    mlp.setModelPath(modelFile.getAbsolutePath());
+    mlp.writeModelToFile();
+    mlp.close();
+
+    MultilayerPerceptron mlpCopy = new MultilayerPerceptron(modelFile.getAbsolutePath());
+    // test on the instances
+    for (double[] instance : instances) {
+      Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+      // the expected label is the last element of the array
+      double expected = instance[2];
+      double actual = mlpCopy.getOutput(input).get(0);
+      assertTrue(expected < 0.5 && actual < 0.5 || expected >= 0.5 && actual >= 0.5);
    }
+    mlpCopy.close();
+  }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.mahout.classifier.mlp.NeuralNetwork.TrainingMethod;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Lists;
+import com.google.common.io.Files;
+
+/**
+ * Test the functionality of {@link NeuralNetwork}.
+ */
+public class TestNeuralNetwork extends MahoutTestCase {
+
+  @Test
+  public void testReadWrite() throws IOException {
+    NeuralNetwork ann = new MultilayerPerceptron();
+    ann.addLayer(2, false, "Identity");
+    ann.addLayer(5, false, "Identity");
+    ann.addLayer(1, true, "Identity");
+    ann.setCostFunction("Minus_Squared");
+    double learningRate = 0.2;
+    double momentumWeight = 0.5;
+    double regularizationWeight = 0.05;
+    ann.setLearningRate(learningRate).setMomentumWeight(momentumWeight).setRegularizationWeight(regularizationWeight);
+
+    // manually set the weights
+    Matrix[] matrices = new DenseMatrix[2];
+    matrices[0] = new DenseMatrix(5, 3);
+    matrices[0].assign(0.2);
+    matrices[1] = new DenseMatrix(1, 6);
+    matrices[1].assign(0.8);
+    ann.setWeightMatrices(matrices);
+
+    // write to a file
+    String modelFilename = "testNeuralNetworkReadWrite";
+    File tmpModelFile = this.getTestTempFile(modelFilename);
+    ann.setModelPath(tmpModelFile.getAbsolutePath());
+    ann.writeModelToFile();
+
+    // read from the file
+    NeuralNetwork annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath());
+    assertEquals(annCopy.getClass().getSimpleName(), annCopy.getModelType());
+    assertEquals(tmpModelFile.getAbsolutePath(), annCopy.getModelPath());
+    assertEquals(learningRate, annCopy.getLearningRate(), 0.000001);
+    assertEquals(momentumWeight, annCopy.getMomentumWeight(), 0.000001);
+    assertEquals(regularizationWeight, annCopy.getRegularizationWeight(), 0.000001);
+    assertEquals(TrainingMethod.GRADIENT_DESCENT, annCopy.getTrainingMethod());
+
+    // compare the weights
+    Matrix[] weightsMatrices = annCopy.getWeightMatrices();
+    for (int i = 0; i < weightsMatrices.length; ++i) {
+      Matrix expectMat = matrices[i];
+      Matrix actualMat = weightsMatrices[i];
+      for (int j = 0; j < expectMat.rowSize(); ++j) {
+        for (int k = 0; k < expectMat.columnSize(); ++k) {
+          assertEquals(expectMat.get(j, k), actualMat.get(j, k), 0.000001);
+        }
+      }
+    }
+  }
+
+  /**
+   * Test the forward functionality.
+   */
+  @Test
+  public void testOutput() {
+    // first network
+    NeuralNetwork ann = new MultilayerPerceptron();
+    ann.addLayer(2, false, "Identity");
+    ann.addLayer(5, false, "Identity");
+    ann.addLayer(1, true, "Identity");
+    ann.setCostFunction("Minus_Squared").setLearningRate(0.1);
+
+    // intentionally initialize all weights to 0.5
+    Matrix[] matrices = new Matrix[2];
+    matrices[0] = new DenseMatrix(5, 3);
+    matrices[0].assign(0.5);
+    matrices[1] = new DenseMatrix(1, 6);
+    matrices[1].assign(0.5);
+    ann.setWeightMatrices(matrices);
+
+    double[] arr = new double[]{0, 1};
+    Vector training = new DenseVector(arr);
+    Vector result = ann.getOutput(training);
+    assertEquals(1, result.size());
+
+    // second network
+    NeuralNetwork ann2 = new MultilayerPerceptron();
+    ann2.addLayer(2, false, "Sigmoid");
+    ann2.addLayer(3, false, "Sigmoid");
+    ann2.addLayer(1, true, "Sigmoid");
+    ann2.setCostFunction("Minus_Squared");
+    ann2.setLearningRate(0.3);
+
+    // intentionally initialize all weights to 0.5
+    Matrix[] matrices2 = new Matrix[2];
+    matrices2[0] = new DenseMatrix(3, 3);
+    matrices2[0].assign(0.5);
+    matrices2[1] = new DenseMatrix(1, 4);
+    matrices2[1].assign(0.5);
+    ann2.setWeightMatrices(matrices2);
+
+    double[] test = {0, 0};
+    double[] result2 = {0.807476};
+
+    Vector vec = ann2.getOutput(new DenseVector(test));
+    double[] arrVec = new double[vec.size()];
+    for (int i = 0; i < arrVec.length; ++i) {
+      arrVec[i] = vec.getQuick(i);
+    }
+    assertArrayEquals(result2, arrVec, 0.000001);
+
+    // third network
+    NeuralNetwork ann3 = new MultilayerPerceptron();
+    ann3.addLayer(2, false, "Sigmoid");
+    ann3.addLayer(3, false, "Sigmoid");
+    ann3.addLayer(1, true, "Sigmoid");
+    ann3.setCostFunction("Minus_Squared").setLearningRate(0.3);
+
+    // intentionally initialize all weights to 0.5
+    Matrix[] initMatrices = new Matrix[2];
+    initMatrices[0] = new DenseMatrix(3, 3);
+    initMatrices[0].assign(0.5);
+    initMatrices[1] = new DenseMatrix(1, 4);
+    initMatrices[1].assign(0.5);
+    ann3.setWeightMatrices(initMatrices);
+
+    double[] instance = {0, 1};
+    Vector output = ann3.getOutput(new DenseVector(instance));
+    assertEquals(0.8315410, output.get(0), 0.000001);
+  }
+
+  @Test
+  public void testNeuralNetwork() throws IOException {
+    testNeuralNetwork("testNeuralNetworkXORLocal", false, false, 10000);
+    testNeuralNetwork("testNeuralNetworkXORWithMomentum", true, false, 5000);
+    testNeuralNetwork("testNeuralNetworkXORWithRegularization", true, true, 5000);
+  }
+
+  private void testNeuralNetwork(String modelFilename, boolean useMomentum,
+      boolean useRegularization, int iterations) throws IOException {
+    NeuralNetwork ann = new MultilayerPerceptron();
+    ann.addLayer(2, false, "Sigmoid");
+    ann.addLayer(3, false, "Sigmoid");
+    ann.addLayer(1, true, "Sigmoid");
+    ann.setCostFunction("Minus_Squared").setLearningRate(0.1);
+
+    if (useMomentum) {
+      ann.setMomentumWeight(0.6);
+    }
+
+    if (useRegularization) {
+      ann.setRegularizationWeight(0.01);
+    }
+
+    double[][] instances = {{0, 1, 1}, {0, 0, 0}, {1, 0, 1}, {1, 1, 0}};
+    for (int i = 0; i < iterations; ++i) {
+      for (double[] instance : instances) {
+        ann.trainOnline(new DenseVector(instance));
+      }
+    }
+
+    for (double[] instance : instances) {
+      Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+      // the expected label is the last element of the array
+      double expected = instance[2];
+      double actual = ann.getOutput(input).get(0);
+      assertTrue(expected < 0.5 && actual < 0.5 || expected >= 0.5 && actual >= 0.5);
+    }
+
+    // write the model into a file and read it back out
+    File tmpModelFile = this.getTestTempFile(modelFilename);
+    ann.setModelPath(tmpModelFile.getAbsolutePath());
+    ann.writeModelToFile();
+
+    NeuralNetwork annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath());
+    // test on the instances
+    for (double[] instance : instances) {
+      Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+      // the expected label is the last element of the array
+      double expected = instance[2];
+      double actual = annCopy.getOutput(input).get(0);
+      assertTrue(expected < 0.5 && actual < 0.5 || expected >= 0.5 && actual >= 0.5);
+    }
+  }
+
+  @Test
+  public void testWithCancerDataSet() throws IOException {
+    String dataSetPath = "src/test/resources/cancer.csv";
+    List<Vector> records = Lists.newArrayList();
+    // returns a mutable list of the data
+    List<String> cancerDataSetList = Files.readLines(new File(dataSetPath), Charsets.UTF_8);
+    // skip the header line, hence remove the first element in the list
+    cancerDataSetList.remove(0);
+    for (String line : cancerDataSetList) {
+      String[] tokens = line.split(",");
+      double[] values = new double[tokens.length];
+      for (int i = 0; i < tokens.length; ++i) {
+        values[i] = Double.parseDouble(tokens[i]);
+      }
+      records.add(new DenseVector(values));
+    }
+
+    int splitPoint = (int) (records.size() * 0.8);
+    List<Vector> trainingSet = records.subList(0, splitPoint);
+    List<Vector> testSet = records.subList(splitPoint, records.size());
+
+    // initialize the neural network model
+    NeuralNetwork ann = new MultilayerPerceptron();
+    int featureDimension = records.get(0).size() - 1;
+    ann.addLayer(featureDimension, false, "Sigmoid");
+    ann.addLayer(featureDimension * 2, false, "Sigmoid");
+    ann.addLayer(1, true, "Sigmoid");
+    ann.setLearningRate(0.05).setMomentumWeight(0.5).setRegularizationWeight(0.001);
+
+    int iteration = 2000;
+    for (int i = 0; i < iteration; ++i) {
+      for (Vector trainingInstance : trainingSet) {
+        ann.trainOnline(trainingInstance);
+      }
+    }
+
+    int correctInstances = 0;
+    for (Vector testInstance : testSet) {
+      Vector res = ann.getOutput(testInstance.viewPart(0, testInstance.size() - 1));
+      double actual = res.get(0);
+      double expected = testInstance.get(testInstance.size() - 1);
+      if (Math.abs(actual - expected) <= 0.1) {
+        ++correctInstances;
+      }
+    }
+    double accuracy = (double) correctInstances / testSet.size() * 100;
+    assertTrue("The classifier is even worse than a random guesser!", accuracy > 50);
+    System.out.printf("Cancer DataSet. Classification accuracy: %d/%d = %f%%\n", correctInstances, testSet.size(), accuracy);
+  }
+
+  @Test
+  public void testWithIrisDataSet() throws IOException {
+    String dataSetPath = "src/test/resources/iris.csv";
+    int numOfClasses = 3;
+    List<Vector> records = Lists.newArrayList();
+    // returns a mutable list of the data
+    List<String> irisDataSetList = Files.readLines(new File(dataSetPath), Charsets.UTF_8);
+    // skip the header line, hence remove the first element in the list
+    irisDataSetList.remove(0);
+
+    for (String line : irisDataSetList) {
+      String[] tokens = line.split(",");
+      // the last three dimensions represent the labels
+      double[] values = new double[tokens.length + numOfClasses - 1];
+      Arrays.fill(values, 0.0);
+      for (int i = 0; i < tokens.length - 1; ++i) {
+        values[i] = Double.parseDouble(tokens[i]);
+      }
+      // add the label values
+      String label = tokens[tokens.length - 1];
+      if (label.equalsIgnoreCase("setosa")) {
+        values[values.length - 3] = 1;
+      } else if (label.equalsIgnoreCase("versicolor")) {
+        values[values.length - 2] = 1;
+      } else { // label 'virginica'
+        values[values.length - 1] = 1;
+      }
+      records.add(new DenseVector(values));
+    }
+
+    Collections.shuffle(records);
+
+    int splitPoint = (int) (records.size() * 0.8);
+    List<Vector> trainingSet = records.subList(0, splitPoint);
+    List<Vector> testSet = records.subList(splitPoint, records.size());
+
+    // initialize the neural network model
+    NeuralNetwork ann = new MultilayerPerceptron();
+    int featureDimension = records.get(0).size() - numOfClasses;
+    ann.addLayer(featureDimension, false, "Sigmoid");
+    ann.addLayer(featureDimension * 2, false, "Sigmoid");
+    ann.addLayer(3, true, "Sigmoid"); // 3-class classification
+    ann.setLearningRate(0.05).setMomentumWeight(0.4).setRegularizationWeight(0.005);
+
+    int iteration = 2000;
+    for (int i = 0; i < iteration; ++i) {
+      for (Vector trainingInstance : trainingSet) {
+        ann.trainOnline(trainingInstance);
+      }
+    }
+
+    int correctInstances = 0;
+    for (Vector testInstance : testSet) {
+      Vector res = ann.getOutput(testInstance.viewPart(0, testInstance.size() - numOfClasses));
+      double[] actualLabels = new double[numOfClasses];
+      for (int i = 0; i < numOfClasses; ++i) {
+        actualLabels[i] = res.get(i);
+      }
+      double[] expectedLabels = new double[numOfClasses];
+      for (int i = 0; i < numOfClasses; ++i) {
+        expectedLabels[i] = testInstance.get(testInstance.size() - numOfClasses + i);
+      }
+
+      boolean allCorrect = true;
+      for (int i = 0; i < numOfClasses; ++i) {
+        if (Math.abs(expectedLabels[i] - actualLabels[i]) >= 0.1) {
+          allCorrect = false;
+          break;
+        }
+      }
+      if (allCorrect) {
+        ++correctInstances;
+      }
+    }
+
+    double accuracy = (double) correctInstances / testSet.size() * 100;
+    assertTrue("The model is even worse than a random guesser.", accuracy > 50);
+
+    System.out.printf("Iris DataSet. Classification accuracy: %d/%d = %f%%\n", correctInstances, testSet.size(), accuracy);
+  }
+
+}