This is an automated email from the ASF dual-hosted git repository. chief pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push: new 4e47c6a IGNITE-10573: [ML] Consistent API for Ensemble training 4e47c6a is described below commit 4e47c6a6531449b5e6bd3593edcd6ddacc535a13 Author: Artem Malykh <amalyk...@gmail.com> AuthorDate: Tue Jan 15 20:19:48 2019 +0300 IGNITE-10573: [ML] Consistent API for Ensemble training This closes #5767 --- .../BaggedLogisticRegressionSGDTrainerExample.java | 8 +- .../ml/tutorial/Step_10_Scaling_With_Stacking.java | 142 ++++++++++++++ .../java/org/apache/ignite/ml/IgniteModel.java | 40 +++- .../ignite/ml/clustering/kmeans/KMeansTrainer.java | 2 +- .../ignite/ml/composition/CompositionUtils.java | 85 +++++++++ .../ignite/ml/composition/DatasetMapping.java | 68 +++++++ .../ignite/ml/composition/bagging/BaggedModel.java | 57 ++++++ .../ml/composition/bagging/BaggedTrainer.java | 212 +++++++++++++++++++++ .../ignite/ml/composition/boosting/GDBTrainer.java | 2 +- .../combinators/package-info.java} | 33 +--- .../parallel/ModelsParallelComposition.java | 67 +++++++ .../parallel/TrainersParallelComposition.java | 145 ++++++++++++++ .../combinators/parallel/package-info.java} | 33 +--- .../sequential/ModelsSequentialComposition.java | 100 ++++++++++ .../sequential/TrainersSequentialComposition.java | 139 ++++++++++++++ .../combinators/sequential/package-info.java} | 33 +--- .../stacking/StackedDatasetTrainer.java | 202 +++++++++----------- .../ml/composition/stacking/StackedModel.java | 72 ++----- .../stacking/StackedVectorDatasetTrainer.java | 1 + .../apache/ignite/ml/dataset/DatasetBuilder.java | 2 +- .../ignite/ml/dataset/UpstreamTransformer.java | 18 +- .../ml/dataset/UpstreamTransformerBuilder.java | 18 +- .../ml/dataset/impl/cache/CacheBasedDataset.java | 4 +- .../impl/cache/CacheBasedDatasetBuilder.java | 6 +- .../ml/dataset/impl/cache/util/ComputeUtils.java | 27 ++- .../ml/dataset/impl/local/LocalDatasetBuilder.java | 25 +-- .../org/apache/ignite/ml/genetic/Chromosome.java | 1 - 
.../org/apache/ignite/ml/genetic/MutateJob.java | 3 +- .../ignite/ml/genetic/cache/GeneCacheConfig.java | 1 - .../ml/genetic/cache/PopulationCacheConfig.java | 1 - .../ml/genetic/parameter/ChromosomeCriteria.java | 1 - .../ml/knn/ann/ANNClassificationTrainer.java | 2 +- .../classification/KNNClassificationTrainer.java | 2 +- .../ml/knn/regression/KNNRegressionTrainer.java | 2 +- .../ignite/ml/multiclass/OneVsRestTrainer.java | 2 +- .../discrete/DiscreteNaiveBayesTrainer.java | 4 +- .../gaussian/GaussianNaiveBayesTrainer.java | 2 +- .../java/org/apache/ignite/ml/nn/MLPTrainer.java | 2 +- .../linear/LinearRegressionLSQRTrainer.java | 2 +- .../linear/LinearRegressionSGDTrainer.java | 2 +- .../logistic/LogisticRegressionSGDTrainer.java | 2 +- .../ml/svm/SVMLinearClassificationTrainer.java | 2 +- .../ml/trainers/AdaptableDatasetTrainer.java | 177 +++++++++++++++-- .../apache/ignite/ml/trainers/DatasetTrainer.java | 40 +++- .../ignite/ml/trainers/TrainerTransformers.java | 59 +----- .../transformers/BaggingUpstreamTransformer.java | 11 +- .../org/apache/ignite/ml/tree/DecisionTree.java | 2 +- .../ml/tree/randomforest/RandomForestTrainer.java | 2 +- .../ml/util/generators/DataStreamGenerator.java | 2 +- .../ml/util/generators/DatasetBuilderAdapter.java | 2 +- .../test/java/org/apache/ignite/ml/TestUtils.java | 2 +- .../apache/ignite/ml/composition/BaggingTest.java | 31 +-- .../ml/environment/LearningEnvironmentTest.java | 2 +- .../apache/ignite/ml/trainers/StackingTest.java | 169 ++++++++++++++++ .../util/generators/DataStreamGeneratorTest.java | 12 +- 55 files changed, 1635 insertions(+), 448 deletions(-) diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java index 58f739d7..c9b10b1 100644 --- 
a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java @@ -22,7 +22,8 @@ import java.util.Arrays; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.Ignition; -import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.bagging.BaggedModel; +import org.apache.ignite.ml.composition.bagging.BaggedTrainer; import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.nn.UpdatesStrategy; @@ -31,7 +32,6 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalcula import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; import org.apache.ignite.ml.selection.cv.CrossValidation; import org.apache.ignite.ml.selection.scoring.metric.Accuracy; -import org.apache.ignite.ml.trainers.DatasetTrainer; import org.apache.ignite.ml.trainers.TrainerTransformers; import org.apache.ignite.ml.util.MLSandboxDatasets; import org.apache.ignite.ml.util.SandboxMLCache; @@ -75,7 +75,7 @@ public class BaggedLogisticRegressionSGDTrainerExample { System.out.println(">>> Perform the training to get the model."); - DatasetTrainer< ModelsComposition, Double> baggedTrainer = TrainerTransformers.makeBagged( + BaggedTrainer<Double> baggedTrainer = TrainerTransformers.makeBagged( trainer, 10, 0.6, @@ -85,7 +85,7 @@ public class BaggedLogisticRegressionSGDTrainerExample { System.out.println(">>> Perform evaluation of the model."); - double[] score = new CrossValidation<ModelsComposition, Double, Integer, Vector>().score( + double[] score = new CrossValidation<BaggedModel, Double, Integer, Vector>().score( baggedTrainer, new Accuracy<>(), ignite, diff --git 
a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java new file mode 100644 index 0000000..ec64764 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.examples.ml.tutorial; + +import java.io.FileNotFoundException; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.ml.composition.stacking.StackedModel; +import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer; +import org.apache.ignite.ml.preprocessing.encoding.EncoderType; +import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer; +import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer; +import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; +import org.apache.ignite.ml.selection.scoring.metric.Accuracy; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; + +/** + * {@link MinMaxScalerTrainer} and {@link NormalizationTrainer} are used in this example due to different values + * distribution in columns and rows. 
+ * <p> + * Code in this example launches Ignite grid and fills the cache with test data (based on Titanic passengers data).</p> + * <p> + * After that it defines preprocessors that extract features from an upstream data and perform other desired changes + * over the extracted data, including the scaling.</p> + * <p> + * Then, it trains the model based on the processed data using decision tree classification.</p> + * <p> + * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p> + */ +public class Step_10_Scaling_With_Stacking { + /** Run example. */ + public static void main(String[] args) { + System.out.println(); + System.out.println(">>> Tutorial step 5 (scaling) example started."); + + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + try { + IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite); + + // Defines first preprocessor that extracts features from an upstream data. + // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare". + IgniteBiFunction<Integer, Object[], Object[]> featureExtractor + = (k, v) -> new Object[] {v[0], v[3], v[4], v[5], v[6], v[8], v[10]}; + + IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1]; + + IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>() + .withEncoderType(EncoderType.STRING_ENCODER) + .withEncodedFeature(1) + .withEncodedFeature(6) // <--- Changed index here. 
+ .fit(ignite, + dataCache, + featureExtractor + ); + + IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>() + .fit(ignite, + dataCache, + strEncoderPreprocessor + ); + + IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>() + .fit( + ignite, + dataCache, + imputingPreprocessor + ); + + IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>() + .withP(1) + .fit( + ignite, + dataCache, + minMaxScalerPreprocessor + ); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0); + DecisionTreeClassificationTrainer trainer1 = new DecisionTreeClassificationTrainer(3, 0); + DecisionTreeClassificationTrainer trainer2 = new DecisionTreeClassificationTrainer(4, 0); + + LogisticRegressionSGDTrainer aggregator = new LogisticRegressionSGDTrainer() + .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg)); + + StackedModel<Vector, Vector, Double, LogisticRegressionModel> mdl = + new StackedVectorDatasetTrainer<>(aggregator) + .addTrainerWithDoubleOutput(trainer) + .addTrainerWithDoubleOutput(trainer1) + .addTrainerWithDoubleOutput(trainer2) + .fit( + ignite, + dataCache, + normalizationPreprocessor, + lbExtractor + ); + + System.out.println("\n>>> Trained model: " + mdl); + + double accuracy = BinaryClassificationEvaluator.evaluate( + dataCache, + mdl, + normalizationPreprocessor, + lbExtractor, + new Accuracy<>() + ); + + System.out.println("\n>>> Accuracy " + accuracy); + System.out.println("\n>>> Test Error " + (1 - accuracy)); + + System.out.println(">>> Tutorial step 5 (scaling) example completed."); + } + catch (FileNotFoundException e) { + e.printStackTrace(); + } + } + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/IgniteModel.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/IgniteModel.java index a1165e1..6268d06 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/IgniteModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/IgniteModel.java @@ -20,8 +20,10 @@ package org.apache.ignite.ml; import java.io.Serializable; import java.util.function.BiFunction; import org.apache.ignite.ml.inference.Model; +import org.apache.ignite.ml.math.functions.IgniteFunction; /** Basic interface for all models. */ +@FunctionalInterface public interface IgniteModel<T, V> extends Model<T, V>, Serializable { /** * Combines this model with other model via specified combiner @@ -37,12 +39,46 @@ public interface IgniteModel<T, V> extends Model<T, V>, Serializable { /** * Get a composition model of the form {@code x -> after(mdl(x))}. * - * @param after Function to apply after this model. + * @param after Model to apply after this model. * @param <V1> Type of input of function applied before this model. * @return Composition model of the form {@code x -> after(mdl(x))}. */ public default <V1> IgniteModel<T, V1> andThen(IgniteModel<V, V1> after) { - return t -> after.predict(predict(t)); + IgniteModel<T, V> self = this; + return new IgniteModel<T, V1>() { + /** {@inheritDoc} */ + @Override public V1 predict(T input) { + return after.predict(self.predict(input)); + } + + /** {@inheritDoc} */ + @Override public void close() { + self.close(); + after.close(); + } + }; + } + + /** + * Get a composition model of the form {@code x -> after(mdl(x))}. + * + * @param after Function to apply after this model. + * @param <V1> Type of input of function applied before this model. + * @return Composition model of the form {@code x -> after(mdl(x))}. 
+ */ + public default <V1> IgniteModel<T, V1> andThen(IgniteFunction<V, V1> after) { + IgniteModel<T, V> self = this; + return new IgniteModel<T, V1>() { + /** {@inheritDoc} */ + @Override public V1 predict(T input) { + return after.apply(self.predict(input)); + } + + /** {@inheritDoc} */ + @Override public void close() { + self.close(); + } + }; } /** diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java index 88ea9b9..3206b5f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java @@ -149,7 +149,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { } /** {@inheritDoc} */ - @Override protected boolean checkState(KMeansModel mdl) { + @Override public boolean isUpdateable(KMeansModel mdl) { return mdl.getCenters().length == k && mdl.distanceMeasure().equals(distance); } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java new file mode 100644 index 0000000..5a2f40a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition; + +import org.apache.ignite.ml.IgniteModel; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.trainers.DatasetTrainer; + +/** + * Various utility functions for trainers composition. + */ +public class CompositionUtils { + /** + * Perform blurring of model type of given trainer to {@code IgniteModel<I, O>}, where I, O are input and output + * types of original model. + * + * @param trainer Trainer to coerce. + * @param <I> Type of input of model produced by coerced trainer. + * @param <O> Type of output of model produced by coerced trainer. + * @param <M> Type of model produced by coerced trainer. + * @param <L> Type of labels. + * @return Trainer coerced to {@code DatasetTrainer<IgniteModel<I, O>, L>}. + */ + public static <I, O, M extends IgniteModel<I, O>, L> DatasetTrainer<IgniteModel<I, O>, L> unsafeCoerce( + DatasetTrainer<? 
extends M, L> trainer) { + return new DatasetTrainer<IgniteModel<I, O>, L>() { + /** {@inheritDoc} */ + @Override public <K, V> IgniteModel<I, O> fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + return trainer.fit(datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override public <K, V> IgniteModel<I, O> update(IgniteModel<I, O> mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + DatasetTrainer<IgniteModel<I, O>, L> trainer1 = (DatasetTrainer<IgniteModel<I, O>, L>)trainer; + return trainer1.update(mdl, datasetBuilder, featureExtractor, lbExtractor); + } + + /** + * This method is never called, instead of constructing logic of update from + * {@link DatasetTrainer#isUpdateable} and + * {@link DatasetTrainer#updateModel} + * in this class we explicitly override update method. + * + * @param mdl Model. + * @return True if current critical for training parameters correspond to parameters from last training. + */ + @Override public boolean isUpdateable(IgniteModel<I, O> mdl) { + throw new IllegalStateException(); + } + + /** + * This method is never called, instead of constructing logic of update from + * {@link DatasetTrainer#isUpdateable(IgniteModel)} and + * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)} + * in this class we explicitly override update method. + * + * @param mdl Model. + * @return Updated model. 
+ */ + @Override protected <K, V> IgniteModel<I, O> updateModel(IgniteModel<I, O> mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + throw new IllegalStateException(); + } + }; + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/DatasetMapping.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/DatasetMapping.java new file mode 100644 index 0000000..9547d54 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/DatasetMapping.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition; + +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * This class represents dataset mapping. This is just a tuple of two mappings: one for features and one for labels. + * + * @param <L1> Type of labels before mapping. + * @param <L2> Type of labels after mapping. + */ +public interface DatasetMapping<L1, L2> { + /** + * Method used to map feature vectors. + * + * @param v Feature vector. + * @return Mapped feature vector. 
+ */ + public default Vector mapFeatures(Vector v) { + return v; + } + + /** + * Method used to map labels. + * + * @param lbl Label. + * @return Mapped label. + */ + public L2 mapLabels(L1 lbl); + + /** + * Dataset mapping which maps features, leaving labels unaffected. + * + * @param mapper Function used to map features. + * @param <L> Type of labels. + * @return Dataset mapping which maps features, leaving labels unaffected. + */ + public static <L> DatasetMapping<L, L> mappingFeatures(IgniteFunction<Vector, Vector> mapper) { + return new DatasetMapping<L, L>() { + /** {@inheritDoc} */ + @Override public Vector mapFeatures(Vector v) { + return mapper.apply(v); + } + + /** {@inheritDoc} */ + @Override public L mapLabels(L lbl) { + return lbl; + } + }; + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedModel.java new file mode 100644 index 0000000..c59a634 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedModel.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.bagging; + +import org.apache.ignite.ml.IgniteModel; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * This class represents model produced by {@link BaggedTrainer}. + * It is a wrapper around inner representation of model produced by {@link BaggedTrainer}. + */ +public class BaggedModel implements IgniteModel<Vector, Double> { + /** Inner representation of model produced by {@link BaggedTrainer}. */ + private IgniteModel<Vector, Double> mdl; + + /** + * Construct instance of this class given specified model. + * @param mdl Model to wrap. + */ + BaggedModel(IgniteModel<Vector, Double> mdl) { + this.mdl = mdl; + } + + /** + * Get wrapped model. + * + * @return Wrapped model. + */ + IgniteModel<Vector, Double> model() { + return mdl; + } + + /** {@inheritDoc} */ + @Override public Double predict(Vector i) { + return mdl.predict(i); + } + + /** {@inheritDoc} */ + @Override public void close() { + mdl.close(); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java new file mode 100644 index 0000000..5b0962a --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.bagging; + +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.ignite.ml.IgniteModel; +import org.apache.ignite.ml.composition.CompositionUtils; +import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition; +import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer; +import org.apache.ignite.ml.trainers.DatasetTrainer; +import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer; +import org.apache.ignite.ml.util.Utils; + +/** + * Trainer encapsulating logic of bootstrap aggregating (bagging). + * This trainer accepts some other trainer and returns bagged version of it. + * Resulting model consists of submodels results of which are aggregated by a specified aggregator. 
+ * <p>Bagging is done + * on both samples and features (<a href="https://en.wikipedia.org/wiki/Bootstrap_aggregating"></a>Samples bagging</a>, + * <a href="https://en.wikipedia.org/wiki/Random_subspace_method"></a>Features bagging</a>).</p> + * + * @param <L> Type of labels. + */ +public class BaggedTrainer<L> extends + DatasetTrainer<BaggedModel, L> { + /** Trainer for which bagged version is created. */ + private final DatasetTrainer<? extends IgniteModel, L> tr; + + /** Aggregator of submodels results. */ + private final PredictionsAggregator aggregator; + + /** Count of submodels in the ensemble. */ + private final int ensembleSize; + + /** Ratio determining which part of dataset will be taken as subsample for each submodel training. */ + private final double subsampleRatio; + + /** Dimensionality of feature vectors. */ + private final int featuresVectorSize; + + /** Dimension of subspace on which all samples from subsample are projected. */ + private final int featureSubspaceDim; + + /** + * Construct instance of this class with given parameters. + * + * @param tr Trainer for making bagged. + * @param aggregator Aggregator of models. + * @param ensembleSize Size of ensemble. + * @param subsampleRatio Ratio (subsample size) / (initial dataset size). + * @param featuresVectorSize Dimensionality of feature vector. + * @param featureSubspaceDim Dimensionality of feature subspace. + */ + public BaggedTrainer(DatasetTrainer<? extends IgniteModel, L> tr, + PredictionsAggregator aggregator, int ensembleSize, double subsampleRatio, int featuresVectorSize, + int featureSubspaceDim) { + this.tr = tr; + this.aggregator = aggregator; + this.ensembleSize = ensembleSize; + this.subsampleRatio = subsampleRatio; + this.featuresVectorSize = featuresVectorSize; + this.featureSubspaceDim = featureSubspaceDim; + } + + /** + * Create trainer bagged trainer. + * + * @return Bagged trainer. 
+ */ + private DatasetTrainer<IgniteModel<Vector, Double>, L> getTrainer() { + List<int[]> mappings = (featuresVectorSize > 0 && featureSubspaceDim != featuresVectorSize) ? + IntStream.range(0, ensembleSize).mapToObj( + modelIdx -> getMapping( + featuresVectorSize, + featureSubspaceDim, + environment.randomNumbersGenerator().nextLong())) + .collect(Collectors.toList()) : + null; + + List<DatasetTrainer<? extends IgniteModel, L>> trainers = Collections.nCopies(ensembleSize, tr); + + // Generate a list of trainers each each copy of original trainer but on its own subspace and subsample. + List<DatasetTrainer<IgniteModel<Vector, Double>, L>> subspaceTrainers = IntStream.range(0, ensembleSize) + .mapToObj(mdlIdx -> { + AdaptableDatasetTrainer<Vector, Double, Vector, Double, ? extends IgniteModel, L> tr = + AdaptableDatasetTrainer.of(trainers.get(mdlIdx)); + if (mappings != null) { + tr = tr.afterFeatureExtractor(featureValues -> { + int[] mapping = mappings.get(mdlIdx); + double[] newFeaturesValues = new double[mapping.length]; + for (int j = 0; j < mapping.length; j++) + newFeaturesValues[j] = featureValues.get(mapping[j]); + + return VectorUtils.of(newFeaturesValues); + }).beforeTrainedModel(getProjector(mappings.get(mdlIdx))); + } + return tr + .withUpstreamTransformerBuilder(BaggingUpstreamTransformer.builder(subsampleRatio, mdlIdx)) + .withEnvironmentBuilder(envBuilder); + }) + .map(CompositionUtils::unsafeCoerce) + .collect(Collectors.toList()); + + AdaptableDatasetTrainer<Vector, Double, Vector, List<Double>, IgniteModel<Vector, List<Double>>, L> finalTrainer = AdaptableDatasetTrainer.of( + new TrainersParallelComposition<>( + subspaceTrainers)).afterTrainedModel(l -> aggregator.apply(l.stream().mapToDouble(Double::valueOf).toArray())); + + return CompositionUtils.unsafeCoerce(finalTrainer); + } + + /** + * Get mapping R^featuresVectorSize -> R^maximumFeaturesCntPerMdl. + * + * @param featuresVectorSize Features vector size (Dimension of initial space). 
+ * @param maximumFeaturesCntPerMdl Dimension of target space. + * @param seed Seed. + * @return Mapping R^featuresVectorSize -> R^maximumFeaturesCntPerMdl. + */ + public static int[] getMapping(int featuresVectorSize, int maximumFeaturesCntPerMdl, long seed) { + return Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, new Random(seed)); + } + + /** + * Get projector from index mapping. + * + * @param mapping Index mapping. + * @return Projector. + */ + public static IgniteFunction<Vector, Vector> getProjector(int[] mapping) { + return v -> { + Vector res = VectorUtils.zeroes(mapping.length); + for (int i = 0; i < mapping.length; i++) + res.set(i, v.get(mapping[i])); + + return res; + }; + } + + /** {@inheritDoc} */ + @Override public <K, V> BaggedModel fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + IgniteModel<Vector, Double> fit = getTrainer().fit(datasetBuilder, featureExtractor, lbExtractor); + return new BaggedModel(fit); + } + + /** {@inheritDoc} */ + @Override public <K, V> BaggedModel update(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + IgniteModel<Vector, Double> updated = getTrainer().update(mdl.model(), datasetBuilder, featureExtractor, lbExtractor); + return new BaggedModel(updated); + } + + /** {@inheritDoc} */ + @Override public BaggedTrainer<L> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) { + return (BaggedTrainer<L>)super.withEnvironmentBuilder(envBuilder); + } + + /** + * This method is never called, instead of constructing logic of update from + * {@link DatasetTrainer#isUpdateable} and + * {@link DatasetTrainer#updateModel} + * in this class we explicitly override update method. + * + * @param mdl Model. + * @return True if current critical for training parameters correspond to parameters from last training. 
+ */ + @Override public boolean isUpdateable(BaggedModel mdl) { + // Should be never called. + throw new IllegalStateException(); + } + + /** + * This method is never called, instead of constructing logic of update from + * {@link DatasetTrainer#isUpdateable} and + * {@link DatasetTrainer#updateModel} + * in this class we explicitly override update method. + * + * @param mdl Model. + * @return Updated model. + */ + @Override protected <K, V> BaggedModel updateModel(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + // Should be never called. + throw new IllegalStateException(); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java index 35502ab..7d88ddb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java @@ -141,7 +141,7 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl } /** {@inheritDoc} */ - @Override protected boolean checkState(ModelsComposition mdl) { + @Override public boolean isUpdateable(ModelsComposition mdl) { return mdl instanceof GDBModel; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/package-info.java similarity index 52% copy from modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java copy to modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/package-info.java index bc4b839..b39067d 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/package-info.java @@ 
-15,35 +15,8 @@ * limitations under the License. */ -package org.apache.ignite.ml.genetic.parameter; - -import java.util.ArrayList; - -import java.util.List; - /** - * Responsible for describing the characteristics of an individual Chromosome. + * <!-- Package description. --> + * Contains various combinators of trainers and models. */ -public class ChromosomeCriteria { - /** List of criteria for a Chromosome */ - private List<String> criteria = new ArrayList<String>(); - - /** - * Retrieve criteria - * - * @return List of strings - */ - public List<String> getCriteria() { - return criteria; - } - - /** - * Set criteria - * - * @param criteria List of criteria to be applied for a Chromosome ;Use format "name=value", ie: "coinType=QUARTER" - */ - public void setCriteria(List<String> criteria) { - this.criteria = criteria; - } - -} +package org.apache.ignite.ml.composition.combinators; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/ModelsParallelComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/ModelsParallelComposition.java new file mode 100644 index 0000000..601b639 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/ModelsParallelComposition.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.composition.combinators.parallel; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.ignite.ml.IgniteModel; + +/** + * Parallel composition of models. + * Parallel composition of models is a model which contains a list of submodels with same input and output types. + * Result of prediction in such model is a list of predictions of each of submodels. + * + * @param <I> Type of submodel input. + * @param <O> Type of submodel output. + */ +public class ModelsParallelComposition<I, O> implements IgniteModel<I, List<O>> { + /** List of submodels. */ + private final List<IgniteModel<I, O>> submodels; + + /** + * Construct an instance of this class from a list of submodels. + * + * @param submodels List of submodels constituting this model. + */ + public ModelsParallelComposition(List<IgniteModel<I, O>> submodels) { + this.submodels = submodels; + } + + /** {@inheritDoc} */ + @Override public List<O> predict(I i) { + return submodels + .stream() + .map(m -> m.predict(i)) + .collect(Collectors.toList()); + } + + /** + * List of submodels constituting this model. + * + * @return List of submodels constituting this model. 
+ */ + public List<IgniteModel<I, O>> submodels() { + return Collections.unmodifiableList(submodels); + } + + /** {@inheritDoc} */ + @Override public void close() { + submodels.forEach(IgniteModel::close); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java new file mode 100644 index 0000000..411ed17 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.combinators.parallel; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.ignite.ml.IgniteModel; +import org.apache.ignite.ml.composition.CompositionUtils; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.environment.parallelism.Promise; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteSupplier; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.trainers.DatasetTrainer; + +/** + * This class represents a parallel composition of trainers. + * Parallel composition of trainers is a trainer itself which trains a list of trainers with same + * input and output. Training is done in following manner: + * <pre> + * 1. Independently train all trainers on the same dataset and get a list of models. + * 2. Combine models produced in step (1) into a {@link ModelsParallelComposition}. + * </pre> + * Updating is made in a similar fashion. + * Like in other trainers combinators we avoid to include type of contained trainers in type parameters + * because otherwise compositions of compositions would have a relatively complex generic type which will + * reduce readability. + * + * @param <I> Type of trainers inputs. + * @param <O> Type of trainers outputs. + * @param <L> Type of dataset labels. + */ +public class TrainersParallelComposition<I, O, L> extends DatasetTrainer<IgniteModel<I, List<O>>, L> { + /** List of trainers. */ + private final List<DatasetTrainer<IgniteModel<I, O>, L>> trainers; + + /** + * Construct an instance of this class from a list of trainers. + * + * @param trainers Trainers. + * @param <M> Type of model. + * @param <T> Type of trainer. + */ + public <M extends IgniteModel<I, O>, T extends DatasetTrainer<? 
extends IgniteModel<I, O>, L>> TrainersParallelComposition( + List<T> trainers) { + this.trainers = trainers.stream().map(CompositionUtils::unsafeCoerce).collect(Collectors.toList()); + } + + /** + * Create parallel composition of trainers contained in a given list. + * + * @param trainers List of trainers. + * @param <I> Type of input of model produced by trainers. + * @param <O> Type of output of model produced by trainers. + * @param <M> Type of model produced by trainers. + * @param <T> Type of trainers. + * @param <L> Type of labels. + * @return Parallel composition of trainers contained in a given list. + */ + public static <I, O, M extends IgniteModel<I, O>, T extends DatasetTrainer<M, L>, L> TrainersParallelComposition<I, O, L> of(List<T> trainers) { + List<DatasetTrainer<IgniteModel<I, O>, L>> trs = + trainers.stream().map(CompositionUtils::unsafeCoerce).collect(Collectors.toList()); + + return new TrainersParallelComposition<>(trs); + } + + /** {@inheritDoc} */ + @Override public <K, V> IgniteModel<I, List<O>> fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + List<IgniteSupplier<IgniteModel<I, O>>> tasks = trainers.stream() + .map(tr -> (IgniteSupplier<IgniteModel<I, O>>)(() -> tr.fit(datasetBuilder, featureExtractor, lbExtractor))) + .collect(Collectors.toList()); + + List<IgniteModel<I, O>> mdls = environment.parallelismStrategy().submit(tasks).stream() + .map(Promise::unsafeGet) + .collect(Collectors.toList()); + + return new ModelsParallelComposition<>(mdls); + } + + /** {@inheritDoc} */ + @Override public <K, V> IgniteModel<I, List<O>> update(IgniteModel<I, List<O>> mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + ModelsParallelComposition<I, O> typedMdl = (ModelsParallelComposition<I, O>)mdl; + + assert typedMdl.submodels().size() == trainers.size(); + 
List<IgniteSupplier<IgniteModel<I, O>>> tasks = new ArrayList<>(); + + for (int i = 0; i < trainers.size(); i++) { + int j = i; + tasks.add(() -> trainers.get(j).update(typedMdl.submodels().get(j), datasetBuilder, featureExtractor, lbExtractor)); + } + + List<IgniteModel<I, O>> mdls = environment.parallelismStrategy().submit(tasks).stream() + .map(Promise::unsafeGet) + .collect(Collectors.toList()); + + return new ModelsParallelComposition<>(mdls); + } + + /** + * This method is never called, instead of constructing logic of update from + * {@link DatasetTrainer#isUpdateable} and + * {@link DatasetTrainer#updateModel} + * in this class we explicitly override update method. + * + * @param mdl Model. + * @return True if current critical for training parameters correspond to parameters from last training. + */ + @Override public boolean isUpdateable(IgniteModel<I, List<O>> mdl) { + // Never called. + throw new IllegalStateException(); + } + + /** + * This method is never called, instead of constructing logic of update from + * {@link DatasetTrainer#isUpdateable(IgniteModel)} and + * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)} + * in this class we explicitly override update method. + * + * @param mdl Model. + * @return Updated model. + */ + @Override protected <K, V> IgniteModel<I, List<O>> updateModel(IgniteModel<I, List<O>> mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + // Never called. 
+ throw new IllegalStateException(); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/package-info.java similarity index 52% copy from modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java copy to modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/package-info.java index bc4b839..cb24250 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/package-info.java @@ -15,35 +15,8 @@ * limitations under the License. */ -package org.apache.ignite.ml.genetic.parameter; - -import java.util.ArrayList; - -import java.util.List; - /** - * Responsible for describing the characteristics of an individual Chromosome. + * <!-- Package description. --> + * Contains parallel combinators of trainers and models. 
*/ -public class ChromosomeCriteria { - /** List of criteria for a Chromosome */ - private List<String> criteria = new ArrayList<String>(); - - /** - * Retrieve criteria - * - * @return List of strings - */ - public List<String> getCriteria() { - return criteria; - } - - /** - * Set criteria - * - * @param criteria List of criteria to be applied for a Chromosome ;Use format "name=value", ie: "coinType=QUARTER" - */ - public void setCriteria(List<String> criteria) { - this.criteria = criteria; - } - -} +package org.apache.ignite.ml.composition.combinators.parallel; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/ModelsSequentialComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/ModelsSequentialComposition.java new file mode 100644 index 0000000..78e2846 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/ModelsSequentialComposition.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.combinators.sequential; + +import java.util.List; +import org.apache.ignite.ml.IgniteModel; +import org.apache.ignite.ml.math.functions.IgniteFunction; + +/** + * Sequential composition of models. + * Sequential composition is a model consisting of two models {@code mdl1 :: I -> O1, mdl2 :: O1 -> O2} with prediction + * corresponding to application of composition {@code mdl1 `andThen` mdl2} to input. + * + * @param <I> Type of input of the first model. + * @param <O1> Type of output of the first model (and input of second). + * @param <O2> Type of output of the second model. + */ +public class ModelsSequentialComposition<I, O1, O2> implements IgniteModel<I, O2> { + /** First model. */ + private IgniteModel<I, O1> mdl1; + + /** Second model. */ + private IgniteModel<O1, O2> mdl2; + + /** + * Get sequential composition of submodels with same type. + * + * @param lst List of submodels. + * @param output2Input Function for conversion output to input. + * @param <I> Type of input of submodel. + * @param <O> Type of output of submodel. + * @return Sequential composition of submodels with same type. + */ + public static <I, O> ModelsSequentialComposition<I, I, O> ofSame(List<? extends IgniteModel<I, O>> lst, + IgniteFunction<O, I> output2Input) { + assert lst.size() >= 2; + + if (lst.size() == 2) + return new ModelsSequentialComposition<>(lst.get(0).andThen(output2Input), + lst.get(1)); + + return new ModelsSequentialComposition<>(lst.get(0).andThen(output2Input), + ofSame(lst.subList(1, lst.size()), output2Input)); + } + + /** + * Construct instance of this class from two given models. + * + * @param mdl1 First model. + * @param mdl2 Second model. + */ + public ModelsSequentialComposition(IgniteModel<I, O1> mdl1, IgniteModel<O1, O2> mdl2) { + this.mdl1 = mdl1; + this.mdl2 = mdl2; + } + + /** + * Get first model. + * + * @return First model. 
+ */ + public IgniteModel<I, O1> firstModel() { + return mdl1; + } + + /** + * Get second model. + * + * @return Second model. + */ + public IgniteModel<O1, O2> secondModel() { + return mdl2; + } + + /** {@inheritDoc} */ + @Override public O2 predict(I i1) { + return mdl1.andThen(mdl2).predict(i1); + } + + /** {@inheritDoc} */ + @Override public void close() { + mdl1.close(); + mdl2.close(); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java new file mode 100644 index 0000000..d36ff9c --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition.combinators.sequential; + +import org.apache.ignite.ml.IgniteModel; +import org.apache.ignite.ml.composition.CompositionUtils; +import org.apache.ignite.ml.composition.DatasetMapping; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.trainers.DatasetTrainer; + +/** + * Sequential composition of trainers. + * Sequential composition of trainers is itself a trainer which produces {@link ModelsSequentialComposition}. + * Training is done in following fashion: + * <pre> + * 1. First trainer is trained and `mdl1` is produced. + * 2. From `mdl1` {@link DatasetMapping} is constructed. This mapping `dsM` encapsulates dependency between first + * training result and second trainer. + * 3. Second trainer is trained using dataset acquired from application of `dsM` to original dataset; `mdl2` is produced. + * 4. `mdl1` and `mdl2` are composed into {@link ModelsSequentialComposition}. + * </pre> + * + * @param <I> Type of input of model produced by first trainer. + * @param <O1> Type of output of model produced by first trainer. + * @param <O2> Type of output of model produced by second trainer. + * @param <L> Type of labels. + */ +public class TrainersSequentialComposition<I, O1, O2, L> extends DatasetTrainer<ModelsSequentialComposition<I, O1, O2>, L> { + /** First trainer. */ + private DatasetTrainer<IgniteModel<I, O1>, L> tr1; + + /** Second trainer. */ + private DatasetTrainer<IgniteModel<O1, O2>, L> tr2; + + /** Dataset mapping. */ + private IgniteFunction<? super IgniteModel<I, O1>, DatasetMapping<L, L>> datasetMapping; + + /** + * Construct sequential composition of given two trainers. + * + * @param tr1 First trainer. + * @param tr2 Second trainer. + * @param datasetMapping Dataset mapping. 
+ */ + public TrainersSequentialComposition(DatasetTrainer<? extends IgniteModel<I, O1>, L> tr1, + DatasetTrainer<? extends IgniteModel<O1, O2>, L> tr2, + IgniteFunction<? super IgniteModel<I, O1>, DatasetMapping<L, L>> datasetMapping) { + this.tr1 = CompositionUtils.unsafeCoerce(tr1); + this.tr2 = CompositionUtils.unsafeCoerce(tr2); + this.datasetMapping = datasetMapping; + } + + /** {@inheritDoc} */ + @Override public <K, V> ModelsSequentialComposition<I, O1, O2> fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + + IgniteModel<I, O1> mdl1 = tr1.fit(datasetBuilder, featureExtractor, lbExtractor); + DatasetMapping<L, L> mapping = datasetMapping.apply(mdl1); + + IgniteModel<O1, O2> mdl2 = tr2.fit(datasetBuilder, + featureExtractor.andThen(mapping::mapFeatures), + lbExtractor.andThen(mapping::mapLabels)); + + return new ModelsSequentialComposition<>(mdl1, mdl2); + } + + /** {@inheritDoc} */ + @Override public <K, V> ModelsSequentialComposition<I, O1, O2> update( + ModelsSequentialComposition<I, O1, O2> mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + + IgniteModel<I, O1> firstUpdated = tr1.update(mdl.firstModel(), datasetBuilder, featureExtractor, lbExtractor); + DatasetMapping<L, L> mapping = datasetMapping.apply(firstUpdated); + + IgniteModel<O1, O2> secondUpdated = tr2.update(mdl.secondModel(), + datasetBuilder, + featureExtractor.andThen(mapping::mapFeatures), + lbExtractor.andThen(mapping::mapLabels)); + + return new ModelsSequentialComposition<>(firstUpdated, secondUpdated); + } + + /** + * This method is never called, instead of constructing logic of update from + * {@link DatasetTrainer#isUpdateable} and + * {@link DatasetTrainer#updateModel} + * in this class we explicitly override update method. + * + * @param mdl Model. 
+ * @return True if current critical for training parameters correspond to parameters from last training. + */ + @Override public boolean isUpdateable(ModelsSequentialComposition<I, O1, O2> mdl) { + // Never called. + throw new IllegalStateException(); + } + + /** + * This method is never called, instead of constructing logic of update from + * {@link DatasetTrainer#isUpdateable(IgniteModel)} and + * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)} + * in this class we explicitly override update method. + * + * @param mdl Model. + * @return Updated model. + */ + @Override protected <K, V> ModelsSequentialComposition<I, O1, O2> updateModel( + ModelsSequentialComposition<I, O1, O2> mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + // Never called. + throw new IllegalStateException(); + } + + /** + * Performs coercion of this trainer to {@code DatasetTrainer<IgniteModel<I, O2>, L>}. + * + * @return Trainer coerced to {@code DatasetTrainer<IgniteModel<I, O2>, L>}. + */ + public DatasetTrainer<IgniteModel<I, O2>, L> unsafeSimplyTyped() { + return CompositionUtils.unsafeCoerce(this); + } +} diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/package-info.java similarity index 52% copy from modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java copy to modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/package-info.java index bc4b839..02ca2df 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/package-info.java @@ -15,35 +15,8 @@ * limitations under the License. 
*/ -package org.apache.ignite.ml.genetic.parameter; - -import java.util.ArrayList; - -import java.util.List; - /** - * Responsible for describing the characteristics of an individual Chromosome. + * <!-- Package description. --> + * Contains sequential combinators of trainers and models. */ -public class ChromosomeCriteria { - /** List of criteria for a Chromosome */ - private List<String> criteria = new ArrayList<String>(); - - /** - * Retrieve criteria - * - * @return List of strings - */ - public List<String> getCriteria() { - return criteria; - } - - /** - * Set criteria - * - * @param criteria List of criteria to be applied for a Chromosome ;Use format "name=value", ie: "coinType=QUARTER" - */ - public void setCriteria(List<String> criteria) { - this.criteria = criteria; - } - -} +package org.apache.ignite.ml.composition.combinators.sequential; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java index e58107d..45fcecc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java @@ -21,15 +21,18 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; import org.apache.ignite.ml.IgniteModel; +import org.apache.ignite.ml.composition.CompositionUtils; +import org.apache.ignite.ml.composition.DatasetMapping; +import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition; +import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.environment.LearningEnvironmentBuilder; -import org.apache.ignite.ml.environment.parallelism.Promise; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import 
org.apache.ignite.ml.math.functions.IgniteBinaryOperator; import org.apache.ignite.ml.math.functions.IgniteFunction; -import org.apache.ignite.ml.math.functions.IgniteSupplier; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer; import org.apache.ignite.ml.trainers.DatasetTrainer; /** @@ -220,31 +223,7 @@ public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L> // Unsafely coerce DatasetTrainer<M1, L> to DatasetTrainer<Model<IS, IA>, L>, but we fully control // usages of this unsafely coerced object, on the other hand this makes work with // submodelTrainers easier. - submodelsTrainers.add(new DatasetTrainer<IgniteModel<IS, IA>, L>() { - /** {@inheritDoc} */ - @Override public <K, V> IgniteModel<IS, IA> fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { - return trainer.fit(datasetBuilder, featureExtractor, lbExtractor); - } - - /** {@inheritDoc} */ - @Override public <K, V> IgniteModel<IS, IA> update(IgniteModel<IS, IA> mdl, DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { - DatasetTrainer<IgniteModel<IS, IA>, L> trainer1 = (DatasetTrainer<IgniteModel<IS, IA>, L>)trainer; - return trainer1.update(mdl, datasetBuilder, featureExtractor, lbExtractor); - } - - /** {@inheritDoc} */ - @Override protected boolean checkState(IgniteModel<IS, IA> mdl) { - return true; - } - - /** {@inheritDoc} */ - @Override protected <K, V> IgniteModel<IS, IA> updateModel(IgniteModel<IS, IA> mdl, DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { - return null; - } - }); + submodelsTrainers.add(CompositionUtils.unsafeCoerce(trainer)); return this; } @@ -254,62 +233,60 @@ public class 
StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L> IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { - return update(null, datasetBuilder, featureExtractor, lbExtractor); + return new StackedModel<>(getTrainer().fit(datasetBuilder, featureExtractor, lbExtractor)); } /** {@inheritDoc} */ @Override public <K, V> StackedModel<IS, IA, O, AM> update(StackedModel<IS, IA, O, AM> mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { - return runOnSubmodels( - ensemble -> { - List<IgniteSupplier<IgniteModel<IS, IA>>> res = new ArrayList<>(); - for (int i = 0; i < ensemble.size(); i++) { - final int j = i; - res.add(() -> { - DatasetTrainer<IgniteModel<IS, IA>, L> trainer = ensemble.get(j); - return mdl == null ? - trainer.fit(datasetBuilder, featureExtractor, lbExtractor) : - trainer.update(mdl.submodels().get(j), datasetBuilder, featureExtractor, lbExtractor); - }); - } - return res; - }, - (at, extr) -> mdl == null ? - at.fit(datasetBuilder, extr, lbExtractor) : - at.update(mdl.aggregatorModel(), datasetBuilder, extr, lbExtractor), - featureExtractor - ); - } - /** {@inheritDoc} */ - @Override public StackedDatasetTrainer<IS, IA, O, AM, L> withEnvironmentBuilder( - LearningEnvironmentBuilder envBuilder) { - submodelsTrainers = - submodelsTrainers.stream().map(x -> x.withEnvironmentBuilder(envBuilder)).collect(Collectors.toList()); - aggregatorTrainer = aggregatorTrainer.withEnvironmentBuilder(envBuilder); - - return this; + return new StackedModel<>(getTrainer().update(mdl, datasetBuilder, featureExtractor, lbExtractor)); } /** - * <pre> - * 1. Obtain models produced by running specified tasks; - * 2. run other specified task on dataset augmented with results of models from step 2. - * </pre> + * Get the trainer for stacking. * - * @param taskSupplier Function used to generate tasks for first step. 
- * @param aggregatorProcessor Function used - * @param featureExtractor Feature extractor. - * @param <K> Type of keys in upstream. - * @param <V> Type of values in upstream. - * @return {@link StackedModel}. + * @return Trainer for stacking. */ - private <K, V> StackedModel<IS, IA, O, AM> runOnSubmodels( - IgniteFunction<List<DatasetTrainer<IgniteModel<IS, IA>, L>>, List<IgniteSupplier<IgniteModel<IS, IA>>>> taskSupplier, - IgniteBiFunction<DatasetTrainer<AM, L>, IgniteBiFunction<K, V, Vector>, AM> aggregatorProcessor, - IgniteBiFunction<K, V, Vector> featureExtractor) { + private DatasetTrainer<IgniteModel<IS, O>, L> getTrainer() { + checkConsistency(); + + List<DatasetTrainer<IgniteModel<IS, IA>, L>> subs = new ArrayList<>(); + if (submodelInput2AggregatingInputConverter != null) { + DatasetTrainer<IgniteModel<IS, IS>, L> id = DatasetTrainer.identityTrainer(); + DatasetTrainer<IgniteModel<IS, IA>, L> mappedId = CompositionUtils.unsafeCoerce( + AdaptableDatasetTrainer.of(id).afterTrainedModel(submodelInput2AggregatingInputConverter)); + subs.add(mappedId); + } + + subs.addAll(submodelsTrainers); + + TrainersParallelComposition<IS, IA, L> composition = new TrainersParallelComposition<>(subs); + IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> featureMapper = getFeatureExtractorForAggregator( + submodelOutput2VectorConverter, + vector2SubmodelInputConverter); + + return AdaptableDatasetTrainer + .of(composition) + .afterTrainedModel(lst -> lst.stream().reduce(aggregatingInputMerger).get()) + .andThen(aggregatorTrainer, model -> new DatasetMapping<L, L>() { + @Override public Vector mapFeatures(Vector v) { + List<IgniteModel<IS, IA>> models = ((ModelsParallelComposition<IS, IA>)model.innerModel()).submodels(); + return featureMapper.apply(models, v); + } + + @Override public L mapLabels(L lbl) { + return lbl; + } + }).unsafeSimplyTyped(); + } + + /** + * Method checking consistency of this trainer. 
+ */ + private void checkConsistency() { // Make sure there is at least one way for submodel input to propagate to aggregator. if (submodelInput2AggregatingInputConverter == null && submodelsTrainers.isEmpty()) throw new IllegalStateException("There should be at least one way for submodels " + @@ -321,60 +298,36 @@ public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L> if (aggregatingInputMerger == null) throw new IllegalStateException("Binary operator used to convert outputs of submodels is not specified"); + } - List<IgniteSupplier<IgniteModel<IS, IA>>> mdlSuppliers = taskSupplier.apply(submodelsTrainers); - - List<IgniteModel<IS, IA>> subMdls = environment.parallelismStrategy().submit(mdlSuppliers).stream() - .map(Promise::unsafeGet) - .collect(Collectors.toList()); - - // Add new columns consisting in submodels output in features. - IgniteBiFunction<K, V, Vector> augmentedExtractor = getFeatureExtractorForAggregator(featureExtractor, - subMdls, - submodelInput2AggregatingInputConverter, - submodelOutput2VectorConverter, - vector2SubmodelInputConverter); - - AM aggregator = aggregatorProcessor.apply(aggregatorTrainer, augmentedExtractor); - - StackedModel<IS, IA, O, AM> res = new StackedModel<>( - aggregator, - aggregatingInputMerger, - submodelInput2AggregatingInputConverter); - - for (IgniteModel<IS, IA> subMdl : subMdls) - res.addSubmodel(subMdl); + /** {@inheritDoc} */ + @Override public StackedDatasetTrainer<IS, IA, O, AM, L> withEnvironmentBuilder( + LearningEnvironmentBuilder envBuilder) { + submodelsTrainers = + submodelsTrainers.stream().map(x -> x.withEnvironmentBuilder(envBuilder)).collect(Collectors.toList()); + aggregatorTrainer = aggregatorTrainer.withEnvironmentBuilder(envBuilder); - return res; + return this; } /** * Get feature extractor which will be used for aggregator trainer from original feature extractor. * This method is static to make sure that we will not grab context of instance in serialization. 
* - * @param featureExtractor Original feature extractor. - * @param subMdls Submodels. + * @param <IS> Type of submodels input. + * @param <IA> Type of aggregator input. * @param <K> Type of upstream keys. - * @param <V> Type of upstream values. + * @param <V> Type of upstream values * @return Feature extractor which will be used for aggregator trainer from original feature extractor. */ - private static <IS, IA, K, V> IgniteBiFunction<K, V, Vector> getFeatureExtractorForAggregator( - IgniteBiFunction<K, V, Vector> featureExtractor, List<IgniteModel<IS, IA>> subMdls, - IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter, + private static <IS, IA, K, V> IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> getFeatureExtractorForAggregator( IgniteFunction<IA, Vector> submodelOutput2VectorConverter, IgniteFunction<Vector, IS> vector2SubmodelInputConverter) { - if (submodelInput2AggregatingInputConverter != null) - return featureExtractor.andThen((Vector v) -> { - Vector[] vs = subMdls.stream().map(sm -> - applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new); - return VectorUtils.concat(v, vs); - }); - else - return featureExtractor.andThen((Vector v) -> { - Vector[] vs = subMdls.stream().map(sm -> - applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new); - return VectorUtils.concat(vs); - }); + return (List<IgniteModel<IS, IA>> subMdls, Vector v) -> { + Vector[] vs = subMdls.stream().map(sm -> + applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new); + return VectorUtils.concat(vs); + }; } /** @@ -396,17 +349,34 @@ public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L> return vector2SubmodelInputConverter.andThen(mdl::predict).andThen(submodelOutput2VectorConverter).apply(v); } - /** {@inheritDoc} */ + /** + * This method is never called, instead of constructing logic 
of update from + * {@link DatasetTrainer#isUpdateable(IgniteModel)} and + * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)} + * in this class we explicitly override update method. + * + * @param mdl Model. + * @return Updated model. + */ @Override protected <K, V> StackedModel<IS, IA, O, AM> updateModel(StackedModel<IS, IA, O, AM> mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { // This method is never called, we override "update" instead. - return null; + throw new IllegalStateException(); } - /** {@inheritDoc} */ - @Override protected boolean checkState(StackedModel<IS, IA, O, AM> mdl) { - return true; + /** + * This method is never called, instead of constructing logic of update from + * {@link DatasetTrainer#isUpdateable} and + * {@link DatasetTrainer#updateModel} + * in this class we explicitly override update method. + * + * @param mdl Model. + * @return True if current critical for training parameters correspond to parameters from last training. + */ + @Override public boolean isUpdateable(StackedModel<IS, IA, O, AM> mdl) { + // Should be never called. 
+ throw new IllegalStateException(); } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java index a9be8f8..34e1a97 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java @@ -17,19 +17,17 @@ package org.apache.ignite.ml.composition.stacking; -import java.util.ArrayList; -import java.util.List; import org.apache.ignite.ml.IgniteModel; -import org.apache.ignite.ml.math.functions.IgniteBinaryOperator; -import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition; /** + * This is a wrapper for model produced by {@link StackedDatasetTrainer}. * Model consisting of two layers: * <pre> * 1. Submodels layer {@code (IS -> IA)}. * 2. Aggregator layer {@code (IA -> O)}. * </pre> - * Submodels layer is a "parallel" composition of several models {@code IS -> IA} each of them getting same input + * Submodels layer is a {@link ModelsParallelComposition} of several models {@code IS -> IA} each of them getting same input * {@code IS} and produce own output, these outputs {@code [IA]} * are combined into a single output with a given binary "merger" operator {@code IA -> IA -> IA}. Result of merge * is then passed to the aggregator layer. @@ -41,66 +39,24 @@ import org.apache.ignite.ml.math.functions.IgniteFunction; * @param <AM> Type of aggregator model. */ public class StackedModel<IS, IA, O, AM extends IgniteModel<IA, O>> implements IgniteModel<IS, O> { - /** Submodels layer. */ - private IgniteModel<IS, IA> subModelsLayer; - - /** Aggregator model. */ - private final AM aggregatorMdl; - - /** Models constituting submodels layer. 
*/ - private List<IgniteModel<IS, IA>> submodels; - - /** Binary operator merging submodels outputs. */ - private final IgniteBinaryOperator<IA> aggregatingInputMerger; - - /** - * Constructs instance of this class. - * - * @param aggregatorMdl Aggregator model. - * @param aggregatingInputMerger Binary operator used to merge submodels outputs. - * @param subMdlInput2AggregatingInput Function converting submodels input to aggregator input. (This function - * is needed when in {@link StackedDatasetTrainer} option to keep original features is chosen). - */ - StackedModel(AM aggregatorMdl, - IgniteBinaryOperator<IA> aggregatingInputMerger, - IgniteFunction<IS, IA> subMdlInput2AggregatingInput) { - this.aggregatorMdl = aggregatorMdl; - this.aggregatingInputMerger = aggregatingInputMerger; - this.subModelsLayer = subMdlInput2AggregatingInput != null ? subMdlInput2AggregatingInput::apply : null; - submodels = new ArrayList<>(); - } - - /** - * Get submodels constituting first layer of this model. - * - * @return Submodels constituting first layer of this model. - */ - List<IgniteModel<IS, IA>> submodels() { - return submodels; - } + /** Model to wrap. */ + private IgniteModel<IS, O> mdl; /** - * Get aggregator model. - * - * @return Aggregator model. + * Construct instance of this class from {@link IgniteModel}. + * @param mdl Model to wrap. */ - AM aggregatorModel() { - return aggregatorMdl; + StackedModel(IgniteModel<IS, O> mdl) { + this.mdl = mdl; } - /** - * Add submodel into first layer. - * - * @param subMdl Submodel to add. - */ - void addSubmodel(IgniteModel<IS, IA> subMdl) { - submodels.add(subMdl); - subModelsLayer = subModelsLayer != null ? 
subModelsLayer.combine(subMdl, aggregatingInputMerger) - : subMdl; + /** {@inheritDoc} */ + @Override public O predict(IS is) { + return mdl.predict(is); } /** {@inheritDoc} */ - @Override public O predict(IS is) { - return subModelsLayer.andThen(aggregatorMdl).predict(is); + @Override public void close() { + mdl.close(); } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java index 7230e3c..c25b721 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java @@ -81,6 +81,7 @@ public class StackedVectorDatasetTrainer<O, AM extends IgniteModel<Vector, O>, L } /** {@inheritDoc} */ + // TODO: IGNITE-10843 Add possibility to keep features with specific indices. @Override public StackedVectorDatasetTrainer<O, AM, L> withOriginalFeaturesKept( IgniteFunction<Vector, Vector> submodelInput2AggregatingInputConverter) { return (StackedVectorDatasetTrainer<O, AM, L>)super.withOriginalFeaturesKept( diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java index 9900659..c826a40 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java @@ -67,7 +67,7 @@ public interface DatasetBuilder<K, V> { * @return Returns new instance of {@link DatasetBuilder} with new {@link UpstreamTransformerBuilder} added * to chain of upstream transformer builders. 
*/ - public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder); + public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder builder); /** * Returns new instance of DatasetBuilder using conjunction of internal filter and {@code filterToAdd}. diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java index 9c0e281..c7fb92f 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java @@ -22,29 +22,15 @@ import java.util.stream.Stream; /** * Interface of transformer of upstream. - * - * @param <K> Type of keys in the upstream. - * @param <V> Type of values in the upstream. */ // TODO: IGNITE-10297: Investigate possibility of API change. @FunctionalInterface -public interface UpstreamTransformer<K, V> extends Serializable { +public interface UpstreamTransformer extends Serializable { /** * Transform upstream. * * @param upstream Upstream to transform. * @return Transformed upstream. */ - public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream); - - /** - * Get composition of this transformer and other transformer which is - * itself is {@link UpstreamTransformer} applying this transformer and then other transformer. - * - * @param other Other transformer. - * @return Composition of this and other transformer. 
- */ - public default UpstreamTransformer<K, V> andThen(UpstreamTransformer<K, V> other) { - return upstream -> other.transform(transform(upstream)); - } + public Stream<UpstreamEntry> transform(Stream<UpstreamEntry> upstream); } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java index 9adfab5..ea9f126 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java @@ -21,19 +21,17 @@ import java.io.Serializable; import org.apache.ignite.ml.environment.LearningEnvironment; /** - * Builder of {@link UpstreamTransformerBuilder}. - * @param <K> Type of keys in upstream. - * @param <V> Type of values in upstream. + * Builder of {@link UpstreamTransformer}. */ @FunctionalInterface -public interface UpstreamTransformerBuilder<K, V> extends Serializable { +public interface UpstreamTransformerBuilder extends Serializable { /** * Create {@link UpstreamTransformer} based on learning environment. * * @param env Learning environment. * @return Upstream transformer. */ - public UpstreamTransformer<K, V> build(LearningEnvironment env); + public UpstreamTransformer build(LearningEnvironment env); /** * Combines two builders (this and other respectively) * @@ -49,11 +47,11 @@ public interface UpstreamTransformerBuilder<K, V> extends Serializable { * @param other Builder to combine with. * @return Compositional builder. 
*/ - public default UpstreamTransformerBuilder<K, V> andThen(UpstreamTransformerBuilder<K, V> other) { - UpstreamTransformerBuilder<K, V> self = this; + public default UpstreamTransformerBuilder andThen(UpstreamTransformerBuilder other) { + UpstreamTransformerBuilder self = this; return env -> { - UpstreamTransformer<K, V> transformer1 = self.build(env); - UpstreamTransformer<K, V> transformer2 = other.build(env); + UpstreamTransformer transformer1 = self.build(env); + UpstreamTransformer transformer2 = other.build(env); return upstream -> transformer2.transform(transformer1.transform(upstream)); }; @@ -66,7 +64,7 @@ public interface UpstreamTransformerBuilder<K, V> extends Serializable { * @param <V> Type of values in upstream. * @return Identity upstream transformer. */ - public static <K, V> UpstreamTransformerBuilder<K, V> identity() { + public static <K, V> UpstreamTransformerBuilder identity() { return env -> upstream -> upstream; } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java index bde4bb6..b2aa00b 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java @@ -64,7 +64,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose private final IgniteBiPredicate<K, V> filter; /** Builder of transformation applied to upstream. */ - private final UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder; + private final UpstreamTransformerBuilder upstreamTransformerBuilder; /** Ignite Cache with partition {@code context}. 
*/ private final IgniteCache<Integer, C> datasetCache; @@ -94,7 +94,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, - UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder, + UpstreamTransformerBuilder upstreamTransformerBuilder, IgniteCache<Integer, C> datasetCache, LearningEnvironmentBuilder envBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder, diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java index be40158..b85bfc2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java @@ -59,7 +59,7 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> { private final IgniteBiPredicate<K, V> filter; /** Upstream transformer builder. 
*/ - private final UpstreamTransformerBuilder<K, V> transformerBuilder; + private final UpstreamTransformerBuilder transformerBuilder; /** * Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset} with default @@ -93,7 +93,7 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> { public CacheBasedDatasetBuilder(Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, - UpstreamTransformerBuilder<K, V> transformerBuilder) { + UpstreamTransformerBuilder transformerBuilder) { this.ignite = ignite; this.upstreamCache = upstreamCache; this.filter = filter; @@ -136,7 +136,7 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> { } /** {@inheritDoc} */ - @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder) { + @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder builder) { return new CacheBasedDatasetBuilder<>(ignite, upstreamCache, filter, transformerBuilder.andThen(builder)); } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java index 7fa1efa..f12977c 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java @@ -185,7 +185,7 @@ public class ComputeUtils { public static <K, V, C extends Serializable, D extends AutoCloseable> D getData( Ignite ignite, String upstreamCacheName, IgniteBiPredicate<K, V> filter, - UpstreamTransformerBuilder<K, V> transformerBuilder, + UpstreamTransformerBuilder transformerBuilder, String datasetCacheName, UUID datasetId, PartitionDataBuilder<K, V, C, D> partDataBuilder, LearningEnvironment env) { @@ -208,8 +208,8 @@ public class ComputeUtils { qry.setPartition(part); 
qry.setFilter(filter); - UpstreamTransformer<K, V> transformer = transformerBuilder.build(env); - UpstreamTransformer<K, V> transformerCp = Utils.copy(transformer); + UpstreamTransformer transformer = transformerBuilder.build(env); + UpstreamTransformer transformerCp = Utils.copy(transformer); long cnt = computeCount(upstreamCache, qry, transformer); @@ -218,9 +218,8 @@ public class ComputeUtils { e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) { Iterator<UpstreamEntry<K, V>> it = cursor.iterator(); - Stream<UpstreamEntry<K, V>> transformedStream = transformerCp.transform(Utils.asStream(it, cnt)); - it = transformedStream.iterator(); - + Stream<UpstreamEntry> transformedStream = transformerCp.transform(Utils.asStream(it, cnt).map(x -> (UpstreamEntry)x)); + it = Utils.asStream(transformedStream.iterator()).map(x -> (UpstreamEntry<K, V>)x).iterator(); Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(it, cnt, "Cache expected to be not modified during dataset data building [partition=" + part + ']'); @@ -268,7 +267,7 @@ public class ComputeUtils { public static <K, V, C extends Serializable> void initContext( Ignite ignite, String upstreamCacheName, - UpstreamTransformerBuilder<K, V> transformerBuilder, + UpstreamTransformerBuilder transformerBuilder, IgniteBiPredicate<K, V> filter, String datasetCacheName, PartitionContextBuilder<K, V, C> ctxBuilder, @@ -287,8 +286,8 @@ public class ComputeUtils { qry.setFilter(filter); C ctx; - UpstreamTransformer<K, V> transformer = transformerBuilder.build(env); - UpstreamTransformer<K, V> transformerCp = Utils.copy(transformer); + UpstreamTransformer transformer = transformerBuilder.build(env); + UpstreamTransformer transformerCp = Utils.copy(transformer); long cnt = computeCount(locUpstreamCache, qry, transformer); @@ -296,8 +295,8 @@ public class ComputeUtils { e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) { Iterator<UpstreamEntry<K, V>> it = cursor.iterator(); - 
Stream<UpstreamEntry<K, V>> transformedStream = transformerCp.transform(Utils.asStream(it, cnt)); - it = transformedStream.iterator(); + Stream<UpstreamEntry> transformedStream = transformerCp.transform(Utils.asStream(it, cnt).map(x -> (UpstreamEntry)x)); + it = Utils.asStream(transformedStream.iterator()).map(x -> (UpstreamEntry<K, V>)x).iterator(); Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>( it, @@ -334,7 +333,7 @@ public class ComputeUtils { Ignite ignite, String upstreamCacheName, IgniteBiPredicate<K, V> filter, - UpstreamTransformerBuilder<K, V> transformerBuilder, + UpstreamTransformerBuilder transformerBuilder, String datasetCacheName, PartitionContextBuilder<K, V, C> ctxBuilder, LearningEnvironmentBuilder envBuilder, @@ -382,11 +381,11 @@ public class ComputeUtils { private static <K, V> long computeCount( IgniteCache<K, V> cache, ScanQuery<K, V> qry, - UpstreamTransformer<K, V> transformer) { + UpstreamTransformer transformer) { try (QueryCursor<UpstreamEntry<K, V>> cursor = cache.query(qry, e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) { - return computeCount(transformer.transform(Utils.asStream(cursor.iterator())).iterator()); + return computeCount(transformer.transform(Utils.asStream(cursor.iterator()).map(x -> (UpstreamEntry<K, V>)x)).iterator()); } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java index b8cd8dc..84f3e08 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java @@ -54,7 +54,7 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { private final IgniteBiPredicate<K, V> filter; /** Upstream transformers. 
*/ - private final UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder; + private final UpstreamTransformerBuilder upstreamTransformerBuilder; /** * Constructs a new instance of local dataset builder that makes {@link LocalDataset} with default predicate that @@ -78,7 +78,7 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { public LocalDatasetBuilder(Map<K, V> upstreamMap, IgniteBiPredicate<K, V> filter, int partitions, - UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder) { + UpstreamTransformerBuilder upstreamTransformerBuilder) { this.upstreamMap = upstreamMap; this.filter = filter; this.partitions = partitions; @@ -129,23 +129,26 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { int cntBeforeTransform = part == partitions - 1 ? entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr); LearningEnvironment env = envs.get(part); - UpstreamTransformer<K, V> transformer1 = upstreamTransformerBuilder.build(env); - UpstreamTransformer<K, V> transformer2 = Utils.copy(transformer1); - UpstreamTransformer<K, V> transformer3 = Utils.copy(transformer1); + UpstreamTransformer transformer1 = upstreamTransformerBuilder.build(env); + UpstreamTransformer transformer2 = Utils.copy(transformer1); + UpstreamTransformer transformer3 = Utils.copy(transformer1); int cnt = (int)transformer1.transform(Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cntBeforeTransform))).count(); - Iterator<UpstreamEntry<K, V>> iter = - transformer2.transform(Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform))).iterator(); + Iterator<UpstreamEntry> iter = + transformer2.transform(Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform)).map(x -> (UpstreamEntry)x)).iterator(); + Iterator<UpstreamEntry<K, V>> convertedBack = Utils.asStream(iter).map(x -> (UpstreamEntry<K, V>)x).iterator(); - C ctx = cntBeforeTransform > 0 ? 
partCtxBuilder.build(env, iter, cnt) : null; + C ctx = cntBeforeTransform > 0 ? partCtxBuilder.build(env, convertedBack, cnt) : null; - Iterator<UpstreamEntry<K, V>> iter1 = transformer3.transform( + Iterator<UpstreamEntry> iter1 = transformer3.transform( Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cntBeforeTransform))).iterator(); + Iterator<UpstreamEntry<K, V>> convertedBack1 = Utils.asStream(iter1).map(x -> (UpstreamEntry<K, V>)x).iterator(); + D data = cntBeforeTransform > 0 ? partDataBuilder.build( env, - iter1, + convertedBack1, cnt, ctx ) : null; @@ -160,7 +163,7 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> { } /** {@inheritDoc} */ - @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder) { + @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder builder) { return new LocalDatasetBuilder<>(upstreamMap, filter, partitions, upstreamTransformerBuilder.andThen(builder)); } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/Chromosome.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/Chromosome.java index ed78e85..0552036 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/Chromosome.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/Chromosome.java @@ -19,7 +19,6 @@ package org.apache.ignite.ml.genetic; import java.util.Arrays; import java.util.concurrent.atomic.AtomicLong; - import org.apache.ignite.cache.query.annotations.QuerySqlField; /** diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/MutateJob.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/MutateJob.java index b03e7ca..d69911a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/MutateJob.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/MutateJob.java @@ -22,11 +22,10 @@ import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import 
org.apache.ignite.IgniteException; import org.apache.ignite.compute.ComputeJobAdapter; +import org.apache.ignite.ml.genetic.parameter.GAGridConstants; import org.apache.ignite.resources.IgniteInstanceResource; import org.apache.ignite.transactions.Transaction; -import org.apache.ignite.ml.genetic.parameter.GAGridConstants; - /** * Responsible for applying mutation on respective Chromosome based on mutation Rate */ diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/cache/GeneCacheConfig.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/cache/GeneCacheConfig.java index c5302ee..f980e22 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/cache/GeneCacheConfig.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/cache/GeneCacheConfig.java @@ -20,7 +20,6 @@ package org.apache.ignite.ml.genetic.cache; import org.apache.ignite.cache.CacheMode; import org.apache.ignite.cache.CacheRebalanceMode; import org.apache.ignite.configuration.CacheConfiguration; - import org.apache.ignite.ml.genetic.Gene; import org.apache.ignite.ml.genetic.functions.GAGridFunction; import org.apache.ignite.ml.genetic.parameter.GAGridConstants; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/cache/PopulationCacheConfig.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/cache/PopulationCacheConfig.java index cae7c1a..6a8b2b4 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/cache/PopulationCacheConfig.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/cache/PopulationCacheConfig.java @@ -21,7 +21,6 @@ import org.apache.ignite.cache.CacheAtomicityMode; import org.apache.ignite.cache.CacheMode; import org.apache.ignite.cache.CacheRebalanceMode; import org.apache.ignite.configuration.CacheConfiguration; - import org.apache.ignite.ml.genetic.Chromosome; import org.apache.ignite.ml.genetic.parameter.GAGridConstants; diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java index bc4b839..745847a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/ChromosomeCriteria.java @@ -18,7 +18,6 @@ package org.apache.ignite.ml.genetic.parameter; import java.util.ArrayList; - import java.util.List; /** diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java index c32ca56..0cdfc52 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java @@ -102,7 +102,7 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass } /** {@inheritDoc} */ - @Override protected boolean checkState(ANNClassificationModel mdl) { + @Override public boolean isUpdateable(ANNClassificationModel mdl) { return mdl.getDistanceMeasure().equals(distance) && mdl.getCandidates().rowSize() == k; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java index c52ad2b..16bf186 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java @@ -60,7 +60,7 @@ public class KNNClassificationTrainer extends SingleLabelDatasetTrainer<KNNClass } /** {@inheritDoc} */ - @Override protected boolean checkState(KNNClassificationModel mdl) { + @Override public boolean isUpdateable(KNNClassificationModel mdl) { return true; } } diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java index 9b348f3..e621801 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java @@ -56,7 +56,7 @@ public class KNNRegressionTrainer extends SingleLabelDatasetTrainer<KNNRegressio } /** {@inheritDoc} */ - @Override protected boolean checkState(KNNRegressionModel mdl) { + @Override public boolean isUpdateable(KNNRegressionModel mdl) { return true; } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java index 4eca27f..a44b5b4 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java @@ -101,7 +101,7 @@ public class OneVsRestTrainer<M extends IgniteModel<Vector, Double>> } /** {@inheritDoc} */ - @Override protected boolean checkState(MultiClassModel<M> mdl) { + @Override public boolean isUpdateable(MultiClassModel<M> mdl) { return true; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java index 0779b84..0179b31 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java @@ -59,7 +59,7 @@ public class DiscreteNaiveBayesTrainer extends SingleLabelDatasetTrainer<Discret } /** {@inheritDoc} */ - @Override protected boolean checkState(DiscreteNaiveBayesModel mdl) { + @Override public boolean 
isUpdateable(DiscreteNaiveBayesModel mdl) { if (mdl.getBucketThresholds().length != bucketThresholds.length) return false; @@ -124,7 +124,7 @@ public class DiscreteNaiveBayesTrainer extends SingleLabelDatasetTrainer<Discret return a.merge(b); }); - if (mdl != null && checkState(mdl)) { + if (mdl != null && isUpdateable(mdl)) { if (checkSumsHolder(sumsHolder, mdl.getSumsHolder())) sumsHolder = sumsHolder.merge(mdl.getSumsHolder()); } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java index cdaac5a..c4ef1bd 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java @@ -55,7 +55,7 @@ public class GaussianNaiveBayesTrainer extends SingleLabelDatasetTrainer<Gaussia } /** {@inheritDoc} */ - @Override protected boolean checkState(GaussianNaiveBayesModel mdl) { + @Override public boolean isUpdateable(GaussianNaiveBayesModel mdl) { return true; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java index ea0bb6c..cf511ec 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java @@ -354,7 +354,7 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer } /** {@inheritDoc} */ - @Override protected boolean checkState(MultilayerPerceptron mdl) { + @Override public boolean isUpdateable(MultilayerPerceptron mdl) { return true; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java index 6b2b11e..e273633 100644 --- 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java @@ -79,7 +79,7 @@ public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<Linea } /** {@inheritDoc} */ - @Override public boolean checkState(LinearRegressionModel mdl) { + @Override public boolean isUpdateable(LinearRegressionModel mdl) { return true; } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java index 4132d35..7dc4df6 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java @@ -160,7 +160,7 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa } /** {@inheritDoc} */ - @Override protected boolean checkState(LinearRegressionModel mdl) { + @Override public boolean isUpdateable(LinearRegressionModel mdl) { return true; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java index 864187d..16ffac3 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java @@ -139,7 +139,7 @@ public class LogisticRegressionSGDTrainer extends SingleLabelDatasetTrainer<Logi } /** {@inheritDoc} */ - @Override protected boolean checkState(LogisticRegressionModel mdl) { + @Override public boolean isUpdateable(LogisticRegressionModel mdl) { return true; } diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java index 67484ea..90bbe37 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java @@ -121,7 +121,7 @@ public class SVMLinearClassificationTrainer extends SingleLabelDatasetTrainer<SV } /** {@inheritDoc} */ - @Override protected boolean checkState(SVMLinearClassificationModel mdl) { + @Override public boolean isUpdateable(SVMLinearClassificationModel mdl) { return true; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java index 4205286..4695946 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java @@ -18,7 +18,10 @@ package org.apache.ignite.ml.trainers; import org.apache.ignite.ml.IgniteModel; +import org.apache.ignite.ml.composition.DatasetMapping; +import org.apache.ignite.ml.composition.combinators.sequential.TrainersSequentialComposition; import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.functions.IgniteFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -46,6 +49,15 @@ public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW> /** Function used to convert output type of wrapped trainer. */ private final IgniteFunction<OW, O> after; + /** Function which is applied after feature extractor. 
*/ + private final IgniteFunction<Vector, Vector> afterFeatureExtractor; + + /** Function which is applied after label extractor. */ + private final IgniteFunction<L, L> afterLabelExtractor; + + /** Upstream transformer builder which will be used in dataset builder. */ + private final UpstreamTransformerBuilder upstreamTransformerBuilder; + /** * Construct instance of this class from a given {@link DatasetTrainer}. * @@ -56,39 +68,65 @@ public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW> * @param <L> Type of labels. * @return Instance of this class. */ - public static <I, O, M extends IgniteModel<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of(DatasetTrainer<M, L> wrapped) { - return new AdaptableDatasetTrainer<>(IgniteFunction.identity(), wrapped, IgniteFunction.identity()); + public static <I, O, M extends IgniteModel<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of( + DatasetTrainer<M, L> wrapped) { + return new AdaptableDatasetTrainer<>(IgniteFunction.identity(), + wrapped, + IgniteFunction.identity(), + IgniteFunction.identity(), + IgniteFunction.identity(), + UpstreamTransformerBuilder.identity()); } /** * Construct instance of this class with specified wrapped trainer and converter functions. * * @param before Function used to convert input type of wrapped trainer. - * @param wrapped Wrapped trainer. + * @param wrapped Wrapped trainer. * @param after Function used to convert output type of wrapped trainer. + * @param extractor Function which is applied after label extractor. + * @param builder Upstream transformer builder which will be used in dataset builder. 
*/ - private AdaptableDatasetTrainer(IgniteFunction<I, IW> before, DatasetTrainer<M, L> wrapped, IgniteFunction<OW, O> after) { + private AdaptableDatasetTrainer(IgniteFunction<I, IW> before, DatasetTrainer<M, L> wrapped, + IgniteFunction<OW, O> after, + IgniteFunction<Vector, Vector> afterFeatureExtractor, + IgniteFunction<L, L> extractor, UpstreamTransformerBuilder builder) { this.before = before; this.wrapped = wrapped; this.after = after; + this.afterFeatureExtractor = afterFeatureExtractor; + afterLabelExtractor = extractor; + upstreamTransformerBuilder = builder; } /** {@inheritDoc} */ @Override public <K, V> AdaptableDatasetModel<I, O, IW, OW, M> fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { - M fit = wrapped.fit(datasetBuilder, featureExtractor, lbExtractor); + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor) { + M fit = wrapped.fit( + datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder), + featureExtractor.andThen(afterFeatureExtractor), + lbExtractor.andThen(afterLabelExtractor)); + return new AdaptableDatasetModel<>(before, fit, after); } /** {@inheritDoc} */ - @Override protected boolean checkState(AdaptableDatasetModel<I, O, IW, OW, M> mdl) { - return wrapped.checkState(mdl.innerModel()); + @Override public boolean isUpdateable(AdaptableDatasetModel<I, O, IW, OW, M> mdl) { + return wrapped.isUpdateable(mdl.innerModel()); } /** {@inheritDoc} */ - @Override protected <K, V> AdaptableDatasetModel<I, O, IW, OW, M> updateModel(AdaptableDatasetModel<I, O, IW, OW, M> mdl, DatasetBuilder<K, V> datasetBuilder, + @Override protected <K, V> AdaptableDatasetModel<I, O, IW, OW, M> updateModel( + AdaptableDatasetModel<I, O, IW, OW, M> mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { - return 
mdl.withInnerModel(wrapped.updateModel(mdl.innerModel(), datasetBuilder, featureExtractor, lbExtractor)); + M updated = wrapped.updateModel( + mdl.innerModel(), + datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder), + featureExtractor.andThen(afterFeatureExtractor), + lbExtractor.andThen(afterLabelExtractor)); + + return mdl.withInnerModel(updated); } /** @@ -101,7 +139,12 @@ public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW> * original trainer. */ public <O1> AdaptableDatasetTrainer<I, O1, IW, OW, M, L> afterTrainedModel(IgniteFunction<O, O1> after) { - return new AdaptableDatasetTrainer<>(before, wrapped, i -> after.apply(this.after.apply(i))); + return new AdaptableDatasetTrainer<>(before, + wrapped, + i -> after.apply(this.after.apply(i)), + afterFeatureExtractor, + afterLabelExtractor, + upstreamTransformerBuilder); } /** @@ -115,6 +158,116 @@ public class AdaptableDatasetTrainer<I, O, IW, OW, M extends IgniteModel<IW, OW> */ public <I1> AdaptableDatasetTrainer<I1, O, IW, OW, M, L> beforeTrainedModel(IgniteFunction<I1, I> before) { IgniteFunction<I1, IW> function = i -> this.before.apply(before.apply(i)); - return new AdaptableDatasetTrainer<>(function, wrapped, after); + return new AdaptableDatasetTrainer<>(function, + wrapped, + after, + afterFeatureExtractor, + afterLabelExtractor, + upstreamTransformerBuilder); + } + + /** + * Specify {@link DatasetMapping} which will be applied to dataset before fitting and updating. + * + * @param mapping {@link DatasetMapping} which will be applied to dataset before fitting and updating. + * @return New trainer of the same type, but with specified mapping applied to dataset before fitting and updating. 
+ */ + public AdaptableDatasetTrainer<I, O, IW, OW, M, L> withDatasetMapping(DatasetMapping<L, L> mapping) { + return of(new DatasetTrainer<M, L>() { + @Override public <K, V> M fit( + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor) { + IgniteBiFunction<K, V, Vector> fe = featureExtractor.andThen(mapping::mapFeatures); + IgniteBiFunction<K, V, L> le = lbExtractor.andThen(mapping::mapLabels); + + return wrapped.fit(datasetBuilder, + fe, + le); + } + + @Override public <K, V> M update(M mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + return wrapped.update(mdl, datasetBuilder, + featureExtractor.andThen(mapping::mapFeatures), + lbExtractor.andThen((IgniteFunction<L, L>)mapping::mapLabels)); + } + + @Override public boolean isUpdateable(M mdl) { + return false; + } + + @Override protected <K, V> M updateModel(M mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + return null; + } + }).beforeTrainedModel(before).afterTrainedModel(after); + } + + /** + * Create a {@link TrainersSequentialComposition} of this trainer and specified trainer. + * + * @param tr Trainer to compose with. + * @param datasetMappingProducer {@link DatasetMapping} producer specifying dependency between this trainer and + * trainer to compose with. + * @param <O1> Type of output of trainer to compose with. + * @param <M1> Type of model produced by the trainer to compose with. + * @return A {@link TrainersSequentialComposition} of this trainer and specified trainer. 
+ */ + public <O1, M1 extends IgniteModel<O, O1>> TrainersSequentialComposition<I, O, O1, L> andThen( + DatasetTrainer<M1, L> tr, + IgniteFunction<AdaptableDatasetModel<I, O, IW, OW, M>, DatasetMapping<L, L>> datasetMappingProducer) { + IgniteFunction<IgniteModel<I, O>, DatasetMapping<L, L>> coercedMapping = mdl -> + datasetMappingProducer.apply((AdaptableDatasetModel<I, O, IW, OW, M>)mdl); + return new TrainersSequentialComposition<>(this, + tr, + coercedMapping); + } + + /** + * Specify function which will be applied after feature extractor. + * + * @param after Function which will be applied after feature extractor. + * @return New trainer with same parameters as this trainer except that specified function will be applied + * after feature extractor. + */ + public AdaptableDatasetTrainer<I, O, IW, OW, M, L> afterFeatureExtractor(IgniteFunction<Vector, Vector> after) { + return new AdaptableDatasetTrainer<>(before, + wrapped, + this.after, + after, + afterLabelExtractor, + upstreamTransformerBuilder); + } + + /** + * Specify function which will be applied after label extractor. + * + * @param after Function which will be applied after label extractor. + * @return New trainer with same parameters as this trainer has except that specified function will be applied + * after label extractor. + */ + public AdaptableDatasetTrainer<I, O, IW, OW, M, L> afterLabelExtractor(IgniteFunction<L, L> after) { + return new AdaptableDatasetTrainer<>(before, + wrapped, + this.after, + afterFeatureExtractor, + after, + upstreamTransformerBuilder); + } + + /** + * Specify which {@link UpstreamTransformerBuilder} will be used. + * + * @param upstreamTransformerBuilder {@link UpstreamTransformerBuilder} to use. + * @return New trainer with same parameters as this trainer has except that specified {@link UpstreamTransformerBuilder} will be used. 
+ */ + public AdaptableDatasetTrainer<I, O, IW, OW, M, L> withUpstreamTransformerBuilder( + UpstreamTransformerBuilder upstreamTransformerBuilder) { + return new AdaptableDatasetTrainer<>(before, + wrapped, + after, + afterFeatureExtractor, + afterLabelExtractor, + upstreamTransformerBuilder); } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java index 88c4bcd..42cac07 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java @@ -71,12 +71,11 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> { * @param <V> Type of a value in {@code upstream} data. * @return Updated model. */ - // public <K,V> M update(M mdl, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { if(mdl != null) { - if (checkState(mdl)) + if (isUpdateable(mdl)) return updateModel(mdl, datasetBuilder, featureExtractor, lbExtractor); else { environment.logger(getClass()).log( @@ -94,7 +93,7 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> { * @param mdl Model. * @return true if current critical for training parameters correspond to parameters from last training. */ - protected abstract boolean checkState(M mdl); + public abstract boolean isUpdateable(M mdl); /** * Used on update phase when given dataset is empty. @@ -308,12 +307,12 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> { } /** - * Creates {@code DatasetTrainer} with same training logic, but able to accept labels of given new type + * Creates {@link DatasetTrainer} with same training logic, but able to accept labels of given new type * of labels. * * @param new2Old Converter of new labels to old labels. * @param <L1> New labels type. 
- * @return {@code DatasetTrainer} with same training logic, but able to accept labels of given new type + * @return {@link DatasetTrainer} with same training logic, but able to accept labels of given new type * of labels. */ public <L1> DatasetTrainer<M, L1> withConvertedLabels(IgniteFunction<L1, L> new2Old) { @@ -326,8 +325,8 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> { } /** {@inheritDoc} */ - @Override protected boolean checkState(M mdl) { - return old.checkState(mdl); + @Override public boolean isUpdateable(M mdl) { + return old.isUpdateable(mdl); } /** {@inheritDoc} */ @@ -362,4 +361,31 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> { } } + /** + * Returns the trainer which returns identity model. + * + * @param <I> Type of model input. + * @param <L> Type of labels in dataset. + * @return Trainer which returns identity model. + */ + public static <I, L> DatasetTrainer<IgniteModel<I, I>, L> identityTrainer() { + return new DatasetTrainer<IgniteModel<I, I>, L>() { + @Override public <K, V> IgniteModel<I, I> fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, L> lbExtractor) { + return x -> x; + } + + /** {@inheritDoc} */ + @Override public boolean isUpdateable(IgniteModel<I, I> mdl) { + return true; + } + + /** {@inheritDoc} */ + @Override protected <K, V> IgniteModel<I, I> updateModel(IgniteModel<I, I> mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) { + return x -> x; + } + }; + } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java index 43c1600..db5522e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java @@ -24,6 +24,7 @@ 
import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.composition.ModelsComposition; +import org.apache.ignite.ml.composition.bagging.BaggedTrainer; import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.environment.LearningEnvironment; @@ -48,12 +49,11 @@ public class TrainerTransformers { * @param ensembleSize Size of ensemble. * @param subsampleRatio Subsample ratio to whole dataset. * @param aggregator Aggregator. - * @param <M> Type of one model in ensemble. * @param <L> Type of labels. * @return Bagged trainer. */ - public static <M extends IgniteModel<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged( - DatasetTrainer<M, L> trainer, + public static <L> BaggedTrainer<L> makeBagged( + DatasetTrainer<? extends IgniteModel, L> trainer, int ensembleSize, double subsampleRatio, PredictionsAggregator aggregator) { @@ -71,58 +71,19 @@ public class TrainerTransformers { * @param <L> Type of labels. * @return Bagged trainer. 
*/ - public static <M extends IgniteModel<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged( + public static <M extends IgniteModel<Vector, Double>, L> BaggedTrainer<L> makeBagged( DatasetTrainer<M, L> trainer, int ensembleSize, double subsampleRatio, int featureVectorSize, int featuresSubspaceDim, PredictionsAggregator aggregator) { - return new DatasetTrainer<ModelsComposition, L>() { - /** {@inheritDoc} */ - @Override public <K, V> ModelsComposition fit( - DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, L> lbExtractor) { - return runOnEnsemble( - (db, i, fe) -> (() -> trainer.fit(db, fe, lbExtractor)), - datasetBuilder, - ensembleSize, - subsampleRatio, - featureVectorSize, - featuresSubspaceDim, - featureExtractor, - aggregator, - environment); - } - - /** {@inheritDoc} */ - @Override protected boolean checkState(ModelsComposition mdl) { - return mdl.getModels().stream().allMatch(m -> trainer.checkState((M)m)); - } - - /** {@inheritDoc} */ - @Override protected <K, V> ModelsComposition updateModel( - ModelsComposition mdl, - DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, L> lbExtractor) { - return runOnEnsemble( - (db, i, fe) -> (() -> trainer.updateModel( - ((ModelWithMapping<Vector, Double, M>)mdl.getModels().get(i)).model(), - db, - fe, - lbExtractor)), - datasetBuilder, - ensembleSize, - subsampleRatio, - featureVectorSize, - featuresSubspaceDim, - featureExtractor, - aggregator, - environment); - } - }.withEnvironmentBuilder(trainer.envBuilder); + return new BaggedTrainer<>(trainer, + aggregator, + ensembleSize, + subsampleRatio, + featureVectorSize, + featuresSubspaceDim); } /** diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java index 
7f45fdd..36e7867 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java @@ -28,11 +28,8 @@ import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder; * This class encapsulates the logic needed to do bagging (bootstrap aggregating) by features. * The action of this class on a given upstream is to replicate each entry in accordance to * Poisson distribution. - * - * @param <K> Type of upstream keys. - * @param <V> Type of upstream values. */ -public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K, V> { +public class BaggingUpstreamTransformer implements UpstreamTransformer { /** Serial version uid. */ private static final long serialVersionUID = -913152523469994149L; @@ -51,8 +48,8 @@ public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K, * @param <V> Type of upstream values. * @return Builder of {@link BaggingUpstreamTransformer}. 
*/ - public static <K, V> UpstreamTransformerBuilder<K, V> builder(double subsampleRatio, int mdlIdx) { - return env -> new BaggingUpstreamTransformer<>(env.randomNumbersGenerator().nextLong() + mdlIdx, subsampleRatio); + public static <K, V> UpstreamTransformerBuilder builder(double subsampleRatio, int mdlIdx) { + return env -> new BaggingUpstreamTransformer(env.randomNumbersGenerator().nextLong() + mdlIdx, subsampleRatio); } /** @@ -67,7 +64,7 @@ public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K, } /** {@inheritDoc} */ - @Override public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream) { + @Override public Stream<UpstreamEntry> transform(Stream<UpstreamEntry> upstream) { PoissonDistribution poisson = new PoissonDistribution( new Well19937c(seed), subsampleRatio, diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java index 35d1ea4..f3fc4ce 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java @@ -106,7 +106,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends Dataset } /** {@inheritDoc} */ - @Override protected boolean checkState(DecisionTreeNode mdl) { + @Override public boolean isUpdateable(DecisionTreeNode mdl) { return true; } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java index d9b8e30..6d92948 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java @@ -239,7 +239,7 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra } /** {@inheritDoc} */ - @Override protected boolean 
checkState(ModelsComposition mdl) { + @Override public boolean isUpdateable(ModelsComposition mdl) { ModelsComposition fakeComposition = buildComposition(Collections.emptyList()); return mdl.getPredictionsAggregator().getClass() == fakeComposition.getPredictionsAggregator().getClass(); } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DataStreamGenerator.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DataStreamGenerator.java index c2fd652..e57c5ba 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DataStreamGenerator.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DataStreamGenerator.java @@ -126,7 +126,7 @@ public interface DataStreamGenerator { * @return Dataset builder. */ public default DatasetBuilder<Vector, Double> asDatasetBuilder(int datasetSize, IgniteBiPredicate<Vector, Double> filter, - int partitions, UpstreamTransformerBuilder<Vector, Double> upstreamTransformerBuilder) { + int partitions, UpstreamTransformerBuilder upstreamTransformerBuilder) { return new DatasetBuilderAdapter(this, datasetSize, filter, partitions, upstreamTransformerBuilder); } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DatasetBuilderAdapter.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DatasetBuilderAdapter.java index 189e053..7e5060e 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DatasetBuilderAdapter.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DatasetBuilderAdapter.java @@ -48,7 +48,7 @@ class DatasetBuilderAdapter extends LocalDatasetBuilder<Vector, Double> { */ public DatasetBuilderAdapter(DataStreamGenerator generator, int datasetSize, IgniteBiPredicate<Vector, Double> filter, int partitions, - UpstreamTransformerBuilder<Vector, Double> upstreamTransformerBuilder) { + UpstreamTransformerBuilder upstreamTransformerBuilder) { super(generator.asMap(datasetSize), filter, 
partitions, upstreamTransformerBuilder); } diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java index fc3bf5c..ed23373 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java @@ -429,7 +429,7 @@ public class TestUtils { } /** {@inheritDoc} */ - @Override public boolean checkState(M mdl) { + @Override public boolean isUpdateable(M mdl) { return true; } diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java index dd4b11e..4f8f412 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java @@ -22,6 +22,8 @@ import java.util.Map; import org.apache.ignite.ml.IgniteModel; import org.apache.ignite.ml.TestUtils; import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.composition.bagging.BaggedModel; +import org.apache.ignite.ml.composition.bagging.BaggedTrainer; import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator; import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator; import org.apache.ignite.ml.dataset.Dataset; @@ -77,18 +79,16 @@ public class BaggingTest extends TrainerTest { .withBatchSize(10) .withSeed(123L); - trainer.withEnvironmentBuilder(TestUtils.testEnvBuilder()); - - DatasetTrainer<ModelsComposition, Double> baggedTrainer = - TrainerTransformers.makeBagged( - trainer, - 10, - 0.7, - 2, - 2, - new OnMajorityPredictionsAggregator()); + BaggedTrainer<Double> baggedTrainer = TrainerTransformers.makeBagged( + trainer, + 10, + 0.7, + 2, + 2, + new OnMajorityPredictionsAggregator()) + .withEnvironmentBuilder(TestUtils.testEnvBuilder()); - ModelsComposition mdl = baggedTrainer.fit( + 
BaggedModel mdl = baggedTrainer.fit( cacheMock, parts, (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), @@ -111,14 +111,17 @@ public class BaggingTest extends TrainerTest { double subsampleRatio = 0.3; - ModelsComposition mdl = TrainerTransformers.makeBagged( + BaggedModel mdl = TrainerTransformers.makeBagged( cntTrainer, 100, subsampleRatio, 2, 2, new MeanValuePredictionsAggregator()) - .fit(cacheMock, parts, null, null); + .fit(cacheMock, + parts, + (integer, doubles) -> VectorUtils.of(doubles), + (integer, doubles) -> doubles[doubles.length - 1]); Double res = mdl.predict(null); @@ -177,7 +180,7 @@ public class BaggingTest extends TrainerTest { } /** {@inheritDoc} */ - @Override protected boolean checkState(IgniteModel<Vector, Double> mdl) { + @Override public boolean isUpdateable(IgniteModel<Vector, Double> mdl) { return true; } diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java index d253ea0..874547f 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java @@ -103,7 +103,7 @@ public class LearningEnvironmentTest { } /** {@inheritDoc} */ - @Override protected boolean checkState(IgniteModel<Object, Vector> mdl) { + @Override public boolean isUpdateable(IgniteModel<Object, Vector> mdl) { return false; } diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java new file mode 100644 index 0000000..9c089ce --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.trainers; + +import java.util.Arrays; +import org.apache.ignite.ml.IgniteModel; +import org.apache.ignite.ml.TestUtils; +import org.apache.ignite.ml.common.TrainerTest; +import org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer; +import org.apache.ignite.ml.composition.stacking.StackedModel; +import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer; +import org.apache.ignite.ml.math.functions.IgniteFunction; +import org.apache.ignite.ml.math.primitives.matrix.Matrix; +import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.VectorUtils; +import org.apache.ignite.ml.nn.Activators; +import org.apache.ignite.ml.nn.MLPTrainer; +import org.apache.ignite.ml.nn.MultilayerPerceptron; +import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.nn.architecture.MLPArchitecture; +import org.apache.ignite.ml.optimization.LossFunctions; +import org.apache.ignite.ml.optimization.SmoothParametrized; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate; +import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator; +import 
org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer; +import org.apache.ignite.ml.regressions.linear.LinearRegressionModel; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import static junit.framework.TestCase.assertEquals; + +/** + * Tests stacked trainers. + */ +public class StackingTest extends TrainerTest { + /** Rule to check exceptions. */ + @Rule + public ExpectedException thrown = ExpectedException.none(); + + /** + * Tests simple stack training. + */ + @Test + public void testSimpleStack() { + StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> trainer = + new StackedDatasetTrainer<>(); + + UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ); + + MLPArchitecture arch = new MLPArchitecture(2). + withAddedLayer(10, true, Activators.RELU). + withAddedLayer(1, false, Activators.SIGMOID); + + MLPTrainer<SimpleGDParameterUpdate> trainer1 = new MLPTrainer<>( + arch, + LossFunctions.MSE, + updatesStgy, + 3000, + 10, + 50, + 123L + ); + + // Convert model trainer to produce Vector -> Vector model + DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> mlpTrainer = + AdaptableDatasetTrainer.of(trainer1) + .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1)) + .afterTrainedModel((Matrix mtx) -> mtx.getRow(0)) + .withConvertedLabels(VectorUtils::num2Arr); + + final double factor = 3; + + StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer + .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor)) + .addTrainer(mlpTrainer) + .withAggregatorInputMerger(VectorUtils::concat) + .withSubmodelOutput2VectorConverter(IgniteFunction.identity()) + .withVector2SubmodelInputConverter(IgniteFunction.identity()) + 
.withOriginalFeaturesKept(IgniteFunction.identity()) + .withEnvironmentBuilder(TestUtils.testEnvBuilder()) + .fit(getCacheMock(xor), + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[v.length - 1]); + + assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3); + assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3); + assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3); + assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3); + } + + /** + * Tests simple stack training. + */ + @Test + public void testSimpleVectorStack() { + StackedVectorDatasetTrainer<Double, LinearRegressionModel, Double> trainer = + new StackedVectorDatasetTrainer<>(); + + UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>( + new SimpleGDUpdateCalculator(0.2), + SimpleGDParameterUpdate::sumLocal, + SimpleGDParameterUpdate::avg + ); + + MLPArchitecture arch = new MLPArchitecture(2). + withAddedLayer(10, true, Activators.RELU). 
+ withAddedLayer(1, false, Activators.SIGMOID); + + DatasetTrainer<MultilayerPerceptron, Double> mlpTrainer = new MLPTrainer<>( + arch, + LossFunctions.MSE, + updatesStgy, + 3000, + 10, + 50, + 123L + ).withConvertedLabels(VectorUtils::num2Arr); + + final double factor = 3; + + StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer + .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor)) + .addMatrix2MatrixTrainer(mlpTrainer) + .withEnvironmentBuilder(TestUtils.testEnvBuilder()) + .fit(getCacheMock(xor), + parts, + (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)), + (k, v) -> v[v.length - 1]); + + assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3); + assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3); + assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3); + assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3); + } + + /** + * Tests that if there is no any way for input of first layer to propagate to second layer, + * exception will be thrown. 
+ */ + @Test + public void testINoWaysOfPropagation() { + StackedDatasetTrainer<Void, Void, Void, IgniteModel<Void, Void>, Void> trainer = + new StackedDatasetTrainer<>(); + thrown.expect(IllegalStateException.class); + trainer.fit(null, null, null); + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java index f2899c2..d711fc4 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java @@ -147,8 +147,8 @@ public class DataStreamGeneratorTest { DatasetBuilder<Vector, Double> b2 = generator.asDatasetBuilder(N, (v, l) -> l == 0, 2); counter.set(0); DatasetBuilder<Vector, Double> b3 = generator.asDatasetBuilder(N, (v, l) -> l == 1, 2, - new UpstreamTransformerBuilder<Vector, Double>() { - @Override public UpstreamTransformer<Vector, Double> build(LearningEnvironment env) { + new UpstreamTransformerBuilder() { + @Override public UpstreamTransformer build(LearningEnvironment env) { return new UpstreamTransformerForTest(); } }); @@ -201,10 +201,10 @@ public class DataStreamGeneratorTest { } /** */ - private static class UpstreamTransformerForTest implements UpstreamTransformer<Vector, Double> { - @Override public Stream<UpstreamEntry<Vector, Double>> transform( - Stream<UpstreamEntry<Vector, Double>> upstream) { - return upstream.map(entry -> new UpstreamEntry<>(entry.getKey(), -entry.getValue())); + private static class UpstreamTransformerForTest implements UpstreamTransformer { + @Override public Stream<UpstreamEntry> transform( + Stream<UpstreamEntry> upstream) { + return upstream.map(entry -> new UpstreamEntry<>(entry.getKey(), -((double)entry.getValue()))); } } }