artemmalykh commented on a change in pull request #5767: [ML] IGNITE-10573: 
Consistent API for Ensemble training
URL: https://github.com/apache/ignite/pull/5767#discussion_r247901728
 
 

 ##########
 File path: 
modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java
 ##########
 @@ -254,62 +233,23 @@ public StackedDatasetTrainer() {
         IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, L> lbExtractor) {
 
-        return update(null, datasetBuilder, featureExtractor, lbExtractor);
+        return new StackedModel<>(getTrainer().fit(datasetBuilder, 
featureExtractor, lbExtractor));
     }
 
     /** {@inheritDoc} */
     @Override public <K, V> StackedModel<IS, IA, O, AM> 
update(StackedModel<IS, IA, O, AM> mdl,
         DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> 
featureExtractor,
         IgniteBiFunction<K, V, L> lbExtractor) {
-        return runOnSubmodels(
-            ensemble -> {
-                List<IgniteSupplier<IgniteModel<IS, IA>>> res = new 
ArrayList<>();
-                for (int i = 0; i < ensemble.size(); i++) {
-                    final int j = i;
-                    res.add(() -> {
-                        DatasetTrainer<IgniteModel<IS, IA>, L> trainer = 
ensemble.get(j);
-                        return mdl == null ?
-                            trainer.fit(datasetBuilder, featureExtractor, 
lbExtractor) :
-                            trainer.update(mdl.submodels().get(j), 
datasetBuilder, featureExtractor, lbExtractor);
-                    });
-                }
-                return res;
-            },
-            (at, extr) -> mdl == null ?
-                at.fit(datasetBuilder, extr, lbExtractor) :
-                at.update(mdl.aggregatorModel(), datasetBuilder, extr, 
lbExtractor),
-            featureExtractor
-        );
-    }
 
-    /** {@inheritDoc} */
-    @Override public StackedDatasetTrainer<IS, IA, O, AM, L> 
withEnvironmentBuilder(
-        LearningEnvironmentBuilder envBuilder) {
-        submodelsTrainers =
-            submodelsTrainers.stream().map(x -> 
x.withEnvironmentBuilder(envBuilder)).collect(Collectors.toList());
-        aggregatorTrainer = 
aggregatorTrainer.withEnvironmentBuilder(envBuilder);
-
-        return this;
+        return new StackedModel<>(getTrainer().update(mdl, datasetBuilder, 
featureExtractor, lbExtractor));
     }
 
     /**
-     * <pre>
-     * 1. Obtain models produced by running specified tasks;
-     * 2. run other specified task on dataset augmented with results of models 
from step 2.
-     * </pre>
+     * Get the trainer for stacking.
      *
-     * @param taskSupplier Function used to generate tasks for first step.
-     * @param aggregatorProcessor Function used
-     * @param featureExtractor Feature extractor.
-     * @param <K> Type of keys in upstream.
-     * @param <V> Type of values in upstream.
-     * @return {@link StackedModel}.
+     * @return Trainer for stacking.
      */
-    private <K, V> StackedModel<IS, IA, O, AM> runOnSubmodels(
-        IgniteFunction<List<DatasetTrainer<IgniteModel<IS, IA>, L>>, 
List<IgniteSupplier<IgniteModel<IS, IA>>>> taskSupplier,
-        IgniteBiFunction<DatasetTrainer<AM, L>, IgniteBiFunction<K, V, 
Vector>, AM> aggregatorProcessor,
-        IgniteBiFunction<K, V, Vector> featureExtractor) {
-
+    private DatasetTrainer<IgniteModel<IS, O>, L> getTrainer() {
 
 Review comment:
   Moved consistency checking into a separate method.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to