IGNITE-10606: [ML] Add new tests for BinaryClassificationMetrics and Evaluator
This closes #5751 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/05fef60d Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/05fef60d Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/05fef60d Branch: refs/heads/ignite-601 Commit: 05fef60d00800608c11a9552ae42dac399719a83 Parents: 85bfcc7 Author: zaleslaw <zaleslaw....@gmail.com> Authored: Thu Dec 27 16:23:21 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Thu Dec 27 16:23:21 2018 +0300 ---------------------------------------------------------------------- .../scoring/cursor/LocalLabelPairCursor.java | 9 +- .../BinaryClassificationEvaluator.java | 279 ++++++++++++++----- .../metric/BinaryClassificationMetrics.java | 9 +- .../ignite/ml/trainers/DatasetTrainer.java | 2 +- .../ignite/ml/knn/KNNClassificationTest.java | 6 +- .../ignite/ml/selection/SelectionTestSuite.java | 7 + .../BinaryClassificationEvaluatorTest.java | 96 +++++++ .../scoring/evaluator/EvaluatorTest.java | 2 +- .../metric/BinaryClassificationMetricsTest.java | 159 +++++++++++ .../BinaryClassificationMetricsValuesTest.java | 48 ++++ 10 files changed, 546 insertions(+), 71 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/05fef60d/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java index f135450..d8c2240 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/cursor/LocalLabelPairCursor.java @@ -100,7 +100,14 @@ public class 
LocalLabelPairCursor<L, K, V, T> implements LabelPairCursor<L> { /** {@inheritDoc} */ @Override public boolean hasNext() { - findNext(); + if (filter == null) { + Map.Entry<K, V> entry = iter.next(); + this.nextEntry = entry; + return iter.hasNext(); + } + + else + findNext(); return nextEntry != null; } http://git-wip-us.apache.org/repos/asf/ignite/blob/05fef60d/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java index 9642bce..5cbe10f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluator.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.selection.scoring.evaluator; +import java.util.Map; import org.apache.ignite.IgniteCache; import org.apache.ignite.lang.IgniteBiPredicate; import org.apache.ignite.ml.Model; @@ -24,6 +25,7 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor; import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor; +import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor; import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricValues; import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetrics; import org.apache.ignite.ml.selection.scoring.metric.Metric; @@ -35,99 +37,179 @@ public class BinaryClassificationEvaluator { /** * Computes the given metric on the given cache. 
* - * @param dataCache The given cache. - * @param mdl The model. + * @param dataCache The given cache. + * @param mdl The model. * @param featureExtractor The feature extractor. - * @param lbExtractor The label extractor. - * @param metric The binary classification metric. - * @param <K> The type of cache entry key. - * @param <V> The type of cache entry value. + * @param lbExtractor The label extractor. + * @param metric The binary classification metric. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. * @return Computed metric. */ public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache, - Model<Vector, L> mdl, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, L> lbExtractor, - Metric<L> metric) { + Model<Vector, L> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, + Metric<L> metric) { return calculateMetric(dataCache, null, mdl, featureExtractor, lbExtractor, metric); } /** * Computes the given metric on the given cache. * - * @param dataCache The given cache. - * @param filter The given filter. - * @param mdl The model. + * @param dataCache The given local data. + * @param mdl The model. * @param featureExtractor The feature extractor. - * @param lbExtractor The label extractor. - * @param metric The binary classification metric. - * @param <L> The type of label. - * @param <K> The type of cache entry key. - * @param <V> The type of cache entry value. + * @param lbExtractor The label extractor. + * @param metric The binary classification metric. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric. 
+ */ + public static <L, K, V> double evaluate(Map<K, V> dataCache, + Model<Vector, L> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, + Metric<L> metric) { + return calculateMetric(dataCache, null, mdl, featureExtractor, lbExtractor, metric); + } + + /** + * Computes the given metric on the given cache. + * + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param metric The binary classification metric. + * @param <L> The type of label. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. * @return Computed metric. */ public static <L, K, V> double evaluate(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, - Model<Vector, L> mdl, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, L> lbExtractor, - Metric<L> metric) { + Model<Vector, L> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, + Metric<L> metric) { + return calculateMetric(dataCache, filter, mdl, featureExtractor, lbExtractor, metric); + } + + /** + * Computes the given metric on the given cache. + * + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param metric The binary classification metric. + * @param <L> The type of label. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric. 
+ */ + public static <L, K, V> double evaluate(Map<K, V> dataCache, IgniteBiPredicate<K, V> filter, + Model<Vector, L> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, + Metric<L> metric) { return calculateMetric(dataCache, filter, mdl, featureExtractor, lbExtractor, metric); } /** * Computes the given metrics on the given cache. * - * @param dataCache The given cache. - * @param mdl The model. + * @param dataCache The given cache. + * @param mdl The model. * @param featureExtractor The feature extractor. - * @param lbExtractor The label extractor. - * @param <K> The type of cache entry key. - * @param <V> The type of cache entry value. + * @param lbExtractor The label extractor. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. * @return Computed metric. */ public static <K, V> BinaryClassificationMetricValues evaluate(IgniteCache<K, V> dataCache, - Model<Vector, Double> mdl, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, Double> lbExtractor) { + Model<Vector, Double> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + return calcMetricValues(dataCache, null, mdl, featureExtractor, lbExtractor); + } + + /** + * Computes the given metrics on the given cache. + * + * @param dataCache The given cache. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric. + */ + public static <K, V> BinaryClassificationMetricValues evaluate(Map<K, V> dataCache, + Model<Vector, Double> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { return calcMetricValues(dataCache, null, mdl, featureExtractor, lbExtractor); } /** * Computes the given metrics on the given cache. 
* - * @param dataCache The given cache. - * @param filter The given filter. - * @param mdl The model. + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. * @param featureExtractor The feature extractor. - * @param lbExtractor The label extractor. - * @param <K> The type of cache entry key. - * @param <V> The type of cache entry value. + * @param lbExtractor The label extractor. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. * @return Computed metric. */ - public static <K, V> BinaryClassificationMetricValues evaluate(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, - Model<Vector, Double> mdl, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, Double> lbExtractor) { + public static <K, V> BinaryClassificationMetricValues evaluate(IgniteCache<K, V> dataCache, + IgniteBiPredicate<K, V> filter, + Model<Vector, Double> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { return calcMetricValues(dataCache, filter, mdl, featureExtractor, lbExtractor); } /** * Computes the given metrics on the given cache. * - * @param dataCache The given cache. - * @param filter The given filter. - * @param mdl The model. + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. * @param featureExtractor The feature extractor. - * @param lbExtractor The label extractor. - * @param <K> The type of cache entry key. - * @param <V> The type of cache entry value. + * @param lbExtractor The label extractor. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric. 
+ */ + public static <K, V> BinaryClassificationMetricValues evaluate(Map<K, V> dataCache, IgniteBiPredicate<K, V> filter, + Model<Vector, Double> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + return calcMetricValues(dataCache, filter, mdl, featureExtractor, lbExtractor); + } + + /** + * Computes the given metrics on the given cache. + * + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. * @return Computed metric. */ private static <K, V> BinaryClassificationMetricValues calcMetricValues(IgniteCache<K, V> dataCache, - IgniteBiPredicate<K, V> filter, - Model<Vector, Double> mdl, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, Double> lbExtractor) { + IgniteBiPredicate<K, V> filter, + Model<Vector, Double> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { BinaryClassificationMetricValues metricValues; BinaryClassificationMetrics binaryMetrics = new BinaryClassificationMetrics(); @@ -139,7 +221,44 @@ public class BinaryClassificationEvaluator { mdl )) { metricValues = binaryMetrics.scoreAll(cursor.iterator()); - } catch (Exception e) { + } + catch (Exception e) { + throw new RuntimeException(e); + } + + return metricValues; + } + + /** + * Computes the given metrics on the given cache. + * + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric. 
+ */ + private static <K, V> BinaryClassificationMetricValues calcMetricValues(Map<K, V> dataCache, + IgniteBiPredicate<K, V> filter, + Model<Vector, Double> mdl, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + BinaryClassificationMetricValues metricValues; + BinaryClassificationMetrics binaryMetrics = new BinaryClassificationMetrics(); + + try (LabelPairCursor<Double> cursor = new LocalLabelPairCursor<>( + dataCache, + filter, + featureExtractor, + lbExtractor, + mdl + )) { + metricValues = binaryMetrics.scoreAll(cursor.iterator()); + } + catch (Exception e) { throw new RuntimeException(e); } @@ -149,20 +268,20 @@ public class BinaryClassificationEvaluator { /** * Computes the given metric on the given cache. * - * @param dataCache The given cache. - * @param filter The given filter. - * @param mdl The model. + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. * @param featureExtractor The feature extractor. - * @param lbExtractor The label extractor. - * @param metric The binary classification metric. - * @param <L> The type of label. - * @param <K> The type of cache entry key. - * @param <V> The type of cache entry value. + * @param lbExtractor The label extractor. + * @param metric The binary classification metric. + * @param <L> The type of label. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. * @return Computed metric. 
*/ private static <L, K, V> double calculateMetric(IgniteCache<K, V> dataCache, IgniteBiPredicate<K, V> filter, - Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, L> lbExtractor, Metric<L> metric) { + Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, Metric<L> metric) { double metricRes; try (LabelPairCursor<L> cursor = new CacheBasedLabelPairCursor<>( @@ -173,7 +292,43 @@ public class BinaryClassificationEvaluator { mdl )) { metricRes = metric.score(cursor.iterator()); - } catch (Exception e) { + } + catch (Exception e) { + throw new RuntimeException(e); + } + + return metricRes; + } + + /** + * Computes the given metric on the given cache. + * + * @param dataCache The given cache. + * @param filter The given filter. + * @param mdl The model. + * @param featureExtractor The feature extractor. + * @param lbExtractor The label extractor. + * @param metric The binary classification metric. + * @param <L> The type of label. + * @param <K> The type of cache entry key. + * @param <V> The type of cache entry value. + * @return Computed metric. 
+ */ + private static <L, K, V> double calculateMetric(Map<K, V> dataCache, IgniteBiPredicate<K, V> filter, + Model<Vector, L> mdl, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor, Metric<L> metric) { + double metricRes; + + try (LabelPairCursor<L> cursor = new LocalLabelPairCursor<>( + dataCache, + filter, + featureExtractor, + lbExtractor, + mdl + )) { + metricRes = metric.score(cursor.iterator()); + } + catch (Exception e) { throw new RuntimeException(e); } http://git-wip-us.apache.org/repos/asf/ignite/blob/05fef60d/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java index bd4067a..35da9fa 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetrics.java @@ -74,7 +74,8 @@ public class BinaryClassificationMetrics implements Metric<Double> { /** */ public BinaryClassificationMetrics withPositiveClsLb(double positiveClsLb) { - this.positiveClsLb = positiveClsLb; + if (Double.isFinite(positiveClsLb)) + this.positiveClsLb = positiveClsLb; return this; } @@ -85,13 +86,15 @@ public class BinaryClassificationMetrics implements Metric<Double> { /** */ public BinaryClassificationMetrics withNegativeClsLb(double negativeClsLb) { - this.negativeClsLb = negativeClsLb; + if (Double.isFinite(negativeClsLb)) + this.negativeClsLb = negativeClsLb; return this; } /** */ public BinaryClassificationMetrics withMetric(Function<BinaryClassificationMetricValues, Double> metric) { - this.metric = metric; + if (metric != null) + this.metric = metric; 
return this; } http://git-wip-us.apache.org/repos/asf/ignite/blob/05fef60d/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java index 161a40c..3f715dc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java @@ -302,7 +302,7 @@ public abstract class DatasetTrainer<M extends Model, L> { // TODO: IGNITE-10441 Think about more elegant ways to perform fluent API. public DatasetTrainer<M, L> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { this.envBuilder = envBuilder; - this.environment = envBuilder.buildForTrainer(); + environment = envBuilder.buildForTrainer(); return this; } http://git-wip-us.apache.org/repos/asf/ignite/blob/05fef60d/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java index 748123a..6fe8a63 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java @@ -85,9 +85,9 @@ public class KNNClassificationTest { .withDistanceMeasure(new EuclideanDistance()) .withStrategy(NNStrategy.SIMPLE); - assertTrue(knnMdl.toString().length() > 0); - assertTrue(knnMdl.toString(true).length() > 0); - assertTrue(knnMdl.toString(false).length() > 0); + assertTrue(!knnMdl.toString().isEmpty()); + assertTrue(!knnMdl.toString(true).isEmpty()); + assertTrue(!knnMdl.toString(false).isEmpty()); Vector firstVector = new 
DenseVector(new double[] {2.0, 2.0}); assertEquals(knnMdl.apply(firstVector), 1.0); http://git-wip-us.apache.org/repos/asf/ignite/blob/05fef60d/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java index e2f8feb..0f62c92 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java @@ -23,8 +23,11 @@ import org.apache.ignite.ml.selection.cv.CrossValidationTest; import org.apache.ignite.ml.selection.paramgrid.ParameterSetGeneratorTest; import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursorTest; import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursorTest; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluatorTest; import org.apache.ignite.ml.selection.scoring.evaluator.EvaluatorTest; import org.apache.ignite.ml.selection.scoring.metric.AccuracyTest; +import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricsTest; +import org.apache.ignite.ml.selection.scoring.metric.BinaryClassificationMetricsValuesTest; import org.apache.ignite.ml.selection.scoring.metric.FmeasureTest; import org.apache.ignite.ml.selection.scoring.metric.PrecisionTest; import org.apache.ignite.ml.selection.scoring.metric.RecallTest; @@ -53,6 +56,10 @@ public class SelectionTestSuite { suite.addTest(new JUnit4TestAdapter(TrainTestDatasetSplitterTest.class)); suite.addTest(new JUnit4TestAdapter(EvaluatorTest.class)); suite.addTest(new JUnit4TestAdapter(CacheBasedLabelPairCursorTest.class)); + suite.addTest(new JUnit4TestAdapter(BinaryClassificationMetricsTest.class)); + suite.addTest(new 
JUnit4TestAdapter(BinaryClassificationMetricsValuesTest.class)); + suite.addTest(new JUnit4TestAdapter(BinaryClassificationEvaluatorTest.class)); + return suite; } http://git-wip-us.apache.org/repos/asf/ignite/blob/05fef60d/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java new file mode 100644 index 0000000..c6222c8 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/BinaryClassificationEvaluatorTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.selection.scoring.evaluator; + +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.knn.NNClassificationModel; +import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.selection.scoring.metric.Accuracy; +import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter; +import org.apache.ignite.ml.selection.split.TrainTestSplit; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link BinaryClassificationEvaluator}. + */ +public class BinaryClassificationEvaluatorTest extends TrainerTest { + /** + * Test evaluator and trainer on classification model y = x. + */ + @Test + public void testEvaluatorWithoutFilter() { + Map<Integer, Vector> cacheMock = new HashMap<>(); + + for (int i = 0; i < twoLinearlySeparableClasses.length; i++) + cacheMock.put(i, VectorUtils.of(twoLinearlySeparableClasses[i])); + + KNNClassificationTrainer trainer = new KNNClassificationTrainer(); + + IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size()); + IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0); + + NNClassificationModel mdl = trainer.fit( + cacheMock, + parts, + featureExtractor, + lbExtractor + ).withK(3); + + double score = BinaryClassificationEvaluator.evaluate(cacheMock, mdl, featureExtractor, lbExtractor, new Accuracy<>()); + + assertEquals(0.9839357429718876, score, 1e-12); + } + + /** + * Test evaluator and trainer on classification model y = x. 
+ */ + @Test + public void testEvaluatorWithFilter() { + Map<Integer, Vector> cacheMock = new HashMap<>(); + + for (int i = 0; i < twoLinearlySeparableClasses.length; i++) + cacheMock.put(i, VectorUtils.of(twoLinearlySeparableClasses[i])); + + KNNClassificationTrainer trainer = new KNNClassificationTrainer(); + + IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size()); + IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0); + + TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>() + .split(0.75); + + NNClassificationModel mdl = trainer.fit( + cacheMock, + split.getTrainFilter(), + parts, + featureExtractor, + lbExtractor + ).withK(3); + + double score = BinaryClassificationEvaluator.evaluate(cacheMock, mdl, featureExtractor, lbExtractor, new Accuracy<>()); + + assertEquals(0.9, score, 1); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/05fef60d/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java index 5025460..9ce35a0 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/evaluator/EvaluatorTest.java @@ -203,7 +203,7 @@ public class EvaluatorTest extends GridCommonAbstractTest { /** */ private void assertResults(CrossValidationResult res, List<double[]> scores, double accuracy, double accuracy2) { - assertTrue(res.toString().length() > 0); + assertTrue(!res.toString().isEmpty()); assertEquals("Best maxDeep", 1.0, res.getBest("maxDeep")); assertEquals("Best minImpurityDecrease", 0.0, res.getBest("minImpurityDecrease")); 
assertArrayEquals("Best score", new double[] {0.6666666666666666, 0.6, 0}, res.getBestScore(), 0); http://git-wip-us.apache.org/repos/asf/ignite/blob/05fef60d/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsTest.java new file mode 100644 index 0000000..a173f5e --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsTest.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.scoring.metric; + +import java.util.Arrays; +import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor; +import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link BinaryClassificationMetrics}. 
+ */ +public class BinaryClassificationMetricsTest { + /** */ + @Test + public void testDefaultBehaviour() { + Metric scoreCalculator = new BinaryClassificationMetrics(); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(1.0, 1.0, 1.0, 1.0), + Arrays.asList(1.0, 1.0, 0.0, 1.0) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + assertEquals(0.75, score, 1e-12); + } + + /** */ + @Test + public void testDefaultBehaviourForScoreAll() { + BinaryClassificationMetrics scoreCalculator = new BinaryClassificationMetrics(); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(1.0, 1.0, 1.0, 1.0), + Arrays.asList(1.0, 1.0, 0.0, 1.0) + ); + + BinaryClassificationMetricValues metricValues = scoreCalculator.scoreAll(cursor.iterator()); + + assertEquals(0.75, metricValues.accuracy(), 1e-12); + } + + /** */ + @Test + public void testAccuracy() { + Metric scoreCalculator = new BinaryClassificationMetrics() + .withNegativeClsLb(1.0) + .withPositiveClsLb(2.0); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(2.0, 2.0, 2.0, 2.0), + Arrays.asList(2.0, 2.0, 1.0, 2.0) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + assertEquals(0.75, score, 1e-12); + } + + /** */ + @Test + public void testCustomMetric() { + Metric scoreCalculator = new BinaryClassificationMetrics() + .withNegativeClsLb(1.0) + .withPositiveClsLb(2.0) + .withMetric(BinaryClassificationMetricValues::tp); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(2.0, 2.0, 2.0, 2.0), + Arrays.asList(2.0, 2.0, 1.0, 2.0) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + assertEquals(3, score, 1e-12); + } + + /** */ + @Test + public void testNullCustomMetric() { + Metric scoreCalculator = new BinaryClassificationMetrics() + .withNegativeClsLb(1.0) + .withPositiveClsLb(2.0) + .withMetric(null); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + 
Arrays.asList(2.0, 2.0, 2.0, 2.0), + Arrays.asList(2.0, 2.0, 1.0, 2.0) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + // accuracy as default metric + assertEquals(0.75, score, 1e-12); + } + + /** */ + @Test + public void testNaNinClassLabels() { + Metric scoreCalculator = new BinaryClassificationMetrics() + .withNegativeClsLb(Double.NaN) + .withPositiveClsLb(Double.POSITIVE_INFINITY); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(1.0, 1.0, 1.0, 1.0), + Arrays.asList(1.0, 1.0, 0.0, 1.0) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + // accuracy as default metric + assertEquals(0.75, score, 1e-12); + } + + /** */ + @Test(expected = org.apache.ignite.ml.selection.scoring.metric.UnknownClassLabelException.class) + public void testFailWithIncorrectClassLabelsInData() { + Metric scoreCalculator = new BinaryClassificationMetrics(); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(2.0, 2.0, 2.0, 2.0), + Arrays.asList(2.0, 2.0, 1.0, 2.0) + ); + + scoreCalculator.score(cursor.iterator()); + } + + /** */ + @Test(expected = org.apache.ignite.ml.selection.scoring.metric.UnknownClassLabelException.class) + public void testFailWithIncorrectClassLabelsInMetrics() { + Metric scoreCalculator = new BinaryClassificationMetrics() + .withPositiveClsLb(42); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(1.0, 1.0, 1.0, 1.0), + Arrays.asList(1.0, 1.0, 0.0, 1.0) + ); + + scoreCalculator.score(cursor.iterator()); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/05fef60d/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsValuesTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsValuesTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsValuesTest.java new file mode 100644 index 0000000..75a8183 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/BinaryClassificationMetricsValuesTest.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.scoring.metric; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link BinaryClassificationMetricValues}. 
+ */ +public class BinaryClassificationMetricsValuesTest { + /** */ + @Test + public void testDefaultBehaviour() { + BinaryClassificationMetricValues metricValues = new BinaryClassificationMetricValues(10, 10, 5, 5); + + assertEquals(10, metricValues.tp(), 1e-2); + assertEquals(10, metricValues.tn(), 1e-2); + assertEquals(5, metricValues.fn(), 1e-2); + assertEquals(5, metricValues.fp(), 1e-2); + assertEquals(0.66, metricValues.accuracy(), 1e-2); + assertEquals(0.66, metricValues.balancedAccuracy(), 1e-2); + assertEquals(0.66, metricValues.f1Score(), 1e-2); + assertEquals(0.33, metricValues.fallOut(), 1e-2); + assertEquals(0.33, metricValues.fdr(), 1e-2); + assertEquals(0.33, metricValues.missRate(), 1e-2); + assertEquals(0.66, metricValues.npv(), 1e-2); + assertEquals(0.66, metricValues.precision(), 1e-2); + assertEquals(0.66, metricValues.recall(), 1e-2); + assertEquals(0.66, metricValues.specificity(), 1e-2); + } +}