http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/mlp/Datasets.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/mlp/Datasets.java b/mr/src/test/java/org/apache/mahout/classifier/mlp/Datasets.java new file mode 100644 index 0000000..76b1d3f --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/mlp/Datasets.java @@ -0,0 +1,866 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.mlp; + +public class Datasets { + + public static final String[] IRIS = new String[] { + "Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species", + "5.1,3.5,1.4,0.2,setosa", + "4.9,3.0,1.4,0.2,setosa", + "4.7,3.2,1.3,0.2,setosa", + "4.6,3.1,1.5,0.2,setosa", + "5.0,3.6,1.4,0.2,setosa", + "5.4,3.9,1.7,0.4,setosa", + "4.6,3.4,1.4,0.3,setosa", + "5.0,3.4,1.5,0.2,setosa", + "4.4,2.9,1.4,0.2,setosa", + "4.9,3.1,1.5,0.1,setosa", + "5.4,3.7,1.5,0.2,setosa", + "4.8,3.4,1.6,0.2,setosa", + "4.8,3.0,1.4,0.1,setosa", + "4.3,3.0,1.1,0.1,setosa", + "5.8,4.0,1.2,0.2,setosa", + "5.7,4.4,1.5,0.4,setosa", + "5.4,3.9,1.3,0.4,setosa", + "5.1,3.5,1.4,0.3,setosa", + "5.7,3.8,1.7,0.3,setosa", + "5.1,3.8,1.5,0.3,setosa", + "5.4,3.4,1.7,0.2,setosa", + "5.1,3.7,1.5,0.4,setosa", + "4.6,3.6,1.0,0.2,setosa", + "5.1,3.3,1.7,0.5,setosa", + "4.8,3.4,1.9,0.2,setosa", + "5.0,3.0,1.6,0.2,setosa", + "5.0,3.4,1.6,0.4,setosa", + "5.2,3.5,1.5,0.2,setosa", + "5.2,3.4,1.4,0.2,setosa", + "4.7,3.2,1.6,0.2,setosa", + "4.8,3.1,1.6,0.2,setosa", + "5.4,3.4,1.5,0.4,setosa", + "5.2,4.1,1.5,0.1,setosa", + "5.5,4.2,1.4,0.2,setosa", + "4.9,3.1,1.5,0.2,setosa", + "5.0,3.2,1.2,0.2,setosa", + "5.5,3.5,1.3,0.2,setosa", + "4.9,3.6,1.4,0.1,setosa", + "4.4,3.0,1.3,0.2,setosa", + "5.1,3.4,1.5,0.2,setosa", + "5.0,3.5,1.3,0.3,setosa", + "4.5,2.3,1.3,0.3,setosa", + "4.4,3.2,1.3,0.2,setosa", + "5.0,3.5,1.6,0.6,setosa", + "5.1,3.8,1.9,0.4,setosa", + "4.8,3.0,1.4,0.3,setosa", + "5.1,3.8,1.6,0.2,setosa", + "4.6,3.2,1.4,0.2,setosa", + "5.3,3.7,1.5,0.2,setosa", + "5.0,3.3,1.4,0.2,setosa", + "7.0,3.2,4.7,1.4,versicolor", + "6.4,3.2,4.5,1.5,versicolor", + "6.9,3.1,4.9,1.5,versicolor", + "5.5,2.3,4.0,1.3,versicolor", + "6.5,2.8,4.6,1.5,versicolor", + "5.7,2.8,4.5,1.3,versicolor", + "6.3,3.3,4.7,1.6,versicolor", + "4.9,2.4,3.3,1.0,versicolor", + "6.6,2.9,4.6,1.3,versicolor", + "5.2,2.7,3.9,1.4,versicolor", + "5.0,2.0,3.5,1.0,versicolor", + "5.9,3.0,4.2,1.5,versicolor", + 
"6.0,2.2,4.0,1.0,versicolor", + "6.1,2.9,4.7,1.4,versicolor", + "5.6,2.9,3.6,1.3,versicolor", + "6.7,3.1,4.4,1.4,versicolor", + "5.6,3.0,4.5,1.5,versicolor", + "5.8,2.7,4.1,1.0,versicolor", + "6.2,2.2,4.5,1.5,versicolor", + "5.6,2.5,3.9,1.1,versicolor", + "5.9,3.2,4.8,1.8,versicolor", + "6.1,2.8,4.0,1.3,versicolor", + "6.3,2.5,4.9,1.5,versicolor", + "6.1,2.8,4.7,1.2,versicolor", + "6.4,2.9,4.3,1.3,versicolor", + "6.6,3.0,4.4,1.4,versicolor", + "6.8,2.8,4.8,1.4,versicolor", + "6.7,3.0,5.0,1.7,versicolor", + "6.0,2.9,4.5,1.5,versicolor", + "5.7,2.6,3.5,1.0,versicolor", + "5.5,2.4,3.8,1.1,versicolor", + "5.5,2.4,3.7,1.0,versicolor", + "5.8,2.7,3.9,1.2,versicolor", + "6.0,2.7,5.1,1.6,versicolor", + "5.4,3.0,4.5,1.5,versicolor", + "6.0,3.4,4.5,1.6,versicolor", + "6.7,3.1,4.7,1.5,versicolor", + "6.3,2.3,4.4,1.3,versicolor", + "5.6,3.0,4.1,1.3,versicolor", + "5.5,2.5,4.0,1.3,versicolor", + "5.5,2.6,4.4,1.2,versicolor", + "6.1,3.0,4.6,1.4,versicolor", + "5.8,2.6,4.0,1.2,versicolor", + "5.0,2.3,3.3,1.0,versicolor", + "5.6,2.7,4.2,1.3,versicolor", + "5.7,3.0,4.2,1.2,versicolor", + "5.7,2.9,4.2,1.3,versicolor", + "6.2,2.9,4.3,1.3,versicolor", + "5.1,2.5,3.0,1.1,versicolor", + "5.7,2.8,4.1,1.3,versicolor", + "6.3,3.3,6.0,2.5,virginica", + "5.8,2.7,5.1,1.9,virginica", + "7.1,3.0,5.9,2.1,virginica", + "6.3,2.9,5.6,1.8,virginica", + "6.5,3.0,5.8,2.2,virginica", + "7.6,3.0,6.6,2.1,virginica", + "4.9,2.5,4.5,1.7,virginica", + "7.3,2.9,6.3,1.8,virginica", + "6.7,2.5,5.8,1.8,virginica", + "7.2,3.6,6.1,2.5,virginica", + "6.5,3.2,5.1,2.0,virginica", + "6.4,2.7,5.3,1.9,virginica", + "6.8,3.0,5.5,2.1,virginica", + "5.7,2.5,5.0,2.0,virginica", + "5.8,2.8,5.1,2.4,virginica", + "6.4,3.2,5.3,2.3,virginica", + "6.5,3.0,5.5,1.8,virginica", + "7.7,3.8,6.7,2.2,virginica", + "7.7,2.6,6.9,2.3,virginica", + "6.0,2.2,5.0,1.5,virginica", + "6.9,3.2,5.7,2.3,virginica", + "5.6,2.8,4.9,2.0,virginica", + "7.7,2.8,6.7,2.0,virginica", + "6.3,2.7,4.9,1.8,virginica", + "6.7,3.3,5.7,2.1,virginica", + 
"7.2,3.2,6.0,1.8,virginica", + "6.2,2.8,4.8,1.8,virginica", + "6.1,3.0,4.9,1.8,virginica", + "6.4,2.8,5.6,2.1,virginica", + "7.2,3.0,5.8,1.6,virginica", + "7.4,2.8,6.1,1.9,virginica", + "7.9,3.8,6.4,2.0,virginica", + "6.4,2.8,5.6,2.2,virginica", + "6.3,2.8,5.1,1.5,virginica", + "6.1,2.6,5.6,1.4,virginica", + "7.7,3.0,6.1,2.3,virginica", + "6.3,3.4,5.6,2.4,virginica", + "6.4,3.1,5.5,1.8,virginica", + "6.0,3.0,4.8,1.8,virginica", + "6.9,3.1,5.4,2.1,virginica", + "6.7,3.1,5.6,2.4,virginica", + "6.9,3.1,5.1,2.3,virginica", + "5.8,2.7,5.1,1.9,virginica", + "6.8,3.2,5.9,2.3,virginica", + "6.7,3.3,5.7,2.5,virginica", + "6.7,3.0,5.2,2.3,virginica", + "6.3,2.5,5.0,1.9,virginica", + "6.5,3.0,5.2,2.0,virginica", + "6.2,3.4,5.4,2.3,virginica", + "5.9,3.0,5.1,1.8,virginica" + }; + + public static final String[] CANCER = new String[] { + "\"V1\",\"V2\",\"V3\",\"V4\",\"V5\",\"V6\",\"V7\",\"V8\",\"V9\",\"target\"", + "5,1,1,1,2,1,3,1,1,0", + "5,4,4,5,7,10,3,2,1,0", + "3,1,1,1,2,2,3,1,1,0", + "6,8,8,1,3,4,3,7,1,0", + "4,1,1,3,2,1,3,1,1,0", + "8,10,10,8,7,10,9,7,1,1", + "1,1,1,1,2,10,3,1,1,0", + "2,1,2,1,2,1,3,1,1,0", + "2,1,1,1,2,1,1,1,5,0", + "4,2,1,1,2,1,2,1,1,0", + "1,1,1,1,1,1,3,1,1,0", + "2,1,1,1,2,1,2,1,1,0", + "5,3,3,3,2,3,4,4,1,1", + "1,1,1,1,2,3,3,1,1,0", + "8,7,5,10,7,9,5,5,4,1", + "7,4,6,4,6,1,4,3,1,1", + "4,1,1,1,2,1,2,1,1,0", + "4,1,1,1,2,1,3,1,1,0", + "10,7,7,6,4,10,4,1,2,1", + "6,1,1,1,2,1,3,1,1,0", + "7,3,2,10,5,10,5,4,4,1", + "10,5,5,3,6,7,7,10,1,1", + "3,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "5,2,3,4,2,7,3,6,1,1", + "3,2,1,1,1,1,2,1,1,0", + "5,1,1,1,2,1,2,1,1,0", + "2,1,1,1,2,1,2,1,1,0", + "1,1,3,1,2,1,1,1,1,0", + "3,1,1,1,1,1,2,1,1,0", + "2,1,1,1,2,1,3,1,1,0", + "10,7,7,3,8,5,7,4,3,1", + "2,1,1,2,2,1,3,1,1,0", + "3,1,2,1,2,1,2,1,1,0", + "2,1,1,1,2,1,2,1,1,0", + "10,10,10,8,6,1,8,9,1,1", + "6,2,1,1,1,1,7,1,1,0", + "5,4,4,9,2,10,5,6,1,1", + "2,5,3,3,6,7,7,5,1,1", + "10,4,3,1,3,3,6,5,2,1", + "6,10,10,2,8,10,7,3,3,1", + "5,6,5,6,10,1,3,1,1,1", + 
"10,10,10,4,8,1,8,10,1,1", + "1,1,1,1,2,1,2,1,2,0", + "3,7,7,4,4,9,4,8,1,1", + "1,1,1,1,2,1,2,1,1,0", + "4,1,1,3,2,1,3,1,1,0", + "7,8,7,2,4,8,3,8,2,1", + "9,5,8,1,2,3,2,1,5,1", + "5,3,3,4,2,4,3,4,1,1", + "10,3,6,2,3,5,4,10,2,1", + "5,5,5,8,10,8,7,3,7,1", + "10,5,5,6,8,8,7,1,1,1", + "10,6,6,3,4,5,3,6,1,1", + "8,10,10,1,3,6,3,9,1,1", + "8,2,4,1,5,1,5,4,4,1", + "5,2,3,1,6,10,5,1,1,1", + "9,5,5,2,2,2,5,1,1,1", + "5,3,5,5,3,3,4,10,1,1", + "1,1,1,1,2,2,2,1,1,0", + "9,10,10,1,10,8,3,3,1,1", + "6,3,4,1,5,2,3,9,1,1", + "1,1,1,1,2,1,2,1,1,0", + "10,4,2,1,3,2,4,3,10,1", + "4,1,1,1,2,1,3,1,1,0", + "5,3,4,1,8,10,4,9,1,1", + "8,3,8,3,4,9,8,9,8,1", + "1,1,1,1,2,1,3,2,1,0", + "5,1,3,1,2,1,2,1,1,0", + "6,10,2,8,10,2,7,8,10,1", + "1,3,3,2,2,1,7,2,1,0", + "9,4,5,10,6,10,4,8,1,1", + "10,6,4,1,3,4,3,2,3,1", + "1,1,2,1,2,2,4,2,1,0", + "1,1,4,1,2,1,2,1,1,0", + "5,3,1,2,2,1,2,1,1,0", + "3,1,1,1,2,3,3,1,1,0", + "2,1,1,1,3,1,2,1,1,0", + "2,2,2,1,1,1,7,1,1,0", + "4,1,1,2,2,1,2,1,1,0", + "5,2,1,1,2,1,3,1,1,0", + "3,1,1,1,2,2,7,1,1,0", + "3,5,7,8,8,9,7,10,7,1", + "5,10,6,1,10,4,4,10,10,1", + "3,3,6,4,5,8,4,4,1,1", + "3,6,6,6,5,10,6,8,3,1", + "4,1,1,1,2,1,3,1,1,0", + "2,1,1,2,3,1,2,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "3,1,1,2,2,1,1,1,1,0", + "4,1,1,1,2,1,3,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "2,1,1,1,2,1,3,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "2,1,1,2,2,1,1,1,1,0", + "5,1,1,1,2,1,3,1,1,0", + "9,6,9,2,10,6,2,9,10,1", + "7,5,6,10,5,10,7,9,4,1", + "10,3,5,1,10,5,3,10,2,1", + "2,3,4,4,2,5,2,5,1,1", + "4,1,2,1,2,1,3,1,1,0", + "8,2,3,1,6,3,7,1,1,1", + "10,10,10,10,10,1,8,8,8,1", + "7,3,4,4,3,3,3,2,7,1", + "10,10,10,8,2,10,4,1,1,1", + "1,6,8,10,8,10,5,7,1,1", + "1,1,1,1,2,1,2,3,1,0", + "6,5,4,4,3,9,7,8,3,1", + "1,3,1,2,2,2,5,3,2,0", + "8,6,4,3,5,9,3,1,1,1", + "10,3,3,10,2,10,7,3,3,1", + "10,10,10,3,10,8,8,1,1,1", + "3,3,2,1,2,3,3,1,1,0", + "1,1,1,1,2,5,1,1,1,0", + "8,3,3,1,2,2,3,2,1,0", + "4,5,5,10,4,10,7,5,8,1", + "1,1,1,1,4,3,1,1,1,0", + "3,2,1,1,2,2,3,1,1,0", + "1,1,2,2,2,1,3,1,1,0", + 
"4,2,1,1,2,2,3,1,1,0", + "10,10,10,2,10,10,5,3,3,1", + "5,3,5,1,8,10,5,3,1,1", + "5,4,6,7,9,7,8,10,1,1", + "1,1,1,1,2,1,2,1,1,0", + "7,5,3,7,4,10,7,5,5,1", + "3,1,1,1,2,1,3,1,1,0", + "8,3,5,4,5,10,1,6,2,1", + "1,1,1,1,10,1,1,1,1,0", + "5,1,3,1,2,1,2,1,1,0", + "2,1,1,1,2,1,3,1,1,0", + "5,10,8,10,8,10,3,6,3,1", + "3,1,1,1,2,1,2,2,1,0", + "3,1,1,1,3,1,2,1,1,0", + "5,1,1,1,2,2,3,3,1,0", + "4,1,1,1,2,1,2,1,1,0", + "3,1,1,1,2,1,1,1,1,0", + "4,1,2,1,2,1,2,1,1,0", + "3,1,1,1,2,1,1,1,1,0", + "2,1,1,1,2,1,1,1,1,0", + "9,5,5,4,4,5,4,3,3,1", + "1,1,1,1,2,5,1,1,1,0", + "2,1,1,1,2,1,2,1,1,0", + "3,4,5,2,6,8,4,1,1,1", + "1,1,1,1,3,2,2,1,1,0", + "3,1,1,3,8,1,5,8,1,0", + "8,8,7,4,10,10,7,8,7,1", + "1,1,1,1,1,1,3,1,1,0", + "7,2,4,1,6,10,5,4,3,1", + "10,10,8,6,4,5,8,10,1,1", + "4,1,1,1,2,3,1,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "5,5,5,6,3,10,3,1,1,1", + "1,2,2,1,2,1,2,1,1,0", + "2,1,1,1,2,1,3,1,1,0", + "9,9,10,3,6,10,7,10,6,1", + "10,7,7,4,5,10,5,7,2,1", + "4,1,1,1,2,1,3,2,1,0", + "3,1,1,1,2,1,3,1,1,0", + "1,1,1,2,1,3,1,1,7,0", + "4,1,1,1,2,2,3,2,1,0", + "5,6,7,8,8,10,3,10,3,1", + "10,8,10,10,6,1,3,1,10,1", + "3,1,1,1,2,1,3,1,1,0", + "1,1,1,2,1,1,1,1,1,0", + "3,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "6,10,10,10,8,10,10,10,7,1", + "8,6,5,4,3,10,6,1,1,1", + "5,8,7,7,10,10,5,7,1,1", + "2,1,1,1,2,1,3,1,1,0", + "5,10,10,3,8,1,5,10,3,1", + "4,1,1,1,2,1,3,1,1,0", + "5,3,3,3,6,10,3,1,1,1", + "1,1,1,1,1,1,3,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "6,1,1,1,2,1,3,1,1,0", + "5,8,8,8,5,10,7,8,1,1", + "8,7,6,4,4,10,5,1,1,1", + "2,1,1,1,1,1,3,1,1,0", + "1,5,8,6,5,8,7,10,1,1", + "10,5,6,10,6,10,7,7,10,1", + "5,8,4,10,5,8,9,10,1,1", + "1,2,3,1,2,1,3,1,1,0", + "10,10,10,8,6,8,7,10,1,1", + "7,5,10,10,10,10,4,10,3,1", + "5,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "3,1,1,1,2,1,3,1,1,0", + "4,1,1,1,2,1,3,1,1,0", + "8,4,4,5,4,7,7,8,2,0", + "5,1,1,4,2,1,3,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "9,7,7,5,5,10,7,8,3,1", + "10,8,8,4,10,10,8,1,1,1", + 
"1,1,1,1,2,1,3,1,1,0", + "5,1,1,1,2,1,3,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "5,10,10,9,6,10,7,10,5,1", + "10,10,9,3,7,5,3,5,1,1", + "1,1,1,1,1,1,3,1,1,0", + "1,1,1,1,1,1,3,1,1,0", + "5,1,1,1,1,1,3,1,1,0", + "8,10,10,10,5,10,8,10,6,1", + "8,10,8,8,4,8,7,7,1,1", + "1,1,1,1,2,1,3,1,1,0", + "10,10,10,10,7,10,7,10,4,1", + "10,10,10,10,3,10,10,6,1,1", + "8,7,8,7,5,5,5,10,2,1", + "1,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "6,10,7,7,6,4,8,10,2,1", + "6,1,3,1,2,1,3,1,1,0", + "1,1,1,2,2,1,3,1,1,0", + "10,6,4,3,10,10,9,10,1,1", + "4,1,1,3,1,5,2,1,1,1", + "7,5,6,3,3,8,7,4,1,1", + "10,5,5,6,3,10,7,9,2,1", + "1,1,1,1,2,1,2,1,1,0", + "10,5,7,4,4,10,8,9,1,1", + "8,9,9,5,3,5,7,7,1,1", + "1,1,1,1,1,1,3,1,1,0", + "10,10,10,3,10,10,9,10,1,1", + "7,4,7,4,3,7,7,6,1,1", + "6,8,7,5,6,8,8,9,2,1", + "8,4,6,3,3,1,4,3,1,0", + "10,4,5,5,5,10,4,1,1,1", + "3,3,2,1,3,1,3,6,1,0", + "10,8,8,2,8,10,4,8,10,1", + "9,8,8,5,6,2,4,10,4,1", + "8,10,10,8,6,9,3,10,10,1", + "10,4,3,2,3,10,5,3,2,1", + "5,1,3,3,2,2,2,3,1,0", + "3,1,1,3,1,1,3,1,1,0", + "2,1,1,1,2,1,3,1,1,0", + "1,1,1,1,2,5,5,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "5,1,1,2,2,2,3,1,1,0", + "8,10,10,8,5,10,7,8,1,1", + "8,4,4,1,2,9,3,3,1,1", + "4,1,1,1,2,1,3,6,1,0", + "1,2,2,1,2,1,1,1,1,0", + "10,4,4,10,2,10,5,3,3,1", + "6,3,3,5,3,10,3,5,3,0", + "6,10,10,2,8,10,7,3,3,1", + "9,10,10,1,10,8,3,3,1,1", + "5,6,6,2,4,10,3,6,1,1", + "3,1,1,1,2,1,1,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "3,1,1,1,2,1,3,1,1,0", + "5,7,7,1,5,8,3,4,1,0", + "10,5,8,10,3,10,5,1,3,1", + "5,10,10,6,10,10,10,6,5,1", + "8,8,9,4,5,10,7,8,1,1", + "10,4,4,10,6,10,5,5,1,1", + "7,9,4,10,10,3,5,3,3,1", + "5,1,4,1,2,1,3,2,1,0", + "10,10,6,3,3,10,4,3,2,1", + "3,3,5,2,3,10,7,1,1,1", + "10,8,8,2,3,4,8,7,8,1", + "1,1,1,1,2,1,3,1,1,0", + "8,4,7,1,3,10,3,9,2,1", + "5,1,1,1,2,1,3,1,1,0", + "3,3,5,2,3,10,7,1,1,1", + "7,2,4,1,3,4,3,3,1,1", + "3,1,1,1,2,1,3,2,1,0", + "3,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "10,5,7,3,3,7,3,3,8,1", + "3,1,1,1,2,1,3,1,1,0", + 
"2,1,1,2,2,1,3,1,1,0", + "1,4,3,10,4,10,5,6,1,1", + "10,4,6,1,2,10,5,3,1,1", + "7,4,5,10,2,10,3,8,2,1", + "8,10,10,10,8,10,10,7,3,1", + "10,10,10,10,10,10,4,10,10,1", + "3,1,1,1,3,1,2,1,1,0", + "6,1,3,1,4,5,5,10,1,1", + "5,6,6,8,6,10,4,10,4,1", + "1,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "10,4,4,6,2,10,2,3,1,1", + "5,5,7,8,6,10,7,4,1,1", + "5,3,4,3,4,5,4,7,1,0", + "8,2,1,1,5,1,1,1,1,0", + "9,1,2,6,4,10,7,7,2,1", + "8,4,10,5,4,4,7,10,1,1", + "1,1,1,1,2,1,3,1,1,0", + "10,10,10,7,9,10,7,10,10,1", + "1,1,1,1,2,1,3,1,1,0", + "8,3,4,9,3,10,3,3,1,1", + "10,8,4,4,4,10,3,10,4,1", + "1,1,1,1,2,1,3,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "7,8,7,6,4,3,8,8,4,1", + "3,1,1,1,2,5,5,1,1,0", + "2,1,1,1,3,1,2,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "8,6,4,10,10,1,3,5,1,1", + "1,1,1,1,2,1,1,1,1,0", + "1,1,1,1,1,1,2,1,1,0", + "5,5,5,2,5,10,4,3,1,1", + "6,8,7,8,6,8,8,9,1,1", + "1,1,1,1,5,1,3,1,1,0", + "4,4,4,4,6,5,7,3,1,0", + "7,6,3,2,5,10,7,4,6,1", + "3,1,1,1,2,1,3,1,1,0", + "5,4,6,10,2,10,4,1,1,1", + "1,1,1,1,2,1,3,1,1,0", + "3,2,2,1,2,1,2,3,1,0", + "10,1,1,1,2,10,5,4,1,1", + "1,1,1,1,2,1,2,1,1,0", + "8,10,3,2,6,4,3,10,1,1", + "10,4,6,4,5,10,7,1,1,1", + "10,4,7,2,2,8,6,1,1,1", + "5,1,1,1,2,1,3,1,2,0", + "5,2,2,2,2,1,2,2,1,0", + "5,4,6,6,4,10,4,3,1,1", + "8,6,7,3,3,10,3,4,2,1", + "1,1,1,1,2,1,1,1,1,0", + "6,5,5,8,4,10,3,4,1,1", + "1,1,1,1,2,1,3,1,1,0", + "1,1,1,1,1,1,2,1,1,0", + "8,5,5,5,2,10,4,3,1,1", + "10,3,3,1,2,10,7,6,1,1", + "1,1,1,1,2,1,3,1,1,0", + "2,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "7,6,4,8,10,10,9,5,3,1", + "1,1,1,1,2,1,1,1,1,0", + "5,2,2,2,3,1,1,3,1,0", + "1,1,1,1,1,1,1,3,1,0", + "3,4,4,10,5,1,3,3,1,1", + "4,2,3,5,3,8,7,6,1,1", + "5,1,1,3,2,1,1,1,1,0", + "2,1,1,1,2,1,3,1,1,0", + "3,4,5,3,7,3,4,6,1,0", + "2,7,10,10,7,10,4,9,4,1", + "1,1,1,1,2,1,2,1,1,0", + "4,1,1,1,3,1,2,2,1,0", + "5,3,3,1,3,3,3,3,3,1", + "8,10,10,7,10,10,7,3,8,1", + "8,10,5,3,8,4,4,10,3,1", + "10,3,5,4,3,7,3,5,3,1", + "6,10,10,10,10,10,8,10,10,1", + "3,10,3,10,6,10,5,1,4,1", + 
"3,2,2,1,4,3,2,1,1,0", + "4,4,4,2,2,3,2,1,1,0", + "2,1,1,1,2,1,3,1,1,0", + "2,1,1,1,2,1,2,1,1,0", + "6,10,10,10,8,10,7,10,7,1", + "5,8,8,10,5,10,8,10,3,1", + "1,1,3,1,2,1,1,1,1,0", + "1,1,3,1,1,1,2,1,1,0", + "4,3,2,1,3,1,2,1,1,0", + "1,1,3,1,2,1,1,1,1,0", + "4,1,2,1,2,1,2,1,1,0", + "5,1,1,2,2,1,2,1,1,0", + "3,1,2,1,2,1,2,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "1,1,1,1,1,1,2,1,1,0", + "3,1,1,4,3,1,2,2,1,0", + "5,3,4,1,4,1,3,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "10,6,3,6,4,10,7,8,4,1", + "3,2,2,2,2,1,3,2,1,0", + "2,1,1,1,2,1,1,1,1,0", + "2,1,1,1,2,1,1,1,1,0", + "3,3,2,2,3,1,1,2,3,0", + "7,6,6,3,2,10,7,1,1,1", + "5,3,3,2,3,1,3,1,1,0", + "2,1,1,1,2,1,2,2,1,0", + "5,1,1,1,3,2,2,2,1,0", + "1,1,1,2,2,1,2,1,1,0", + "10,8,7,4,3,10,7,9,1,1", + "3,1,1,1,2,1,2,1,1,0", + "1,1,1,1,1,1,1,1,1,0", + "1,2,3,1,2,1,2,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "3,1,1,1,2,1,3,1,1,0", + "4,1,1,1,2,1,1,1,1,0", + "3,2,1,1,2,1,2,2,1,0", + "1,2,3,1,2,1,1,1,1,0", + "3,10,8,7,6,9,9,3,8,1", + "3,1,1,1,2,1,1,1,1,0", + "5,3,3,1,2,1,2,1,1,0", + "3,1,1,1,2,4,1,1,1,0", + "1,2,1,3,2,1,1,2,1,0", + "1,1,1,1,2,1,2,1,1,0", + "4,2,2,1,2,1,2,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "2,3,2,2,2,2,3,1,1,0", + "3,1,2,1,2,1,2,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "10,10,10,6,8,4,8,5,1,1", + "5,1,2,1,2,1,3,1,1,0", + "8,5,6,2,3,10,6,6,1,1", + "3,3,2,6,3,3,3,5,1,0", + "8,7,8,5,10,10,7,2,1,1", + "1,1,1,1,2,1,2,1,1,0", + "5,2,2,2,2,2,3,2,2,0", + "2,3,1,1,5,1,1,1,1,0", + "3,2,2,3,2,3,3,1,1,0", + "10,10,10,7,10,10,8,2,1,1", + "4,3,3,1,2,1,3,3,1,0", + "5,1,3,1,2,1,2,1,1,0", + "3,1,1,1,2,1,1,1,1,0", + "9,10,10,10,10,10,10,10,1,1", + "5,3,6,1,2,1,1,1,1,0", + "8,7,8,2,4,2,5,10,1,1", + "1,1,1,1,2,1,2,1,1,0", + "2,1,1,1,2,1,2,1,1,0", + "1,3,1,1,2,1,2,2,1,0", + "5,1,1,3,4,1,3,2,1,0", + "5,1,1,1,2,1,2,2,1,0", + "3,2,2,3,2,1,1,1,1,0", + "6,9,7,5,5,8,4,2,1,0", + "10,8,10,1,3,10,5,1,1,1", + "10,10,10,1,6,1,2,8,1,1", + "4,1,1,1,2,1,1,1,1,0", + "4,1,3,3,2,1,1,1,1,0", + "5,1,1,1,2,1,1,1,1,0", + "10,4,3,10,4,10,10,1,1,1", + 
"5,2,2,4,2,4,1,1,1,0", + "1,1,1,3,2,3,1,1,1,0", + "1,1,1,1,2,2,1,1,1,0", + "5,1,1,6,3,1,2,1,1,0", + "2,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "5,1,1,1,2,1,1,1,1,0", + "1,1,1,1,1,1,1,1,1,0", + "5,7,9,8,6,10,8,10,1,1", + "4,1,1,3,1,1,2,1,1,0", + "5,1,1,1,2,1,1,1,1,0", + "3,1,1,3,2,1,1,1,1,0", + "4,5,5,8,6,10,10,7,1,1", + "2,3,1,1,3,1,1,1,1,0", + "10,2,2,1,2,6,1,1,2,1", + "10,6,5,8,5,10,8,6,1,1", + "8,8,9,6,6,3,10,10,1,1", + "5,1,2,1,2,1,1,1,1,0", + "5,1,3,1,2,1,1,1,1,0", + "5,1,1,3,2,1,1,1,1,0", + "3,1,1,1,2,5,1,1,1,0", + "6,1,1,3,2,1,1,1,1,0", + "4,1,1,1,2,1,1,2,1,0", + "4,1,1,1,2,1,1,1,1,0", + "10,9,8,7,6,4,7,10,3,1", + "10,6,6,2,4,10,9,7,1,1", + "6,6,6,5,4,10,7,6,2,1", + "4,1,1,1,2,1,1,1,1,0", + "1,1,2,1,2,1,2,1,1,0", + "3,1,1,1,1,1,2,1,1,0", + "6,1,1,3,2,1,1,1,1,0", + "6,1,1,1,1,1,1,1,1,0", + "4,1,1,1,2,1,1,1,1,0", + "5,1,1,1,2,1,1,1,1,0", + "3,1,1,1,2,1,1,1,1,0", + "4,1,2,1,2,1,1,1,1,0", + "4,1,1,1,2,1,1,1,1,0", + "5,2,1,1,2,1,1,1,1,0", + "4,8,7,10,4,10,7,5,1,1", + "5,1,1,1,1,1,1,1,1,0", + "5,3,2,4,2,1,1,1,1,0", + "9,10,10,10,10,5,10,10,10,1", + "8,7,8,5,5,10,9,10,1,1", + "5,1,2,1,2,1,1,1,1,0", + "1,1,1,3,1,3,1,1,1,0", + "3,1,1,1,1,1,2,1,1,0", + "10,10,10,10,6,10,8,1,5,1", + "3,6,4,10,3,3,3,4,1,1", + "6,3,2,1,3,4,4,1,1,1", + "1,1,1,1,2,1,1,1,1,0", + "5,8,9,4,3,10,7,1,1,1", + "4,1,1,1,1,1,2,1,1,0", + "5,10,10,10,6,10,6,5,2,1", + "5,1,2,10,4,5,2,1,1,0", + "3,1,1,1,1,1,2,1,1,0", + "1,1,1,1,1,1,1,1,1,0", + "4,2,1,1,2,1,1,1,1,0", + "4,1,1,1,2,1,2,1,1,0", + "4,1,1,1,2,1,2,1,1,0", + "6,1,1,1,2,1,3,1,1,0", + "4,1,1,1,2,1,2,1,1,0", + "4,1,1,2,2,1,2,1,1,0", + "4,1,1,1,2,1,3,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "3,3,1,1,2,1,1,1,1,0", + "8,10,10,10,7,5,4,8,7,1", + "1,1,1,1,2,4,1,1,1,0", + "5,1,1,1,2,1,1,1,1,0", + "2,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "5,1,1,1,2,1,2,1,1,0", + "5,1,1,1,2,1,1,1,1,0", + "3,1,1,1,1,1,2,1,1,0", + "6,6,7,10,3,10,8,10,2,1", + "4,10,4,7,3,10,9,10,1,1", + "1,1,1,1,1,1,1,1,1,0", + "1,1,1,1,1,1,2,1,1,0", + "3,1,2,2,2,1,1,1,1,0", + 
"4,7,8,3,4,10,9,1,1,1", + "1,1,1,1,3,1,1,1,1,0", + "4,1,1,1,3,1,1,1,1,0", + "10,4,5,4,3,5,7,3,1,1", + "7,5,6,10,4,10,5,3,1,1", + "3,1,1,1,2,1,2,1,1,0", + "3,1,1,2,2,1,1,1,1,0", + "4,1,1,1,2,1,1,1,1,0", + "4,1,1,1,2,1,3,1,1,0", + "6,1,3,2,2,1,1,1,1,0", + "4,1,1,1,1,1,2,1,1,0", + "7,4,4,3,4,10,6,9,1,1", + "4,2,2,1,2,1,2,1,1,0", + "1,1,1,1,1,1,3,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "2,1,1,1,2,1,2,1,1,0", + "1,1,3,2,2,1,3,1,1,0", + "5,1,1,1,2,1,3,1,1,0", + "5,1,2,1,2,1,3,1,1,0", + "4,1,1,1,2,1,2,1,1,0", + "6,1,1,1,2,1,2,1,1,0", + "5,1,1,1,2,2,2,1,1,0", + "3,1,1,1,2,1,1,1,1,0", + "5,3,1,1,2,1,1,1,1,0", + "4,1,1,1,2,1,2,1,1,0", + "2,1,3,2,2,1,2,1,1,0", + "5,1,1,1,2,1,2,1,1,0", + "6,10,10,10,4,10,7,10,1,1", + "2,1,1,1,1,1,1,1,1,0", + "3,1,1,1,1,1,1,1,1,0", + "7,8,3,7,4,5,7,8,2,1", + "3,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "3,2,2,2,2,1,4,2,1,0", + "4,4,2,1,2,5,2,1,2,0", + "3,1,1,1,2,1,1,1,1,0", + "4,3,1,1,2,1,4,8,1,0", + "5,2,2,2,1,1,2,1,1,0", + "5,1,1,3,2,1,1,1,1,0", + "2,1,1,1,2,1,2,1,1,0", + "5,1,1,1,2,1,2,1,1,0", + "5,1,1,1,2,1,3,1,1,0", + "5,1,1,1,2,1,3,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "4,1,1,1,2,1,3,2,1,0", + "5,7,10,10,5,10,10,10,1,1", + "3,1,2,1,2,1,3,1,1,0", + "4,1,1,1,2,3,2,1,1,0", + "8,4,4,1,6,10,2,5,2,1", + "10,10,8,10,6,5,10,3,1,1", + "8,10,4,4,8,10,8,2,1,1", + "7,6,10,5,3,10,9,10,2,1", + "3,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "10,9,7,3,4,2,7,7,1,1", + "5,1,2,1,2,1,3,1,1,0", + "5,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,3,1,1,0", + "5,1,2,1,2,1,2,1,1,0", + "5,7,10,6,5,10,7,5,1,1", + "6,10,5,5,4,10,6,10,1,1", + "3,1,1,1,2,1,1,1,1,0", + "5,1,1,6,3,1,1,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "8,10,10,10,6,10,10,10,1,1", + "5,1,1,1,2,1,2,2,1,0", + "9,8,8,9,6,3,4,1,1,1", + "5,1,1,1,2,1,1,1,1,0", + "4,10,8,5,4,1,10,1,1,1", + "2,5,7,6,4,10,7,6,1,1", + "10,3,4,5,3,10,4,1,1,1", + "5,1,2,1,2,1,1,1,1,0", + "4,8,6,3,4,10,7,1,1,1", + "5,1,1,1,2,1,2,1,1,0", + "4,1,2,1,2,1,2,1,1,0", + 
"5,1,3,1,2,1,3,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "5,2,4,1,1,1,1,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "1,1,1,1,1,1,2,1,1,0", + "4,1,1,1,2,1,2,1,1,0", + "5,4,6,8,4,1,8,10,1,1", + "5,3,2,8,5,10,8,1,2,1", + "10,5,10,3,5,8,7,8,3,1", + "4,1,1,2,2,1,1,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "5,10,10,10,10,10,10,1,1,1", + "5,1,1,1,2,1,1,1,1,0", + "10,4,3,10,3,10,7,1,2,1", + "5,10,10,10,5,2,8,5,1,1", + "8,10,10,10,6,10,10,10,10,1", + "2,3,1,1,2,1,2,1,1,0", + "2,1,1,1,1,1,2,1,1,0", + "4,1,3,1,2,1,2,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "4,1,1,1,2,1,2,1,1,0", + "5,1,1,1,2,1,2,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "6,3,3,3,3,2,6,1,1,0", + "7,1,2,3,2,1,2,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "5,1,1,2,1,1,2,1,1,0", + "3,1,3,1,3,4,1,1,1,0", + "4,6,6,5,7,6,7,7,3,1", + "2,1,1,1,2,5,1,1,1,0", + "2,1,1,1,2,1,1,1,1,0", + "4,1,1,1,2,1,1,1,1,0", + "6,2,3,1,2,1,1,1,1,0", + "5,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "8,7,4,4,5,3,5,10,1,1", + "3,1,1,1,2,1,1,1,1,0", + "3,1,4,1,2,1,1,1,1,0", + "10,10,7,8,7,1,10,10,3,1", + "4,2,4,3,2,2,2,1,1,0", + "4,1,1,1,2,1,1,1,1,0", + "5,1,1,3,2,1,1,1,1,0", + "4,1,1,3,2,1,1,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "2,1,1,1,2,1,1,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "1,2,2,1,2,1,1,1,1,0", + "1,1,1,3,2,1,1,1,1,0", + "5,10,10,10,10,2,10,10,10,1", + "3,1,1,1,2,1,2,1,1,0", + "3,1,1,2,3,4,1,1,1,0", + "1,2,1,3,2,1,2,1,1,0", + "5,1,1,1,2,1,2,2,1,0", + "4,1,1,1,2,1,2,1,1,0", + "3,1,1,1,2,1,3,1,1,0", + "3,1,1,1,2,1,2,1,1,0", + "5,1,1,1,2,1,2,1,1,0", + "5,4,5,1,8,1,3,6,1,0", + "7,8,8,7,3,10,7,2,3,1", + "1,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "4,1,1,1,2,1,3,1,1,0", + "1,1,3,1,2,1,2,1,1,0", + "1,1,3,1,2,1,2,1,1,0", + "3,1,1,3,2,1,2,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "5,2,2,2,2,1,1,1,2,0", + "3,1,1,1,2,1,3,1,1,0", + "5,7,4,1,6,1,7,10,3,1", + "5,10,10,8,5,5,7,10,1,1", + "3,10,7,8,5,8,7,4,1,1", + "3,2,1,2,2,1,3,1,1,0", + "2,1,1,1,2,1,3,1,1,0", + "5,3,2,1,3,1,1,1,1,0", + "1,1,1,1,2,1,2,1,1,0", + "4,1,4,1,2,1,1,1,1,0", + 
"1,1,2,1,2,1,2,1,1,0", + "5,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "2,1,1,1,2,1,1,1,1,0", + "10,10,10,10,5,10,10,10,7,1", + "5,10,10,10,4,10,5,6,3,1", + "5,1,1,1,2,1,3,2,1,0", + "1,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,1,1,1,0", + "3,1,1,1,2,1,2,3,1,0", + "4,1,1,1,2,1,1,1,1,0", + "1,1,1,1,2,1,1,1,8,0", + "1,1,1,3,2,1,1,1,1,0", + "5,10,10,5,4,5,4,4,1,1", + "3,1,1,1,2,1,1,1,1,0", + "3,1,1,1,2,1,2,1,2,0", + "3,1,1,1,3,2,1,1,1,0", + "2,1,1,1,2,1,1,1,1,0", + "5,10,10,3,7,3,8,10,2,1", + "4,8,6,4,3,4,10,6,1,1", + "4,8,8,5,4,5,10,4,1,1" + }; + + private Datasets() {} + +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptronTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptronTest.java b/mr/src/test/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptronTest.java new file mode 100644 index 0000000..522ac4a --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptronTest.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.classifier.mlp; + +import java.io.File; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +public class RunMultilayerPerceptronTest extends MahoutTestCase { + + @Test + public void runMultilayerPerceptron() throws Exception { + + // Train a model first + String modelFileName = "mlp.model"; + File modelFile = getTestTempFile(modelFileName); + + File irisDataset = getTestTempFile("iris.csv"); + writeLines(irisDataset, Datasets.IRIS); + + String[] argsTrain = { + "-i", irisDataset.getAbsolutePath(), + "-sh", + "-labels", "setosa", "versicolor", "virginica", + "-mo", modelFile.getAbsolutePath(), + "-u", + "-ls", "4", "8", "3" + }; + + TrainMultilayerPerceptron.main(argsTrain); + + assertTrue(modelFile.exists()); + + String outputFileName = "labelResult.txt"; + File outputFile = getTestTempFile(outputFileName); + + String[] argsLabeling = { + "-i", irisDataset.getAbsolutePath(), + "-sh", + "-cr", "0", "3", + "-mo", modelFile.getAbsolutePath(), + "-o", outputFile.getAbsolutePath() + }; + + RunMultilayerPerceptron.main(argsLabeling); + + assertTrue(outputFile.exists()); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java b/mr/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java new file mode 100644 index 0000000..93013b6 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.classifier.mlp; + +import java.io.File; +import java.io.IOException; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.Arrays; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +/** + * Test the functionality of {@link MultilayerPerceptron} + */ +public class TestMultilayerPerceptron extends MahoutTestCase { + + @Test + public void testMLP() throws IOException { + testMLP("testMLPXORLocal", false, false, 8000); + testMLP("testMLPXORLocalWithMomentum", true, false, 4000); + testMLP("testMLPXORLocalWithRegularization", true, true, 2000); + } + + private void testMLP(String modelFilename, boolean useMomentum, + boolean useRegularization, int iterations) throws IOException { + MultilayerPerceptron mlp = new MultilayerPerceptron(); + mlp.addLayer(2, false, "Sigmoid"); + mlp.addLayer(3, false, "Sigmoid"); + mlp.addLayer(1, true, "Sigmoid"); + mlp.setCostFunction("Minus_Squared").setLearningRate(0.2); + if (useMomentum) { + mlp.setMomentumWeight(0.6); + } + + if (useRegularization) { + mlp.setRegularizationWeight(0.01); + } + + double[][] instances = { { 0, 1, 1 }, { 0, 0, 0 }, { 1, 0, 1 }, { 1, 1, 0 } }; + for (int i = 0; i < iterations; ++i) { + for (double[] instance : instances) { + Vector features = new DenseVector(Arrays.copyOf(instance, instance.length - 1)); + 
mlp.train((int) instance[2], features); + } + } + + for (double[] instance : instances) { + Vector input = new DenseVector(instance).viewPart(0, instance.length - 1); + // the expected output is the last element in array + double actual = instance[2]; + double expected = mlp.getOutput(input).get(0); + assertTrue(actual < 0.5 && expected < 0.5 || actual >= 0.5 && expected >= 0.5); + } + + // write model into file and read out + File modelFile = this.getTestTempFile(modelFilename); + mlp.setModelPath(modelFile.getAbsolutePath()); + mlp.writeModelToFile(); + mlp.close(); + + MultilayerPerceptron mlpCopy = new MultilayerPerceptron(modelFile.getAbsolutePath()); + // test on instances + for (double[] instance : instances) { + Vector input = new DenseVector(instance).viewPart(0, instance.length - 1); + // the expected output is the last element in array + double actual = instance[2]; + double expected = mlpCopy.getOutput(input).get(0); + assertTrue(actual < 0.5 && expected < 0.5 || actual >= 0.5 && expected >= 0.5); + } + mlpCopy.close(); + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java b/mr/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java new file mode 100644 index 0000000..ebe5424 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.mlp; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.csv.CSVUtils; +import org.apache.mahout.classifier.mlp.NeuralNetwork.TrainingMethod; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +import com.google.common.base.Charsets; +import com.google.common.collect.Lists; +import com.google.common.io.Files; + +/** Test the functionality of {@link NeuralNetwork}. 
*/ +public class TestNeuralNetwork extends MahoutTestCase { + + + @Test + public void testReadWrite() throws IOException { + NeuralNetwork ann = new MultilayerPerceptron(); + ann.addLayer(2, false, "Identity"); + ann.addLayer(5, false, "Identity"); + ann.addLayer(1, true, "Identity"); + ann.setCostFunction("Minus_Squared"); + double learningRate = 0.2; + double momentumWeight = 0.5; + double regularizationWeight = 0.05; + ann.setLearningRate(learningRate) + .setMomentumWeight(momentumWeight) + .setRegularizationWeight(regularizationWeight); + + // Manually set weights + Matrix[] matrices = new DenseMatrix[2]; + matrices[0] = new DenseMatrix(5, 3); + matrices[0].assign(0.2); + matrices[1] = new DenseMatrix(1, 6); + matrices[1].assign(0.8); + ann.setWeightMatrices(matrices); + + // Write to file + String modelFilename = "testNeuralNetworkReadWrite"; + File tmpModelFile = this.getTestTempFile(modelFilename); + ann.setModelPath(tmpModelFile.getAbsolutePath()); + ann.writeModelToFile(); + + // Read from file + NeuralNetwork annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath()); + assertEquals(annCopy.getClass().getSimpleName(), annCopy.getModelType()); + assertEquals(tmpModelFile.getAbsolutePath(), annCopy.getModelPath()); + assertEquals(learningRate, annCopy.getLearningRate(), EPSILON); + assertEquals(momentumWeight, annCopy.getMomentumWeight(), EPSILON); + assertEquals(regularizationWeight, annCopy.getRegularizationWeight(), EPSILON); + assertEquals(TrainingMethod.GRADIENT_DESCENT, annCopy.getTrainingMethod()); + + // Compare weights + Matrix[] weightsMatrices = annCopy.getWeightMatrices(); + for (int i = 0; i < weightsMatrices.length; ++i) { + Matrix expectMat = matrices[i]; + Matrix actualMat = weightsMatrices[i]; + for (int j = 0; j < expectMat.rowSize(); ++j) { + for (int k = 0; k < expectMat.columnSize(); ++k) { + assertEquals(expectMat.get(j, k), actualMat.get(j, k), EPSILON); + } + } + } + } + + /** Test the forward functionality. 
*/ + @Test + public void testOutput() { + // First network + NeuralNetwork ann = new MultilayerPerceptron(); + ann.addLayer(2, false, "Identity"); + ann.addLayer(5, false, "Identity"); + ann.addLayer(1, true, "Identity"); + ann.setCostFunction("Minus_Squared").setLearningRate(0.1); + + // Intentionally initialize all weights to 0.5 + Matrix[] matrices = new Matrix[2]; + matrices[0] = new DenseMatrix(5, 3); + matrices[0].assign(0.5); + matrices[1] = new DenseMatrix(1, 6); + matrices[1].assign(0.5); + ann.setWeightMatrices(matrices); + + double[] arr = new double[] { 0, 1 }; + Vector training = new DenseVector(arr); + Vector result = ann.getOutput(training); + assertEquals(1, result.size()); + + // Second network + NeuralNetwork ann2 = new MultilayerPerceptron(); + ann2.addLayer(2, false, "Sigmoid"); + ann2.addLayer(3, false, "Sigmoid"); + ann2.addLayer(1, true, "Sigmoid"); + ann2.setCostFunction("Minus_Squared"); + ann2.setLearningRate(0.3); + + // Intentionally initialize all weights to 0.5 + Matrix[] matrices2 = new Matrix[2]; + matrices2[0] = new DenseMatrix(3, 3); + matrices2[0].assign(0.5); + matrices2[1] = new DenseMatrix(1, 4); + matrices2[1].assign(0.5); + ann2.setWeightMatrices(matrices2); + + double[] test = { 0, 0 }; + double[] result2 = { 0.807476 }; + + Vector vec = ann2.getOutput(new DenseVector(test)); + double[] arrVec = new double[vec.size()]; + for (int i = 0; i < arrVec.length; ++i) { + arrVec[i] = vec.getQuick(i); + } + assertArrayEquals(result2, arrVec, EPSILON); + + NeuralNetwork ann3 = new MultilayerPerceptron(); + ann3.addLayer(2, false, "Sigmoid"); + ann3.addLayer(3, false, "Sigmoid"); + ann3.addLayer(1, true, "Sigmoid"); + ann3.setCostFunction("Minus_Squared").setLearningRate(0.3); + + // Intentionally initialize all weights to 0.5 + Matrix[] initMatrices = new Matrix[2]; + initMatrices[0] = new DenseMatrix(3, 3); + initMatrices[0].assign(0.5); + initMatrices[1] = new DenseMatrix(1, 4); + initMatrices[1].assign(0.5); + 
ann3.setWeightMatrices(initMatrices); + + double[] instance = {0, 1}; + Vector output = ann3.getOutput(new DenseVector(instance)); + assertEquals(0.8315410, output.get(0), EPSILON); + } + + @Test + public void testNeuralNetwork() throws IOException { + testNeuralNetwork("testNeuralNetworkXORLocal", false, false, 10000); + testNeuralNetwork("testNeuralNetworkXORWithMomentum", true, false, 5000); + testNeuralNetwork("testNeuralNetworkXORWithRegularization", true, true, 5000); + } + + private void testNeuralNetwork(String modelFilename, boolean useMomentum, + boolean useRegularization, int iterations) throws IOException { + NeuralNetwork ann = new MultilayerPerceptron(); + ann.addLayer(2, false, "Sigmoid"); + ann.addLayer(3, false, "Sigmoid"); + ann.addLayer(1, true, "Sigmoid"); + ann.setCostFunction("Minus_Squared").setLearningRate(0.1); + + if (useMomentum) { + ann.setMomentumWeight(0.6); + } + + if (useRegularization) { + ann.setRegularizationWeight(0.01); + } + + double[][] instances = { { 0, 1, 1 }, { 0, 0, 0 }, { 1, 0, 1 }, { 1, 1, 0 } }; + for (int i = 0; i < iterations; ++i) { + for (double[] instance : instances) { + ann.trainOnline(new DenseVector(instance)); + } + } + + for (double[] instance : instances) { + Vector input = new DenseVector(instance).viewPart(0, instance.length - 1); + // The expected output is the last element in array + double actual = instance[2]; + double expected = ann.getOutput(input).get(0); + assertTrue(actual < 0.5 && expected < 0.5 || actual >= 0.5 && expected >= 0.5); + } + + // Write model into file and read out + File tmpModelFile = this.getTestTempFile(modelFilename); + ann.setModelPath(tmpModelFile.getAbsolutePath()); + ann.writeModelToFile(); + + NeuralNetwork annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath()); + // Test on instances + for (double[] instance : instances) { + Vector input = new DenseVector(instance).viewPart(0, instance.length - 1); + // The expected output is the last element in array + double 
actual = instance[2]; + double expected = annCopy.getOutput(input).get(0); + assertTrue(actual < 0.5 && expected < 0.5 || actual >= 0.5 && expected >= 0.5); + } + } + + @Test + public void testWithCancerDataSet() throws IOException { + + File cancerDataset = getTestTempFile("cancer.csv"); + writeLines(cancerDataset, Datasets.CANCER); + + List<Vector> records = Lists.newArrayList(); + // Returns a mutable list of the data + List<String> cancerDataSetList = Files.readLines(cancerDataset, Charsets.UTF_8); + // Skip the header line, hence remove the first element in the list + cancerDataSetList.remove(0); + for (String line : cancerDataSetList) { + String[] tokens = CSVUtils.parseLine(line); + double[] values = new double[tokens.length]; + for (int i = 0; i < tokens.length; ++i) { + values[i] = Double.parseDouble(tokens[i]); + } + records.add(new DenseVector(values)); + } + + int splitPoint = (int) (records.size() * 0.8); + List<Vector> trainingSet = records.subList(0, splitPoint); + List<Vector> testSet = records.subList(splitPoint, records.size()); + + // initialize neural network model + NeuralNetwork ann = new MultilayerPerceptron(); + int featureDimension = records.get(0).size() - 1; + ann.addLayer(featureDimension, false, "Sigmoid"); + ann.addLayer(featureDimension * 2, false, "Sigmoid"); + ann.addLayer(1, true, "Sigmoid"); + ann.setLearningRate(0.05).setMomentumWeight(0.5).setRegularizationWeight(0.001); + + int iteration = 2000; + for (int i = 0; i < iteration; ++i) { + for (Vector trainingInstance : trainingSet) { + ann.trainOnline(trainingInstance); + } + } + + int correctInstances = 0; + for (Vector testInstance : testSet) { + Vector res = ann.getOutput(testInstance.viewPart(0, testInstance.size() - 1)); + double actual = res.get(0); + double expected = testInstance.get(testInstance.size() - 1); + if (Math.abs(actual - expected) <= 0.1) { + ++correctInstances; + } + } + double accuracy = (double) correctInstances / testSet.size() * 100; + assertTrue("The 
classifier is even worse than a random guesser!", accuracy > 50); + System.out.printf("Cancer DataSet. Classification precision: %d/%d = %f%%\n", correctInstances, testSet.size(), accuracy); + } + + @Test + public void testWithIrisDataSet() throws IOException { + + File irisDataset = getTestTempFile("iris.csv"); + writeLines(irisDataset, Datasets.IRIS); + + int numOfClasses = 3; + List<Vector> records = Lists.newArrayList(); + // Returns a mutable list of the data + List<String> irisDataSetList = Files.readLines(irisDataset, Charsets.UTF_8); + // Skip the header line, hence remove the first element in the list + irisDataSetList.remove(0); + + for (String line : irisDataSetList) { + String[] tokens = CSVUtils.parseLine(line); + // Last three dimensions represent the labels + double[] values = new double[tokens.length + numOfClasses - 1]; + Arrays.fill(values, 0.0); + for (int i = 0; i < tokens.length - 1; ++i) { + values[i] = Double.parseDouble(tokens[i]); + } + // Add label values + String label = tokens[tokens.length - 1]; + if (label.equalsIgnoreCase("setosa")) { + values[values.length - 3] = 1; + } else if (label.equalsIgnoreCase("versicolor")) { + values[values.length - 2] = 1; + } else { // label 'virginica' + values[values.length - 1] = 1; + } + records.add(new DenseVector(values)); + } + + Collections.shuffle(records); + + int splitPoint = (int) (records.size() * 0.8); + List<Vector> trainingSet = records.subList(0, splitPoint); + List<Vector> testSet = records.subList(splitPoint, records.size()); + + // Initialize neural network model + NeuralNetwork ann = new MultilayerPerceptron(); + int featureDimension = records.get(0).size() - numOfClasses; + ann.addLayer(featureDimension, false, "Sigmoid"); + ann.addLayer(featureDimension * 2, false, "Sigmoid"); + ann.addLayer(3, true, "Sigmoid"); // 3-class classification + ann.setLearningRate(0.05).setMomentumWeight(0.4).setRegularizationWeight(0.005); + + int iteration = 2000; + for (int i = 0; i < iteration; ++i) 
{ + for (Vector trainingInstance : trainingSet) { + ann.trainOnline(trainingInstance); + } + } + + int correctInstances = 0; + for (Vector testInstance : testSet) { + Vector res = ann.getOutput(testInstance.viewPart(0, testInstance.size() - numOfClasses)); + double[] actualLabels = new double[numOfClasses]; + for (int i = 0; i < numOfClasses; ++i) { + actualLabels[i] = res.get(i); + } + double[] expectedLabels = new double[numOfClasses]; + for (int i = 0; i < numOfClasses; ++i) { + expectedLabels[i] = testInstance.get(testInstance.size() - numOfClasses + i); + } + + boolean allCorrect = true; + for (int i = 0; i < numOfClasses; ++i) { + if (Math.abs(expectedLabels[i] - actualLabels[i]) >= 0.1) { + allCorrect = false; + break; + } + } + if (allCorrect) { + ++correctInstances; + } + } + + double accuracy = (double) correctInstances / testSet.size() * 100; + assertTrue("The model is even worse than a random guesser.", accuracy > 50); + + System.out.printf("Iris DataSet. Classification precision: %d/%d = %f%%\n", + correctInstances, testSet.size(), accuracy); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptronTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptronTest.java b/mr/src/test/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptronTest.java new file mode 100644 index 0000000..b905509 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptronTest.java @@ -0,0 +1,105 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.mlp; + +import java.io.File; + +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +public class TrainMultilayerPerceptronTest extends MahoutTestCase { + + @Test + public void testIrisDataset() throws Exception { + String modelFileName = "mlp.model"; + File modelFile = getTestTempFile(modelFileName); + + File irisDataset = getTestTempFile("iris.csv"); + writeLines(irisDataset, Datasets.IRIS); + + String[] args = { + "-i", irisDataset.getAbsolutePath(), + "-sh", + "-labels", "setosa", "versicolor", "virginica", + "-mo", modelFile.getAbsolutePath(), + "-u", + "-ls", "4", "8", "3" + }; + + TrainMultilayerPerceptron.main(args); + + assertTrue(modelFile.exists()); + } + + @Test + public void initializeModelWithDifferentParameters() throws Exception { + String modelFileName = "mlp.model"; + File modelFile1 = getTestTempFile(modelFileName); + + File irisDataset = getTestTempFile("iris.csv"); + writeLines(irisDataset, Datasets.IRIS); + + String[] args1 = { + "-i", irisDataset.getAbsolutePath(), + "-sh", + "-labels", "setosa", "versicolor", "virginica", + "-mo", modelFile1.getAbsolutePath(), + "-u", + "-ls", "4", "8", "3", + "-l", "0.2", "-m", "0.35", "-r", "0.0001" + }; + + MultilayerPerceptron mlp1 = trainModel(args1, modelFile1); + assertEquals(0.2, mlp1.getLearningRate(), EPSILON); + assertEquals(0.35, mlp1.getMomentumWeight(), 
EPSILON); + assertEquals(0.0001, mlp1.getRegularizationWeight(), EPSILON); + + assertEquals(4, mlp1.getLayerSize(0) - 1); + assertEquals(8, mlp1.getLayerSize(1) - 1); + assertEquals(3, mlp1.getLayerSize(2)); // Final layer has no bias neuron + + // MLP with default learning rate, momentum weight, and regularization weight + File modelFile2 = this.getTestTempFile(modelFileName); + + String[] args2 = { + "-i", irisDataset.getAbsolutePath(), + "-sh", + "-labels", "setosa", "versicolor", "virginica", + "-mo", modelFile2.getAbsolutePath(), + "-ls", "4", "10", "18", "3" + }; + + MultilayerPerceptron mlp2 = trainModel(args2, modelFile2); + assertEquals(0.5, mlp2.getLearningRate(), EPSILON); + assertEquals(0.1, mlp2.getMomentumWeight(), EPSILON); + assertEquals(0, mlp2.getRegularizationWeight(), EPSILON); + + assertEquals(4, mlp2.getLayerSize(0) - 1); + assertEquals(10, mlp2.getLayerSize(1) - 1); + assertEquals(18, mlp2.getLayerSize(2) - 1); + assertEquals(3, mlp2.getLayerSize(3)); // Final layer has no bias neuron + + } + + private MultilayerPerceptron trainModel(String[] args, File modelFile) throws Exception { + TrainMultilayerPerceptron.main(args); + return new MultilayerPerceptron(modelFile.getAbsolutePath()); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java new file mode 100644 index 0000000..f658738 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes; + +import org.apache.mahout.math.DenseVector; +import org.junit.Before; +import org.junit.Test; + +public final class ComplementaryNaiveBayesClassifierTest extends NaiveBayesTestBase { + + private ComplementaryNaiveBayesClassifier classifier; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + NaiveBayesModel model = createComplementaryNaiveBayesModel(); + classifier = new ComplementaryNaiveBayesClassifier(model); + } + + @Test + public void testNaiveBayes() throws Exception { + assertEquals(4, classifier.numCategories()); + assertEquals(0, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 })))); + assertEquals(1, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 })))); + assertEquals(2, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 })))); + assertEquals(3, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 })))); + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java ---------------------------------------------------------------------- diff --git 
a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java new file mode 100644 index 0000000..3b83492 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes; + +import org.junit.Test; + +public class NaiveBayesModelTest extends NaiveBayesTestBase { + + @Test + public void testRandomModelGeneration() { + // fetch the standard model fixture that the base class builds in setUp() + NaiveBayesModel standardModel = getStandardModel(); + // check whether the model is valid + standardModel.validate(); + + // same for Complementary + NaiveBayesModel complementaryModel = getComplementaryModel(); + complementaryModel.validate(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java new file mode 100644 index 0000000..974b90c --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes; + +import java.io.File; + +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.MathHelper; +import org.junit.Before; +import org.junit.Test; + +public class NaiveBayesTest extends MahoutTestCase { + + private Configuration conf; + private File inputFile; + private File outputDir; + private File tempDir; + + static final Text LABEL_STOLEN = new Text("/stolen/"); + static final Text LABEL_NOT_STOLEN = new Text("/not_stolen/"); + + static final Vector.Element COLOR_RED = MathHelper.elem(0, 1); + static final Vector.Element COLOR_YELLOW = MathHelper.elem(1, 1); + static final Vector.Element TYPE_SPORTS = MathHelper.elem(2, 1); + static final Vector.Element TYPE_SUV = MathHelper.elem(3, 1); + static final Vector.Element ORIGIN_DOMESTIC = MathHelper.elem(4, 1); + static final Vector.Element ORIGIN_IMPORTED = MathHelper.elem(5, 1); + + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + conf = getConfiguration(); + + inputFile = getTestTempFile("trainingInstances.seq"); + outputDir = getTestTempDir("output"); + outputDir.delete(); + tempDir = getTestTempDir("tmp"); + + SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(conf), conf, + new Path(inputFile.getAbsolutePath()), Text.class, VectorWritable.class); + + try { + writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC)); + 
writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC)); + writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC)); + writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC)); + writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED)); + writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED)); + writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED)); + writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC)); + writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED)); + writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED)); + } finally { + Closeables.close(writer, false); + } + } + + @Test + public void toyData() throws Exception { + TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob(); + trainNaiveBayes.setConf(conf); + trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), + "-el", "--tempDir", tempDir.getAbsolutePath() }); + + NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf); + + AbstractVectorClassifier classifier = new StandardNaiveBayesClassifier(naiveBayesModel); + + assertEquals(2, classifier.numCategories()); + + Vector prediction = classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get()); + + // should be classified as not stolen + assertTrue(prediction.get(0) < prediction.get(1)); + } + + @Test + public void toyDataComplementary() throws Exception { + TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob(); + trainNaiveBayes.setConf(conf); + trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), + "-el", "--trainComplementary", + 
"--tempDir", tempDir.getAbsolutePath() }); + + NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf); + + AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel); + + assertEquals(2, classifier.numCategories()); + + Vector prediction = classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get()); + + // should be classified as not stolen + assertTrue(prediction.get(0) < prediction.get(1)); + } + + static VectorWritable trainingInstance(Vector.Element... elems) { + DenseVector trainingInstance = new DenseVector(6); + for (Vector.Element elem : elems) { + trainingInstance.set(elem.index(), elem.get()); + } + return new VectorWritable(trainingInstance); + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java new file mode 100644 index 0000000..a943b7b --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java @@ -0,0 +1,135 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; + +public abstract class NaiveBayesTestBase extends MahoutTestCase { + + private NaiveBayesModel standardModel; + private NaiveBayesModel complementaryModel; + + @Override + public void setUp() throws Exception { + super.setUp(); + standardModel = createStandardNaiveBayesModel(); + standardModel.validate(); + complementaryModel = createComplementaryNaiveBayesModel(); + complementaryModel.validate(); + } + + protected NaiveBayesModel getStandardModel() { + return standardModel; + } + protected NaiveBayesModel getComplementaryModel() { + return complementaryModel; + } + + protected static double complementaryNaiveBayesThetaWeight(int label, + Matrix weightMatrix, + Vector labelSum, + Vector featureSum) { + double weight = 0.0; + double alpha = 1.0; + for (int i = 0; i < featureSum.size(); i++) { + double score = weightMatrix.get(i, label); + double lSum = labelSum.get(label); + double fSum = featureSum.get(i); + double totalSum = featureSum.zSum(); + double numerator = fSum - score + alpha; + double denominator = totalSum - lSum + featureSum.size(); + weight += Math.abs(Math.log(numerator / denominator)); + } + return weight; + } + + protected static double naiveBayesThetaWeight(int label, + Matrix weightMatrix, + Vector labelSum, + Vector 
featureSum) { + double weight = 0.0; + double alpha = 1.0; + for (int feature = 0; feature < featureSum.size(); feature++) { + double score = weightMatrix.get(feature, label); + double lSum = labelSum.get(label); + double numerator = score + alpha; + double denominator = lSum + featureSum.size(); + weight += Math.abs(Math.log(numerator / denominator)); + } + return weight; + } + + protected static NaiveBayesModel createStandardNaiveBayesModel() { + double[][] matrix = { + { 0.7, 0.1, 0.1, 0.3 }, + { 0.4, 0.4, 0.1, 0.1 }, + { 0.1, 0.0, 0.8, 0.1 }, + { 0.1, 0.1, 0.1, 0.7 } }; + + double[] labelSumArray = { 1.2, 1.0, 1.0, 1.0 }; + double[] featureSumArray = { 1.3, 0.6, 1.1, 1.2 }; + + DenseMatrix weightMatrix = new DenseMatrix(matrix); + DenseVector labelSum = new DenseVector(labelSumArray); + DenseVector featureSum = new DenseVector(featureSumArray); + + // now generate the model + return new NaiveBayesModel(weightMatrix, featureSum, labelSum, null, 1.0f, false); + } + + protected static NaiveBayesModel createComplementaryNaiveBayesModel() { + double[][] matrix = { + { 0.7, 0.1, 0.1, 0.3 }, + { 0.4, 0.4, 0.1, 0.1 }, + { 0.1, 0.0, 0.8, 0.1 }, + { 0.1, 0.1, 0.1, 0.7 } }; + + double[] labelSumArray = { 1.2, 1.0, 1.0, 1.0 }; + double[] featureSumArray = { 1.3, 0.6, 1.1, 1.2 }; + + DenseMatrix weightMatrix = new DenseMatrix(matrix); + DenseVector labelSum = new DenseVector(labelSumArray); + DenseVector featureSum = new DenseVector(featureSumArray); + + double[] thetaNormalizerSum = { + complementaryNaiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum), + complementaryNaiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum), + complementaryNaiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum), + complementaryNaiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum) }; + + // now generate the model + return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f, true); + } + + protected static int 
maxIndex(Vector instance) { + int maxIndex = -1; + double maxScore = Integer.MIN_VALUE; + for (Element label : instance.all()) { + if (label.get() >= maxScore) { + maxIndex = label.index(); + maxScore = label.get(); + } + } + return maxIndex; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java new file mode 100644 index 0000000..a432ac9 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes; + +import org.apache.mahout.math.DenseVector; +import org.junit.Before; +import org.junit.Test; + + +public final class StandardNaiveBayesClassifierTest extends NaiveBayesTestBase { + + private StandardNaiveBayesClassifier classifier; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + NaiveBayesModel model = createStandardNaiveBayesModel(); + classifier = new StandardNaiveBayesClassifier(model); + } + + @Test + public void testNaiveBayes() throws Exception { + assertEquals(4, classifier.numCategories()); + assertEquals(0, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 })))); + assertEquals(1, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 })))); + assertEquals(2, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 })))); + assertEquals(3, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 })))); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java new file mode 100644 index 0000000..a9541c9 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java @@ -0,0 +1,85 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes.training; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Counter; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.map.OpenObjectIntHashMap; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; + +public class IndexInstancesMapperTest extends MahoutTestCase { + + private Mapper.Context ctx; + private OpenObjectIntHashMap<String> labelIndex; + private VectorWritable instance; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + ctx = EasyMock.createMock(Mapper.Context.class); + instance = new VectorWritable(new DenseVector(new double[] { 1, 0, 1, 1, 0 })); + + labelIndex = new OpenObjectIntHashMap<String>(); + labelIndex.put("bird", 0); + labelIndex.put("cat", 1); + } + + + @Test + public void index() throws Exception { + + ctx.write(new IntWritable(0), instance); + + EasyMock.replay(ctx); + + IndexInstancesMapper indexInstances = new IndexInstancesMapper(); + setField(indexInstances, "labelIndex", labelIndex); + + indexInstances.map(new Text("/bird/"), instance, ctx); + + EasyMock.verify(ctx); + } + + @Test + public void skip() throws Exception 
{ + + Counter skippedInstances = EasyMock.createMock(Counter.class); + + EasyMock.expect(ctx.getCounter(IndexInstancesMapper.Counter.SKIPPED_INSTANCES)).andReturn(skippedInstances); + skippedInstances.increment(1); + + EasyMock.replay(ctx, skippedInstances); + + IndexInstancesMapper indexInstances = new IndexInstancesMapper(); + setField(indexInstances, "labelIndex", labelIndex); + + indexInstances.map(new Text("/fish/"), instance, ctx); + + EasyMock.verify(ctx, skippedInstances); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java new file mode 100644 index 0000000..746ae0d --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes.training; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.easymock.EasyMock; +import org.junit.Test; + +public class ThetaMapperTest extends MahoutTestCase { + + @Test + public void standard() throws Exception { + + Mapper.Context ctx = EasyMock.createMock(Mapper.Context.class); + ComplementaryThetaTrainer trainer = EasyMock.createMock(ComplementaryThetaTrainer.class); + + Vector instance1 = new DenseVector(new double[] { 1, 2, 3 }); + Vector instance2 = new DenseVector(new double[] { 4, 5, 6 }); + + Vector perLabelThetaNormalizer = new DenseVector(new double[] { 7, 8 }); + + ThetaMapper thetaMapper = new ThetaMapper(); + setField(thetaMapper, "trainer", trainer); + + trainer.train(0, instance1); + trainer.train(1, instance2); + EasyMock.expect(trainer.retrievePerLabelThetaNormalizer()).andReturn(perLabelThetaNormalizer); + ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer)); + + EasyMock.replay(ctx, trainer); + + thetaMapper.map(new IntWritable(0), new VectorWritable(instance1), ctx); + thetaMapper.map(new IntWritable(1), new VectorWritable(instance2), ctx); + thetaMapper.cleanup(ctx); + + EasyMock.verify(ctx, trainer); + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java new file mode 100644 index 0000000..af0b464 --- /dev/null 
+++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes.training; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.easymock.EasyMock; +import org.junit.Test; + +public class WeightsMapperTest extends MahoutTestCase { + + @Test + public void scores() throws Exception { + + Mapper.Context ctx = EasyMock.createMock(Mapper.Context.class); + Vector instance1 = new DenseVector(new double[] { 1, 0, 0.5, 0.5, 0 }); + Vector instance2 = new DenseVector(new double[] { 0, 0.5, 0, 0, 0 }); + Vector instance3 = new DenseVector(new double[] { 1, 0.5, 1, 1.5, 1 }); + + Vector weightsPerLabel = new DenseVector(new double[] { 0, 0 }); + + ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), + new VectorWritable(new DenseVector(new double[] { 2, 1, 1.5, 2, 1 }))); + ctx.write(new 
Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), + new VectorWritable(new DenseVector(new double[] { 2.5, 5 }))); + + EasyMock.replay(ctx); + + WeightsMapper weights = new WeightsMapper(); + setField(weights, "weightsPerLabel", weightsPerLabel); + + weights.map(new IntWritable(0), new VectorWritable(instance1), ctx); + weights.map(new IntWritable(0), new VectorWritable(instance2), ctx); + weights.map(new IntWritable(1), new VectorWritable(instance3), ctx); + + weights.cleanup(ctx); + + EasyMock.verify(ctx); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java new file mode 100644 index 0000000..ade25b8 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java @@ -0,0 +1,164 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.math.Matrix; +import org.junit.Test; + +public class HMMAlgorithmsTest extends HMMTestBase { + + /** + * Test the forward algorithm by comparing the alpha values with the values + * obtained from HMM R model. We test the test observation sequence "O1" "O0" + * "O2" "O2" "O0" "O0" "O1" by comparing the generated alpha values to the + * R-generated "reference". + */ + @Test + public void testForwardAlgorithm() { + // intialize the expected alpha values + double[][] alphaExpectedA = { + {0.02, 0.0392, 0.002438, 0.00035456, 0.0011554672, 7.158497e-04, + 4.614927e-05}, + {0.01, 0.0054, 0.001824, 0.00069486, 0.0007586904, 2.514137e-04, + 1.721505e-05}, + {0.32, 0.0262, 0.002542, 0.00038026, 0.0001360234, 3.002345e-05, + 9.659608e-05}, + {0.03, 0.0000, 0.013428, 0.00951084, 0.0000000000, 0.000000e+00, + 2.428986e-05},}; + // fetch the alpha matrix using the forward algorithm + Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), false); + // first do some basic checking + assertNotNull(alpha); + assertEquals(4, alpha.numCols()); + assertEquals(7, alpha.numRows()); + // now compare the resulting matrices + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 7; ++j) { + assertEquals(alphaExpectedA[i][j], alpha.get(j, i), EPSILON); + } + } + } + + @Test + public void testLogScaledForwardAlgorithm() { + // intialize the expected alpha values + double[][] alphaExpectedA = { + {0.02, 0.0392, 0.002438, 0.00035456, 0.0011554672, 7.158497e-04, + 4.614927e-05}, + {0.01, 0.0054, 0.001824, 0.00069486, 0.0007586904, 2.514137e-04, + 1.721505e-05}, + {0.32, 0.0262, 0.002542, 0.00038026, 0.0001360234, 3.002345e-05, + 9.659608e-05}, + {0.03, 0.0000, 0.013428, 0.00951084, 0.0000000000, 0.000000e+00, + 2.428986e-05},}; + // fetch the alpha matrix using the forward algorithm + Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), true); + // first do some 
basic checking + assertNotNull(alpha); + assertEquals(4, alpha.numCols()); + assertEquals(7, alpha.numRows()); + // now compare the resulting matrices + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 7; ++j) { + assertEquals(Math.log(alphaExpectedA[i][j]), alpha.get(j, i), EPSILON); + } + } + } + + /** + * Test the backward algorithm by comparing the beta values with the values + * obtained from HMM R model. We test the following observation sequence "O1" + * "O0" "O2" "O2" "O0" "O0" "O1" by comparing the generated beta values to the + * R-generated "reference". + */ + @Test + public void testBackwardAlgorithm() { + // intialize the expected beta values + double[][] betaExpectedA = { + {0.0015730559, 0.003543656, 0.00738264, 0.040692, 0.0848, 0.17, 1}, + {0.0017191865, 0.002386795, 0.00923652, 0.052232, 0.1018, 0.17, 1}, + {0.0003825772, 0.001238558, 0.00259464, 0.012096, 0.0664, 0.66, 1}, + {0.0004390858, 0.007076994, 0.01063512, 0.013556, 0.0304, 0.17, 1}}; + // fetch the beta matrix using the backward algorithm + Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), false); + // first do some basic checking + assertNotNull(beta); + assertEquals(4, beta.numCols()); + assertEquals(7, beta.numRows()); + // now compare the resulting matrices + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 7; ++j) { + assertEquals(betaExpectedA[i][j], beta.get(j, i), EPSILON); + } + } + } + + @Test + public void testLogScaledBackwardAlgorithm() { + // intialize the expected beta values + double[][] betaExpectedA = { + {0.0015730559, 0.003543656, 0.00738264, 0.040692, 0.0848, 0.17, 1}, + {0.0017191865, 0.002386795, 0.00923652, 0.052232, 0.1018, 0.17, 1}, + {0.0003825772, 0.001238558, 0.00259464, 0.012096, 0.0664, 0.66, 1}, + {0.0004390858, 0.007076994, 0.01063512, 0.013556, 0.0304, 0.17, 1}}; + // fetch the beta matrix using the backward algorithm + Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), true); + // first do some basic 
checking + assertNotNull(beta); + assertEquals(4, beta.numCols()); + assertEquals(7, beta.numRows()); + // now compare the resulting matrices + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 7; ++j) { + assertEquals(Math.log(betaExpectedA[i][j]), beta.get(j, i), EPSILON); + } + } + } + + @Test + public void testViterbiAlgorithm() { + // initialize the expected hidden sequence + int[] expected = {2, 0, 3, 3, 0, 0, 2}; + // fetch the viterbi generated sequence + int[] computed = HmmAlgorithms.viterbiAlgorithm(getModel(), getSequence(), false); + // first make sure we return the correct size + assertNotNull(computed); + assertEquals(computed.length, getSequence().length); + // now check the contents + for (int i = 0; i < getSequence().length; ++i) { + assertEquals(expected[i], computed[i]); + } + } + + @Test + public void testLogScaledViterbiAlgorithm() { + // initialize the expected hidden sequence + int[] expected = {2, 0, 3, 3, 0, 0, 2}; + // fetch the viterbi generated sequence + int[] computed = HmmAlgorithms.viterbiAlgorithm(getModel(), getSequence(), true); + // first make sure we return the correct size + assertNotNull(computed); + assertEquals(computed.length, getSequence().length); + // now check the contents + for (int i = 0; i < getSequence().length; ++i) { + assertEquals(expected[i], computed[i]); + } + + } + +}
