Repository: ignite
Updated Branches:
  refs/heads/master edcc1089a -> 4800b8729


IGNITE-10605: [ML] Add binary metrics calculations to Cross-Validation

This closes #5712


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/4800b872
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/4800b872
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/4800b872

Branch: refs/heads/master
Commit: 4800b8729565728877d31b6161a19fe1632225d8
Parents: edcc108
Author: zaleslaw <[email protected]>
Authored: Fri Dec 21 14:43:21 2018 +0300
Committer: Yury Babak <[email protected]>
Committed: Fri Dec 21 14:43:21 2018 +0300

----------------------------------------------------------------------
 .../ml/selection/cv/CrossValidationExample.java |  23 ++-
 .../Step_8_CV_with_Param_Grid_and_metrics.java  | 192 +++++++++++++++++++
 .../BinaryClassificationEvaluator.java          |   8 +-
 .../metric/BinaryClassificationMetrics.java     |  25 ++-
 .../ml/selection/cv/CrossValidationTest.java    |  42 ++++
 5 files changed, 281 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/4800b872/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
----------------------------------------------------------------------
diff --git 
a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
 
b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
index 552bcd2..462186c 100644
--- 
a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
+++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
@@ -27,6 +27,8 @@ import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.selection.cv.CrossValidation;
 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import 
org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
+import 
org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
 
