Repository: ignite Updated Branches: refs/heads/ignite-10639 ffad2ac6a -> 9298b2dec
IGNITE-10371: [ML] Add multiple metrics calculation for Binary Classification Evaluation process This closes #5612 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/f2d6e436 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/f2d6e436 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/f2d6e436 Branch: refs/heads/ignite-10639 Commit: f2d6e43605f8daaa12c3cb729a1ddd56bf36f564 Parents: ad9d561 Author: zaleslaw <[email protected]> Authored: Tue Dec 18 21:53:32 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Dec 18 21:53:32 2018 +0300 ---------------------------------------------------------------------- .../ml/selection/scoring/EvaluatorExample.java | 83 +++++++++ .../scoring/MultipleMetricsExample.java | 82 ++++++++ .../ml/tutorial/Step_1_Read_and_Learn.java | 6 +- .../examples/ml/tutorial/Step_2_Imputing.java | 6 +- .../examples/ml/tutorial/Step_3_Categorial.java | 6 +- .../Step_3_Categorial_with_One_Hot_Encoder.java | 6 +- .../ml/tutorial/Step_4_Add_age_fare.java | 6 +- .../examples/ml/tutorial/Step_5_Scaling.java | 6 +- .../tutorial/Step_5_Scaling_with_Pipeline.java | 6 +- .../ignite/examples/ml/tutorial/Step_6_KNN.java | 6 +- .../ml/tutorial/Step_7_Split_train_test.java | 6 +- .../ignite/examples/ml/tutorial/Step_8_CV.java | 6 +- .../ml/tutorial/Step_8_CV_with_Param_Grid.java | 6 +- .../ml/tutorial/Step_9_Go_to_LogReg.java | 6 +- .../cursor/CacheBasedLabelPairCursor.java | 6 +- .../BinaryClassificationEvaluator.java | 184 ++++++++++++++++++ .../selection/scoring/evaluator/Evaluator.java | 104 ----------- .../ml/selection/scoring/metric/Accuracy.java | 5 + .../BinaryClassificationMetricValues.java | 185 +++++++++++++++++++ .../metric/BinaryClassificationMetrics.java | 86 +++++++++ .../selection/scoring/metric/ClassMetric.java | 37 ++++ .../ml/selection/scoring/metric/Fmeasure.java | 22 ++- .../ml/selection/scoring/metric/Metric.java | 9 + 
.../ml/selection/scoring/metric/Precision.java | 23 ++- .../ml/selection/scoring/metric/Recall.java | 22 ++- .../metric/UnknownClassLabelException.java | 38 ++++ .../scoring/evaluator/EvaluatorTest.java | 6 +- 27 files changed, 789 insertions(+), 175 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java new file mode 100644 index 0000000..a6a989b --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.examples.ml.selection.scoring; + +import java.io.FileNotFoundException; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; +import org.apache.ignite.ml.selection.scoring.metric.Accuracy; +import org.apache.ignite.ml.svm.SVMLinearClassificationModel; +import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer; +import org.apache.ignite.ml.util.MLSandboxDatasets; +import org.apache.ignite.ml.util.SandboxMLCache; + +/** + * Run SVM classification trainer ({@link SVMLinearClassificationTrainer}) over distributed dataset. + * <p> + * Code in this example launches Ignite grid and fills the cache with test data points (based on the + * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a>).</p> + * <p> + * After that it trains the model based on the specified data using + * <a href="https://en.wikipedia.org/wiki/Support_vector_machine">linear SVM</a> algorithm.</p> + * <p> + * Finally, this example uses {@link BinaryClassificationEvaluator} to compare the model's predictions with the + * expected labels (ground truth) and prints the resulting accuracy and test error.</p> + * <p> + * You can change the test data used in this example and re-run it to explore this algorithm further.</p> + */ +public class EvaluatorExample { + /** Run example. */ + public static void main(String[] args) throws FileNotFoundException { + System.out.println(); + System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example started."); + // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite) + .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS); + + SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer(); + + IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size()); + IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0); + + SVMLinearClassificationModel mdl = trainer.fit( + ignite, + dataCache, + featureExtractor, + lbExtractor + ); + + double accuracy = BinaryClassificationEvaluator.evaluate( + dataCache, + mdl, + featureExtractor, + lbExtractor, + new Accuracy<>() + ); + + System.out.println("\n>>> Accuracy " + accuracy); + System.out.println("\n>>> Test Error " + (1 - accuracy)); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java new file mode 100644 index 0000000..b8c76e0 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.selection.scoring; + +import java.io.FileNotFoundException; +import java.util.Map; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; +import org.apache.ignite.ml.svm.SVMLinearClassificationModel; +import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer; +import org.apache.ignite.ml.util.MLSandboxDatasets; +import org.apache.ignite.ml.util.SandboxMLCache; + +/** + * Run SVM classification trainer ({@link SVMLinearClassificationTrainer}) over distributed dataset.
+ * <p> + * Code in this example launches Ignite grid and fills the cache with test data points (based on the + * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a>).</p> + * <p> + * After that it trains the model based on the specified data using + * <a href="https://en.wikipedia.org/wiki/Support_vector_machine">linear SVM</a> algorithm.</p> + * <p> + * Finally, this example uses {@link BinaryClassificationEvaluator} to compute several binary classification + * metrics at once and prints the name and value of each metric.</p> + * <p> + * You can change the test data used in this example and re-run it to explore this algorithm further.</p> + */ +public class MultipleMetricsExample { + /** Run example. */ + public static void main(String[] args) throws FileNotFoundException { + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite) + .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS); + + SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer(); + + IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size()); + IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0); + + SVMLinearClassificationModel mdl = trainer.fit( + ignite, + dataCache, + featureExtractor, + lbExtractor + ); + + Map<String, Double> scores = BinaryClassificationEvaluator.evaluate( + dataCache, + mdl, + featureExtractor, + lbExtractor + ).toMap(); + + scores.forEach( + (metricName, score) -> System.out.println("\n>>>" + metricName + ": " + score) + ); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java ---------------------------------------------------------------------- diff --git
a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java index 481fa1d..0cbde9c 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java @@ -24,7 +24,7 @@ import org.apache.ignite.Ignition; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -38,7 +38,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; * <p> * After that it trains the model based on the specified data using decision tree classification.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> */ public class Step_1_Read_and_Learn { /** Run example. 
*/ @@ -66,7 +66,7 @@ public class Step_1_Read_and_Learn { System.out.println("\n>>> Trained model: " + mdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, mdl, featureExtractor, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java index d60dc4b..6fe41ab 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java @@ -25,7 +25,7 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -40,7 +40,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; * <p> * Then, it trains the model based on the processed data using decision tree classification.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> */ public class Step_2_Imputing { /** Run example. 
*/ @@ -75,7 +75,7 @@ public class Step_2_Imputing { System.out.println("\n>>> Trained model: " + mdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, mdl, imputingPreprocessor, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java index ac2fe08..f9bd014 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java @@ -26,7 +26,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -43,7 +43,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; * <p> * Then, it trains the model based on the processed data using decision tree classification.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> */ public class Step_3_Categorial { /** Run example. 
*/ @@ -88,7 +88,7 @@ public class Step_3_Categorial { System.out.println("\n>>> Trained model: " + mdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, mdl, imputingPreprocessor, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java index f0b6efe..0b3e235 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java @@ -26,7 +26,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -44,7 +44,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; * <p> * Then, it trains the model based on the processed data using decision tree classification.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> */ public class 
Step_3_Categorial_with_One_Hot_Encoder { /** Run example. */ @@ -91,7 +91,7 @@ public class Step_3_Categorial_with_One_Hot_Encoder { System.out.println("\n>>> Trained model: " + mdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, mdl, imputingPreprocessor, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java index 71e9efd..7576cd6 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java @@ -26,7 +26,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -41,7 +41,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; * <p> * Then, it trains the model based on the processed data using decision tree classification.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> */ public class Step_4_Add_age_fare 
{ /** Run example. */ @@ -87,7 +87,7 @@ public class Step_4_Add_age_fare { System.out.println("\n>>> Trained model: " + mdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, mdl, imputingPreprocessor, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java index fe7bf91..065eb90 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java @@ -28,7 +28,7 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; @@ -44,7 +44,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; * <p> * Then, it trains the model based on the processed data using decision tree classification.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> */ public class Step_5_Scaling { /** Run example. 
*/ @@ -105,7 +105,7 @@ public class Step_5_Scaling { System.out.println("\n>>> Trained model: " + mdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, mdl, normalizationPreprocessor, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java index bd7cc21..2e97ccb 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java @@ -29,7 +29,7 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; @@ -44,7 +44,7 @@ import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; * <p> * Then, it trains the model based on the processed data using decision tree classification.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> */ public class Step_5_Scaling_with_Pipeline { /** Run example. 
*/ @@ -79,7 +79,7 @@ public class Step_5_Scaling_with_Pipeline { System.out.println("\n>>> Trained model: " + mdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, mdl, mdl.getFeatureExtractor(), http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java index a35b841..d7a0b88 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java @@ -31,7 +31,7 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; /** @@ -45,7 +45,7 @@ import org.apache.ignite.ml.selection.scoring.metric.Accuracy; * <p> * Then, it trains the model based on the processed data using kNN classification.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> */ public class Step_6_KNN { /** Run example. 
*/ @@ -106,7 +106,7 @@ public class Step_6_KNN { System.out.println("\n>>> Trained model: " + mdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, mdl, normalizationPreprocessor, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java index 53d4d0a..a988abe 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java @@ -28,7 +28,7 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType; import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter; import org.apache.ignite.ml.selection.split.TrainTestSplit; @@ -47,7 +47,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; * <p> * Then, it trains the model based on the processed data using decision tree classification.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> */ public class Step_7_Split_train_test { /** Run 
example. */ @@ -112,7 +112,7 @@ public class Step_7_Split_train_test { System.out.println("\n>>> Trained model: " + mdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, split.getTestFilter(), mdl, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java index feedccf..8a962f3 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java @@ -30,7 +30,7 @@ import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; import org.apache.ignite.ml.selection.cv.CrossValidation; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter; import org.apache.ignite.ml.selection.split.TrainTestSplit; @@ -48,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; * Then, it tunes hyperparams with K-fold Cross-Validation on the split training set and trains the model based on * the processed data using decision tree classification and the obtained hyperparams.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> * <p> * The 
purpose of cross-validation is model checking, not model building.</p> * <p> @@ -175,7 +175,7 @@ public class Step_8_CV { System.out.println("\n>>> Trained model: " + bestMdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, split.getTestFilter(), bestMdl, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java index 670f025..19951a2 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java @@ -32,7 +32,7 @@ import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; import org.apache.ignite.ml.selection.cv.CrossValidation; import org.apache.ignite.ml.selection.cv.CrossValidationResult; import org.apache.ignite.ml.selection.paramgrid.ParamGrid; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter; import org.apache.ignite.ml.selection.split.TrainTestSplit; @@ -50,7 +50,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode; * Then, it tunes hyperparams with K-fold Cross-Validation on the split training set and trains the model based on * the processed data using decision tree classification and the obtained hyperparams.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this
example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> * <p> * The purpose of cross-validation is model checking, not model building.</p> * <p> @@ -163,7 +163,7 @@ public class Step_8_CV_with_Param_Grid { System.out.println("\n>>> Trained model: " + bestMdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, split.getTestFilter(), bestMdl, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java index 58466bd..eb12e58 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java @@ -35,7 +35,7 @@ import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.apache.ignite.ml.selection.cv.CrossValidation; -import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter; import org.apache.ignite.ml.selection.split.TrainTestSplit; @@ -52,7 +52,7 @@ import org.apache.ignite.ml.selection.split.TrainTestSplit; * Then, it tunes hyperparams with K-fold Cross-Validation on the split training set and trains the model based on * the processed data using logistic regression and the obtained
hyperparams.</p> * <p> - * Finally, this example uses {@link Evaluator} functionality to compute metrics from predictions.</p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> */ public class Step_9_Go_to_LogReg { /** Run example. */ @@ -207,7 +207,7 @@ public class Step_9_Go_to_LogReg { System.out.println("\n>>> Trained model: " + bestMdl); - double accuracy = Evaluator.evaluate( + double accuracy = BinaryClassificationEvaluator.evaluate( dataCache, split.getTestFilter(), bestMdl, http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursor.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursor.java index 037cf45..06be96b 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursor.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/CacheBasedLabelPairCursor.java @@ -98,12 +98,14 @@ public class CacheBasedLabelPairCursor<L, K, V> implements LabelPairCursor<L> { * Queries the specified cache using the specified filter. * * @param upstreamCache Ignite cache with {@code upstream} data. - * @param filter Filter for {@code upstream} data. + * @param filter Filter for {@code upstream} data. If {@code null} then all entries will be returned. * @return Query cursor. */ private QueryCursor<Cache.Entry<K, V>> query(IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter) { ScanQuery<K, V> qry = new ScanQuery<>(); - qry.setFilter(filter); + + if (filter != null) // Avoid qry.setFilter(null) so the query stays correct even if its null handling changes.
+ qry.setFilter(filter); return upstreamCache.query(qry); } http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java new file mode 100644 index 0000000..30adc5c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.selection.scoring.evaluator; + +import org.apache.ignite.IgniteCache; +import org.apache.ignite.lang.IgniteBiPredicate; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor; +import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor; +import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues; +import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics; +import org.apache.ignite.ml.selection.scoring.metric.Metric; + +/** + * Binary classification evaluator that computes metrics from predictions and ground truth values. + */ +public class BinaryClassificationEvaluator { + /** + * Computes the given metric on the given cache. + * + * @param dataCache The given cache. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param metric The binary classification metric. + * @param <L> The type of label. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric. + */ + public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache, + Model<Vector, L> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, + Metric<L> metric) { + return calculateMetric(dataCache, null, mdl, featureExtractor, lbExtractor, metric); + } + + /** + * Computes the given metric on the given cache. + * + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param metric The binary classification metric. + * @param <L> The type of label. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value.
+ * @return Computed metric. + */ + public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, + Model<Vector, L> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, + Metric<L> metric) { + return calculateMetric(dataCache, filter, mdl, featureExtractor, lbExtractor, metric); + } + + /** + * Computes all binary classification metrics on the given cache. + * + * @param dataCache The given cache. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric values. + */ + public static <K, V> BinaryClassificationMetricValues evaluate(IgniteCache<K, V> dataCache, + Model<Vector, Double> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + return calcMetricValues(dataCache, null, mdl, featureExtractor, lbExtractor); + } + + /** + * Computes all binary classification metrics on the given cache. + * + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric values. + */ + public static <K, V> BinaryClassificationMetricValues evaluate(IgniteCache<K, V> dataCache, + IgniteBiPredicate<K, V> filter, + Model<Vector, Double> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + return calcMetricValues(dataCache, filter, mdl, featureExtractor, lbExtractor); + } + + /** + * Computes all binary classification metrics on the given cache. + * + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. + * @param featureExtractor The feature extractor.
+ * @param lbExtractor The label extractor. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric values. + */ + private static <K, V> BinaryClassificationMetricValues calcMetricValues(IgniteCache<K, V> dataCache, + IgniteBiPredicate<K, V> filter, + Model<Vector, Double> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + BinaryClassificationMetricValues metricValues; + BinaryClassificationMetrics binaryMetrics = new BinaryClassificationMetrics(); + + try (LabelPairCursor<Double> cursor = new CacheBasedLabelPairCursor<>( + dataCache, + filter, + featureExtractor, + lbExtractor, + mdl + )) { + metricValues = binaryMetrics.score(cursor.iterator()); + } catch (Exception e) { + throw new RuntimeException(e); + } + + return metricValues; + } + + /** + * Computes the given metric on the given cache. + * + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param metric The binary classification metric. + * @param <L> The type of label. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric.
+ */ + private static <L, K, V> double calculateMetric(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, + Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, + Metric<L> metric) { + double metricRes; + + try (LabelPairCursor<L> cursor = new CacheBasedLabelPairCursor<>( + dataCache, + filter, + featureExtractor, + lbExtractor, + mdl + )) { + metricRes = metric.score(cursor.iterator()); + } catch (Exception e) { + throw new RuntimeException(e); + } + + return metricRes; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java deleted file mode 100644 index 4535831..0000000 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/Evaluator.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package org.apache.ignite.ml.selection.scoring.evaluator; - -import org.apache.ignite.IgniteCache; -import org.apache.ignite.lang.IgniteBiPredicate; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import org.apache.ignite.ml.math.primitives.vector.Vector; -import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor; -import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor; -import org.apache.ignite.ml.selection.scoring.metric.Accuracy; - -/** - * Binary classification evaluator that compute metrics from predictions. - */ -public class Evaluator { - /** - * Computes the given metric on the given cache. - * - * @param dataCache The given cache. - * @param mdl The model. - * @param featureExtractor The feature extractor. - * @param lbExtractor The label extractor. - * @param metric The binary classification metric. - * @param <L> The type of label. - * @param <K> The type of cache entry key. - * @param <V> The type of cache entry value. - * @return Computed metric. - */ - public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache, - Model<Vector, L> mdl, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, L> lbExtractor, - Accuracy<L> metric) { - double metricRes; - - try (LabelPairCursor<L> cursor = new CacheBasedLabelPairCursor<L, K, V>( - dataCache, - featureExtractor, - lbExtractor, - mdl - )) { - metricRes = metric.score(cursor.iterator()); - } - catch (Exception e) { - throw new RuntimeException(e); - } - - return metricRes; - } - - /** - * Computes the given metric on the given cache. - * - * @param dataCache The given cache. - * @param filter The given filter. - * @param mdl The model. - * @param featureExtractor The feature extractor. - * @param lbExtractor The label extractor. - * @param metric The binary classification metric. - * @param <L> The type of label. - * @param <K> The type of cache entry key.
- * @param <V> The type of cache entry value. - * @return Computed metric. - */ - public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, - Model<Vector, L> mdl, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, L> lbExtractor, - Accuracy<L> metric) { - double metricRes; - - try (LabelPairCursor<L> cursor = new CacheBasedLabelPairCursor<L, K, V>( - dataCache, - filter, - featureExtractor, - lbExtractor, - mdl - )) { - metricRes = metric.score(cursor.iterator()); - } - catch (Exception e) { - throw new RuntimeException(e); - } - - return metricRes; - } -} http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java index 30e6299..fd0656c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Accuracy.java @@ -45,4 +45,9 @@ public class Accuracy<L> implements Metric<L> { return 1.0 * correctCnt / totalCnt; } + + /** {@inheritDoc} */ + @Override public String name() { + return "accuracy"; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricValues.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricValues.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricValues.java new file mode 100644 index 0000000..04cd981 --- /dev/null +++
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricValues.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.scoring.metric; + +import java.lang.reflect.Field; +import java.util.HashMap; +import java.util.Map; + +/** + * Immutable holder of binary classification metric values computed from a confusion matrix. + */ +public class BinaryClassificationMetricValues { + /** True Positive (TP). */ + private final double tp; + + /** True Negative (TN). */ + private final double tn; + + /** False Positive (FP). */ + private final double fp; + + /** False Negative (FN). */ + private final double fn; + + /** Sensitivity or True Positive Rate (TPR). */ + private final double recall; + + /** Specificity (SPC) or True Negative Rate (TNR). */ + private final double specificity; + + /** Precision or Positive Predictive Value (PPV). */ + private final double precision; + + /** Negative Predictive Value (NPV). */ + private final double npv; + + /** Fall-out or False Positive Rate (FPR). */ + private final double fallOut; + + /** False Discovery Rate (FDR). */ + private final double fdr; + + /** Miss Rate or False Negative Rate (FNR). */ + private final double missRate; + + /** Accuracy.
*/ + private final double accuracy; + + /** Balanced accuracy. */ + private final double balancedAccuracy; + + /** F1-Score is the harmonic mean of Precision and Sensitivity. */ + private final double f1Score; + + /** + * Initializes all metric values from the four confusion matrix counts. + * + * @param tp True Positive (TP). + * @param tn True Negative (TN). + * @param fp False Positive (FP). + * @param fn False Negative (FN). + */ + public BinaryClassificationMetricValues(long tp, long tn, long fp, long fn) { + this.tp = tp; + this.tn = tn; + this.fp = fp; + this.fn = fn; + + long p = tp + fn; + long n = tn + fp; + long positivePredictions = tp + fp; + long negativePredictions = tn + fn; + + // according to https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure + recall = p == 0 ? 1 : (double)tp / p; + precision = positivePredictions == 0 ? 1 : (double)tp / positivePredictions; + specificity = n == 0 ? 1 : (double)tn / n; + npv = negativePredictions == 0 ? 1 : (double)tn / negativePredictions; + fallOut = n == 0 ? 0 : (double)fp / n; // FPR, FDR, FNR are complements (1 - x) of TNR, PPV, TPR + fdr = positivePredictions == 0 ? 0 : (double)fp / positivePredictions; + missRate = p == 0 ? 0 : (double)fn / p; + + f1Score = (recall + precision) == 0 ? 0 : 2 * (recall * precision) / (recall + precision); + + accuracy = (p + n) == 0 ? 1 : (double)(tp + tn) / (p + n); + balancedAccuracy = (recall + specificity) / 2; // both already handle empty classes, so no 0/0 here + } + + /** Returns True Positive (TP) count. */ + public double tp() { + return tp; + } + + /** Returns True Negative (TN) count. */ + public double tn() { + return tn; + } + + /** Returns False Positive (FP) count. */ + public double fp() { + return fp; + } + + /** Returns False Negative (FN) count. */ + public double fn() { + return fn; + } + + /** Returns Sensitivity or True Positive Rate (TPR). */ + public double recall() { + return recall; + } + + /** Returns Specificity (SPC) or True Negative Rate (TNR). */ + public double specificity() { + return specificity; + } + + /** Returns Precision or Positive Predictive Value (PPV). */ + public double precision() { + return precision; + } + + /** Returns Negative Predictive Value (NPV).
*/ + public double npv() { + return npv; + } + + /** Returns Fall-out or False Positive Rate (FPR). */ + public double fallOut() { + return fallOut; + } + + /** Returns False Discovery Rate (FDR). */ + public double fdr() { + return fdr; + } + + /** Returns Miss Rate or False Negative Rate (FNR). */ + public double missRate() { + return missRate; + } + + /** Returns Accuracy. */ + public double accuracy() { + return accuracy; + } + + /** Returns Balanced accuracy. */ + public double balancedAccuracy() { + return balancedAccuracy; + } + + /** Returns F1-Score, the harmonic mean of Precision and Sensitivity. */ + public double f1Score() { + return f1Score; + } + + /** Returns a map from metric name (field name) to metric value for every {@code double} field of this class. */ + public Map<String, Double> toMap() { + Map<String, Double> metricValues = new HashMap<>(); + for (Field field : BinaryClassificationMetricValues.class.getDeclaredFields()) + try { + if (field.getType() == double.class) metricValues.put(field.getName(), field.getDouble(this)); + } catch (IllegalAccessException e) { + throw new IllegalStateException(e); + } + return metricValues; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java new file mode 100644 index 0000000..0b15d04 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.scoring.metric; + +import java.util.Iterator; +import org.apache.ignite.ml.selection.scoring.LabelPair; + +/** + * Binary classification metrics calculator. + */ +public class BinaryClassificationMetrics { + /** Positive class label. Default value is 1.0. */ + private double positiveClsLb = 1.0; + + /** Negative class label. Default value is 0.0. */ + private double negativeClsLb; + + /** + * Calculates binary metrics values. + * + * @param iter Iterator that supplies pairs of truth values and predicted values. + * @return Scores for all binary metrics.
+ */ + public BinaryClassificationMetricValues score(Iterator<LabelPair<Double>> iter) { + long tp = 0; + long tn = 0; + long fp = 0; + long fn = 0; + + while (iter.hasNext()) { + LabelPair<Double> e = iter.next(); + + double prediction = e.getPrediction(); + double truth = e.getTruth(); + + if (prediction != negativeClsLb && prediction != positiveClsLb) + throw new UnknownClassLabelException(prediction, positiveClsLb, negativeClsLb); + if (truth != negativeClsLb && truth != positiveClsLb) + throw new UnknownClassLabelException(truth, positiveClsLb, negativeClsLb); + + if (truth == positiveClsLb && prediction == positiveClsLb) tp++; + else if (truth == positiveClsLb && prediction == negativeClsLb) fn++; + else if (truth == negativeClsLb && prediction == negativeClsLb) tn++; + else if (truth == negativeClsLb && prediction == positiveClsLb) fp++; + } + + return new BinaryClassificationMetricValues(tp, tn, fp, fn); + } + + /** Returns the positive class label. */ + public double positiveClsLb() { + return positiveClsLb; + } + + /** Sets the positive class label and returns this calculator. */ + public BinaryClassificationMetrics withPositiveClsLb(double positiveClsLb) { + this.positiveClsLb = positiveClsLb; + return this; + } + + /** Returns the negative class label. */ + public double negativeClsLb() { + return negativeClsLb; + } + + /** Sets the negative class label and returns this calculator. */ + public BinaryClassificationMetrics withNegativeClsLb(double negativeClsLb) { + this.negativeClsLb = negativeClsLb; + return this; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java new file mode 100644 index 0000000..f89e683 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/ClassMetric.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.scoring.metric; + +/** + * Metric calculator for one class label. + * + * @param <L> Type of a label (truth or prediction). + */ +public abstract class ClassMetric<L> implements Metric<L> { + /** Class label. */ + protected final L clsLb; + + /** + * Creates a metric calculator for the class of interest (positive class). + * + * @param clsLb The label. + */ + public ClassMetric(L clsLb) { + this.clsLb = clsLb; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Fmeasure.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Fmeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Fmeasure.java index 1267584..fe36f51 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Fmeasure.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Fmeasure.java @@ -25,9 +25,15 @@ import org.apache.ignite.ml.selection.scoring.LabelPair; * * @param <L> Type of a label (truth or prediction).
*/ -public class Fmeasure<L> implements Metric<L> { - /** Class label. */ - private L clsLb; +public class Fmeasure<L> extends ClassMetric<L> { + /** + * The class of interest or positive class. + * + * @param clsLb The label. + */ + public Fmeasure(L clsLb) { + super(clsLb); + } /** {@inheritDoc} */ @Override public double score(Iterator<LabelPair<L>> it) { @@ -68,12 +74,8 @@ public class Fmeasure<L> implements Metric<L> { return Double.NaN; } - /** - * The class of interest or positive class. - * - * @param clsLb The label. - */ - public Fmeasure(L clsLb) { - this.clsLb = clsLb; + /** {@inheritDoc} */ + @Override public String name() { + return "F-measure for class with label " + clsLb; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Metric.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Metric.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Metric.java index 60d1e41..783d7fc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Metric.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Metric.java @@ -33,4 +33,13 @@ public interface Metric<L> { * @return Score. */ public double score(Iterator<LabelPair<L>> iter); + + /** + * Returns the metric's name. + * + * NOTE: Should be unique to calculate multiple metrics correctly. + * + * @return String name representation.
+ */ + public String name(); } http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Precision.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Precision.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Precision.java index 0a583c8..482a027 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Precision.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Precision.java @@ -25,9 +25,15 @@ import org.apache.ignite.ml.selection.scoring.LabelPair; * * @param <L> Type of a label (truth or prediction). */ -public class Precision<L> implements Metric<L> { - /** Class label. */ - private L clsLb; +public class Precision<L> extends ClassMetric<L> { + /** + * The class of interest or positive class. + * + * @param clsLb The label. + */ + public Precision(L clsLb) { + super(clsLb); + } /** {@inheritDoc} */ @Override public double score(Iterator<LabelPair<L>> it) { @@ -59,12 +65,9 @@ public class Precision<L> implements Metric<L> { return Double.NaN; } - /** - * The class of interest or positive class. - * - * @param clsLb The label.
- */ - public Precision(L clsLb) { - this.clsLb = clsLb; + /** {@inheritDoc} */ + @Override public String name() { + return "precision for class with label " + clsLb; } + } http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Recall.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Recall.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Recall.java index 60324e7..d459e94 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Recall.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/Recall.java @@ -25,9 +25,15 @@ import org.apache.ignite.ml.selection.scoring.LabelPair; * * @param <L> Type of a label (truth or prediction). */ -public class Recall<L> implements Metric<L> { - /** Class label. */ - private L clsLb; +public class Recall<L> extends ClassMetric<L> { + /** + * The class of interest or positive class. + * + * @param clsLb The label. + */ + public Recall(L clsLb) { + super(clsLb); + } /** {@inheritDoc} */ @Override public double score(Iterator<LabelPair<L>> it) { @@ -59,12 +65,8 @@ public class Recall<L> implements Metric<L> { return Double.NaN; } - /** - * The class of interest or positive class. - * - * @param clsLb The label.
- */ - public Recall(L clsLb) { - this.clsLb = clsLb; + /** {@inheritDoc} */ + @Override public String name() { + return "recall for class with label " + clsLb; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/UnknownClassLabelException.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/UnknownClassLabelException.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/UnknownClassLabelException.java new file mode 100644 index 0000000..0531f2e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/UnknownClassLabelException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.scoring.metric; + +import org.apache.ignite.IgniteException; + +/** + * Indicates an unknown class label for metric calculator. + */ +public class UnknownClassLabelException extends IgniteException { + /** */ + private static final long serialVersionUID = 0L; + + + /** + * @param incorrectVal Incorrect value.
+ * @param positiveClsLb Positive class label. + * @param negativeClsLb Negative class label. + */ + public UnknownClassLabelException(double incorrectVal, double positiveClsLb, double negativeClsLb) { + super("The next class label: " + incorrectVal + " is not positive class label: " + positiveClsLb + " or negative class label: " + negativeClsLb); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/f2d6e436/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java index 1abf7f0..f17ac73 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java @@ -52,7 +52,7 @@ import static org.apache.ignite.ml.TestUtils.testEnvBuilder; import static org.junit.Assert.assertArrayEquals; /** - * Tests for {@link Evaluator} that require to start the whole Ignite infrastructure. IMPL NOTE based on + * Tests for {@link BinaryClassificationEvaluator} that require to start the whole Ignite infrastructure. IMPL NOTE based on * Step_8_CV_with_Param_Grid example. */ public class EvaluatorTest extends GridCommonAbstractTest { @@ -260,7 +260,7 @@ public class EvaluatorTest extends GridCommonAbstractTest { lbExtractor ); - actualAccuracy.set(Evaluator.evaluate( + actualAccuracy.set(BinaryClassificationEvaluator.evaluate( cache, split.getTestFilter(), bestMdl, @@ -269,7 +269,7 @@ public class EvaluatorTest extends GridCommonAbstractTest { new Accuracy<>() )); - actualAccuracy2.set(Evaluator.evaluate( + actualAccuracy2.set(BinaryClassificationEvaluator.evaluate( cache, bestMdl, preprocessor,
