IGNITE-8542: [ML] Add OneVsRest Trainer to handle cases with multiple class labels in dataset.
This closes #5512 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/c3fd4a93 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/c3fd4a93 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/c3fd4a93 Branch: refs/heads/ignite-9720 Commit: c3fd4a930cc1a76b4d1fbccc6d764bdfe88da941 Parents: 3885f3f Author: zaleslaw <[email protected]> Authored: Wed Nov 28 01:45:11 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Wed Nov 28 01:45:11 2018 +0300 ---------------------------------------------------------------------- .../ignite/ml/multiclass/MultiClassModel.java | 115 +++++++++++++++ .../ignite/ml/multiclass/OneVsRestTrainer.java | 147 +++++++++++++++++++ .../org/apache/ignite/ml/IgniteMLTestSuite.java | 4 +- .../ml/multiclass/MultiClassTestSuite.java | 32 ++++ .../ml/multiclass/OneVsRestTrainerTest.java | 126 ++++++++++++++++ 5 files changed, 423 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java new file mode 100644 index 0000000..8520aa9 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.multiclass; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.TreeMap; +import org.apache.ignite.ml.Exportable; +import org.apache.ignite.ml.Exporter; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** Base class for multi-classification model for set of classifiers. */ +public class MultiClassModel<M extends Model<Vector, Double>> implements Model<Vector, Double>, Exportable<MultiClassModel>, Serializable { + /** */ + private static final long serialVersionUID = -114986533359917L; + + /** List of models associated with each class. */ + private Map<Double, M> models; + + /** */ + public MultiClassModel() { + this.models = new HashMap<>(); + } + + /** + * Adds a specific binary classifier to the bunch of same classifiers. + * + * @param clsLb The class label for the added model. + * @param mdl The model. + */ + public void add(double clsLb, M mdl) { + models.put(clsLb, mdl); + } + + /** + * @param clsLb Class label. + * @return model for class label if it exists. 
+ */ + public Optional<M> getModel(Double clsLb) { + return Optional.ofNullable(models.get(clsLb)); + } + + /** {@inheritDoc} */ + @Override public Double apply(Vector input) { + TreeMap<Double, Double> maxMargins = new TreeMap<>(); + + models.forEach((k, v) -> maxMargins.put(v.apply(input), k)); + + // returns value the most closest to 1 + return maxMargins.lastEntry().getValue(); + } + + /** {@inheritDoc} */ + @Override public <P> void saveModel(Exporter<MultiClassModel, P> exporter, P path) { + exporter.save(this, path); + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + + if (o == null || getClass() != o.getClass()) + return false; + + MultiClassModel mdl = (MultiClassModel)o; + + return Objects.equals(models, mdl.models); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + return Objects.hash(models); + } + + /** {@inheritDoc} */ + @Override public String toString() { + StringBuilder wholeStr = new StringBuilder(); + + models.forEach((clsLb, mdl) -> + wholeStr + .append("The class with label ") + .append(clsLb) + .append(" has classifier: ") + .append(mdl.toString()) + .append(System.lineSeparator()) + ); + + return wholeStr.toString(); + } + + /** {@inheritDoc} */ + @Override public String toString(boolean pretty) { + return toString(); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java new file mode 100644 index 0000000..7426506 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.multiclass; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.dataset.Dataset; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.PartitionDataBuilder; +import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap; +import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap; +import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; + +/** + * This is a common heuristic (one-vs-rest) trainer for multi-class labeled models: it trains one binary classifier + * per class label, relabeling that class as 1 and all other classes as 0. + * + * NOTE: The current implementation suffers from unbalanced training over the dataset, because labels are reassigned + * from the full label range to 0 and 1 without any weighting.
+ */ +public class OneVsRestTrainer<M extends Model<Vector, Double>> + extends SingleLabelDatasetTrainer<MultiClassModel<M>> { + /** The common binary classifier with all hyper-parameters to spread them for all separate trainings . */ + private SingleLabelDatasetTrainer<M> classifier; + + /** */ + public OneVsRestTrainer(SingleLabelDatasetTrainer<M> classifier) { + this.classifier = classifier; + } + + /** + * Trains model based on the specified data. + * + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @return Model. + */ + @Override public <K, V> MultiClassModel<M> fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override public <K, V> MultiClassModel<M> updateModel(MultiClassModel<M> newMdl, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor); + + if (classes.isEmpty()) + return getLastTrainedModelOrThrowEmptyDatasetException(newMdl); + + MultiClassModel<M> multiClsMdl = new MultiClassModel<>(); + + classes.forEach(clsLb -> { + IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> { + Double lb = lbExtractor.apply(k, v); + + if (lb.equals(clsLb)) + return 1.0; + else + return 0.0; + }; + + M mdl = Optional.ofNullable(newMdl) + .flatMap(multiClassModel -> multiClassModel.getModel(clsLb)) + .map(learnedModel -> classifier.update(learnedModel, datasetBuilder, featureExtractor, lbTransformer)) + .orElseGet(() -> classifier.fit(datasetBuilder, featureExtractor, lbTransformer)); + + multiClsMdl.add(clsLb, mdl); + }); + + return multiClsMdl; + } + + /** {@inheritDoc} */ + @Override protected boolean checkState(MultiClassModel<M> 
mdl) { + return true; + } + + /** Iterates among dataset and collects class labels. */ + private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Double> lbExtractor) { + assert datasetBuilder != null; + + PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor); + + List<Double> res = new ArrayList<>(); + + try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build( + (upstream, upstreamSize) -> new EmptyContext(), + partDataBuilder + )) { + final Set<Double> clsLabels = dataset.compute(data -> { + final Set<Double> locClsLabels = new HashSet<>(); + + final double[] lbs = data.getY(); + + for (double lb : lbs) + locClsLabels.add(lb); + + return locClsLabels; + }, (a, b) -> { + if (a == null) + return b == null ? new HashSet<>() : b; + if (b == null) + return a; + return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()); + }); + + if (clsLabels != null) + res.addAll(clsLabels); + + } + catch (Exception e) { + throw new RuntimeException(e); + } + return res; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java index f9645d8..78d6659 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java @@ -26,6 +26,7 @@ import org.apache.ignite.ml.genetic.GAGridTestSuite; import org.apache.ignite.ml.inference.InferenceTestSuite; import org.apache.ignite.ml.knn.KNNTestSuite; import org.apache.ignite.ml.math.MathImplMainTestSuite; +import org.apache.ignite.ml.multiclass.MultiClassTestSuite; import 
org.apache.ignite.ml.nn.MLPTestSuite; import org.apache.ignite.ml.pipeline.PipelineTestSuite; import org.apache.ignite.ml.preprocessing.PreprocessingTestSuite; @@ -61,7 +62,8 @@ import org.junit.runners.Suite; StructuresTestSuite.class, CommonTestSuite.class, InferenceTestSuite.class, - BaggingTest.class + BaggingTest.class, + MultiClassTestSuite.class }) public class IgniteMLTestSuite { // No-op. http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java new file mode 100644 index 0000000..551597f --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.multiclass; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite for multi-class classification (one-vs-rest) trainers.
+ */ +@RunWith(Suite.class) [email protected]({ + OneVsRestTrainerTest.class +}) +public class MultiClassTestSuite { + // No-op. +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java new file mode 100644 index 0000000..9842d92 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.multiclass; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests for {@link OneVsRestTrainer}. + */ +public class OneVsRestTrainerTest extends TrainerTest { + /** + * Test trainer on 2 linearly separable sets. + */ + @Test + public void testTrainWithTheLinearlySeparableCase() { + Map<Integer, double[]> cacheMock = new HashMap<>(); + + for (int i = 0; i < twoLinearlySeparableClasses.length; i++) + cacheMock.put(i, twoLinearlySeparableClasses[i]); + + LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) + .withMaxIterations(1000) + .withLocIterations(10) + .withBatchSize(100) + .withSeed(123L); + + OneVsRestTrainer<LogisticRegressionModel> trainer = new OneVsRestTrainer<>(binaryTrainer); + + MultiClassModel mdl = trainer.fit( + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + Assert.assertTrue(mdl.toString().length() > 0); + Assert.assertTrue(mdl.toString(true).length() > 0); + Assert.assertTrue(mdl.toString(false).length() > 0); + + TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-100, 
0)), PRECISION); + TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 0)), PRECISION); + } + + /** */ + @Test + public void testUpdate() { + Map<Integer, double[]> cacheMock = new HashMap<>(); + + for (int i = 0; i < twoLinearlySeparableClasses.length; i++) + cacheMock.put(i, twoLinearlySeparableClasses[i]); + + LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>() + .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)) + .withMaxIterations(1000) + .withLocIterations(10) + .withBatchSize(100) + .withSeed(123L); + + OneVsRestTrainer<LogisticRegressionModel> trainer = new OneVsRestTrainer<>(binaryTrainer); + + MultiClassModel originalMdl = trainer.fit( + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + MultiClassModel updatedOnSameDS = trainer.update( + originalMdl, + cacheMock, + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + MultiClassModel updatedOnEmptyDS = trainer.update( + originalMdl, + new HashMap<Integer, double[]>(), + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), + (k, v) -> v[0] + ); + + List<Vector> vectors = Arrays.asList( + VectorUtils.of(-100, 0), + VectorUtils.of(100, 0) + ); + + for (Vector vec : vectors) { + TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION); + TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION); + } + } +}