@@ -71,7 +73,7 @@ public class CrossValidationExample {
             CrossValidation<DecisionTreeNode, Double, Integer, LabeledPoint> 
scoreCalculator
                 = new CrossValidation<>();
 
-            double[] scores = scoreCalculator.score(
+            double[] accuracyScores = scoreCalculator.score(
                 trainer,
                 new Accuracy<>(),
                 ignite,
@@ -81,7 +83,24 @@ public class CrossValidationExample {
                 4
             );
 
-            System.out.println(">>> Accuracy: " + Arrays.toString(scores));
+            System.out.println(">>> Accuracy: " + 
Arrays.toString(accuracyScores));
+
+            BinaryClassificationMetrics metrics = new 
BinaryClassificationMetrics()
+                .withNegativeClsLb(0.0)
+                .withPositiveClsLb(1.0)
+                
.withMetric(BinaryClassificationMetricValues::balancedAccuracy);
+
+            double[] balancedAccuracyScores = scoreCalculator.score(
+                trainer,
+                metrics,
+                ignite,
+                trainingSet,
+                (k, v) -> VectorUtils.of(v.x, v.y),
+                (k, v) -> v.lb,
+                4
+            );
+
+            System.out.println(">>> Balanced Accuracy: " + 
Arrays.toString(balancedAccuracyScores));
 
             System.out.println(">>> Cross validation score calculator example 
completed.");
         }

http://git-wip-us.apache.org/repos/asf/ignite/blob/4800b872/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
----------------------------------------------------------------------
diff --git 
a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
 
b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
new file mode 100644
index 0000000..0ea0ca2
--- /dev/null
+++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.tutorial;
+
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
+import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
+import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
+import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
+import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
+import org.apache.ignite.ml.selection.cv.CrossValidation;
+import org.apache.ignite.ml.selection.cv.CrossValidationResult;
+import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
+import 
org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import 
org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
+import 
org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
+import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
+import org.apache.ignite.ml.selection.split.TrainTestSplit;
+import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+
+/**
+ * To choose the best hyperparameters the cross-validation with {@link 
ParamGrid} will be used in this example.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test 
data (based on Titanic passengers data).</p>
+ * <p>
+ * After that it defines how to split the data to train and test sets and 
configures preprocessors that extract
+ * features from an upstream data and perform other desired changes over the 
extracted data.</p>
+ * <p>
+ * Then, it tunes hyperparams with K-fold Cross-Validation on the split 
training set and trains the model based on
+ * the processed data using decision tree classification and the obtained 
hyperparams.</p>
+ * <p>
+ * Finally, this example uses {@link BinaryClassificationEvaluator} 
functionality to compute metrics from predictions.</p>
+ * <p>
+ * The purpose of cross-validation is model checking, not model building.</p>
+ * <p>
+ * We train {@code k} different models.</p>
+ * <p>
+ * They differ in that {@code 1/(k-1)}th of the training data is exchanged 
against other cases.</p>
+ * <p>
+ * These models are sometimes called surrogate models because the (average) 
performance measured for these models
+ * is taken as a surrogate of the performance of the model trained on all 
cases.</p>
+ * <p>
+ * All scenarios are described there: 
https://sebastianraschka.com/faq/docs/evaluate-a-model.html</p>
+ */
+public class Step_8_CV_with_Param_Grid_and_metrics {
+    /** Run example. */
+    public static void main(String[] args) {
+        System.out.println();
+        System.out.println(">>> Tutorial step 8 (cross-validation with param 
grid) example started.");
+
+        try (Ignite ignite = 
Ignition.start("examples/config/example-ignite.xml")) {
+            try {
+                IgniteCache<Integer, Object[]> dataCache = 
TitanicUtils.readPassengers(ignite);
+
+                // Defines first preprocessor that extracts features from an 
upstream data.
+                // Extracts "pclass", "sibsp", "parch", "sex", "embarked", 
"age", "fare" .
+                IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
+                    = (k, v) -> new Object[] {v[0], v[3], v[4], v[5], v[6], 
v[8], v[10]};
+
+                IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, 
v) -> (double)v[1];
+
+                TrainTestSplit<Integer, Object[]> split = new 
TrainTestDatasetSplitter<Integer, Object[]>()
+                    .split(0.75);
+
+                IgniteBiFunction<Integer, Object[], Vector> 
strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>()
+                    .withEncoderType(EncoderType.STRING_ENCODER)
+                    .withEncodedFeature(1)
+                    .withEncodedFeature(6) // <--- Changed index here.
+                    .fit(ignite,
+                        dataCache,
+                        featureExtractor
+                    );
+
+                IgniteBiFunction<Integer, Object[], Vector> 
imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
+                    .fit(ignite,
+                        dataCache,
+                        strEncoderPreprocessor
+                    );
+
+                IgniteBiFunction<Integer, Object[], Vector> 
minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
+                    .fit(
+                        ignite,
+                        dataCache,
+                        imputingPreprocessor
+                    );
+
+                IgniteBiFunction<Integer, Object[], Vector> 
normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
+                    .withP(2)
+                    .fit(
+                        ignite,
+                        dataCache,
+                        minMaxScalerPreprocessor
+                    );
+
+                // Tune hyperparams with K-fold Cross-Validation on the split 
training set.
+
+                DecisionTreeClassificationTrainer trainerCV = new 
DecisionTreeClassificationTrainer();
+
+                CrossValidation<DecisionTreeNode, Double, Integer, Object[]> 
scoreCalculator
+                    = new CrossValidation<>();
+
+                ParamGrid paramGrid = new ParamGrid()
+                    .addHyperParam("maxDeep", new Double[] {1.0, 2.0, 3.0, 
4.0, 5.0, 10.0, 10.0})
+                    .addHyperParam("minImpurityDecrease", new Double[] {0.0, 
0.25, 0.5});
+
+                BinaryClassificationMetrics metrics = new 
BinaryClassificationMetrics()
+                    .withNegativeClsLb(0.0)
+                    .withPositiveClsLb(1.0)
+                    .withMetric(BinaryClassificationMetricValues::accuracy);
+
+                CrossValidationResult crossValidationRes = 
scoreCalculator.score(
+                    trainerCV,
+                    metrics,
+                    ignite,
+                    dataCache,
+                    split.getTrainFilter(),
+                    normalizationPreprocessor,
+                    lbExtractor,
+                    3,
+                    paramGrid
+                );
+
+                System.out.println("Train with maxDeep: " + 
crossValidationRes.getBest("maxDeep")
+                    + " and minImpurityDecrease: " + 
crossValidationRes.getBest("minImpurityDecrease"));
+
+                DecisionTreeClassificationTrainer trainer = new 
DecisionTreeClassificationTrainer()
+                    .withMaxDeep(crossValidationRes.getBest("maxDeep"))
+                    
.withMinImpurityDecrease(crossValidationRes.getBest("minImpurityDecrease"));
+
+                System.out.println(crossValidationRes);
+
+                System.out.println("Best score: " + 
Arrays.toString(crossValidationRes.getBestScore()));
+                System.out.println("Best hyper params: " + 
crossValidationRes.getBestHyperParams());
+                System.out.println("Best average score: " + 
crossValidationRes.getBestAvgScore());
+
+                crossValidationRes.getScoringBoard().forEach((hyperParams, 
score)
+                    -> System.out.println("Score " + Arrays.toString(score) + 
" for hyper params " + hyperParams));
+
+                // Train decision tree model.
+                DecisionTreeNode bestMdl = trainer.fit(
+                    ignite,
+                    dataCache,
+                    split.getTrainFilter(),
+                    normalizationPreprocessor,
+                    lbExtractor
+                );
+
+                System.out.println("\n>>> Trained model: " + bestMdl);
+
+                double accuracy = BinaryClassificationEvaluator.evaluate(
+                    dataCache,
+                    split.getTestFilter(),
+                    bestMdl,
+                    normalizationPreprocessor,
+                    lbExtractor,
+                    new Accuracy<>()
+                );
+
+                System.out.println("\n>>> Accuracy " + accuracy);
+                System.out.println("\n>>> Test Error " + (1 - accuracy));
+
+                System.out.println(">>> Tutorial step 8 (cross-validation with 
param grid) example started.");
+            }
+            catch (FileNotFoundException e) {
+                e.printStackTrace();
+            }
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/4800b872/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java
index 30adc5c..9642bce 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java
@@ -49,7 +49,6 @@ public class BinaryClassificationEvaluator {
                                             IgniteBiFunction<K, V, Vector> 
featureExtractor,
                                             IgniteBiFunction<K, V, L> 
lbExtractor,
                                             Metric<L> metric) {
-
         return calculateMetric(dataCache, null, mdl, featureExtractor, 
lbExtractor, metric);
     }
 
@@ -72,7 +71,6 @@ public class BinaryClassificationEvaluator {
                                             IgniteBiFunction<K, V, Vector> 
featureExtractor,
                                             IgniteBiFunction<K, V, L> 
lbExtractor,
                                             Metric<L> metric) {
-
         return calculateMetric(dataCache, filter, mdl, featureExtractor, 
lbExtractor, metric);
     }
 
@@ -140,7 +138,7 @@ public class BinaryClassificationEvaluator {
             lbExtractor,
             mdl
         )) {
-            metricValues = binaryMetrics.score(cursor.iterator());
+            metricValues = binaryMetrics.scoreAll(cursor.iterator());
         } catch (Exception e) {
             throw new RuntimeException(e);
         }
@@ -163,8 +161,8 @@ public class BinaryClassificationEvaluator {
      * @return Computed metric.
      */
     private static <L, K, V> double calculateMetric(IgniteCache<K, V> 
dataCache, IgniteBiPredicate<K, V> filter,
-                                                    Model<Vector, L> mdl, 
IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> 
lbExtractor,
-                                                    Metric<L> metric) {
+                                                    Model<Vector, L> mdl, 
IgniteBiFunction<K, V, Vector> featureExtractor,
+                                                    IgniteBiFunction<K, V, L> 
lbExtractor, Metric<L> metric) {
         double metricRes;
 
         try (LabelPairCursor<L> cursor = new CacheBasedLabelPairCursor<>(

http://git-wip-us.apache.org/repos/asf/ignite/blob/4800b872/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java
index 0b15d04..bd4067a 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java
@@ -18,25 +18,30 @@
 package org.apache.ignite.ml.selection.scoring.metric;
 
 import java.util.Iterator;
+import java.util.function.Function;
 import org.apache.ignite.ml.selection.scoring.LabelPair;
 
 /**
  * Binary classification metrics calculator.
+ * It can be used in two ways: to calculate all binary classification metrics at once or to compute a specific metric.
  */
-public class BinaryClassificationMetrics {
+public class BinaryClassificationMetrics implements Metric<Double> {
     /** Positive class label. */
     private double positiveClsLb = 1.0;
 
     /** Negative class label. Default value is 0.0. */
     private double negativeClsLb;
 
+    /** The main metric to get individual score. */
+    private Function<BinaryClassificationMetricValues, Double> metric = 
BinaryClassificationMetricValues::accuracy;
+
     /**
      * Calculates binary metrics values.
      *
      * @param iter Iterator that supplies pairs of truth values and predicated.
      * @return Scores for all binary metrics.
      */
-    public BinaryClassificationMetricValues score(Iterator<LabelPair<Double>> 
iter) {
+    public BinaryClassificationMetricValues 
scoreAll(Iterator<LabelPair<Double>> iter) {
         long tp = 0;
         long tn = 0;
         long fp = 0;
@@ -83,4 +88,20 @@ public class BinaryClassificationMetrics {
         this.negativeClsLb = negativeClsLb;
         return this;
     }
+
+    /** */
+    public BinaryClassificationMetrics 
withMetric(Function<BinaryClassificationMetricValues, Double> metric) {
+        this.metric = metric;
+        return this;
+    }
+
+    /** {@inheritDoc} */
+    @Override public double score(Iterator<LabelPair<Double>> iter) {
+        return metric.apply(scoreAll(iter));
+    }
+
+    /** {@inheritDoc} */
+    @Override public String name() {
+        return "Binary classification metrics";
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/4800b872/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
index 3e8b9dd..e64aa7a 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
@@ -21,6 +21,8 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import 
org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues;
+import 
org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics;
 import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeNode;
 import org.junit.Test;
@@ -71,6 +73,46 @@ public class CrossValidationTest {
 
     /** */
     @Test
+    public void testScoreWithGoodDatasetAndBinaryMetrics() {
+        Map<Integer, Double> data = new HashMap<>();
+
+        for (int i = 0; i < 1000; i++)
+            data.put(i, i > 500 ? 1.0 : 0.0);
+
+        DecisionTreeClassificationTrainer trainer = new 
DecisionTreeClassificationTrainer(1, 0);
+
+        CrossValidation<DecisionTreeNode, Double, Integer, Double> 
scoreCalculator =
+            new CrossValidation<>();
+
+        int folds = 4;
+
+        BinaryClassificationMetrics metrics = new BinaryClassificationMetrics()
+            .withMetric(BinaryClassificationMetricValues::accuracy);
+
+        verifyScores(folds, scoreCalculator.score(
+            trainer,
+            metrics,
+            data,
+            1,
+            (k, v) -> VectorUtils.of(k),
+            (k, v) -> v,
+            folds
+        ));
+
+        verifyScores(folds, scoreCalculator.score(
+            trainer,
+            new Accuracy<>(),
+            data,
+            (e1, e2) -> true,
+            1,
+            (k, v) -> VectorUtils.of(k),
+            (k, v) -> v,
+            folds
+        ));
+    }
+
+    /** */
+    @Test
     public void testScoreWithBadDataset() {
         Map<Integer, Double> data = new HashMap<>();
 

Reply via email to