http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java new file mode 100644 index 0000000..efd233f --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java @@ -0,0 +1,248 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; + +import com.google.common.base.Preconditions; + +/** + * Defines the interface for classifiers that take a vector as input. This is + * implemented as an abstract class so that it can implement a number of handy + * convenience methods related to classification of vectors. + * + * <p> + * A classifier takes an input vector and calculates the scores (usually + * probabilities) that the input vector belongs to one of {@code n} + * categories. In {@code AbstractVectorClassifier} each category is denoted + * by an integer {@code c} between {@code 0} and {@code n-1} + * (inclusive). + * + * <p> + * New users should start by looking at {@link #classifyFull} (not {@link #classify}). + * + */ +public abstract class AbstractVectorClassifier { + + /** Minimum allowable log likelihood value. */ + public static final double MIN_LOG_LIKELIHOOD = -100.0; + + /** + * Returns the number of categories that a target variable can be assigned to. + * A vector classifier will encode its output as an integer from + * {@code 0} to {@code numCategories()-1} (inclusive). + * + * @return The number of categories. + */ + public abstract int numCategories(); + + /** + * Compute and return a vector containing {@code n-1} scores, where + * {@code n} is equal to {@code numCategories()}, given an input + * vector {@code instance}. Higher scores indicate that the input vector + * is more likely to belong to that category. The categories are denoted by + * the integers {@code 0} through {@code n-1} (inclusive), and the + * scores in the returned vector correspond to categories 1 through + * {@code n-1} (leaving out category 0). It is assumed that the score for + * category 0 is one minus the sum of the scores in the returned vector. + * + * @param instance A feature vector to be classified. + * @return A vector of probabilities in 1 of {@code n-1} encoding. 
+ */ + public abstract Vector classify(Vector instance); + + /** + * Compute and return a vector of scores before applying the inverse link + * function. For logistic regression and other generalized linear models, this + * is just the linear part of the classification. + * + * <p> + * The implementation of this method provided by {@code AbstractVectorClassifier} throws an + * {@link UnsupportedOperationException}. Your subclass must explicitly override this method to support + * this operation. + * + * @param features A feature vector to be classified. + * @return A vector of scores. If transformed by the link function, these will become probabilities. + */ + public Vector classifyNoLink(Vector features) { + throw new UnsupportedOperationException(this.getClass().getName() + + " doesn't support classification without a link"); + } + + /** + * Classifies a vector in the special case of a binary classifier where + * {@link #classify(Vector)} would return a vector with only one element. As + * such, using this method can avoid the allocation of a vector. + * + * @param instance The feature vector to be classified. + * @return The score for category 1. + * + * @see #classify(Vector) + */ + public abstract double classifyScalar(Vector instance); + + /** + * Computes and returns a vector containing {@code n} scores, where + * {@code n} is {@code numCategories()}, given an input vector + * {@code instance}. Higher scores indicate that the input vector is more + * likely to belong to the corresponding category. The categories are denoted + * by the integers {@code 0} through {@code n-1} (inclusive). + * + * <p> + * Using this method it is possible to classify an input vector, for example, + * by selecting the category with the largest score. If + * {@code classifier} is an instance of + * {@code AbstractVectorClassifier} and {@code input} is a + * {@code Vector} of features describing an element to be classified, + * then the following code could be used to classify {@code input}.<br> + * {@code + * Vector scores = classifier.classifyFull(input);<br> + * int assignedCategory = scores.maxValueIndex();<br> + * } Here {@code assignedCategory} is the index of the category + * with the maximum score. + * + * <p> + * If an {@code n-1} encoding is acceptable, and allocation performance + * is an issue, then the {@link #classify(Vector)} method is probably better + * to use. + * + * @see #classify(Vector) + * @see #classifyFull(Vector r, Vector instance) + * + * @param instance A vector of features to be classified. + * @return A vector of probabilities, one for each category. + */ + public Vector classifyFull(Vector instance) { + return classifyFull(new DenseVector(numCategories()), instance); + } + + /** + * Computes and returns a vector containing {@code n} scores, where + * {@code n} is {@code numCategories()}, given an input vector + * {@code instance}. Higher scores indicate that the input vector is more + * likely to belong to the corresponding category. The categories are denoted + * by the integers {@code 0} through {@code n-1} (inclusive). The + * main difference between this method and {@link #classifyFull(Vector)} is + * that this method allows a user to provide a previously allocated + * {@code Vector r} to store the returned scores. + * + * <p> + * Using this method it is possible to classify an input vector, for example, + * by selecting the category with the largest score. 
If + * {@code classifier} is an instance of + * {@code AbstractVectorClassifier}, {@code result} is a non-null + * {@code Vector}, and {@code input} is a {@code Vector} of + * features describing an element to be classified, then the following code + * could be used to classify {@code input}.<br> + * {@code + * Vector scores = classifier.classifyFull(result, input); // Notice that scores == result<br> + * int assignedCategory = scores.maxValueIndex();<br> + * } Here {@code assignedCategory} is the index of the category + * with the maximum score. + * + * @param r Where to put the results. + * @param instance A vector of features to be classified. + * @return A vector of scores/probabilities, one for each category. + */ + public Vector classifyFull(Vector r, Vector instance) { + r.viewPart(1, numCategories() - 1).assign(classify(instance)); + r.setQuick(0, 1.0 - r.zSum()); + return r; + } + + + /** + * Returns n-1 probabilities, one for each of categories 1 through + * {@code n-1}, for each row of a matrix, where {@code n} is equal + * to {@code numCategories()}. The probability of the missing 0-th + * category is 1 - rowSum(this result). + * + * @param data The matrix whose rows are the input vectors to classify + * @return A matrix of scores, one row per row of the input matrix, one column for each but the last category. + */ + public Matrix classify(Matrix data) { + Matrix r = new DenseMatrix(data.numRows(), numCategories() - 1); + for (int row = 0; row < data.numRows(); row++) { + r.assignRow(row, classify(data.viewRow(row))); + } + return r; + } + + /** + * Returns a matrix where the rows of the matrix each contain {@code n} probabilities, one for each category. + * + * @param data The matrix whose rows are the input vectors to classify + * @return A matrix of scores, one row per row of the input matrix, one column per category. + */ + public Matrix classifyFull(Matrix data) { + Matrix r = new DenseMatrix(data.numRows(), numCategories()); + for (int row = 0; row < data.numRows(); row++) { + classifyFull(r.viewRow(row), data.viewRow(row)); + } + return r; + } + + /** + * Returns a vector of probabilities of category 1, one for each row + * of a matrix. This only makes sense if there are exactly two categories, but + * calling this method in that case can save a number of vector allocations. + * + * @param data The matrix whose rows are vectors to classify + * @return A vector of scores, with one value per row of the input matrix. + */ + public Vector classifyScalar(Matrix data) { + Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories"); + + Vector r = new DenseVector(data.numRows()); + for (int row = 0; row < data.numRows(); row++) { + r.set(row, classifyScalar(data.viewRow(row))); + } + return r; + } + + /** + * Returns a measure of how good the classification for a particular example + * actually is. + * + * @param actual The correct category for the example. + * @param data The vector to be classified. + * @return The log likelihood of the correct answer as estimated by the current model. This will always be <= 0 + * and larger (closer to 0) indicates better accuracy. In order to simplify code that maintains running averages, + * we bound this value at -100. 
+ */ + public double logLikelihood(int actual, Vector data) { + if (numCategories() == 2) { + double p = classifyScalar(data); + if (actual > 0) { + return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p)); + } else { + return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p)); + } + } else { + Vector p = classify(data); + if (actual > 0) { + return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p.get(actual - 1))); + } else { + return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p.zSum())); + } + } + } +}
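To make the contract above concrete, here is a minimal usage sketch of the AbstractVectorClassifier API. It assumes a hypothetical, already-trained subclass instance named model (for example, one of Mahout's SGD logistic-regression classifiers); the variable names are illustrative and not part of this commit:

  // Classify one feature vector with any AbstractVectorClassifier subclass.
  Vector input = new DenseVector(new double[] {1.0, 0.5, 0.0});

  // classifyFull returns all n scores; pick the highest-scoring category.
  Vector scores = model.classifyFull(input);
  int assignedCategory = scores.maxValueIndex();

  // classify returns only n-1 scores (the category-0 score is implied by the rest),
  // which avoids one allocation when a 1-of-(n-1) encoding is acceptable.
  Vector partialScores = model.classify(input);

  // For a binary model, classifyScalar returns the category-1 score without
  // allocating a vector, and logLikelihood(actual, input) is bounded below
  // by MIN_LOG_LIKELIHOOD (-100).
  if (model.numCategories() == 2) {
    double pCategory1 = model.classifyScalar(input);
    double ll = model.logLikelihood(1, input);
  }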
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java new file mode 100644 index 0000000..29eaa0d --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java @@ -0,0 +1,74 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier; + +/** + * Result of a document classification: the label and the associated score (usually a probability). + */ +public class ClassifierResult { + + private String label; + private double score; + private double logLikelihood = Double.MAX_VALUE; + + public ClassifierResult() { } + + public ClassifierResult(String label, double score) { + this.label = label; + this.score = score; + } + + public ClassifierResult(String label) { + this.label = label; + } + + public ClassifierResult(String label, double score, double logLikelihood) { + this.label = label; + this.score = score; + this.logLikelihood = logLikelihood; + } + + public double getLogLikelihood() { + return logLikelihood; + } + + public void setLogLikelihood(double logLikelihood) { + this.logLikelihood = logLikelihood; + } + + public String getLabel() { + return label; + } + + public double getScore() { + return score; + } + + public void setLabel(String label) { + this.label = label; + } + + public void setScore(double score) { + this.score = score; + } + + @Override + public String toString() { + return "ClassifierResult{" + "category='" + label + '\'' + ", score=" + score + '}'; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java new file mode 100644 index 0000000..73ba521 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java @@ -0,0 +1,444 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import com.google.common.base.Preconditions; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.math3.stat.descriptive.moment.Mean; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev; +import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.Matrix; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The ConfusionMatrix Class stores the result of Classification of a Test Dataset. + * + * The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default. + * + * See http://en.wikipedia.org/wiki/Confusion_matrix for background + */ +public class ConfusionMatrix { + private static final Logger LOG = LoggerFactory.getLogger(ConfusionMatrix.class); + private final Map<String,Integer> labelMap = new LinkedHashMap<>(); + private final int[][] confusionMatrix; + private int samples = 0; + private String defaultLabel = "unknown"; + + public ConfusionMatrix(Collection<String> labels, String defaultLabel) { + confusionMatrix = new int[labels.size() + 1][labels.size() + 1]; + this.defaultLabel = defaultLabel; + int i = 0; + for (String label : labels) { + labelMap.put(label, i++); + } + labelMap.put(defaultLabel, i); + } + + public ConfusionMatrix(Matrix m) { + confusionMatrix = new int[m.numRows()][m.numRows()]; + setMatrix(m); + } + + public int[][] getConfusionMatrix() { + return confusionMatrix; + } + + public Collection<String> getLabels() { + return Collections.unmodifiableCollection(labelMap.keySet()); + } + + private int numLabels() { + return labelMap.size(); + } + + public double getAccuracy(String label) { + int labelId = labelMap.get(label); + int labelTotal = 0; + int correct = 0; + for (int i = 0; i < numLabels(); i++) { + labelTotal += confusionMatrix[labelId][i]; + if (i == labelId) { + correct += confusionMatrix[labelId][i]; + } + } + return 100.0 * correct / labelTotal; + } + + // Producer accuracy + public double getAccuracy() { + int total = 0; + int correct = 0; + for (int i = 0; i < numLabels(); i++) { + for (int j = 0; j < numLabels(); j++) { + total += confusionMatrix[i][j]; + if (i == j) { + correct += confusionMatrix[i][j]; + } + } + } + return 100.0 * correct / total; + } + + /** Sum of true positives and false negatives */ + private int getActualNumberOfTestExamplesForClass(String label) { + int labelId = labelMap.get(label); + int sum = 0; + for (int i = 0; i < numLabels(); i++) { + sum += confusionMatrix[labelId][i]; + } + return sum; + } + + public double getPrecision(String label) { + int labelId = labelMap.get(label); + int truePositives = confusionMatrix[labelId][labelId]; + int falsePositives = 0; + for (int i = 0; i < numLabels(); i++) { + if (i == labelId) { + continue; + } + falsePositives += confusionMatrix[i][labelId]; + } + + if (truePositives + 
falsePositives == 0) { + return 0; + } + + return ((double) truePositives) / (truePositives + falsePositives); + } + + public double getWeightedPrecision() { + double[] precisions = new double[numLabels()]; + double[] weights = new double[numLabels()]; + + int index = 0; + for (String label : labelMap.keySet()) { + precisions[index] = getPrecision(label); + weights[index] = getActualNumberOfTestExamplesForClass(label); + index++; + } + return new Mean().evaluate(precisions, weights); + } + + public double getRecall(String label) { + int labelId = labelMap.get(label); + int truePositives = confusionMatrix[labelId][labelId]; + int falseNegatives = 0; + for (int i = 0; i < numLabels(); i++) { + if (i == labelId) { + continue; + } + falseNegatives += confusionMatrix[labelId][i]; + } + if (truePositives + falseNegatives == 0) { + return 0; + } + return ((double) truePositives) / (truePositives + falseNegatives); + } + + public double getWeightedRecall() { + double[] recalls = new double[numLabels()]; + double[] weights = new double[numLabels()]; + + int index = 0; + for (String label : labelMap.keySet()) { + recalls[index] = getRecall(label); + weights[index] = getActualNumberOfTestExamplesForClass(label); + index++; + } + return new Mean().evaluate(recalls, weights); + } + + public double getF1score(String label) { + double precision = getPrecision(label); + double recall = getRecall(label); + if (precision + recall == 0) { + return 0; + } + return 2 * precision * recall / (precision + recall); + } + + public double getWeightedF1score() { + double[] f1Scores = new double[numLabels()]; + double[] weights = new double[numLabels()]; + + int index = 0; + for (String label : labelMap.keySet()) { + f1Scores[index] = getF1score(label); + weights[index] = getActualNumberOfTestExamplesForClass(label); + index++; + } + return new Mean().evaluate(f1Scores, weights); + } + + // User accuracy + public double getReliability() { + int count = 0; + double accuracy = 0; + for (String label: labelMap.keySet()) { + if (!label.equals(defaultLabel)) { + accuracy += getAccuracy(label); + } + count++; + } + return accuracy / count; + } + + /** + * Accuracy vs. randomly classifying all samples. + * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy()) + * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales. + * Educational and Psychological Measurement 20:37-46. 
+ * + * Formula and variable names from: + * http://www.yale.edu/ceo/OEFS/Accuracy.pdf + * + * @return double + */ + public double getKappa() { + double a = 0.0; + double b = 0.0; + for (int i = 0; i < confusionMatrix.length; i++) { + a += confusionMatrix[i][i]; + double br = 0; + for (int j = 0; j < confusionMatrix.length; j++) { + br += confusionMatrix[i][j]; + } + double bc = 0; + for (int[] vec : confusionMatrix) { + bc += vec[i]; + } + b += br * bc; + } + return (samples * a - b) / (samples * samples - b); + } + + /** + * Standard deviation of normalized producer accuracy + * Not a standard score + * @return double + */ + public RunningAverageAndStdDev getNormalizedStats() { + RunningAverageAndStdDev summer = new FullRunningAverageAndStdDev(); + for (int d = 0; d < confusionMatrix.length; d++) { + double total = 0; + for (int j = 0; j < confusionMatrix.length; j++) { + total += confusionMatrix[d][j]; + } + summer.addDatum(confusionMatrix[d][d] / (total + 0.000001)); + } + + return summer; + } + + public int getCorrect(String label) { + int labelId = labelMap.get(label); + return confusionMatrix[labelId][labelId]; + } + + public int getTotal(String label) { + int labelId = labelMap.get(label); + int labelTotal = 0; + for (int i = 0; i < labelMap.size(); i++) { + labelTotal += confusionMatrix[labelId][i]; + } + return labelTotal; + } + + public void addInstance(String correctLabel, ClassifierResult classifiedResult) { + samples++; + incrementCount(correctLabel, classifiedResult.getLabel()); + } + + public void addInstance(String correctLabel, String classifiedLabel) { + samples++; + incrementCount(correctLabel, classifiedLabel); + } + + public int getCount(String correctLabel, String classifiedLabel) { + if(!labelMap.containsKey(correctLabel)) { + LOG.warn("Label {} did not appear in the training examples", correctLabel); + return 0; + } + Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel); + int correctId = labelMap.get(correctLabel); + int classifiedId = labelMap.get(classifiedLabel); + return confusionMatrix[correctId][classifiedId]; + } + + public void putCount(String correctLabel, String classifiedLabel, int count) { + if(!labelMap.containsKey(correctLabel)) { + LOG.warn("Label {} did not appear in the training examples", correctLabel); + return; + } + Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel); + int correctId = labelMap.get(correctLabel); + int classifiedId = labelMap.get(classifiedLabel); + if (confusionMatrix[correctId][classifiedId] == 0.0 && count != 0) { + samples++; + } + confusionMatrix[correctId][classifiedId] = count; + } + + public String getDefaultLabel() { + return defaultLabel; + } + + public void incrementCount(String correctLabel, String classifiedLabel, int count) { + putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel)); + } + + public void incrementCount(String correctLabel, String classifiedLabel) { + incrementCount(correctLabel, classifiedLabel, 1); + } + + public ConfusionMatrix merge(ConfusionMatrix b) { + Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match"); + for (String correctLabel : this.labelMap.keySet()) { + for (String classifiedLabel : this.labelMap.keySet()) { + incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel)); + } + } + return this; + } + + public Matrix getMatrix() { + int length = confusionMatrix.length; + 
Matrix m = new DenseMatrix(length, length); + for (int r = 0; r < length; r++) { + for (int c = 0; c < length; c++) { + m.set(r, c, confusionMatrix[r][c]); + } + } + Map<String,Integer> labels = new HashMap<>(); + for (Map.Entry<String, Integer> entry : labelMap.entrySet()) { + labels.put(entry.getKey(), entry.getValue()); + } + m.setRowLabelBindings(labels); + m.setColumnLabelBindings(labels); + return m; + } + + public void setMatrix(Matrix m) { + int length = confusionMatrix.length; + if (m.numRows() != m.numCols()) { + throw new IllegalArgumentException( + "ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols() + ") must be square"); + } + for (int r = 0; r < length; r++) { + for (int c = 0; c < length; c++) { + confusionMatrix[r][c] = (int) Math.round(m.get(r, c)); + } + } + Map<String,Integer> labels = m.getRowLabelBindings(); + if (labels == null) { + labels = m.getColumnLabelBindings(); + } + if (labels != null) { + String[] sorted = sortLabels(labels); + verifyLabels(length, sorted); + labelMap.clear(); + for (int i = 0; i < length; i++) { + labelMap.put(sorted[i], i); + } + } + } + + private static String[] sortLabels(Map<String,Integer> labels) { + String[] sorted = new String[labels.size()]; + for (Map.Entry<String,Integer> entry : labels.entrySet()) { + sorted[entry.getValue()] = entry.getKey(); + } + return sorted; + } + + private static void verifyLabels(int length, String[] sorted) { + Preconditions.checkArgument(sorted.length == length, "One label, one row"); + for (int i = 0; i < length; i++) { + if (sorted[i] == null) { + Preconditions.checkArgument(false, "One label, one row"); + } + } + } + + /** + * This is overloaded. toString() is not a formatted report you print for a manager :) + * Assume that if there are no default assignments, the default feature was not used + */ + @Override + public String toString() { + StringBuilder returnString = new StringBuilder(200); + returnString.append("=======================================================").append('\n'); + returnString.append("Confusion Matrix\n"); + returnString.append("-------------------------------------------------------").append('\n'); + + int unclassified = getTotal(defaultLabel); + for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) { + if (entry.getKey().equals(defaultLabel) && unclassified == 0) { + continue; + } + + returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t'); + } + + returnString.append("<--Classified as").append('\n'); + for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) { + if (entry.getKey().equals(defaultLabel) && unclassified == 0) { + continue; + } + String correctLabel = entry.getKey(); + int labelTotal = 0; + for (String classifiedLabel : this.labelMap.keySet()) { + if (classifiedLabel.equals(defaultLabel) && unclassified == 0) { + continue; + } + returnString.append( + StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t'); + labelTotal += getCount(correctLabel, classifiedLabel); + } + returnString.append(" | ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t') + .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)) + .append(" = ").append(correctLabel).append('\n'); + } + if (unclassified > 0) { + returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n'); + } + returnString.append('\n'); + return returnString.toString(); + } + + static String getSmallLabel(int i) { + int val = i; + 
StringBuilder returnString = new StringBuilder(); + do { + int n = val % 26; + returnString.insert(0, (char) ('a' + n)); + val /= 26; + } while (val > 0); + return returnString.toString(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java new file mode 100644 index 0000000..af1d5e7 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier; + +import org.apache.mahout.math.Vector; + +import java.io.Closeable; + +/** + * The simplest interface for online learning algorithms. + */ +public interface OnlineLearner extends Closeable { + /** + * Updates the model using a particular target variable value and a feature vector. + * <p/> + * There may be an assumption that if multiple passes through the training data are necessary, then + * the training examples will be presented in the same order. This is because the order of + * training examples may be used to assign records to different data splits for evaluation by + * cross-validation. Without the order invariance, records might be assigned inconsistently to + * training and test splits and error estimates could be seriously affected. + * <p/> + * If re-ordering is necessary, use the alternative API, which allows a tracking key to be + * attached to each training example. + * + * @param actual The value of the target variable. This value should be in the half-open + * interval [0..n) where n is the number of target categories. + * @param instance The feature vector for this example. + */ + void train(int actual, Vector instance); + + /** + * Updates the model using a particular target variable value and a feature vector. + * <p/> + * There may be an assumption that if multiple passes through the training data are necessary, then + * the tracking key for a record will be the same for each pass, that there will be a + * relatively large number of distinct tracking keys, and that the low-order bits of the tracking + * keys will not correlate with any of the input variables. This tracking key is used to assign + * training examples to different test/training splits. + * <p/> + * Examples of useful tracking keys include id-numbers for the training records derived from + * a database id for the base table from which the record is derived, or the offset of + * the original data record in a data file. 
+ * + * @param trackingKey The tracking key for this training example. + * @param groupKey An optional value that allows examples to be grouped in the computation of + * the update to the model. + * @param actual The value of the target variable. This value should be in the half-open + * interval [0..n) where n is the number of target categories. + * @param instance The feature vector for this example. + */ + void train(long trackingKey, String groupKey, int actual, Vector instance); + + /** + * Updates the model using a particular target variable value and a feature vector. + * <p/> + * There may be an assumption that if multiple passes through the training data are necessary, then + * the tracking key for a record will be the same for each pass, that there will be a + * relatively large number of distinct tracking keys, and that the low-order bits of the tracking + * keys will not correlate with any of the input variables. This tracking key is used to assign + * training examples to different test/training splits. + * <p/> + * Examples of useful tracking keys include id-numbers for the training records derived from + * a database id for the base table from which the record is derived, or the offset of + * the original data record in a data file. + * + * @param trackingKey The tracking key for this training example. + * @param actual The value of the target variable. This value should be in the half-open + * interval [0..n) where n is the number of target categories. + * @param instance The feature vector for this example. + */ + void train(long trackingKey, int actual, Vector instance); + + /** + * Prepares the classifier for classification and deallocates any temporary data structures. + * + * An online classifier should be able to accept more training after being closed, but + * closing the classifier may make classification more efficient. + */ + @Override + void close(); +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java new file mode 100644 index 0000000..35c11ee --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java @@ -0,0 +1,144 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier; + +import java.text.DecimalFormat; +import java.text.NumberFormat; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +import org.apache.commons.lang3.StringUtils; + +/** + * RegressionResultAnalyzer captures the regression statistics and displays them in a tabular manner + */ +public class RegressionResultAnalyzer { + + private static class Result { + private final double actual; + private final double result; + Result(double actual, double result) { + this.actual = actual; + this.result = result; + } + double getActual() { + return actual; + } + double getResult() { + return result; + } + } + + private List<Result> results; + + /** + * + * @param actual + * The actual answer + * @param result + * The regression result + */ + public void addInstance(double actual, double result) { + if (results == null) { + results = new ArrayList<>(); + } + results.add(new Result(actual, result)); + } + + /** + * + * @param results + * The results table + */ + public void setInstances(double[][] results) { + for (double[] res : results) { + addInstance(res[0], res[1]); + } + } + + @Override + public String toString() { + double sumActual = 0.0; + double sumActualSquared = 0.0; + double sumResult = 0.0; + double sumResultSquared = 0.0; + double sumActualResult = 0.0; + double sumAbsolute = 0.0; + double sumAbsoluteSquared = 0.0; + int predictable = 0; + int unpredictable = 0; + + for (Result res : results) { + double actual = res.getActual(); + double result = res.getResult(); + if (Double.isNaN(result)) { + unpredictable++; + } else { + sumActual += actual; + sumActualSquared += actual * actual; + sumResult += result; + sumResultSquared += result * result; + sumActualResult += actual * result; + double absolute = Math.abs(actual - result); + sumAbsolute += absolute; + sumAbsoluteSquared += absolute * absolute; + predictable++; + } + } + + StringBuilder returnString = new StringBuilder(); + + returnString.append("=======================================================\n"); + returnString.append("Summary\n"); + returnString.append("-------------------------------------------------------\n"); + + if (predictable > 0) { + double varActual = sumActualSquared - sumActual * sumActual / predictable; + double varResult = sumResultSquared - sumResult * sumResult / predictable; + double varCo = sumActualResult - sumActual * sumResult / predictable; + + double correlation; + if (varActual * varResult <= 0) { + correlation = 0.0; + } else { + correlation = varCo / Math.sqrt(varActual * varResult); + } + + Locale.setDefault(Locale.US); + NumberFormat decimalFormatter = new DecimalFormat("0.####"); + + returnString.append(StringUtils.rightPad("Correlation coefficient", 40)).append(": ").append( + StringUtils.leftPad(decimalFormatter.format(correlation), 10)).append('\n'); + returnString.append(StringUtils.rightPad("Mean absolute error", 40)).append(": ").append( + StringUtils.leftPad(decimalFormatter.format(sumAbsolute / predictable), 10)).append('\n'); + returnString.append(StringUtils.rightPad("Root mean squared error", 40)).append(": ").append( + StringUtils.leftPad(decimalFormatter.format(Math.sqrt(sumAbsoluteSquared / predictable)), + 10)).append('\n'); + } + returnString.append(StringUtils.rightPad("Predictable Instances", 40)).append(": ").append( + StringUtils.leftPad(Integer.toString(predictable), 10)).append('\n'); + returnString.append(StringUtils.rightPad("Unpredictable Instances", 40)).append(": ").append( 
StringUtils.leftPad(Integer.toString(unpredictable), 10)).append('\n'); + returnString.append(StringUtils.rightPad("Total Regressed Instances", 40)).append(": ").append( + StringUtils.leftPad(Integer.toString(results.size()), 10)).append('\n'); + returnString.append('\n'); + + return returnString.toString(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java new file mode 100644 index 0000000..1711f19 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java @@ -0,0 +1,132 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier; + +import java.text.DecimalFormat; +import java.text.NumberFormat; +import java.util.Collection; + +import org.apache.commons.lang3.StringUtils; +import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev; +import org.apache.mahout.math.stats.OnlineSummarizer; + +/** ResultAnalyzer captures the classification statistics and displays them in a tabular manner */ +public class ResultAnalyzer { + + private final ConfusionMatrix confusionMatrix; + private final OnlineSummarizer summarizer; + private boolean hasLL; + + /* + * === Summary === + * + * Correctly Classified Instances 635 92.9722 % Incorrectly Classified Instances 48 7.0278 % Kappa statistic + * 0.923 Mean absolute error 0.0096 Root mean squared error 0.0817 Relative absolute error 9.9344 % Root + * relative squared error 37.2742 % Total Number of Instances 683 + */ + private int correctlyClassified; + private int incorrectlyClassified; + + public ResultAnalyzer(Collection<String> labelSet, String defaultLabel) { + confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel); + summarizer = new OnlineSummarizer(); + } + + public ConfusionMatrix getConfusionMatrix() { + return this.confusionMatrix; + } + + /** + * + * @param correctLabel + * The correct label + * @param classifiedResult + * The classified result + * @return whether the instance was correct or not + */ + public boolean addInstance(String correctLabel, ClassifierResult classifiedResult) { + boolean result = correctLabel.equals(classifiedResult.getLabel()); + if (result) { + correctlyClassified++; + } else { + incorrectlyClassified++; + } + confusionMatrix.addInstance(correctLabel, classifiedResult); + if (classifiedResult.getLogLikelihood() != Double.MAX_VALUE) { + summarizer.add(classifiedResult.getLogLikelihood()); + hasLL = true; + } + return result; + } + + 
@Override + public String toString() { + StringBuilder returnString = new StringBuilder(); + + returnString.append('\n'); + returnString.append("=======================================================\n"); + returnString.append("Summary\n"); + returnString.append("-------------------------------------------------------\n"); + int totalClassified = correctlyClassified + incorrectlyClassified; + double percentageCorrect = (double) 100 * correctlyClassified / totalClassified; + double percentageIncorrect = (double) 100 * incorrectlyClassified / totalClassified; + NumberFormat decimalFormatter = new DecimalFormat("0.####"); + + returnString.append(StringUtils.rightPad("Correctly Classified Instances", 40)).append(": ").append( + StringUtils.leftPad(Integer.toString(correctlyClassified), 10)).append('\t').append( + StringUtils.leftPad(decimalFormatter.format(percentageCorrect), 10)).append("%\n"); + returnString.append(StringUtils.rightPad("Incorrectly Classified Instances", 40)).append(": ").append( + StringUtils.leftPad(Integer.toString(incorrectlyClassified), 10)).append('\t').append( + StringUtils.leftPad(decimalFormatter.format(percentageIncorrect), 10)).append("%\n"); + returnString.append(StringUtils.rightPad("Total Classified Instances", 40)).append(": ").append( + StringUtils.leftPad(Integer.toString(totalClassified), 10)).append('\n'); + returnString.append('\n'); + + returnString.append(confusionMatrix); + returnString.append("=======================================================\n"); + returnString.append("Statistics\n"); + returnString.append("-------------------------------------------------------\n"); + + RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats(); + returnString.append(StringUtils.rightPad("Kappa", 40)).append( + StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getKappa()), 10)).append('\n'); + returnString.append(StringUtils.rightPad("Accuracy", 40)).append( + StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getAccuracy()), 10)).append("%\n"); + returnString.append(StringUtils.rightPad("Reliability", 40)).append( + StringUtils.leftPad(decimalFormatter.format(normStats.getAverage() * 100.00000001), 10)).append("%\n"); + returnString.append(StringUtils.rightPad("Reliability (standard deviation)", 40)).append( + StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append('\n'); + returnString.append(StringUtils.rightPad("Weighted precision", 40)).append( + StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedPrecision()), 10)).append('\n'); + returnString.append(StringUtils.rightPad("Weighted recall", 40)).append( + StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedRecall()), 10)).append('\n'); + returnString.append(StringUtils.rightPad("Weighted F1 score", 40)).append( + StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedF1score()), 10)).append('\n'); + + if (hasLL) { + returnString.append(StringUtils.rightPad("Log-likelihood", 30)).append("mean : ").append( + StringUtils.leftPad(decimalFormatter.format(summarizer.getMean()), 10)).append('\n'); + returnString.append(StringUtils.rightPad("", 30)).append(StringUtils.rightPad("25%-ile : ", 10)).append( + StringUtils.leftPad(decimalFormatter.format(summarizer.getQuartile(1)), 10)).append('\n'); + returnString.append(StringUtils.rightPad("", 30)).append(StringUtils.rightPad("75%-ile : ", 10)).append( + StringUtils.leftPad(decimalFormatter.format(summarizer.getQuartile(3)), 10)).append('\n'); + 
} + + return returnString.toString(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java new file mode 100644 index 0000000..f79a429 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df; + +import org.apache.mahout.classifier.df.builder.TreeBuilder; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.node.Node; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.Random; + +/** + * Builds a tree using bagging + */ +@Deprecated +public class Bagging { + + private static final Logger log = LoggerFactory.getLogger(Bagging.class); + + private final TreeBuilder treeBuilder; + + private final Data data; + + private final boolean[] sampled; + + public Bagging(TreeBuilder treeBuilder, Data data) { + this.treeBuilder = treeBuilder; + this.data = data; + sampled = new boolean[data.size()]; + } + + /** + * Builds one tree + */ + public Node build(Random rng) { + log.debug("Bagging..."); + Arrays.fill(sampled, false); + Data bag = data.bagging(rng, sampled); + + log.debug("Building..."); + return treeBuilder.build(rng, bag); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java new file mode 100644 index 0000000..c94292c --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java @@ -0,0 +1,174 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * <p/> + * http://www.apache.org/licenses/LICENSE-2.0 + * <p/> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; + +/** + * Utility class that contains various helper methods + */ +@Deprecated +public final class DFUtils { + + private DFUtils() { + } + + /** + * Writes a Node[] into a DataOutput + * @throws java.io.IOException + */ + public static void writeArray(DataOutput out, Node[] array) throws IOException { + out.writeInt(array.length); + for (Node w : array) { + w.write(out); + } + } + + /** + * Reads a Node[] from a DataInput + * @throws java.io.IOException + */ + public static Node[] readNodeArray(DataInput in) throws IOException { + int length = in.readInt(); + Node[] nodes = new Node[length]; + for (int index = 0; index < length; index++) { + nodes[index] = Node.read(in); + } + + return nodes; + } + + /** + * Writes a double[] into a DataOutput + * @throws java.io.IOException + */ + public static void writeArray(DataOutput out, double[] array) throws IOException { + out.writeInt(array.length); + for (double value : array) { + out.writeDouble(value); + } + } + + /** + * Reads a double[] from a DataInput + * @throws java.io.IOException + */ + public static double[] readDoubleArray(DataInput in) throws IOException { + int length = in.readInt(); + double[] array = new double[length]; + for (int index = 0; index < length; index++) { + array[index] = in.readDouble(); + } + + return array; + } + + /** + * Writes an int[] into a DataOutput + * @throws java.io.IOException + */ + public static void writeArray(DataOutput out, int[] array) throws IOException { + out.writeInt(array.length); + for (int value : array) { + out.writeInt(value); + } + } + + /** + * Reads an int[] from a DataInput + * @throws java.io.IOException + */ + public static int[] readIntArray(DataInput in) throws IOException { + int length = in.readInt(); + int[] array = new int[length]; + for (int index = 0; index < length; index++) { + array[index] = in.readInt(); + } + + return array; + } + + /** + * Returns a list of all files in the output directory + * @throws IOException if no file is found + */ + public static Path[] listOutputFiles(FileSystem fs, Path outputPath) throws IOException { + List<Path> outputFiles = new ArrayList<>(); + for (FileStatus s : fs.listStatus(outputPath, PathFilters.logsCRCFilter())) { + if (!s.isDir() && !s.getPath().getName().startsWith("_")) { + outputFiles.add(s.getPath()); + } + } + if (outputFiles.isEmpty()) { + throw new IOException("No output found !"); + } + return outputFiles.toArray(new Path[outputFiles.size()]); + } + + /** + * Formats a time interval 
in milliseconds to a String in the form "<hours>h <minutes>m <seconds>s <millis>" + */ + public static String elapsedTime(long milli) { + long seconds = milli / 1000; + milli %= 1000; + + long minutes = seconds / 60; + seconds %= 60; + + long hours = minutes / 60; + minutes %= 60; + + return hours + "h " + minutes + "m " + seconds + "s " + milli; + } + + public static void storeWritable(Configuration conf, Path path, Writable writable) throws IOException { + FileSystem fs = path.getFileSystem(conf); + + try (FSDataOutputStream out = fs.create(path)) { + writable.write(out); + } + } + + /** + * Write a string to a path. + * @param conf From which the file system will be picked + * @param path Where the string will be written + * @param string The string to write + * @throws IOException if things go poorly + */ + public static void storeString(Configuration conf, Path path, String string) throws IOException { + try (DataOutputStream out = path.getFileSystem(conf).create(path)) { + out.write(string.getBytes(Charset.defaultCharset())); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java new file mode 100644 index 0000000..c11cf34 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java @@ -0,0 +1,241 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataUtils; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.node.Node; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Represents a forest of decision trees. 
+ */ +@Deprecated +public class DecisionForest implements Writable { + + private final List<Node> trees; + + private DecisionForest() { + trees = new ArrayList<>(); + } + + public DecisionForest(List<Node> trees) { + Preconditions.checkArgument(trees != null && !trees.isEmpty(), "trees argument must not be null or empty"); + + this.trees = trees; + } + + List<Node> getTrees() { + return trees; + } + + /** + * Classifies the data, storing each tree's prediction for every instance in {@code predictions} + */ + public void classify(Data data, double[][] predictions) { + Preconditions.checkArgument(data.size() == predictions.length, "predictions.length must be equal to data.size()"); + + if (data.isEmpty()) { + return; // nothing to classify + } + + int treeId = 0; + for (Node tree : trees) { + for (int index = 0; index < data.size(); index++) { + if (predictions[index] == null) { + predictions[index] = new double[trees.size()]; + } + predictions[index][treeId] = tree.classify(data.get(index)); + } + treeId++; + } + } + + /** + * Predicts the label for the instance + * + * @param rng + * Random number generator, used to break ties randomly + * @return NaN if the label cannot be predicted + */ + public double classify(Dataset dataset, Random rng, Instance instance) { + if (dataset.isNumerical(dataset.getLabelId())) { + double sum = 0; + int cnt = 0; + for (Node tree : trees) { + double prediction = tree.classify(instance); + if (!Double.isNaN(prediction)) { + sum += prediction; + cnt++; + } + } + + if (cnt > 0) { + return sum / cnt; + } else { + return Double.NaN; + } + } else { + int[] predictions = new int[dataset.nblabels()]; + for (Node tree : trees) { + double prediction = tree.classify(instance); + if (!Double.isNaN(prediction)) { + predictions[(int) prediction]++; + } + } + + if (DataUtils.sum(predictions) == 0) { + return Double.NaN; // no prediction available + } + + return DataUtils.maxindex(rng, predictions); + } + } + + /** + * @return Mean number of nodes per tree + */ + public long meanNbNodes() { + long sum = 0; + + for (Node tree : trees) { + sum += tree.nbNodes(); + } + + return sum / trees.size(); + } + + /** + * @return Total number of nodes in all the trees + */ + public long nbNodes() { + long sum = 0; + + for (Node tree : trees) { + sum += tree.nbNodes(); + } + + return sum; + } + + /** + * @return Mean maximum depth per tree + */ + public long meanMaxDepth() { + long sum = 0; + + for (Node tree : trees) { + sum += tree.maxDepth(); + } + + return sum / trees.size(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof DecisionForest)) { + return false; + } + + DecisionForest rf = (DecisionForest) obj; + + return trees.size() == rf.getTrees().size() && trees.containsAll(rf.getTrees()); + } + + @Override + public int hashCode() { + return trees.hashCode(); + } + + @Override + public void write(DataOutput dataOutput) throws IOException { + dataOutput.writeInt(trees.size()); + for (Node tree : trees) { + tree.write(dataOutput); + } + } + + /** + * Reads the trees from the input and adds them to the existing trees + */ + @Override + public void readFields(DataInput dataInput) throws IOException { + int size = dataInput.readInt(); + for (int i = 0; i < size; i++) { + trees.add(Node.read(dataInput)); + } + } + + /** + * Read the forest from inputStream + * @param dataInput - input forest + * @return {@link org.apache.mahout.classifier.df.DecisionForest} + * @throws IOException + */ + public static DecisionForest read(DataInput dataInput) throws 
+  public static DecisionForest read(DataInput dataInput) throws IOException {
+    DecisionForest forest = new DecisionForest();
+    forest.readFields(dataInput);
+    return forest;
+  }
+
+  /**
+   * Loads the forest from a single file or a directory of files
+   *
+   * @throws java.io.IOException if an I/O error occurs while reading
+   */
+  public static DecisionForest load(Configuration conf, Path forestPath) throws IOException {
+    FileSystem fs = forestPath.getFileSystem(conf);
+    Path[] files;
+    if (fs.getFileStatus(forestPath).isDir()) {
+      files = DFUtils.listOutputFiles(fs, forestPath);
+    } else {
+      files = new Path[]{forestPath};
+    }
+
+    DecisionForest forest = null;
+    for (Path path : files) {
+      try (FSDataInputStream dataInput = new FSDataInputStream(fs.open(path))) {
+        if (forest == null) {
+          forest = read(dataInput);
+        } else {
+          forest.readFields(dataInput);
+        }
+      }
+    }
+
+    return forest;
+  }
+
+}
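A minimal usage sketch (not part of the patch) of the class above: load a serialized forest and classify one instance. The Configuration, Path, Dataset and Instance values are assumed to be prepared elsewhere; only DecisionForest.load and classify(Dataset, Random, Instance) come from the code above, and the helper class name is hypothetical.

    import java.io.IOException;
    import java.util.Random;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.classifier.df.DecisionForest;
    import org.apache.mahout.classifier.df.data.Dataset;
    import org.apache.mahout.classifier.df.data.Instance;

    // Hypothetical helper, for illustration only.
    public final class ForestUsageSketch {
      public static double predict(Configuration conf, Path forestPath,
                                   Dataset dataset, Instance instance) throws IOException {
        // load() accepts a single forest file or a directory of forest files
        DecisionForest forest = DecisionForest.load(conf, forestPath);
        // averages the tree predictions for regression, majority-votes for
        // classification; returns NaN when no tree could classify the instance
        return forest.classify(dataset, new Random(), instance);
      }
    }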
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
new file mode 100644
index 0000000..13cd386
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Various methods to compute error estimates from the output of a random forest
+ */
+@Deprecated
+public final class ErrorEstimate {
+
+  private ErrorEstimate() {
+  }
+
+  public static double errorRate(double[] labels, double[] predictions) {
+    Preconditions.checkArgument(labels.length == predictions.length, "labels.length != predictions.length");
+    double nberrors = 0; // number of instances that got bad predictions
+    double datasize = 0; // number of classified instances
+
+    for (int index = 0; index < labels.length; index++) {
+      if (predictions[index] == -1) {
+        continue; // instance not classified
+      }
+
+      if (predictions[index] != labels[index]) {
+        nberrors++;
+      }
+
+      datasize++;
+    }
+
+    return nberrors / datasize;
+  }
+
+}
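A tiny illustration (not part of the patch) of the -1 convention used by errorRate above: unclassified instances are skipped, so the rate is taken over classified instances only. The class name is hypothetical.

    import org.apache.mahout.classifier.df.ErrorEstimate;

    // Hypothetical example, for illustration only.
    public final class ErrorRateSketch {
      public static void main(String[] args) {
        double[] labels      = {0.0, 1.0, 1.0};
        double[] predictions = {0.0, -1.0, 0.0}; // -1 marks an unclassified instance
        // one wrong prediction out of two classified instances -> prints 0.5
        System.out.println(ErrorEstimate.errorRate(labels, predictions));
      }
    }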
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
new file mode 100644
index 0000000..9f84e9c
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
@@ -0,0 +1,422 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+import org.apache.mahout.classifier.df.split.IgSplit;
+import org.apache.mahout.classifier.df.split.OptIgSplit;
+import org.apache.mahout.classifier.df.split.RegressionSplit;
+import org.apache.mahout.classifier.df.split.Split;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Random;
+
+/**
+ * Builds a classification tree or regression tree<br>
+ * A classification tree is built when the criterion variable is a categorical attribute.<br>
+ * A regression tree is built when the criterion variable is a numerical attribute.
+ */
+@Deprecated
+public class DecisionTreeBuilder implements TreeBuilder {
+
+  private static final Logger log = LoggerFactory.getLogger(DecisionTreeBuilder.class);
+
+  private static final int[] NO_ATTRIBUTES = new int[0];
+  private static final double EPSILON = 1.0e-6;
+
+  /**
+   * indicates which CATEGORICAL attributes have already been selected in the parent nodes
+   */
+  private boolean[] selected;
+  /**
+   * number of attributes to select randomly at each node
+   */
+  private int m;
+  /**
+   * IgSplit implementation
+   */
+  private IgSplit igSplit;
+  /**
+   * whether the tree is complemented
+   */
+  private boolean complemented = true;
+  /**
+   * minimum number of instances required to split a branch
+   */
+  private double minSplitNum = 2.0;
+  /**
+   * minimum proportion of the total variance required for a split (regression)
+   */
+  private double minVarianceProportion = 1.0e-3;
+  /**
+   * full set data
+   */
+  private Data fullSet;
+  /**
+   * minimum variance for split
+   */
+  private double minVariance = Double.NaN;
+
+  public void setM(int m) {
+    this.m = m;
+  }
+
+  public void setIgSplit(IgSplit igSplit) {
+    this.igSplit = igSplit;
+  }
+
+  public void setComplemented(boolean complemented) {
+    this.complemented = complemented;
+  }
+
+  public void setMinSplitNum(int minSplitNum) {
+    this.minSplitNum = minSplitNum;
+  }
+
+  public void setMinVarianceProportion(double minVarianceProportion) {
+    this.minVarianceProportion = minVarianceProportion;
+  }
+
+  @Override
+  public Node build(Random rng, Data data) {
+    if (selected == null) {
+      selected = new boolean[data.getDataset().nbAttributes()];
+      selected[data.getDataset().getLabelId()] = true; // never select the label
+    }
+    if (m == 0) {
+      // set default m
+      double e = data.getDataset().nbAttributes() - 1;
+      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+        // regression
+        m = (int) Math.ceil(e / 3.0);
+      } else {
+        // classification
+        m = (int) Math.ceil(Math.sqrt(e));
+      }
+    }
+
+    if (data.isEmpty()) {
+      return new Leaf(Double.NaN);
+    }
+
+    double sum = 0.0;
+    if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+      // regression
+      // compute the sum and sum of squares of the labels
+      double sumSquared = 0.0;
+      for (int i = 0; i < data.size(); i++) {
+        double label = data.getDataset().getLabel(data.get(i));
+        sum += label;
+        sumSquared += label * label;
+      }
+
+      // computes the variance
+      double var = sumSquared - (sum * sum) / data.size();
+
+      // computes the minimum variance
+      if (Double.compare(minVariance, Double.NaN) == 0) {
+        minVariance = var / data.size() * minVarianceProportion;
+        log.debug("minVariance:{}", minVariance);
+      }
+
+      // the variance is compared with the minimum variance
+      if ((var / data.size()) < minVariance) {
+        log.debug("variance({}) < minVariance({}) Leaf({})", var / data.size(), minVariance, sum / data.size());
+        return new Leaf(sum / data.size());
+      }
+    } else {
+      // classification
+      if (isIdentical(data)) {
+        return new Leaf(data.majorityLabel(rng));
+      }
+      if (data.identicalLabel()) {
+        return new Leaf(data.getDataset().getLabel(data.get(0)));
+      }
+    }
+
+    // store full set data
+    if (fullSet == null) {
+      fullSet = data;
+    }
+
+    int[] attributes = randomAttributes(rng, selected, m);
+    if (attributes == null || attributes.length == 0) {
+      // we tried all the attributes and could not split the data anymore
+      double label;
+      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+        // regression
+        label = sum / data.size();
+      } else {
+        // classification
+        label = data.majorityLabel(rng);
+      }
Leaf({})", label); + return new Leaf(label); + } + + if (igSplit == null) { + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + // regression + igSplit = new RegressionSplit(); + } else { + // classification + igSplit = new OptIgSplit(); + } + } + + // find the best split + Split best = null; + for (int attr : attributes) { + Split split = igSplit.computeSplit(data, attr); + if (best == null || best.getIg() < split.getIg()) { + best = split; + } + } + + // information gain is near to zero. + if (best.getIg() < EPSILON) { + double label; + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + label = sum / data.size(); + } else { + label = data.majorityLabel(rng); + } + log.debug("ig is near to zero Leaf({})", label); + return new Leaf(label); + } + + log.debug("best split attr:{}, split:{}, ig:{}", best.getAttr(), best.getSplit(), best.getIg()); + + boolean alreadySelected = selected[best.getAttr()]; + if (alreadySelected) { + // attribute already selected + log.warn("attribute {} already selected in a parent node", best.getAttr()); + } + + Node childNode; + if (data.getDataset().isNumerical(best.getAttr())) { + boolean[] temp = null; + + Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit())); + Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit())); + + if (loSubset.isEmpty() || hiSubset.isEmpty()) { + // the selected attribute did not change the data, avoid using it in the child notes + selected[best.getAttr()] = true; + } else { + // the data changed, so we can unselect all previousely selected NUMERICAL attributes + temp = selected; + selected = cloneCategoricalAttributes(data.getDataset(), selected); + } + + // size of the subset is less than the minSpitNum + if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) { + // branch is not split + double label; + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + label = sum / data.size(); + } else { + label = data.majorityLabel(rng); + } + log.debug("branch is not split Leaf({})", label); + return new Leaf(label); + } + + Node loChild = build(rng, loSubset); + Node hiChild = build(rng, hiSubset); + + // restore the selection state of the attributes + if (temp != null) { + selected = temp; + } else { + selected[best.getAttr()] = alreadySelected; + } + + childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild); + } else { // CATEGORICAL attribute + double[] values = data.values(best.getAttr()); + + // tree is complemented + Collection<Double> subsetValues = null; + if (complemented) { + subsetValues = new HashSet<>(); + for (double value : values) { + subsetValues.add(value); + } + values = fullSet.values(best.getAttr()); + } + + int cnt = 0; + Data[] subsets = new Data[values.length]; + for (int index = 0; index < values.length; index++) { + if (complemented && !subsetValues.contains(values[index])) { + continue; + } + subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index])); + if (subsets[index].size() >= minSplitNum) { + cnt++; + } + } + + // size of the subset is less than the minSpitNum + if (cnt < 2) { + // branch is not split + double label; + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + label = sum / data.size(); + } else { + label = data.majorityLabel(rng); + } + log.debug("branch is not split Leaf({})", label); + return new Leaf(label); + } + + selected[best.getAttr()] = true; + + Node[] children = new Node[values.length]; + for (int 
+      for (int index = 0; index < values.length; index++) {
+        if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {
+          // the tree is complemented
+          double label;
+          if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+            label = sum / data.size();
+          } else {
+            label = data.majorityLabel(rng);
+          }
+          log.debug("complemented Leaf({})", label);
+          children[index] = new Leaf(label);
+          continue;
+        }
+        children[index] = build(rng, subsets[index]);
+      }
+
+      selected[best.getAttr()] = alreadySelected;
+
+      childNode = new CategoricalNode(best.getAttr(), values, children);
+    }
+
+    return childNode;
+  }
+
+  /**
+   * Checks whether all the vectors have identical attribute values. Selected attributes are ignored.
+   *
+   * @return true if all the vectors are identical or the data is empty<br>
+   *         false otherwise
+   */
+  private boolean isIdentical(Data data) {
+    if (data.isEmpty()) {
+      return true;
+    }
+
+    Instance instance = data.get(0);
+    for (int attr = 0; attr < selected.length; attr++) {
+      if (selected[attr]) {
+        continue;
+      }
+
+      for (int index = 1; index < data.size(); index++) {
+        if (data.get(index).get(attr) != instance.get(attr)) {
+          return false;
+        }
+      }
+    }
+
+    return true;
+  }
+
+  /**
+   * Make a copy of the selection state of the attributes, unselect all numerical attributes
+   *
+   * @param selected selection state to clone
+   * @return cloned selection state
+   */
+  private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
+    boolean[] cloned = new boolean[selected.length];
+
+    for (int i = 0; i < selected.length; i++) {
+      cloned[i] = !dataset.isNumerical(i) && selected[i];
+    }
+    cloned[dataset.getLabelId()] = true;
+
+    return cloned;
+  }
+
+  /**
+   * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
+   *
+   * @param rng random-numbers generator
+   * @param selected attributes' state (selected or not)
+   * @param m number of attributes to choose
+   * @return list of the selected attributes' indices, or an empty array if all attributes have already been selected
+   */
+  private static int[] randomAttributes(Random rng, boolean[] selected, int m) {
+    int nbNonSelected = 0; // number of non selected attributes
+    for (boolean sel : selected) {
+      if (!sel) {
+        nbNonSelected++;
+      }
+    }
+
+    if (nbNonSelected == 0) {
+      log.warn("All attributes are selected!");
+      return NO_ATTRIBUTES;
+    }
+
+    int[] result;
+    if (nbNonSelected <= m) {
+      // return all non selected attributes
+      result = new int[nbNonSelected];
+      int index = 0;
+      for (int attr = 0; attr < selected.length; attr++) {
+        if (!selected[attr]) {
+          result[index++] = attr;
+        }
+      }
+    } else {
+      result = new int[m];
+      for (int index = 0; index < m; index++) {
+        // randomly choose a "non selected" attribute
+        int rind;
+        do {
+          rind = rng.nextInt(selected.length);
+        } while (selected[rind]);
+
+        result[index] = rind;
+        selected[rind] = true; // temporarily mark the chosen attribute as selected
+      }
+
+      // unmark the chosen attributes, they are not selected yet
+      for (int attr : result) {
+        selected[attr] = false;
+      }
+    }
+
+    return result;
+  }
+}
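A short sketch (not part of the patch) of driving the builder above directly; the Data instance is assumed to be loaded elsewhere and the class name is hypothetical. Leaving m at 0 keeps the defaults computed in build(): with e = nbAttributes - 1, ceil(sqrt(e)) attributes are sampled per node for classification and ceil(e / 3) for regression, so for example 10 usable attributes give m = 4 for classification.

    import java.util.Random;
    import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
    import org.apache.mahout.classifier.df.data.Data;
    import org.apache.mahout.classifier.df.node.Node;

    // Hypothetical helper, for illustration only.
    public final class TreeBuildSketch {
      public static Node buildOneTree(Data data, long seed) {
        DecisionTreeBuilder builder = new DecisionTreeBuilder();
        builder.setMinSplitNum(2);                // do not split branches smaller than this
        builder.setMinVarianceProportion(1.0e-3); // regression stopping threshold
        // m is left at 0, so build() picks the defaults described above
        return builder.build(new Random(seed), data);
      }
    }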
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java
new file mode 100644
index 0000000..3392fb1
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java
@@ -0,0 +1,253 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+import org.apache.mahout.classifier.df.split.IgSplit;
+import org.apache.mahout.classifier.df.split.OptIgSplit;
+import org.apache.mahout.classifier.df.split.Split;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Random;
+
+/**
+ * Builds a Decision Tree <br>
+ * Based on the algorithm described in the "Decision Trees" tutorials by Andrew W. Moore, available at:<br>
+ * <br>
+ * http://www.cs.cmu.edu/~awm/tutorials
+ * <br><br>
+ * This class can be used when the criterion variable is a categorical attribute.
+ */
+@Deprecated
+public class DefaultTreeBuilder implements TreeBuilder {
+
+  private static final Logger log = LoggerFactory.getLogger(DefaultTreeBuilder.class);
+
+  private static final int[] NO_ATTRIBUTES = new int[0];
+
+  /**
+   * indicates which CATEGORICAL attributes have already been selected in the parent nodes
+   */
+  private boolean[] selected;
+  /**
+   * number of attributes to select randomly at each node
+   */
+  private int m = 1;
+  /**
+   * IgSplit implementation
+   */
+  private final IgSplit igSplit;
+
+  public DefaultTreeBuilder() {
+    igSplit = new OptIgSplit();
+  }
+
+  public void setM(int m) {
+    this.m = m;
+  }
+
+  @Override
+  public Node build(Random rng, Data data) {
+
+    if (selected == null) {
+      selected = new boolean[data.getDataset().nbAttributes()];
+      selected[data.getDataset().getLabelId()] = true; // never select the label
+    }
+
+    if (data.isEmpty()) {
+      return new Leaf(-1);
+    }
+    if (isIdentical(data)) {
+      return new Leaf(data.majorityLabel(rng));
+    }
+    if (data.identicalLabel()) {
+      return new Leaf(data.getDataset().getLabel(data.get(0)));
+    }
+
+    int[] attributes = randomAttributes(rng, selected, m);
+    if (attributes == null || attributes.length == 0) {
+      // we tried all the attributes and could not split the data anymore
+      return new Leaf(data.majorityLabel(rng));
+    }
+
+    // find the best split
+    Split best = null;
+    for (int attr : attributes) {
+      Split split = igSplit.computeSplit(data, attr);
+      if (best == null || best.getIg() < split.getIg()) {
+        best = split;
+      }
+    }
+
+    boolean alreadySelected = selected[best.getAttr()];
+    if (alreadySelected) {
+      // attribute already selected
+      log.warn("attribute {} already selected in a parent node", best.getAttr());
+    }
+
+    Node childNode;
+    if (data.getDataset().isNumerical(best.getAttr())) {
+      boolean[] temp = null;
+
+      Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
+      Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));
+
+      if (loSubset.isEmpty() || hiSubset.isEmpty()) {
+        // the selected attribute did not change the data, avoid using it in the child nodes
+        selected[best.getAttr()] = true;
+      } else {
+        // the data changed, so we can unselect all previously selected NUMERICAL attributes
+        temp = selected;
+        selected = cloneCategoricalAttributes(data.getDataset(), selected);
+      }
+
+      Node loChild = build(rng, loSubset);
+      Node hiChild = build(rng, hiSubset);
+
+      // restore the selection state of the attributes
+      if (temp != null) {
+        selected = temp;
+      } else {
+        selected[best.getAttr()] = alreadySelected;
+      }
+
+      childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
+    } else { // CATEGORICAL attribute
+      selected[best.getAttr()] = true;
+
+      double[] values = data.values(best.getAttr());
+      Node[] children = new Node[values.length];
+
+      for (int index = 0; index < values.length; index++) {
+        Data subset = data.subset(Condition.equals(best.getAttr(), values[index]));
+        children[index] = build(rng, subset);
+      }
+
+      selected[best.getAttr()] = alreadySelected;
+
+      childNode = new CategoricalNode(best.getAttr(), values, children);
+    }
+
+    return childNode;
+  }
+
+  /**
+   * Checks whether all the vectors have identical attribute values. Selected attributes are ignored.
+   *
+   * @return true if all the vectors are identical or the data is empty<br>
+   *         false otherwise
+   */
+  private boolean isIdentical(Data data) {
+    if (data.isEmpty()) {
+      return true;
+    }
+
+    Instance instance = data.get(0);
+    for (int attr = 0; attr < selected.length; attr++) {
+      if (selected[attr]) {
+        continue;
+      }
+
+      for (int index = 1; index < data.size(); index++) {
+        if (data.get(index).get(attr) != instance.get(attr)) {
+          return false;
+        }
+      }
+    }
+
+    return true;
+  }
+
+
+  /**
+   * Make a copy of the selection state of the attributes, unselect all numerical attributes
+   *
+   * @param selected selection state to clone
+   * @return cloned selection state
+   */
+  private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
+    boolean[] cloned = new boolean[selected.length];
+
+    for (int i = 0; i < selected.length; i++) {
+      cloned[i] = !dataset.isNumerical(i) && selected[i];
+    }
+
+    return cloned;
+  }
+
+  /**
+   * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
+   *
+   * @param rng random-numbers generator
+   * @param selected attributes' state (selected or not)
+   * @param m number of attributes to choose
+   * @return list of the selected attributes' indices, or an empty array if all attributes have already been selected
+   */
+  protected static int[] randomAttributes(Random rng, boolean[] selected, int m) {
+    int nbNonSelected = 0; // number of non selected attributes
+    for (boolean sel : selected) {
+      if (!sel) {
+        nbNonSelected++;
+      }
+    }
+
+    if (nbNonSelected == 0) {
+      log.warn("All attributes are selected!");
+      return NO_ATTRIBUTES;
+    }
+
+    int[] result;
+    if (nbNonSelected <= m) {
+      // return all non selected attributes
+      result = new int[nbNonSelected];
+      int index = 0;
+      for (int attr = 0; attr < selected.length; attr++) {
+        if (!selected[attr]) {
+          result[index++] = attr;
+        }
+      }
+    } else {
+      result = new int[m];
+      for (int index = 0; index < m; index++) {
+        // randomly choose a "non selected" attribute
+        int rind;
+        do {
+          rind = rng.nextInt(selected.length);
+        } while (selected[rind]);
+
+        result[index] = rind;
+        selected[rind] = true; // temporarily mark the chosen attribute as selected
+      }
+
+      // unmark the chosen attributes, they are not selected yet
+      for (int attr : result) {
+        selected[attr] = false;
+      }
+    }
+
+    return result;
+  }
+}
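For comparison, a sketch (not part of the patch) of the simpler builder above, which always splits with OptIgSplit and uses a fixed m; the class name is hypothetical and the Data is assumed to be loaded elsewhere. Note the sampling trick in randomAttributes: chosen attributes are temporarily marked selected so they cannot be drawn twice, then unmarked before the method returns.

    import java.util.Random;
    import org.apache.mahout.classifier.df.builder.DefaultTreeBuilder;
    import org.apache.mahout.classifier.df.data.Data;
    import org.apache.mahout.classifier.df.node.Node;

    // Hypothetical helper, for illustration only.
    public final class DefaultBuilderSketch {
      public static Node buildOneTree(Data data, long seed) {
        DefaultTreeBuilder builder = new DefaultTreeBuilder(); // information-gain splits via OptIgSplit
        builder.setM(1); // attributes sampled per node; this builder defaults to 1
        return builder.build(new Random(seed), data);
      }
    }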
