artemmalykh commented on a change in pull request #5767: [ML] IGNITE-10573: 
Consistent API for Ensemble training
URL: https://github.com/apache/ignite/pull/5767#discussion_r247890104
 
 

 ##########
 File path: 
modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java
 ##########
 @@ -322,59 +262,63 @@ public StackedDatasetTrainer() {
         if (aggregatingInputMerger == null)
             throw new IllegalStateException("Binary operator used to convert 
outputs of submodels is not specified");
 
-        List<IgniteSupplier<IgniteModel<IS, IA>>> mdlSuppliers = 
taskSupplier.apply(submodelsTrainers);
+        List<DatasetTrainer<IgniteModel<IS, IA>, L>> subs = new ArrayList<>();
+        if (submodelInput2AggregatingInputConverter != null) {
+            DatasetTrainer<IgniteModel<IS, IS>, L> id = 
DatasetTrainer.identityTrainer();
+            DatasetTrainer<IgniteModel<IS, IA>, L> mappedId = 
CompositionUtils.unsafeCoerce(
+                
AdaptableDatasetTrainer.of(id).afterTrainedModel(submodelInput2AggregatingInputConverter));
+            subs.add(mappedId);
+        }
 
-        List<IgniteModel<IS, IA>> subMdls = 
environment.parallelismStrategy().submit(mdlSuppliers).stream()
-            .map(Promise::unsafeGet)
-            .collect(Collectors.toList());
+        subs.addAll(submodelsTrainers);
 
-        // Add new columns consisting in submodels output in features.
-        IgniteBiFunction<K, V, Vector> augmentedExtractor = 
getFeatureExtractorForAggregator(featureExtractor,
-            subMdls,
-            submodelInput2AggregatingInputConverter,
+        TrainersParallelComposition<IS, IA, L> composition = new 
TrainersParallelComposition<>(subs);
+
+        IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> 
featureMapper = getFeatureExtractorForAggregator(
             submodelOutput2VectorConverter,
             vector2SubmodelInputConverter);
 
-        AM aggregator = aggregatorProcessor.apply(aggregatorTrainer, 
augmentedExtractor);
+        return AdaptableDatasetTrainer
+            .of(composition)
+            .afterTrainedModel(lst -> 
lst.stream().reduce(aggregatingInputMerger).get())
+            .andThen(aggregatorTrainer, model -> new DatasetMapping<L, L>() {
+                @Override public Vector mapFeatures(Vector v) {
+                    List<IgniteModel<IS, IA>> models = 
((ModelsParallelComposition<IS, IA>)model.innerModel()).submodels();
+                    return featureMapper.apply(models, v);
+                }
 
-        StackedModel<IS, IA, O, AM> res = new StackedModel<>(
-            aggregator,
-            aggregatingInputMerger,
-            submodelInput2AggregatingInputConverter);
+                @Override public L mapLabels(L lbl) {
+                    return lbl;
+                }
+            }).unsafeSimplyTyped();
+    }
 
-        for (IgniteModel<IS, IA> subMdl : subMdls)
-            res.addSubmodel(subMdl);
+    /** {@inheritDoc} */
+    @Override public StackedDatasetTrainer<IS, IA, O, AM, L> 
withEnvironmentBuilder(
+        LearningEnvironmentBuilder envBuilder) {
+        submodelsTrainers =
+            submodelsTrainers.stream().map(x -> 
x.withEnvironmentBuilder(envBuilder)).collect(Collectors.toList());
+        aggregatorTrainer = 
aggregatorTrainer.withEnvironmentBuilder(envBuilder);
 
-        return res;
+        return this;
     }
 
     /**
      * Get feature extractor which will be used for aggregator trainer from 
original feature extractor.
      * This method is static to make sure that we will not grab context of 
instance in serialization.
      *
-     * @param featureExtractor Original feature extractor.
-     * @param subMdls Submodels.
      * @param <K> Type of upstream keys.
      * @param <V> Type of upstream values.
      * @return Feature extractor which will be used for aggregator trainer 
from original feature extractor.
      */
-    private static <IS, IA, K, V> IgniteBiFunction<K, V, Vector> 
getFeatureExtractorForAggregator(
-        IgniteBiFunction<K, V, Vector> featureExtractor, List<IgniteModel<IS, 
IA>> subMdls,
-        IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter,
+    private static <IS, IA, K, V> IgniteBiFunction<List<IgniteModel<IS, IA>>, 
Vector, Vector> getFeatureExtractorForAggregator(
 
 Review comment:
   Fixed.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to