Repository: ignite Updated Branches: refs/heads/master edcc1089a -> 4800b8729
IGNITE-10605: [ML] Add binary metrics calculations to Cross-Validation This closes #5712 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/4800b872 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/4800b872 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/4800b872 Branch: refs/heads/master Commit: 4800b8729565728877d31b6161a19fe1632225d8 Parents: edcc108 Author: zaleslaw <[email protected]> Authored: Fri Dec 21 14:43:21 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Fri Dec 21 14:43:21 2018 +0300 ---------------------------------------------------------------------- .../ml/selection/cv/CrossValidationExample.java | 23 ++- .../Step_8_CV_with_Param_Grid_and_metrics.java | 192 +++++++++++++++++++ .../BinaryClassificationEvaluator.java | 8 +- .../metric/BinaryClassificationMetrics.java | 25 ++- .../ml/selection/cv/CrossValidationTest.java | 42 ++++ 5 files changed, 281 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/4800b872/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java index 552bcd2..462186c 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java @@ -27,6 +27,8 @@ import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.selection.cv.CrossValidation; import 
org.apache.ignite.ml.selection.scoring.metric.Accuracy; +import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues; +import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -71,7 +73,7 @@ public class CrossValidationExample { CrossValidation<DecisionTreeNode, Double, Integer, LabeledPoint> scoreCalculator = new CrossValidation<>(); - double[] scores = scoreCalculator.score( + double[] accuracyScores = scoreCalculator.score( trainer, new Accuracy<>(), ignite, @@ -81,7 +83,24 @@ public class CrossValidationExample { 4 ); - System.out.println(">>> Accuracy: " + Arrays.toString(scores)); + System.out.println(">>> Accuracy: " + Arrays.toString(accuracyScores)); + + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics() + .withNegativeClsLb(0.0) + .withPositiveClsLb(1.0) + .withMetric(BinaryClassificationMetricValues::balancedAccuracy); + + double[] balancedAccuracyScores = scoreCalculator.score( + trainer, + metrics, + ignite, + trainingSet, + (k, v) -> VectorUtils.of(v.x, v.y), + (k, v) -> v.lb, + 4 + ); + + System.out.println(">>> Balanced Accuracy: " + Arrays.toString(balancedAccuracyScores)); System.out.println(">>> Cross validation score calculator example completed."); } http://git-wip-us.apache.org/repos/asf/ignite/blob/4800b872/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java new file mode 100644 index 0000000..0ea0ca2 --- /dev/null +++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.tutorial; + +import java.io.FileNotFoundException; +import java.util.Arrays; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; +import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; +import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; +import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; +import org.apache.ignite.ml.selection.cv.CrossValidation; +import org.apache.ignite.ml.selection.cv.CrossValidationResult; +import org.apache.ignite.ml.selection.paramgrid.ParamGrid; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; +import org.apache.ignite.ml.selection.scoring.metric.Accuracy; +import 
org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues; +import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics; +import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter; +import org.apache.ignite.ml.selection.split.TrainTestSplit; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; + +/** + * To choose the best hyperparameters the cross-validation with {@link ParamGrid} will be used in this example. + * <p> + * Code in this example launches Ignite grid and fills the cache with test data (based on Titanic passengers data).</p> + * <p> + * After that it defines how to split the data to train and test sets and configures preprocessors that extract + * features from an upstream data and perform other desired changes over the extracted data.</p> + * <p> + * Then, it tunes hyperparams with K-fold Cross-Validation on the split training set and trains the model based on + * the processed data using decision tree classification and the obtained hyperparams.</p> + * <p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> + * <p> + * The purpose of cross-validation is model checking, not model building.</p> + * <p> + * We train {@code k} different models.</p> + * <p> + * They differ in that {@code 1/(k-1)}th of the training data is exchanged against other cases.</p> + * <p> + * These models are sometimes called surrogate models because the (average) performance measured for these models + * is taken as a surrogate of the performance of the model trained on all cases.</p> + * <p> + * All scenarios are described there: https://sebastianraschka.com/faq/docs/evaluate-a-model.html</p> + */ +public class Step_8_CV_with_Param_Grid_and_metrics { + /** Run example. 
*/ + public static void main(String[] args) { + System.out.println(); + System.out.println(">>> Tutorial step 8 (cross-validation with param grid) example started."); + + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + try { + IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite); + + // Defines first preprocessor that extracts features from an upstream data. + // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare" . + IgniteBiFunction<Integer, Object[], Object[]> featureExtractor + = (k, v) -> new Object[] {v[0], v[3], v[4], v[5], v[6], v[8], v[10]}; + + IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1]; + + TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>() + .split(0.75); + + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) + .withEncodedFeature(1) + .withEncodedFeature(6) // <--- Changed index here. + .fit(ignite, + dataCache, + featureExtractor + ); + + IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>() + .fit(ignite, + dataCache, + strEncoderPreprocessor + ); + + IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>() + .fit( + ignite, + dataCache, + imputingPreprocessor + ); + + IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>() + .withP(2) + .fit( + ignite, + dataCache, + minMaxScalerPreprocessor + ); + + // Tune hyperparams with K-fold Cross-Validation on the split training set. 
+ + DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer(); + + CrossValidation<DecisionTreeNode, Double, Integer, Object[]> scoreCalculator + = new CrossValidation<>(); + + ParamGrid paramGrid = new ParamGrid() + .addHyperParam("maxDeep", new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 10.0}) + .addHyperParam("minImpurityDecrease", new Double[] {0.0, 0.25, 0.5}); + + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics() + .withNegativeClsLb(0.0) + .withPositiveClsLb(1.0) + .withMetric(BinaryClassificationMetricValues::accuracy); + + CrossValidationResult crossValidationRes = scoreCalculator.score( + trainerCV, + metrics, + ignite, + dataCache, + split.getTrainFilter(), + normalizationPreprocessor, + lbExtractor, + 3, + paramGrid + ); + + System.out.println("Train with maxDeep: " + crossValidationRes.getBest("maxDeep") + + " and minImpurityDecrease: " + crossValidationRes.getBest("minImpurityDecrease")); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer() + .withMaxDeep(crossValidationRes.getBest("maxDeep")) + .withMinImpurityDecrease(crossValidationRes.getBest("minImpurityDecrease")); + + System.out.println(crossValidationRes); + + System.out.println("Best score: " + Arrays.toString(crossValidationRes.getBestScore())); + System.out.println("Best hyper params: " + crossValidationRes.getBestHyperParams()); + System.out.println("Best average score: " + crossValidationRes.getBestAvgScore()); + + crossValidationRes.getScoringBoard().forEach((hyperParams, score) + -> System.out.println("Score " + Arrays.toString(score) + " for hyper params " + hyperParams)); + + // Train decision tree model. 
+ DecisionTreeNode bestMdl = trainer.fit( + ignite, + dataCache, + split.getTrainFilter(), + normalizationPreprocessor, + lbExtractor + ); + + System.out.println("\n>>> Trained model: " + bestMdl); + + double accuracy = BinaryClassificationEvaluator.evaluate( + dataCache, + split.getTestFilter(), + bestMdl, + normalizationPreprocessor, + lbExtractor, + new Accuracy<>() + ); + + System.out.println("\n>>> Accuracy " + accuracy); + System.out.println("\n>>> Test Error " + (1 - accuracy)); + + System.out.println(">>> Tutorial step 8 (cross-validation with param grid) example started."); + } + catch (FileNotFoundException e) { + e.printStackTrace(); + } + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/4800b872/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java index 30adc5c..9642bce 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java @@ -49,7 +49,6 @@ public class BinaryClassificationEvaluator { IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, Metric<L> metric) { - return calculateMetric(dataCache, null, mdl, featureExtractor, lbExtractor, metric); } @@ -72,7 +71,6 @@ public class BinaryClassificationEvaluator { IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, Metric<L> metric) { - return calculateMetric(dataCache, filter, mdl, featureExtractor, lbExtractor, metric); } @@ -140,7 +138,7 @@ public class BinaryClassificationEvaluator { 
lbExtractor, mdl )) { - metricValues = binaryMetrics.score(cursor.iterator()); + metricValues = binaryMetrics.scoreAll(cursor.iterator()); } catch (Exception e) { throw new RuntimeException(e); } @@ -163,8 +161,8 @@ public class BinaryClassificationEvaluator { * @return Computed metric. */ private static <L, K, V> double calculateMetric(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, - Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, - Metric<L> metric) { + Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, Metric<L> metric) { double metricRes; try (LabelPairCursor<L> cursor = new CacheBasedLabelPairCursor<>( http://git-wip-us.apache.org/repos/asf/ignite/blob/4800b872/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java index 0b15d04..bd4067a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java @@ -18,25 +18,30 @@ package org.apache.ignite.ml.selection.scoring.metric; import java.util.Iterator; +import java.util.function.Function; import org.apache.ignite.ml.selection.scoring.LabelPair; /** * Binary classification metrics calculator. + * It can be used in two ways: to calculate all binary classification metrics or a single specific metric. */ -public class BinaryClassificationMetrics { +public class BinaryClassificationMetrics implements Metric<Double> { /** Positive class label. 
*/ private double positiveClsLb = 1.0; /** Negative class label. Default value is 0.0. */ private double negativeClsLb; + /** The main metric to get individual score. */ + private Function<BinaryClassificationMetricValues, Double> metric = BinaryClassificationMetricValues::accuracy; + /** * Calculates binary metrics values. * * @param iter Iterator that supplies pairs of truth values and predicted values. * @return Scores for all binary metrics. */ - public BinaryClassificationMetricValues score(Iterator<LabelPair<Double>> iter) { + public BinaryClassificationMetricValues scoreAll(Iterator<LabelPair<Double>> iter) { long tp = 0; long tn = 0; long fp = 0; @@ -83,4 +88,20 @@ public class BinaryClassificationMetrics { this.negativeClsLb = negativeClsLb; return this; } + + /** */ + public BinaryClassificationMetrics withMetric(Function<BinaryClassificationMetricValues, Double> metric) { + this.metric = metric; + return this; + } + + /** {@inheritDoc} */ + @Override public double score(Iterator<LabelPair<Double>> iter) { + return metric.apply(scoreAll(iter)); + } + + /** {@inheritDoc} */ + @Override public String name() { + return "Binary classification metrics"; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/4800b872/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java index 3e8b9dd..e64aa7a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java @@ -21,6 +21,8 @@ import java.util.HashMap; import java.util.Map; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; +import 
org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues; +import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; import org.junit.Test; @@ -71,6 +73,46 @@ public class CrossValidationTest { /** */ @Test + public void testScoreWithGoodDatasetAndBinaryMetrics() { + Map<Integer, Double> data = new HashMap<>(); + + for (int i = 0; i < 1000; i++) + data.put(i, i > 500 ? 1.0 : 0.0); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0); + + CrossValidation<DecisionTreeNode, Double, Integer, Double> scoreCalculator = + new CrossValidation<>(); + + int folds = 4; + + BinaryClassificationMetrics metrics = new BinaryClassificationMetrics() + .withMetric(BinaryClassificationMetricValues::accuracy); + + verifyScores(folds, scoreCalculator.score( + trainer, + metrics, + data, + 1, + (k, v) -> VectorUtils.of(k), + (k, v) -> v, + folds + )); + + verifyScores(folds, scoreCalculator.score( + trainer, + new Accuracy<>(), + data, + (e1, e2) -> true, + 1, + (k, v) -> VectorUtils.of(k), + (k, v) -> v, + folds + )); + } + + /** */ + @Test public void testScoreWithBadDataset() { Map<Integer, Double> data = new HashMap<>();
