Repository: ignite Updated Branches: refs/heads/master f2d6e4360 -> 388f7ffc4
IGNITE-10719: [ML] LearningEnvironmentBuilder is not passed in makeBagged This closes #5684 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/388f7ffc Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/388f7ffc Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/388f7ffc Branch: refs/heads/master Commit: 388f7ffc448ed44f1a00b03a7bb7f57d57bc117a Parents: f2d6e43 Author: Artem Malykh <[email protected]> Authored: Tue Dec 18 22:00:10 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Dec 18 22:00:10 2018 +0300 ---------------------------------------------------------------------- .../org/apache/ignite/ml/trainers/TrainerTransformers.java | 2 +- .../ml/tree/randomforest/RandomForestClassifierTrainer.java | 6 ++++++ .../tree/randomforest/RandomForestClassifierTrainerTest.java | 4 +++- 3 files changed, 10 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/388f7ffc/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java index 1019a39..80a57e0 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java @@ -122,7 +122,7 @@ public class TrainerTransformers { aggregator, environment); } - }; + }.withEnvironmentBuilder(trainer.envBuilder); } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/388f7ffc/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java index 7832584..a76a941 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java @@ -30,6 +30,7 @@ import org.apache.ignite.ml.dataset.feature.ObjectHistogram; import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition; import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; import org.apache.ignite.ml.tree.randomforest.data.TreeRoot; import org.apache.ignite.ml.tree.randomforest.data.impurity.GiniHistogram; import org.apache.ignite.ml.tree.randomforest.data.impurity.GiniHistogramsComputer; @@ -110,4 +111,9 @@ public class RandomForestClassifierTrainer @Override protected LeafValuesComputer<ObjectHistogram<BootstrappedVector>> createLeafStatisticsAggregator() { return new ClassifierLeafValuesComputer(lblMapping); } + + /** {@inheritDoc} */ + @Override public RandomForestClassifierTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (RandomForestClassifierTrainer)super.withEnvironmentBuilder(envBuilder); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/388f7ffc/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java index 3a038ff..7d282df 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java @@ -20,6 +20,7 @@ package org.apache.ignite.ml.tree.randomforest; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; +import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; @@ -81,7 +82,8 @@ public class RandomForestClassifierTrainerTest extends TrainerTest { meta.add(new FeatureMeta("", i, false)); RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta) .withAmountOfTrees(100) - .withFeaturesCountSelectionStrgy(x -> 2); + .withFeaturesCountSelectionStrgy(x -> 2) + .withEnvironmentBuilder(TestUtils.testEnvBuilder()); ModelsComposition originalMdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v); ModelsComposition updatedOnSameDS = trainer.update(originalMdl, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
