Repository: ignite
Updated Branches:
  refs/heads/master f2d6e4360 -> 388f7ffc4


IGNITE-10719: [ML] LearningEnvironmentBuilder is not passed in makeBagged

This closes #5684


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/388f7ffc
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/388f7ffc
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/388f7ffc

Branch: refs/heads/master
Commit: 388f7ffc448ed44f1a00b03a7bb7f57d57bc117a
Parents: f2d6e43
Author: Artem Malykh <[email protected]>
Authored: Tue Dec 18 22:00:10 2018 +0300
Committer: Yury Babak <[email protected]>
Committed: Tue Dec 18 22:00:10 2018 +0300

----------------------------------------------------------------------
 .../org/apache/ignite/ml/trainers/TrainerTransformers.java     | 2 +-
 .../ml/tree/randomforest/RandomForestClassifierTrainer.java    | 6 ++++++
 .../tree/randomforest/RandomForestClassifierTrainerTest.java   | 4 +++-
 3 files changed, 10 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/388f7ffc/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
index 1019a39..80a57e0 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
@@ -122,7 +122,7 @@ public class TrainerTransformers {
                     aggregator,
                     environment);
             }
-        };
+        }.withEnvironmentBuilder(trainer.envBuilder);
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/ignite/blob/388f7ffc/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
index 7832584..a76a941 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
@@ -30,6 +30,7 @@ import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
 import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition;
 import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
 import org.apache.ignite.ml.tree.randomforest.data.TreeRoot;
 import org.apache.ignite.ml.tree.randomforest.data.impurity.GiniHistogram;
 import org.apache.ignite.ml.tree.randomforest.data.impurity.GiniHistogramsComputer;
@@ -110,4 +111,9 @@ public class RandomForestClassifierTrainer
     @Override protected LeafValuesComputer<ObjectHistogram<BootstrappedVector>> createLeafStatisticsAggregator() {
         return new ClassifierLeafValuesComputer(lblMapping);
     }
+
+    /** {@inheritDoc} */
+    @Override public RandomForestClassifierTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+        return (RandomForestClassifierTrainer)super.withEnvironmentBuilder(envBuilder);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/388f7ffc/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
index 3a038ff..7d282df 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
@@ -20,6 +20,7 @@ package org.apache.ignite.ml.tree.randomforest;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
@@ -81,7 +82,8 @@ public class RandomForestClassifierTrainerTest extends TrainerTest {
             meta.add(new FeatureMeta("", i, false));
         RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta)
             .withAmountOfTrees(100)
-            .withFeaturesCountSelectionStrgy(x -> 2);
+            .withFeaturesCountSelectionStrgy(x -> 2)
+            .withEnvironmentBuilder(TestUtils.testEnvBuilder());
 
         ModelsComposition originalMdl = trainer.fit(sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
         ModelsComposition updatedOnSameDS = trainer.update(originalMdl, sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);

Reply via email to