IGNITE-9387: [ML] Model updating this closes #4659
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/f4c18f11 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/f4c18f11 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/f4c18f11 Branch: refs/heads/master Commit: f4c18f114a435ab818e5c9b179e4854df611dd74 Parents: 02afb43 Author: Alexey Platonov <[email protected]> Authored: Tue Sep 4 18:11:48 2018 +0300 Committer: Yury Babak <[email protected]> Committed: Tue Sep 4 18:11:48 2018 +0300 ---------------------------------------------------------------------- .../GDBOnTreesClassificationTrainerExample.java | 3 +- .../GDBOnTreesRegressionTrainerExample.java | 3 +- .../ml/clustering/kmeans/KMeansTrainer.java | 27 ++- .../ml/composition/BaggingModelTrainer.java | 20 +++ .../ml/composition/ModelsComposition.java | 10 +- .../ml/composition/ModelsCompositionFormat.java | 61 +++++++ .../ml/composition/boosting/GDBTrainer.java | 19 ++- .../ignite/ml/knn/NNClassificationModel.java | 16 ++ .../ml/knn/ann/ANNClassificationModel.java | 15 +- .../ml/knn/ann/ANNClassificationTrainer.java | 93 +++++++---- .../ignite/ml/knn/ann/ANNModelFormat.java | 12 +- .../classification/KNNClassificationModel.java | 40 ++++- .../KNNClassificationTrainer.java | 20 ++- .../ml/knn/regression/KNNRegressionTrainer.java | 19 ++- .../ml/math/isolve/lsqr/AbstractLSQR.java | 6 +- .../ignite/ml/math/isolve/lsqr/LSQROnHeap.java | 2 +- .../org/apache/ignite/ml/nn/MLPTrainer.java | 25 ++- .../ml/preprocessing/PreprocessingTrainer.java | 3 +- .../linear/LinearRegressionLSQRTrainer.java | 27 ++- .../linear/LinearRegressionSGDTrainer.java | 97 ++++++++++- .../binomial/LogisticRegressionSGDTrainer.java | 42 ++++- .../LogRegressionMultiClassModel.java | 9 + .../LogRegressionMultiClassTrainer.java | 33 +++- .../SVMLinearBinaryClassificationTrainer.java | 64 +++++-- .../SVMLinearMultiClassClassificationModel.java | 9 + ...VMLinearMultiClassClassificationTrainer.java | 82 +++++++-- 
.../ignite/ml/trainers/DatasetTrainer.java | 166 +++++++++++++++++++ .../org/apache/ignite/ml/tree/DecisionTree.java | 23 +++ .../RandomForestClassifierTrainer.java | 8 +- .../tree/randomforest/RandomForestTrainer.java | 28 +++- .../ignite/ml/clustering/KMeansTrainerTest.java | 82 ++++++--- .../ignite/ml/common/CollectionsTest.java | 9 +- .../ignite/ml/common/LocalModelsTest.java | 5 +- .../ml/composition/boosting/GDBTrainerTest.java | 2 +- .../ignite/ml/knn/ANNClassificationTest.java | 44 +++++ .../ignite/ml/knn/KNNClassificationTest.java | 39 +++++ .../apache/ignite/ml/knn/KNNRegressionTest.java | 40 +++++ .../org/apache/ignite/ml/nn/MLPTrainerTest.java | 64 +++++++ .../linear/LinearRegressionLSQRTrainerTest.java | 51 ++++++ .../linear/LinearRegressionSGDTrainerTest.java | 62 +++++++ .../logistic/LogRegMultiClassTrainerTest.java | 58 +++++++ .../LogisticRegressionSGDTrainerTest.java | 46 +++++ .../ignite/ml/svm/SVMBinaryTrainerTest.java | 41 +++++ .../ignite/ml/svm/SVMMultiClassTrainerTest.java | 43 +++++ .../RandomForestClassifierTrainerTest.java | 33 +++- .../RandomForestRegressionTrainerTest.java | 31 ++++ 46 files changed, 1503 insertions(+), 129 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java index 130b91a..075eab2 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java 
@@ -23,6 +23,7 @@ import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.composition.ModelsComposition; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.trainers.DatasetTrainer; @@ -58,7 +59,7 @@ public class GDBOnTreesClassificationTrainerExample { IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg); // Create regression trainer. - DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.); + DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.); // Train decision tree model. Model<Vector, Double> mdl = trainer.fit( http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java index 31dd2b0..b2b08d0 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java @@ -23,6 +23,7 @@ import org.apache.ignite.Ignition; import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; import org.apache.ignite.configuration.CacheConfiguration; import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.composition.ModelsComposition; import 
org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; import org.apache.ignite.ml.trainers.DatasetTrainer; @@ -58,7 +59,7 @@ public class GDBOnTreesRegressionTrainerExample { IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg); // Create regression trainer. - DatasetTrainer<Model<Vector, Double>, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.); + DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.); // Train decision tree model. Model<Vector, Double> mdl = trainer.fit( http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java index 5b880fcc..2596dbc 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Random; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -72,6 +73,14 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { */ @Override public <K, V> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override protected <K, V> KMeansModel updateModel(KMeansModel mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, 
Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + assert datasetBuilder != null; PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( @@ -85,7 +94,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { (upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { - final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> { + final Integer cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> { if (a == null) return b == null ? 0 : b; if (b == null) @@ -93,7 +102,12 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { return b; }); - centers = initClusterCentersRandomly(dataset, k); + if (cols == null) + return getLastTrainedModelOrThrowEmptyDatasetException(mdl); + + centers = Optional.ofNullable(mdl) + .map(KMeansModel::centers) + .orElseGet(() -> initClusterCentersRandomly(dataset, k)); boolean converged = false; int iteration = 0; @@ -127,6 +141,11 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { return new KMeansModel(centers, distance); } + /** {@inheritDoc} */ + @Override protected boolean checkState(KMeansModel mdl) { + return mdl.centers().length == k && mdl.distanceMeasure().equals(distance); + } + /** * Prepares the data to define new centroids on current iteration. * @@ -281,10 +300,12 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> { return this; } + /** + * @return centroid statistics. 
+ */ public ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> getCentroidStat() { return centroidStat; } - } /** http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java index f439789..493c1da 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java @@ -177,4 +177,24 @@ public abstract class BaggingModelTrainer extends DatasetTrainer<ModelsCompositi return VectorUtils.of(newFeaturesValues); }); } + + /** + * Learn new models on dataset and create new Compositions over them and already learned models. + * + * @param mdl Learned model. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + * @return New models composition. 
+ */ + @Override public <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + ArrayList<Model<Vector, Double>> newModels = new ArrayList<>(mdl.getModels()); + newModels.addAll(fit(datasetBuilder, featureExtractor, lbExtractor).getModels()); + + return new ModelsComposition(newModels, predictionsAggregator); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java index e14fa6d..36ee626 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java @@ -19,6 +19,8 @@ package org.apache.ignite.ml.composition; import java.util.Collections; import java.util.List; +import org.apache.ignite.ml.Exportable; +import org.apache.ignite.ml.Exporter; import org.apache.ignite.ml.Model; import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; import org.apache.ignite.ml.math.primitives.vector.Vector; @@ -27,7 +29,7 @@ import org.apache.ignite.ml.util.ModelTrace; /** * Model consisting of several models and prediction aggregation strategy. */ -public class ModelsComposition implements Model<Vector, Double> { +public class ModelsComposition implements Model<Vector, Double>, Exportable<ModelsCompositionFormat> { /** * Predictions aggregator. 
*/ @@ -78,6 +80,12 @@ public class ModelsComposition implements Model<Vector, Double> { } /** {@inheritDoc} */ + @Override public <P> void saveModel(Exporter<ModelsCompositionFormat, P> exporter, P path) { + ModelsCompositionFormat format = new ModelsCompositionFormat(models, predictionsAggregator); + exporter.save(format, path); + } + + /** {@inheritDoc} */ @Override public String toString() { return toString(false); } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java new file mode 100644 index 0000000..68af0a9 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.composition; + +import java.io.Serializable; +import java.util.List; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator; +import org.apache.ignite.ml.math.primitives.vector.Vector; + +/** + * ModelsComposition representation. + * + * @see ModelsComposition + */ +public class ModelsCompositionFormat implements Serializable { + /** Serial version uid. */ + private static final long serialVersionUID = 9115341364082681837L; + + /** Models. */ + private List<Model<Vector, Double>> models; + + /** Predictions aggregator. */ + private PredictionsAggregator predictionsAggregator; + + /** + * Creates an instance of ModelsCompositionFormat. + * + * @param models Models. + * @param predictionsAggregator Predictions aggregator. + */ + public ModelsCompositionFormat(List<Model<Vector, Double>> models,PredictionsAggregator predictionsAggregator) { + this.models = models; + this.predictionsAggregator = predictionsAggregator; + } + + /** */ + public List<Model<Vector, Double>> models() { + return models; + } + + /** */ + public PredictionsAggregator predictionsAggregator() { + return predictionsAggregator; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java index 5a0f52a..c7f21dd 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java @@ -52,7 +52,7 @@ import org.jetbrains.annotations.NotNull; * * But in practice Decision Trees is most used regressors (see: {@link DecisionTreeRegressionTrainer}). 
*/ -public abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, Double> { +public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Double> { /** Gradient step. */ private final double gradientStep; @@ -81,7 +81,7 @@ public abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, D } /** {@inheritDoc} */ - @Override public <K, V> Model<Vector, Double> fit(DatasetBuilder<K, V> datasetBuilder, + @Override public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { @@ -119,6 +119,21 @@ public abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, D }; } + + //TODO: This method will be implemented in IGNITE-9412 + /** {@inheritDoc} */ + @Override public <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + throw new UnsupportedOperationException(); + } + + //TODO: This method will be implemented in IGNITE-9412 + /** {@inheritDoc} */ + @Override protected boolean checkState(ModelsComposition mdl) { + throw new UnsupportedOperationException(); + } + /** * Defines unique labels in dataset if need (useful in case of classification). 
* http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java index b7a57f5..d435f91 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java @@ -174,6 +174,11 @@ public abstract class NNClassificationModel implements Model<Vector, Double>, Ex return 1.0; // strategy.SIMPLE } + /** */ + public DistanceMeasure getDistanceMeasure() { + return distanceMeasure; + } + /** {@inheritDoc} */ @Override public int hashCode() { int res = 1; @@ -212,6 +217,17 @@ public abstract class NNClassificationModel implements Model<Vector, Double>, Ex .toString(); } + /** + * Sets parameters from other model to this model. + * + * @param mdl Model. 
+ */ + protected void copyParametersFrom(NNClassificationModel mdl) { + this.k = mdl.k; + this.distanceMeasure = mdl.distanceMeasure; + this.stgy = mdl.stgy; + } + /** */ public abstract <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path); } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java index e8c0b4a..bec82a9 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java @@ -44,12 +44,18 @@ public class ANNClassificationModel extends NNClassificationModel { /** The labeled set of candidates. */ private final LabeledVectorSet<ProbableLabel, LabeledVector> candidates; + /** Centroid statistics. */ + private final ANNClassificationTrainer.CentroidStat centroindsStat; + /** * Build the model based on a candidates set. * @param centers The candidates set. 
+ * @param centroindsStat */ - public ANNClassificationModel(LabeledVectorSet<ProbableLabel, LabeledVector> centers) { + public ANNClassificationModel(LabeledVectorSet<ProbableLabel, LabeledVector> centers, + ANNClassificationTrainer.CentroidStat centroindsStat) { this.candidates = centers; + this.centroindsStat = centroindsStat; } /** */ @@ -57,6 +63,11 @@ public class ANNClassificationModel extends NNClassificationModel { return candidates; } + /** */ + public ANNClassificationTrainer.CentroidStat getCentroindsStat() { + return centroindsStat; + } + /** {@inheritDoc} */ @Override public Double apply(Vector v) { List<LabeledVector> neighbors = findKNearestNeighbors(v); @@ -65,7 +76,7 @@ public class ANNClassificationModel extends NNClassificationModel { /** */ @Override public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path) { - ANNModelFormat mdlData = new ANNModelFormat(k, distanceMeasure, stgy, candidates); + ANNModelFormat mdlData = new ANNModelFormat(k, distanceMeasure, stgy, candidates, centroindsStat); exporter.save(mdlData, path); } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java index 1c45812..3e32b67 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java @@ -17,9 +17,13 @@ package org.apache.ignite.ml.knn.ann; +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentSkipListSet; +import java.util.stream.Collectors; import 
org.apache.ignite.lang.IgniteBiTuple; import org.apache.ignite.ml.clustering.kmeans.KMeansModel; import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer; @@ -39,8 +43,8 @@ import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; import org.jetbrains.annotations.NotNull; /** - * ANN algorithm trainer to solve multi-class classification task. - * This trainer is based on ACD strategy and KMeans clustering algorithm to find centroids. + * ANN algorithm trainer to solve multi-class classification task. This trainer is based on ACD strategy and KMeans + * clustering algorithm to find centroids. */ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClassificationModel> { /** Amount of clusters. */ @@ -61,29 +65,55 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass /** * Trains model based on the specified data. * - * @param datasetBuilder Dataset builder. + * @param datasetBuilder Dataset builder. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. + * @param lbExtractor Label extractor. * @return Model. 
*/ - @Override public <K, V> ANNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - final Vector[] centers = getCentroids(featureExtractor, lbExtractor, datasetBuilder); + @Override public <K, V> ANNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - final CentroidStat centroidStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers); + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override protected <K, V> ANNClassificationModel updateModel(ANNClassificationModel mdl, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + List<Vector> centers; + CentroidStat centroidStat; + if (mdl != null) { + centers = Arrays.stream(mdl.getCandidates().data()).map(x -> x.features()).collect(Collectors.toList()); + CentroidStat newStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers); + if(newStat == null) + return mdl; + CentroidStat oldStat = mdl.getCentroindsStat(); + centroidStat = newStat.merge(oldStat); + } else { + centers = getCentroids(featureExtractor, lbExtractor, datasetBuilder); + centroidStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers); + } final LabeledVectorSet<ProbableLabel, LabeledVector> dataset = buildLabelsForCandidates(centers, centroidStat); - return new ANNClassificationModel(dataset); + return new ANNClassificationModel(dataset, centroidStat); + } + + /** {@inheritDoc} */ + @Override protected boolean checkState(ANNClassificationModel mdl) { + return mdl.getDistanceMeasure().equals(distance) && mdl.getCandidates().rowSize() == k; } /** */ - @NotNull private LabeledVectorSet<ProbableLabel, LabeledVector> 
buildLabelsForCandidates(Vector[] centers, CentroidStat centroidStat) { + @NotNull private LabeledVectorSet<ProbableLabel, LabeledVector> buildLabelsForCandidates(List<Vector> centers, + CentroidStat centroidStat) { // init - final LabeledVector<Vector, ProbableLabel>[] arr = new LabeledVector[centers.length]; + final LabeledVector<Vector, ProbableLabel>[] arr = new LabeledVector[centers.size()]; // fill label for each centroid - for (int i = 0; i < centers.length; i++) - arr[i] = new LabeledVector<>(centers[i], fillProbableLabel(i, centroidStat)); + for (int i = 0; i < centers.size(); i++) + arr[i] = new LabeledVector<>(centers.get(i), fillProbableLabel(i, centroidStat)); return new LabeledVectorSet<>(arr); } @@ -92,13 +122,14 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass * Perform KMeans clusterization algorithm to find centroids. * * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. - * @param datasetBuilder The dataset builder. - * @param <K> Type of a key in {@code upstream} data. - * @param <V> Type of a value in {@code upstream} data. + * @param lbExtractor Label extractor. + * @param datasetBuilder The dataset builder. + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. * @return The arrays of vectors. 
*/ - private <K, V> Vector[] getCentroids(IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, DatasetBuilder<K, V> datasetBuilder) { + private <K, V> List<Vector> getCentroids(IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor, DatasetBuilder<K, V> datasetBuilder) { KMeansTrainer trainer = new KMeansTrainer() .withK(k) .withMaxIterations(maxIterations) @@ -112,7 +143,7 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass lbExtractor ); - return mdl.centers(); + return Arrays.asList(mdl.centers()); } /** */ @@ -125,21 +156,24 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass ConcurrentHashMap<Double, Integer> centroidLbDistribution = centroidStat.centroidStat().get(centroidIdx); - if(centroidStat.counts.containsKey(centroidIdx)){ + if (centroidStat.counts.containsKey(centroidIdx)) { int clusterSize = centroidStat .counts .get(centroidIdx); clsLbls.keySet().forEach( - (label) -> clsLbls.put(label, centroidLbDistribution.containsKey(label) ? ((double) (centroidLbDistribution.get(label)) / clusterSize) : 0.0) + (label) -> clsLbls.put(label, centroidLbDistribution.containsKey(label) ? 
((double)(centroidLbDistribution.get(label)) / clusterSize) : 0.0) ); } return new ProbableLabel(clsLbls); } /** */ - private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, Vector[] centers) { + private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor, List<Vector> centers) { + PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( featureExtractor, lbExtractor @@ -174,7 +208,7 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass } res.counts.merge(centroidIdx, 1, - (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> i1 + i2); + (IgniteBiFunction<Integer, Integer, Integer>)(i1, i2) -> i1 + i2); } return res; }, (a, b) -> { @@ -194,15 +228,15 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass * Find the closest cluster center index and distance to it from a given point. * * @param centers Centers to look in. - * @param pnt Point. + * @param pnt Point. 
*/ - private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] centers, LabeledVector pnt) { + private IgniteBiTuple<Integer, Double> findClosestCentroid(List<Vector> centers, LabeledVector pnt) { double bestDistance = Double.POSITIVE_INFINITY; int bestInd = 0; - for (int i = 0; i < centers.length; i++) { - if (centers[i] != null) { - double dist = distance.compute(centers[i], pnt.features()); + for (int i = 0; i < centers.size(); i++) { + if (centers.get(i) != null) { + double dist = distance.compute(centers.get(i), pnt.features()); if (dist < bestDistance) { bestDistance = dist; bestInd = i; @@ -212,7 +246,6 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass return new IgniteBiTuple<>(bestInd, bestDistance); } - /** * Gets the amount of clusters. * @@ -314,7 +347,9 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass } /** Service class used for statistics. */ - public static class CentroidStat { + public static class CentroidStat implements Serializable { + /** Serial version uid. */ + private static final long serialVersionUID = 7624883170532045144L; /** Count of points closest to the center with a given index. 
*/ ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap<>(); http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java index e10f3b2..be09828 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java @@ -30,6 +30,9 @@ import org.apache.ignite.ml.structures.LabeledVectorSet; * @see ANNClassificationModel */ public class ANNModelFormat extends KNNModelFormat implements Serializable { + /** Centroid statistics. */ + private final ANNClassificationTrainer.CentroidStat candidatesStat; + /** The labeled set of candidates. */ private LabeledVectorSet<ProbableLabel, LabeledVector> candidates; @@ -38,15 +41,18 @@ public class ANNModelFormat extends KNNModelFormat implements Serializable { * @param k Amount of nearest neighbors. * @param measure Distance measure. * @param stgy kNN strategy. 
+ * @param candidatesStat */ public ANNModelFormat(int k, - DistanceMeasure measure, - NNStrategy stgy, - LabeledVectorSet<ProbableLabel, LabeledVector> candidates) { + DistanceMeasure measure, + NNStrategy stgy, + LabeledVectorSet<ProbableLabel, LabeledVector> candidates, + ANNClassificationTrainer.CentroidStat candidatesStat) { this.k = k; this.distanceMeasure = measure; this.stgy = stgy; this.candidates = candidates; + this.candidatesStat = candidatesStat; } /** */ http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java index 0b88f81..0d03ee5 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.knn.classification; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -42,25 +43,29 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp /** */ private static final long serialVersionUID = -127386523291350345L; - /** Dataset. */ - private Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset; + /** Datasets. */ + private List<Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>>> datasets; /** * Builds the model via prepared dataset. + * * @param dataset Specially prepared object to run algorithm over it. 
*/ public KNNClassificationModel(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) { - this.dataset = dataset; + this.datasets = new ArrayList<>(); + if (dataset != null) + datasets.add(dataset); } /** {@inheritDoc} */ @Override public Double apply(Vector v) { - if(dataset != null) { + if (!datasets.isEmpty()) { List<LabeledVector> neighbors = findKNearestNeighbors(v); return classify(neighbors, v, stgy); - } else + } else { throw new IllegalStateException("The train kNN dataset is null"); + } } /** */ @@ -77,6 +82,17 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp * @return K-nearest neighbors. */ protected List<LabeledVector> findKNearestNeighbors(Vector v) { + List<LabeledVector> neighborsFromPartitions = datasets.stream() + .flatMap(dataset -> findKNearestNeighborsInDataset(v, dataset).stream()) + .collect(Collectors.toList()); + + LabeledVectorSet<Double, LabeledVector> neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions); + + return Arrays.asList(getKClosestVectors(neighborsToFilter, getDistances(v, neighborsToFilter))); + } + + private List<LabeledVector> findKNearestNeighborsInDataset(Vector v, + Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) { List<LabeledVector> neighborsFromPartitions = dataset.compute(data -> { TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, data); return Arrays.asList(getKClosestVectors(data, distanceIdxPairs)); @@ -88,12 +104,14 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp return Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()); }); + if(neighborsFromPartitions == null) + return Collections.emptyList(); + LabeledVectorSet<Double, LabeledVector> neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions); return Arrays.asList(getKClosestVectors(neighborsToFilter, getDistances(v, neighborsToFilter))); } - /** */ private double 
classify(List<LabeledVector> neighbors, Vector v, NNStrategy stgy) { Map<Double, Double> clsVotes = new HashMap<>(); @@ -116,5 +134,13 @@ public class KNNClassificationModel extends NNClassificationModel implements Exp return getClassWithMaxVotes(clsVotes); } - + /** + * Copy parameters from other model and save all datasets from it. + * + * @param model Model. + */ + public void copyStateFrom(KNNClassificationModel model) { + this.copyParametersFrom(model); + datasets.addAll(model.datasets); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java index e0a81f9..1a3ff73 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java @@ -37,6 +37,24 @@ public class KNNClassificationTrainer extends SingleLabelDatasetTrainer<KNNClass */ @Override public <K, V> KNNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - return new KNNClassificationModel(KNNUtils.buildDataset(datasetBuilder, featureExtractor, lbExtractor)); + + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override public <K, V> KNNClassificationModel updateModel(KNNClassificationModel mdl, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + KNNClassificationModel res = new KNNClassificationModel(KNNUtils.buildDataset(datasetBuilder, + 
featureExtractor, lbExtractor)); + if (mdl != null) + res.copyStateFrom(mdl); + return res; + } + + /** {@inheritDoc} */ + @Override protected boolean checkState(KNNClassificationModel mdl) { + return true; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java index 395ce61..7a42dc8 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java @@ -37,6 +37,23 @@ public class KNNRegressionTrainer extends SingleLabelDatasetTrainer<KNNRegressio */ public <K, V> KNNRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - return new KNNRegressionModel(KNNUtils.buildDataset(datasetBuilder, featureExtractor, lbExtractor)); + + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override public <K, V> KNNRegressionModel updateModel(KNNRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + KNNRegressionModel res = new KNNRegressionModel(KNNUtils.buildDataset(datasetBuilder, + featureExtractor, lbExtractor)); + if (mdl != null) + res.copyStateFrom(mdl); + return res; + } + + /** {@inheritDoc} */ + @Override protected boolean checkState(KNNRegressionModel mdl) { + return true; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java 
---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java index 7a362f7..c9281c0 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java @@ -78,7 +78,9 @@ public abstract class AbstractLSQR { */ public LSQRResult solve(double damp, double atol, double btol, double conlim, double iterLim, boolean calcVar, double[] x0) { - int n = getColumns(); + Integer n = getColumns(); + if(n == null) + return null; if (iterLim < 0) iterLim = 2 * n; @@ -313,7 +315,7 @@ public abstract class AbstractLSQR { protected abstract double[] iter(double bnorm, double[] target); /** */ - protected abstract int getColumns(); + protected abstract Integer getColumns(); /** */ private static double[] symOrtho(double a, double b) { http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java index f75caef..14356e1 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java @@ -100,7 +100,7 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable { * * @return number of columns */ - @Override protected int getColumns() { + @Override protected Integer getColumns() { return dataset.compute( data -> data.getFeatures() == null ? 
null : data.getFeatures().length / data.getRows(), (a, b) -> { http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java index 6727ba9..8f1a4cb 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java @@ -111,12 +111,25 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override protected <K, V> MultilayerPerceptron updateModel(MultilayerPerceptron lastLearnedModel, + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) { + try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build( new EmptyContextBuilder<>(), new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor) )) { - MLPArchitecture arch = archSupplier.apply(dataset); - MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed)); + MultilayerPerceptron mdl; + if (lastLearnedModel != null) { + mdl = lastLearnedModel; + } else { + MLPArchitecture arch = archSupplier.apply(dataset); + mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed)); + } ParameterUpdateCalculator<? 
super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator(); for (int i = 0; i < maxIterations; i += locIterations) { @@ -178,6 +191,9 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer } ); + if (totUp == null) + return getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedModel); + P update = updatesStgy.allUpdatesReducer().apply(totUp); mdl = updater.update(mdl, update); } @@ -189,6 +205,11 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer } } + /** {@inheritDoc} */ + @Override protected boolean checkState(MultilayerPerceptron mdl) { + return true; + } + /** * Builds a batch of the data by fetching specified rows. * http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java index 1886ee5..b977864 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java @@ -17,6 +17,7 @@ package org.apache.ignite.ml.preprocessing; +import java.util.Map; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.ml.dataset.DatasetBuilder; @@ -24,8 +25,6 @@ import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; import org.apache.ignite.ml.math.functions.IgniteBiFunction; -import java.util.Map; - /** * Trainer for preprocessor. 
* http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java index 8197779..5497177 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java @@ -38,16 +38,34 @@ public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<Linea @Override public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, + DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + LSQRResult res; try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>( datasetBuilder, new SimpleLabeledDatasetDataBuilder<>( new FeatureExtractorWrapper<>(featureExtractor), - lbExtractor.andThen(e -> new double[]{e}) + lbExtractor.andThen(e -> new double[] {e}) ) )) { - res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null); + double[] x0 = null; + if (mdl != null) { + int x0Size = mdl.getWeights().size() + 1; + Vector weights = mdl.getWeights().like(x0Size); + mdl.getWeights().nonZeroes().forEach(ith -> weights.set(ith.index(), ith.get())); + weights.set(weights.size() - 1, mdl.getIntercept()); + x0 = weights.asArray(); + } + res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, x0); + if (res == null) 
+ return getLastTrainedModelOrThrowEmptyDatasetException(mdl); } catch (Exception e) { throw new RuntimeException(e); @@ -58,4 +76,9 @@ public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<Linea return new LinearRegressionModel(weights, x[x.length - 1]); } + + /** {@inheritDoc} */ + @Override protected boolean checkState(LinearRegressionModel mdl) { + return true; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java index 44f60d1..125ed24 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java @@ -19,6 +19,7 @@ package org.apache.ignite.ml.regressions.linear; import java.io.Serializable; import java.util.Arrays; +import java.util.Optional; import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; @@ -34,6 +35,7 @@ import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; +import org.jetbrains.annotations.NotNull; /** * Trainer of the linear regression model based on stochastic gradient descent algorithm. @@ -43,16 +45,16 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy; /** Max number of iteration. 
*/ - private final int maxIterations; + private int maxIterations = 1000; /** Batch size. */ - private final int batchSize; + private int batchSize = 10; /** Number of local iterations. */ - private final int locIterations; + private int locIterations = 100; /** Seed for random generator. */ - private final long seed; + private long seed = System.currentTimeMillis(); /** * Constructs a new instance of linear regression SGD trainer. @@ -72,10 +74,24 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa this.seed = seed; } + /** + * Constructs a new instance of linear regression SGD trainer. + */ + public LinearRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy) { + this.updatesStgy = updatesStgy; + } + /** {@inheritDoc} */ @Override public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> { int cols = dataset.compute(data -> { @@ -108,7 +124,10 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa IgniteBiFunction<K, V, double[]> lbE = (IgniteBiFunction<K, V, double[]>)(k, v) -> new double[] {lbExtractor.apply(k, v)}; - MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, lbE); + MultilayerPerceptron mlp = Optional.ofNullable(mdl) + .map(this::restoreMLPState) + .map(m -> trainer.update(m, datasetBuilder, featureExtractor, lbE)) + .orElseGet(() -> trainer.fit(datasetBuilder, featureExtractor, lbE)); double[] p = 
mlp.parameters().getStorage().data(); @@ -117,4 +136,72 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa p[p.length - 1] ); } + + /** + * @param mdl Model. + * @return state of MLP from last learning. + */ + @NotNull private MultilayerPerceptron restoreMLPState(LinearRegressionModel mdl) { + Vector weights = mdl.getWeights(); + double intercept = mdl.getIntercept(); + MLPArchitecture architecture1 = new MLPArchitecture(weights.size()); + architecture1 = architecture1.withAddedLayer(1, true, Activators.LINEAR); + MLPArchitecture architecture = architecture1; + MultilayerPerceptron perceptron = new MultilayerPerceptron(architecture); + + Vector mlpState = weights.like(weights.size() + 1); + weights.nonZeroes().forEach(ith -> mlpState.set(ith.index(), ith.get())); + mlpState.set(mlpState.size() - 1, intercept); + perceptron.setParameters(mlpState); + return perceptron; + } + + /** {@inheritDoc} */ + @Override protected boolean checkState(LinearRegressionModel mdl) { + return true; + } + + /** + * Set up the max number of iterations before convergence. + * + * @param maxIterations The parameter value. + * @return Model with new max number of iterations before convergence parameter value. + */ + public LinearRegressionSGDTrainer<P> withMaxIterations(int maxIterations) { + this.maxIterations = maxIterations; + return this; + } + + /** + * Set up the batchSize parameter. + * + * @param batchSize The size of learning batch. + * @return Trainer with new batch size parameter value. + */ + public LinearRegressionSGDTrainer<P> withBatchSize(int batchSize) { + this.batchSize = batchSize; + return this; + } + + /** + * Set up the amount of local iterations of SGD algorithm. + * + * @param amountOfLocIterations The parameter value. + * @return Trainer with new locIterations parameter value. 
+ */ + public LinearRegressionSGDTrainer<P> withLocIterations(int amountOfLocIterations) { + this.locIterations = amountOfLocIterations; + return this; + } + + /** + * Set up the random seed parameter. + * + * @param seed Seed for random generator. + * @return Trainer with new seed parameter value. + */ + public LinearRegressionSGDTrainer<P> withSeed(long seed) { + this.seed = seed; + return this; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java index 6396279..839dab5 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java @@ -34,6 +34,7 @@ import org.apache.ignite.ml.nn.UpdatesStrategy; import org.apache.ignite.ml.nn.architecture.MLPArchitecture; import org.apache.ignite.ml.optimization.LossFunctions; import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer; +import org.jetbrains.annotations.NotNull; /** * Trainer of the logistic regression model based on stochastic gradient descent algorithm. 
@@ -76,8 +77,15 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single @Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { - IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> { + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> { int cols = dataset.compute(data -> { if (data.getFeatures() == null) return null; @@ -106,7 +114,13 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single seed ); - MultilayerPerceptron mlp = trainer.fit(datasetBuilder, featureExtractor, (k, v) -> new double[] {lbExtractor.apply(k, v)}); + IgniteBiFunction<K, V, double[]> lbExtractorWrapper = (k, v) -> new double[] {lbExtractor.apply(k, v)}; + MultilayerPerceptron mlp; + if(mdl != null) { + mlp = restoreMLPState(mdl); + mlp = trainer.update(mlp, datasetBuilder, featureExtractor, lbExtractorWrapper); + } else + mlp = trainer.fit(datasetBuilder, featureExtractor, lbExtractorWrapper); double[] params = mlp.parameters().getStorage().data(); @@ -114,4 +128,28 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single params[params.length - 1] ); } + + /** + * @param mdl Model. + * @return state of MLP from last learning. 
+ */ + @NotNull private MultilayerPerceptron restoreMLPState(LogisticRegressionModel mdl) { + Vector weights = mdl.weights(); + double intercept = mdl.intercept(); + MLPArchitecture architecture1 = new MLPArchitecture(weights.size()); + architecture1 = architecture1.withAddedLayer(1, true, Activators.SIGMOID); + MLPArchitecture architecture = architecture1; + MultilayerPerceptron perceptron = new MultilayerPerceptron(architecture); + + Vector mlpState = weights.like(weights.size() + 1); + weights.nonZeroes().forEach(ith -> mlpState.set(ith.index(), ith.get())); + mlpState.set(mlpState.size() - 1, intercept); + perceptron.setParameters(mlpState); + return perceptron; + } + + /** {@inheritDoc} */ + @Override protected boolean checkState(LogisticRegressionModel mdl) { + return true; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java index 56d2d29..a7c9118 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.TreeMap; import org.apache.ignite.ml.Exportable; import org.apache.ignite.ml.Exporter; @@ -103,4 +104,12 @@ public class LogRegressionMultiClassModel implements Model<Vector, Double>, Expo public void add(double clsLb, LogisticRegressionModel mdl) { models.put(clsLb, mdl); } + + /** + * 
@param clsLb Class label. + * @return model for class label if it exists. + */ + public Optional<LogisticRegressionModel> getModel(Double clsLb) { + return Optional.ofNullable(models.get(clsLb)); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java index 4885373..eb44301 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -33,6 +34,7 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.nn.MultilayerPerceptron; import org.apache.ignite.ml.nn.UpdatesStrategy; +import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel; import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer; import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap; import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap; @@ -71,6 +73,19 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> IgniteBiFunction<K, V, Double> lbExtractor) { List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor); + return 
updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override public <K, V> LogRegressionMultiClassModel updateModel(LogRegressionMultiClassModel mdl, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor); + + if(classes.isEmpty()) + return getLastTrainedModelOrThrowEmptyDatasetException(mdl); + LogRegressionMultiClassModel multiClsMdl = new LogRegressionMultiClassModel(); classes.forEach(clsLb -> { @@ -85,12 +100,23 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> else return 0.0; }; - multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, featureExtractor, lbTransformer)); + + LogisticRegressionModel model = Optional.ofNullable(mdl) + .flatMap(multiClassModel -> multiClassModel.getModel(clsLb)) + .map(learnedModel -> trainer.update(learnedModel, datasetBuilder, featureExtractor, lbTransformer)) + .orElseGet(() -> trainer.fit(datasetBuilder, featureExtractor, lbTransformer)); + + multiClsMdl.add(clsLb, model); }); return multiClsMdl; } + /** {@inheritDoc} */ + @Override protected boolean checkState(LogRegressionMultiClassModel mdl) { + return true; + } + /** Iterates among dataset and collects class labels. */ private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) { @@ -121,7 +147,8 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()); }); - res.addAll(clsLabels); + if (clsLabels != null) + res.addAll(clsLabels); } catch (Exception e) { @@ -191,7 +218,7 @@ public class LogRegressionMultiClassTrainer<P extends Serializable> } /** - * Set up the regularization parameter. + * Set up the random seed parameter. * * @param seed Seed for random generator. 
* @return Trainer with new seed parameter value. http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java index 933a712..573df1a 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java @@ -22,9 +22,11 @@ import org.apache.ignite.ml.dataset.Dataset; import org.apache.ignite.ml.dataset.DatasetBuilder; import org.apache.ignite.ml.dataset.PartitionDataBuilder; import org.apache.ignite.ml.dataset.primitive.context.EmptyContext; +import org.apache.ignite.ml.math.StorageConstants; import org.apache.ignite.ml.math.functions.IgniteBiFunction; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector; import org.apache.ignite.ml.structures.LabeledVector; import org.apache.ignite.ml.structures.LabeledVectorSet; import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap; @@ -61,6 +63,14 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai @Override public <K, V> SVMLinearBinaryClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override protected <K, V> SVMLinearBinaryClassificationModel updateModel(SVMLinearBinaryClassificationModel mdl, + DatasetBuilder<K, V> 
datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + assert datasetBuilder != null; PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>( @@ -74,29 +84,57 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai (upstream, upstreamSize) -> new EmptyContext(), partDataBuilder )) { - final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> { - if (a == null) - return b == null ? 0 : b; - if (b == null) - return a; - return b; - }); - - final int weightVectorSizeWithIntercept = cols + 1; - - weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept); + if (mdl == null) { + final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> { + if (a == null) + return b == null ? 0 : b; + if (b == null) + return a; + return b; + }); + + final int weightVectorSizeWithIntercept = cols + 1; + weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept); + } else { + weights = getStateVector(mdl); + } for (int i = 0; i < this.getAmountOfIterations(); i++) { Vector deltaWeights = calculateUpdates(weights, dataset); + if (deltaWeights == null) + return getLastTrainedModelOrThrowEmptyDatasetException(mdl); + weights = weights.plus(deltaWeights); // creates new vector } - } - catch (Exception e) { + } catch (Exception e) { throw new RuntimeException(e); } return new SVMLinearBinaryClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0)); } + /** {@inheritDoc} */ + @Override protected boolean checkState(SVMLinearBinaryClassificationModel mdl) { + return true; + } + + /** + * @param mdl Model. + * @return vector of model weights with intercept. 
+ */ + private Vector getStateVector(SVMLinearBinaryClassificationModel mdl) { + double intercept = mdl.intercept(); + Vector weights = mdl.weights(); + + int stateVectorSize = weights.size() + 1; + Vector result = weights.isDense() ? + new DenseVector(stateVectorSize) : + new SparseVector(stateVectorSize, StorageConstants.RANDOM_ACCESS_MODE); + + result.set(0, intercept); + weights.nonZeroes().forEach(ith -> result.set(ith.index() + 1, ith.get())); + return result; + } + /** */ @NotNull private Vector initializeWeightsWithZeros(int vectorSize) { return new DenseVector(vectorSize); http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java index 4b04824..46bf4b2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.TreeMap; import org.apache.ignite.ml.Exportable; import org.apache.ignite.ml.Exporter; @@ -102,4 +103,12 @@ public class SVMLinearMultiClassClassificationModel implements Model<Vector, Dou public void add(double clsLb, SVMLinearBinaryClassificationModel mdl) { models.put(clsLb, mdl); } + + /** + * @param clsLb Class label. + * @return model trained for target class if it exists.
+ */ + public Optional<SVMLinearBinaryClassificationModel> getModelForClass(double clsLb) { + return Optional.ofNullable(models.get(clsLb)); + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java index 4b7cc95..b77baa2 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java @@ -57,15 +57,26 @@ public class SVMLinearMultiClassClassificationTrainer /** * Trains model based on the specified data. * - * @param datasetBuilder Dataset builder. + * @param datasetBuilder Dataset builder. * @param featureExtractor Feature extractor. - * @param lbExtractor Label extractor. + * @param lbExtractor Label extractor. * @return Model.
*/ @Override public <K, V> SVMLinearMultiClassClassificationModel fit(DatasetBuilder<K, V> datasetBuilder, - IgniteBiFunction<K, V, Vector> featureExtractor, - IgniteBiFunction<K, V, Double> lbExtractor) { + IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + return updateModel(null, datasetBuilder, featureExtractor, lbExtractor); + } + + /** {@inheritDoc} */ + @Override public <K, V> SVMLinearMultiClassClassificationModel updateModel( + SVMLinearMultiClassClassificationModel mdl, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor); + if (classes.isEmpty()) + return getLastTrainedModelOrThrowEmptyDatasetException(mdl); SVMLinearMultiClassClassificationModel multiClsMdl = new SVMLinearMultiClassClassificationModel(); @@ -84,14 +95,60 @@ public class SVMLinearMultiClassClassificationTrainer else return -1.0; }; - multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, featureExtractor, lbTransformer)); + + SVMLinearBinaryClassificationModel model; + if (mdl == null) + model = learnNewModel(trainer, datasetBuilder, featureExtractor, lbTransformer); + else + model = updateModel(mdl, clsLb, trainer, datasetBuilder, featureExtractor, lbTransformer); + multiClsMdl.add(clsLb, model); }); return multiClsMdl; } + /** {@inheritDoc} */ + @Override protected boolean checkState(SVMLinearMultiClassClassificationModel mdl) { + return true; + } + + /** + * Trains model based on the specified data. + * + * @param svmTrainer Prepared SVM trainer. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. 
+ */ + private <K, V> SVMLinearBinaryClassificationModel learnNewModel(SVMLinearBinaryClassificationTrainer svmTrainer, + DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, + IgniteBiFunction<K, V, Double> lbExtractor) { + + return svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor); + } + + /** + * Updates already learned model or fit new model if there is no model for current class label. + * + * @param multiClsMdl Learning multi-class model. + * @param clsLb Current class label. + * @param svmTrainer Prepared SVM trainer. + * @param datasetBuilder Dataset builder. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + */ + private <K, V> SVMLinearBinaryClassificationModel updateModel(SVMLinearMultiClassClassificationModel multiClsMdl, + Double clsLb, SVMLinearBinaryClassificationTrainer svmTrainer, DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) { + + return multiClsMdl.getModelForClass(clsLb) + .map(learnedModel -> svmTrainer.update(learnedModel, datasetBuilder, featureExtractor, lbExtractor)) + .orElseGet(() -> svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor)); + } + /** Iterates among dataset and collects class labels. 
*/ - private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) { + private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, + IgniteBiFunction<K, V, Double> lbExtractor) { assert datasetBuilder != null; PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor); @@ -107,7 +164,8 @@ public class SVMLinearMultiClassClassificationTrainer final double[] lbs = data.getY(); - for (double lb : lbs) locClsLabels.add(lb); + for (double lb : lbs) + locClsLabels.add(lb); return locClsLabels; }, (a, b) -> { @@ -118,8 +176,8 @@ public class SVMLinearMultiClassClassificationTrainer return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet()); }); - res.addAll(clsLabels); - + if (clsLabels != null) + res.addAll(clsLabels); } catch (Exception e) { throw new RuntimeException(e); } @@ -132,7 +190,7 @@ public class SVMLinearMultiClassClassificationTrainer * @param lambda The regularization parameter. Should be more than 0.0. * @return Trainer with new lambda parameter value. */ - public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) { + public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) { assert lambda > 0.0; this.lambda = lambda; return this; @@ -162,7 +220,7 @@ public class SVMLinearMultiClassClassificationTrainer * @param amountOfIterations The parameter value. * @return Trainer with new amountOfIterations parameter value. */ - public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) { + public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int amountOfIterations) { this.amountOfIterations = amountOfIterations; return this; } @@ -182,7 +240,7 @@ public class SVMLinearMultiClassClassificationTrainer * @param amountOfLocIterations The parameter value. 
* @return Trainer with new amountOfLocIterations parameter value. */ - public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) { + public SVMLinearMultiClassClassificationTrainer withAmountOfLocIterations(int amountOfLocIterations) { this.amountOfLocIterations = amountOfLocIterations; return this; }
