IGNITE-9387: [ML] Model updating

this closes #4659


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/f4c18f11
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/f4c18f11
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/f4c18f11

Branch: refs/heads/master
Commit: f4c18f114a435ab818e5c9b179e4854df611dd74
Parents: 02afb43
Author: Alexey Platonov <[email protected]>
Authored: Tue Sep 4 18:11:48 2018 +0300
Committer: Yury Babak <[email protected]>
Committed: Tue Sep 4 18:11:48 2018 +0300

----------------------------------------------------------------------
 .../GDBOnTreesClassificationTrainerExample.java |   3 +-
 .../GDBOnTreesRegressionTrainerExample.java     |   3 +-
 .../ml/clustering/kmeans/KMeansTrainer.java     |  27 ++-
 .../ml/composition/BaggingModelTrainer.java     |  20 +++
 .../ml/composition/ModelsComposition.java       |  10 +-
 .../ml/composition/ModelsCompositionFormat.java |  61 +++++++
 .../ml/composition/boosting/GDBTrainer.java     |  19 ++-
 .../ignite/ml/knn/NNClassificationModel.java    |  16 ++
 .../ml/knn/ann/ANNClassificationModel.java      |  15 +-
 .../ml/knn/ann/ANNClassificationTrainer.java    |  93 +++++++----
 .../ignite/ml/knn/ann/ANNModelFormat.java       |  12 +-
 .../classification/KNNClassificationModel.java  |  40 ++++-
 .../KNNClassificationTrainer.java               |  20 ++-
 .../ml/knn/regression/KNNRegressionTrainer.java |  19 ++-
 .../ml/math/isolve/lsqr/AbstractLSQR.java       |   6 +-
 .../ignite/ml/math/isolve/lsqr/LSQROnHeap.java  |   2 +-
 .../org/apache/ignite/ml/nn/MLPTrainer.java     |  25 ++-
 .../ml/preprocessing/PreprocessingTrainer.java  |   3 +-
 .../linear/LinearRegressionLSQRTrainer.java     |  27 ++-
 .../linear/LinearRegressionSGDTrainer.java      |  97 ++++++++++-
 .../binomial/LogisticRegressionSGDTrainer.java  |  42 ++++-
 .../LogRegressionMultiClassModel.java           |   9 +
 .../LogRegressionMultiClassTrainer.java         |  33 +++-
 .../SVMLinearBinaryClassificationTrainer.java   |  64 +++++--
 .../SVMLinearMultiClassClassificationModel.java |   9 +
 ...VMLinearMultiClassClassificationTrainer.java |  82 +++++++--
 .../ignite/ml/trainers/DatasetTrainer.java      | 166 +++++++++++++++++++
 .../org/apache/ignite/ml/tree/DecisionTree.java |  23 +++
 .../RandomForestClassifierTrainer.java          |   8 +-
 .../tree/randomforest/RandomForestTrainer.java  |  28 +++-
 .../ignite/ml/clustering/KMeansTrainerTest.java |  82 ++++++---
 .../ignite/ml/common/CollectionsTest.java       |   9 +-
 .../ignite/ml/common/LocalModelsTest.java       |   5 +-
 .../ml/composition/boosting/GDBTrainerTest.java |   2 +-
 .../ignite/ml/knn/ANNClassificationTest.java    |  44 +++++
 .../ignite/ml/knn/KNNClassificationTest.java    |  39 +++++
 .../apache/ignite/ml/knn/KNNRegressionTest.java |  40 +++++
 .../org/apache/ignite/ml/nn/MLPTrainerTest.java |  64 +++++++
 .../linear/LinearRegressionLSQRTrainerTest.java |  51 ++++++
 .../linear/LinearRegressionSGDTrainerTest.java  |  62 +++++++
 .../logistic/LogRegMultiClassTrainerTest.java   |  58 +++++++
 .../LogisticRegressionSGDTrainerTest.java       |  46 +++++
 .../ignite/ml/svm/SVMBinaryTrainerTest.java     |  41 +++++
 .../ignite/ml/svm/SVMMultiClassTrainerTest.java |  43 +++++
 .../RandomForestClassifierTrainerTest.java      |  33 +++-
 .../RandomForestRegressionTrainerTest.java      |  31 ++++
 46 files changed, 1503 insertions(+), 129 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
----------------------------------------------------------------------
diff --git 
a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
 
b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
index 130b91a..075eab2 100644
--- 
a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
+++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
@@ -23,6 +23,7 @@ import org.apache.ignite.Ignition;
 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.trainers.DatasetTrainer;
@@ -58,7 +59,7 @@ public class GDBOnTreesClassificationTrainerExample {
                 IgniteCache<Integer, double[]> trainingSet = 
fillTrainingData(ignite, trainingSetCfg);
 
                 // Create regression trainer.
-                DatasetTrainer<Model<Vector, Double>, Double> trainer = new 
GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.);
+                DatasetTrainer<ModelsComposition, Double> trainer = new 
GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.);
 
                 // Train decision tree model.
                 Model<Vector, Double> mdl = trainer.fit(

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
----------------------------------------------------------------------
diff --git 
a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
 
b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
index 31dd2b0..b2b08d0 100644
--- 
a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
+++ 
b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
@@ -23,6 +23,7 @@ import org.apache.ignite.Ignition;
 import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
 import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.trainers.DatasetTrainer;
@@ -58,7 +59,7 @@ public class GDBOnTreesRegressionTrainerExample {
                 IgniteCache<Integer, double[]> trainingSet = 
fillTrainingData(ignite, trainingSetCfg);
 
                 // Create regression trainer.
-                DatasetTrainer<Model<Vector, Double>, Double> trainer = new 
GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.);
+                DatasetTrainer<ModelsComposition, Double> trainer = new 
GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.);
 
                 // Train decision tree model.
                 Model<Vector, Double> mdl = trainer.fit(

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
index 5b880fcc..2596dbc 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
@@ -21,6 +21,7 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Optional;
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
@@ -72,6 +73,14 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
      */
     @Override public <K, V> KMeansModel fit(DatasetBuilder<K, V> 
datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> KMeansModel updateModel(KMeansModel mdl, 
DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
         assert datasetBuilder != null;
 
         PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, 
LabeledVector>> partDataBuilder = new 
LabeledDatasetPartitionDataBuilderOnHeap<>(
@@ -85,7 +94,7 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
             (upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
-            final int cols = 
dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
+            final Integer cols = 
dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
                 if (a == null)
                     return b == null ? 0 : b;
                 if (b == null)
@@ -93,7 +102,12 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
                 return b;
             });
 
-            centers = initClusterCentersRandomly(dataset, k);
+            if (cols == null)
+                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
+
+            centers = Optional.ofNullable(mdl)
+                .map(KMeansModel::centers)
+                .orElseGet(() -> initClusterCentersRandomly(dataset, k));
 
             boolean converged = false;
             int iteration = 0;
@@ -127,6 +141,11 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
         return new KMeansModel(centers, distance);
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(KMeansModel mdl) {
+        return mdl.centers().length == k && 
mdl.distanceMeasure().equals(distance);
+    }
+
     /**
      * Prepares the data to define new centroids on current iteration.
      *
@@ -281,10 +300,12 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
             return this;
         }
 
+        /**
+         * @return centroid statistics.
+         */
         public ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> 
getCentroidStat() {
             return centroidStat;
         }
-
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
index f439789..493c1da 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
@@ -177,4 +177,24 @@ public abstract class BaggingModelTrainer extends 
DatasetTrainer<ModelsCompositi
             return VectorUtils.of(newFeaturesValues);
         });
     }
+
+    /**
+     * Learn new models on dataset and create new Compositions over them and 
already learned models.
+     *
+     * @param mdl Learned model.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return New models composition.
+     */
+    @Override public <K, V> ModelsComposition updateModel(ModelsComposition 
mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
+        ArrayList<Model<Vector, Double>> newModels = new 
ArrayList<>(mdl.getModels());
+        newModels.addAll(fit(datasetBuilder, featureExtractor, 
lbExtractor).getModels());
+
+        return new ModelsComposition(newModels, predictionsAggregator);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
index e14fa6d..36ee626 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
@@ -19,6 +19,8 @@ package org.apache.ignite.ml.composition;
 
 import java.util.Collections;
 import java.util.List;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
 import org.apache.ignite.ml.Model;
 import 
org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -27,7 +29,7 @@ import org.apache.ignite.ml.util.ModelTrace;
 /**
  * Model consisting of several models and prediction aggregation strategy.
  */
-public class ModelsComposition implements Model<Vector, Double> {
+public class ModelsComposition implements Model<Vector, Double>, 
Exportable<ModelsCompositionFormat> {
     /**
      * Predictions aggregator.
      */
@@ -78,6 +80,12 @@ public class ModelsComposition implements Model<Vector, 
Double> {
     }
 
     /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<ModelsCompositionFormat, P> 
exporter, P path) {
+        ModelsCompositionFormat format = new ModelsCompositionFormat(models, 
predictionsAggregator);
+        exporter.save(format, path);
+    }
+
+    /** {@inheritDoc} */
     @Override public String toString() {
         return toString(false);
     }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java
new file mode 100644
index 0000000..68af0a9
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition;
+
+import java.io.Serializable;
+import java.util.List;
+import org.apache.ignite.ml.Model;
+import 
org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * ModelsComposition representation.
+ *
+ * @see ModelsComposition
+ */
+public class ModelsCompositionFormat implements Serializable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 9115341364082681837L;
+
+    /** Models. */
+    private List<Model<Vector, Double>> models;
+
+    /** Predictions aggregator. */
+    private PredictionsAggregator predictionsAggregator;
+
+    /**
+     * Creates an instance of ModelsCompositionFormat.
+     *
+     * @param models Models.
+     * @param predictionsAggregator Predictions aggregator.
+     */
+    public ModelsCompositionFormat(List<Model<Vector, Double>> 
models,PredictionsAggregator predictionsAggregator) {
+        this.models = models;
+        this.predictionsAggregator = predictionsAggregator;
+    }
+
+    /** */
+    public List<Model<Vector, Double>> models() {
+        return models;
+    }
+
+    /** */
+    public PredictionsAggregator predictionsAggregator() {
+        return predictionsAggregator;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
index 5a0f52a..c7f21dd 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
@@ -52,7 +52,7 @@ import org.jetbrains.annotations.NotNull;
  *
  * But in practice Decision Trees is most used regressors (see: {@link 
DecisionTreeRegressionTrainer}).
  */
-public abstract class GDBTrainer extends DatasetTrainer<Model<Vector, Double>, 
Double> {
+public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, 
Double> {
     /** Gradient step. */
     private final double gradientStep;
 
@@ -81,7 +81,7 @@ public abstract class GDBTrainer extends 
DatasetTrainer<Model<Vector, Double>, D
     }
 
     /** {@inheritDoc} */
-    @Override public <K, V> Model<Vector, Double> fit(DatasetBuilder<K, V> 
datasetBuilder,
+    @Override public <K, V> ModelsComposition fit(DatasetBuilder<K, V> 
datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor,
         IgniteBiFunction<K, V, Double> lbExtractor) {
 
@@ -119,6 +119,21 @@ public abstract class GDBTrainer extends 
DatasetTrainer<Model<Vector, Double>, D
         };
     }
 
+
+    //TODO: This method will be implemented in IGNITE-9412
+    /** {@inheritDoc} */
+    @Override public <K, V> ModelsComposition updateModel(ModelsComposition 
mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
+        throw new UnsupportedOperationException();
+    }
+
+    //TODO: This method will be implemented in IGNITE-9412
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(ModelsComposition mdl) {
+        throw new UnsupportedOperationException();
+    }
+
     /**
      * Defines unique labels in dataset if need (useful in case of 
classification).
      *

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java
index b7a57f5..d435f91 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java
@@ -174,6 +174,11 @@ public abstract class NNClassificationModel implements 
Model<Vector, Double>, Ex
             return 1.0; // strategy.SIMPLE
     }
 
+    /** */
+    public DistanceMeasure getDistanceMeasure() {
+        return distanceMeasure;
+    }
+
     /** {@inheritDoc} */
     @Override public int hashCode() {
         int res = 1;
@@ -212,6 +217,17 @@ public abstract class NNClassificationModel implements 
Model<Vector, Double>, Ex
             .toString();
     }
 
+    /**
+     * Sets parameters from other model to this model.
+     *
+     * @param mdl Model.
+     */
+    protected void copyParametersFrom(NNClassificationModel mdl) {
+        this.k = mdl.k;
+        this.distanceMeasure = mdl.distanceMeasure;
+        this.stgy = mdl.stgy;
+    }
+
     /** */
     public abstract <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P 
path);
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java
index e8c0b4a..bec82a9 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java
@@ -44,12 +44,18 @@ public class ANNClassificationModel extends 
NNClassificationModel  {
     /** The labeled set of candidates. */
     private final LabeledVectorSet<ProbableLabel, LabeledVector> candidates;
 
+    /** Centroid statistics. */
+    private final ANNClassificationTrainer.CentroidStat centroindsStat;
+
     /**
      * Build the model based on a candidates set.
      * @param centers The candidates set.
+     * @param centroindsStat
      */
-    public ANNClassificationModel(LabeledVectorSet<ProbableLabel, 
LabeledVector> centers) {
+    public ANNClassificationModel(LabeledVectorSet<ProbableLabel, 
LabeledVector> centers,
+        ANNClassificationTrainer.CentroidStat centroindsStat) {
        this.candidates = centers;
+       this.centroindsStat = centroindsStat;
     }
 
     /** */
@@ -57,6 +63,11 @@ public class ANNClassificationModel extends 
NNClassificationModel  {
         return candidates;
     }
 
+    /** */
+    public ANNClassificationTrainer.CentroidStat getCentroindsStat() {
+        return centroindsStat;
+    }
+
     /** {@inheritDoc} */
     @Override public Double apply(Vector v) {
             List<LabeledVector> neighbors = findKNearestNeighbors(v);
@@ -65,7 +76,7 @@ public class ANNClassificationModel extends 
NNClassificationModel  {
 
     /** */
     @Override public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, 
P path) {
-        ANNModelFormat mdlData = new ANNModelFormat(k, distanceMeasure, stgy, 
candidates);
+        ANNModelFormat mdlData = new ANNModelFormat(k, distanceMeasure, stgy, 
candidates, centroindsStat);
         exporter.save(mdlData, path);
     }
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
index 1c45812..3e32b67 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
@@ -17,9 +17,13 @@
 
 package org.apache.ignite.ml.knn.ann;
 
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
 import java.util.TreeMap;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentSkipListSet;
+import java.util.stream.Collectors;
 import org.apache.ignite.lang.IgniteBiTuple;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
 import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
@@ -39,8 +43,8 @@ import 
org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 import org.jetbrains.annotations.NotNull;
 
 /**
- * ANN algorithm trainer to solve multi-class classification task.
- * This trainer is based on ACD strategy and KMeans clustering algorithm to 
find centroids.
+ * ANN algorithm trainer to solve multi-class classification task. This 
trainer is based on ACD strategy and KMeans
+ * clustering algorithm to find centroids.
  */
 public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClassificationModel> {
     /** Amount of clusters. */
@@ -61,29 +65,55 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
     /**
      * Trains model based on the specified data.
      *
-     * @param datasetBuilder   Dataset builder.
+     * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor      Label extractor.
+     * @param lbExtractor Label extractor.
      * @return Model.
      */
-    @Override public <K, V> ANNClassificationModel fit(DatasetBuilder<K, V> 
datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, 
IgniteBiFunction<K, V, Double> lbExtractor) {
-        final Vector[] centers = getCentroids(featureExtractor, lbExtractor, 
datasetBuilder);
+    @Override public <K, V> ANNClassificationModel fit(DatasetBuilder<K, V> 
datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
 
-        final CentroidStat centroidStat = getCentroidStat(datasetBuilder, 
featureExtractor, lbExtractor, centers);
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> ANNClassificationModel 
updateModel(ANNClassificationModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> 
featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        List<Vector> centers;
+        CentroidStat centroidStat;
+        if (mdl != null) {
+            centers = Arrays.stream(mdl.getCandidates().data()).map(x -> 
x.features()).collect(Collectors.toList());
+            CentroidStat newStat = getCentroidStat(datasetBuilder, 
featureExtractor, lbExtractor, centers);
+            if(newStat == null)
+                return mdl;
+            CentroidStat oldStat = mdl.getCentroindsStat();
+            centroidStat = newStat.merge(oldStat);
+        } else {
+            centers = getCentroids(featureExtractor, lbExtractor, 
datasetBuilder);
+            centroidStat = getCentroidStat(datasetBuilder, featureExtractor, 
lbExtractor, centers);
+        }
 
         final LabeledVectorSet<ProbableLabel, LabeledVector> dataset = 
buildLabelsForCandidates(centers, centroidStat);
 
-        return new ANNClassificationModel(dataset);
+        return new ANNClassificationModel(dataset, centroidStat);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(ANNClassificationModel mdl) {
+        return mdl.getDistanceMeasure().equals(distance) && 
mdl.getCandidates().rowSize() == k;
     }
 
     /** */
-    @NotNull private LabeledVectorSet<ProbableLabel, LabeledVector> 
buildLabelsForCandidates(Vector[] centers, CentroidStat centroidStat) {
+    @NotNull private LabeledVectorSet<ProbableLabel, LabeledVector> 
buildLabelsForCandidates(List<Vector> centers,
+        CentroidStat centroidStat) {
         // init
-        final LabeledVector<Vector, ProbableLabel>[] arr = new 
LabeledVector[centers.length];
+        final LabeledVector<Vector, ProbableLabel>[] arr = new 
LabeledVector[centers.size()];
 
         // fill label for each centroid
-        for (int i = 0; i < centers.length; i++)
-            arr[i] = new LabeledVector<>(centers[i], fillProbableLabel(i, 
centroidStat));
+        for (int i = 0; i < centers.size(); i++)
+            arr[i] = new LabeledVector<>(centers.get(i), fillProbableLabel(i, 
centroidStat));
 
         return new LabeledVectorSet<>(arr);
     }
@@ -92,13 +122,14 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
      * Perform KMeans clusterization algorithm to find centroids.
      *
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor      Label extractor.
-     * @param datasetBuilder   The dataset builder.
-     * @param <K>              Type of a key in {@code upstream} data.
-     * @param <V>              Type of a value in {@code upstream} data.
+     * @param lbExtractor Label extractor.
+     * @param datasetBuilder The dataset builder.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
      * @return The arrays of vectors.
      */
-    private <K, V> Vector[] getCentroids(IgniteBiFunction<K, V, Vector> 
featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor, DatasetBuilder<K, 
V> datasetBuilder) {
+    private <K, V> List<Vector> getCentroids(IgniteBiFunction<K, V, Vector> 
featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor, DatasetBuilder<K, V> 
datasetBuilder) {
         KMeansTrainer trainer = new KMeansTrainer()
             .withK(k)
             .withMaxIterations(maxIterations)
@@ -112,7 +143,7 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
             lbExtractor
         );
 
-        return mdl.centers();
+        return Arrays.asList(mdl.centers());
     }
 
     /** */
@@ -125,21 +156,24 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
         ConcurrentHashMap<Double, Integer> centroidLbDistribution
             = centroidStat.centroidStat().get(centroidIdx);
 
-        if(centroidStat.counts.containsKey(centroidIdx)){
+        if (centroidStat.counts.containsKey(centroidIdx)) {
 
             int clusterSize = centroidStat
                 .counts
                 .get(centroidIdx);
 
             clsLbls.keySet().forEach(
-                (label) -> clsLbls.put(label, 
centroidLbDistribution.containsKey(label) ? ((double) 
(centroidLbDistribution.get(label)) / clusterSize) : 0.0)
+                (label) -> clsLbls.put(label, 
centroidLbDistribution.containsKey(label) ? 
((double)(centroidLbDistribution.get(label)) / clusterSize) : 0.0)
             );
         }
         return new ProbableLabel(clsLbls);
     }
 
     /** */
-    private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> 
datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, 
IgniteBiFunction<K, V, Double> lbExtractor, Vector[] centers) {
+    private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> 
datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor, List<Vector> centers) {
+
         PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, 
LabeledVector>> partDataBuilder = new 
LabeledDatasetPartitionDataBuilderOnHeap<>(
             featureExtractor,
             lbExtractor
@@ -174,7 +208,7 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
                     }
 
                     res.counts.merge(centroidIdx, 1,
-                        (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) 
-> i1 + i2);
+                        (IgniteBiFunction<Integer, Integer, Integer>)(i1, i2) 
-> i1 + i2);
                 }
                 return res;
             }, (a, b) -> {
@@ -194,15 +228,15 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
      * Find the closest cluster center index and distance to it from a given 
point.
      *
      * @param centers Centers to look in.
-     * @param pnt     Point.
+     * @param pnt Point.
      */
-    private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] 
centers, LabeledVector pnt) {
+    private IgniteBiTuple<Integer, Double> findClosestCentroid(List<Vector> 
centers, LabeledVector pnt) {
         double bestDistance = Double.POSITIVE_INFINITY;
         int bestInd = 0;
 
-        for (int i = 0; i < centers.length; i++) {
-            if (centers[i] != null) {
-                double dist = distance.compute(centers[i], pnt.features());
+        for (int i = 0; i < centers.size(); i++) {
+            if (centers.get(i) != null) {
+                double dist = distance.compute(centers.get(i), pnt.features());
                 if (dist < bestDistance) {
                     bestDistance = dist;
                     bestInd = i;
@@ -212,7 +246,6 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
         return new IgniteBiTuple<>(bestInd, bestDistance);
     }
 
-
     /**
      * Gets the amount of clusters.
      *
@@ -314,7 +347,9 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
     }
 
     /** Service class used for statistics. */
-    public static class CentroidStat {
+    public static class CentroidStat implements Serializable {
+        /** Serial version uid. */
+        private static final long serialVersionUID = 7624883170532045144L;
 
         /** Count of points closest to the center with a given index. */
         ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> 
centroidStat = new ConcurrentHashMap<>();

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java
index e10f3b2..be09828 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNModelFormat.java
@@ -30,6 +30,9 @@ import org.apache.ignite.ml.structures.LabeledVectorSet;
  * @see ANNClassificationModel
  */
 public class ANNModelFormat extends KNNModelFormat implements Serializable {
+    /** Centroid statistics. */
+    private final ANNClassificationTrainer.CentroidStat candidatesStat;
+
     /** The labeled set of candidates. */
     private LabeledVectorSet<ProbableLabel, LabeledVector> candidates;
 
@@ -38,15 +41,18 @@ public class ANNModelFormat extends KNNModelFormat 
implements Serializable {
      * @param k Amount of nearest neighbors.
      * @param measure Distance measure.
      * @param stgy kNN strategy.
+     * @param candidatesStat Statistics of centroids (label distribution and point counts per centroid).
      */
     public ANNModelFormat(int k,
-                          DistanceMeasure measure,
-                          NNStrategy stgy,
-                          LabeledVectorSet<ProbableLabel, LabeledVector> 
candidates) {
+        DistanceMeasure measure,
+        NNStrategy stgy,
+        LabeledVectorSet<ProbableLabel, LabeledVector> candidates,
+        ANNClassificationTrainer.CentroidStat candidatesStat) {
         this.k = k;
         this.distanceMeasure = measure;
         this.stgy = stgy;
         this.candidates = candidates;
+        this.candidatesStat = candidatesStat;
     }
 
     /** */

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
index 0b88f81..0d03ee5 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.knn.classification;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -42,25 +43,29 @@ public class KNNClassificationModel extends 
NNClassificationModel implements Exp
     /** */
     private static final long serialVersionUID = -127386523291350345L;
 
-    /** Dataset. */
-    private Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> 
dataset;
+    /** Datasets. */
+    private List<Dataset<EmptyContext, LabeledVectorSet<Double, 
LabeledVector>>> datasets;
 
     /**
      * Builds the model via prepared dataset.
+     *
      * @param dataset Specially prepared object to run algorithm over it.
      */
     public KNNClassificationModel(Dataset<EmptyContext, 
LabeledVectorSet<Double, LabeledVector>> dataset) {
-        this.dataset = dataset;
+        this.datasets = new ArrayList<>();
+        if (dataset != null)
+            datasets.add(dataset);
     }
 
     /** {@inheritDoc} */
     @Override public Double apply(Vector v) {
-        if(dataset != null) {
+        if (!datasets.isEmpty()) {
             List<LabeledVector> neighbors = findKNearestNeighbors(v);
 
             return classify(neighbors, v, stgy);
-        } else
+        } else {
             throw new IllegalStateException("The train kNN dataset is null");
+        }
     }
 
     /** */
@@ -77,6 +82,17 @@ public class KNNClassificationModel extends 
NNClassificationModel implements Exp
      * @return K-nearest neighbors.
      */
     protected List<LabeledVector> findKNearestNeighbors(Vector v) {
+        List<LabeledVector> neighborsFromPartitions = datasets.stream()
+            .flatMap(dataset -> findKNearestNeighborsInDataset(v, 
dataset).stream())
+            .collect(Collectors.toList());
+
+        LabeledVectorSet<Double, LabeledVector> neighborsToFilter = 
buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
+
+        return Arrays.asList(getKClosestVectors(neighborsToFilter, 
getDistances(v, neighborsToFilter)));
+    }
+
+    private List<LabeledVector> findKNearestNeighborsInDataset(Vector v,
+        Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> 
dataset) {
         List<LabeledVector> neighborsFromPartitions = dataset.compute(data -> {
             TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, 
data);
             return Arrays.asList(getKClosestVectors(data, distanceIdxPairs));
@@ -88,12 +104,14 @@ public class KNNClassificationModel extends 
NNClassificationModel implements Exp
             return Stream.concat(a.stream(), 
b.stream()).collect(Collectors.toList());
         });
 
+        if(neighborsFromPartitions == null)
+            return Collections.emptyList();
+
         LabeledVectorSet<Double, LabeledVector> neighborsToFilter = 
buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
 
         return Arrays.asList(getKClosestVectors(neighborsToFilter, 
getDistances(v, neighborsToFilter)));
     }
 
-
     /** */
     private double classify(List<LabeledVector> neighbors, Vector v, 
NNStrategy stgy) {
         Map<Double, Double> clsVotes = new HashMap<>();
@@ -116,5 +134,13 @@ public class KNNClassificationModel extends 
NNClassificationModel implements Exp
         return getClassWithMaxVotes(clsVotes);
     }
 
-
+    /**
+     * Copies parameters from another model and saves all datasets from it.
+     *
+     * @param model Model.
+     */
+    public void copyStateFrom(KNNClassificationModel model) {
+        this.copyParametersFrom(model);
+        datasets.addAll(model.datasets);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
index e0a81f9..1a3ff73 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
@@ -37,6 +37,24 @@ public class KNNClassificationTrainer extends 
SingleLabelDatasetTrainer<KNNClass
      */
     @Override public <K, V> KNNClassificationModel fit(DatasetBuilder<K, V> 
datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
-        return new 
KNNClassificationModel(KNNUtils.buildDataset(datasetBuilder, featureExtractor, 
lbExtractor));
+
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> KNNClassificationModel 
updateModel(KNNClassificationModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> 
featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        KNNClassificationModel res = new 
KNNClassificationModel(KNNUtils.buildDataset(datasetBuilder,
+            featureExtractor, lbExtractor));
+        if (mdl != null)
+            res.copyStateFrom(mdl);
+        return res;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(KNNClassificationModel mdl) {
+        return true;
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
index 395ce61..7a42dc8 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
@@ -37,6 +37,23 @@ public class KNNRegressionTrainer extends 
SingleLabelDatasetTrainer<KNNRegressio
      */
     public <K, V> KNNRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
-        return new KNNRegressionModel(KNNUtils.buildDataset(datasetBuilder, 
featureExtractor, lbExtractor));
+
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> KNNRegressionModel updateModel(KNNRegressionModel 
mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
+        KNNRegressionModel res = new 
KNNRegressionModel(KNNUtils.buildDataset(datasetBuilder,
+            featureExtractor, lbExtractor));
+        if (mdl != null)
+            res.copyStateFrom(mdl);
+        return res;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(KNNRegressionModel mdl) {
+        return true;
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java
index 7a362f7..c9281c0 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/AbstractLSQR.java
@@ -78,7 +78,9 @@ public abstract class AbstractLSQR {
      */
     public LSQRResult solve(double damp, double atol, double btol, double 
conlim, double iterLim, boolean calcVar,
         double[] x0) {
-        int n = getColumns();
+        Integer n = getColumns();
+        if(n == null)
+            return null;
 
         if (iterLim < 0)
             iterLim = 2 * n;
@@ -313,7 +315,7 @@ public abstract class AbstractLSQR {
     protected abstract double[] iter(double bnorm, double[] target);
 
     /** */
-    protected abstract int getColumns();
+    protected abstract Integer getColumns();
 
     /** */
     private static double[] symOrtho(double a, double b) {

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
index f75caef..14356e1 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
@@ -100,7 +100,7 @@ public class LSQROnHeap<K, V> extends AbstractLSQR 
implements AutoCloseable {
      *
      * @return number of columns
      */
-    @Override protected int getColumns() {
+    @Override protected Integer getColumns() {
         return dataset.compute(
             data -> data.getFeatures() == null ? null : 
data.getFeatures().length / data.getRows(),
             (a, b) -> {

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
index 6727ba9..8f1a4cb 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
@@ -111,12 +111,25 @@ public class MLPTrainer<P extends Serializable> extends 
MultiLabelDatasetTrainer
     public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, double[]> lbExtractor) {
 
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> MultilayerPerceptron 
updateModel(MultilayerPerceptron lastLearnedModel,
+        DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, double[]> lbExtractor) {
+
         try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = 
datasetBuilder.build(
             new EmptyContextBuilder<>(),
             new SimpleLabeledDatasetDataBuilder<>(featureExtractor, 
lbExtractor)
         )) {
-            MLPArchitecture arch = archSupplier.apply(dataset);
-            MultilayerPerceptron mdl = new MultilayerPerceptron(arch, new 
RandomInitializer(seed));
+            MultilayerPerceptron mdl;
+            if (lastLearnedModel != null) {
+                mdl = lastLearnedModel;
+            } else {
+                MLPArchitecture arch = archSupplier.apply(dataset);
+                mdl = new MultilayerPerceptron(arch, new 
RandomInitializer(seed));
+            }
             ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater 
= updatesStgy.getUpdatesCalculator();
 
             for (int i = 0; i < maxIterations; i += locIterations) {
@@ -178,6 +191,9 @@ public class MLPTrainer<P extends Serializable> extends 
MultiLabelDatasetTrainer
                     }
                 );
 
+                if (totUp == null)
+                    return 
getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedModel);
+
                 P update = updatesStgy.allUpdatesReducer().apply(totUp);
                 mdl = updater.update(mdl, update);
             }
@@ -189,6 +205,11 @@ public class MLPTrainer<P extends Serializable> extends 
MultiLabelDatasetTrainer
         }
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(MultilayerPerceptron mdl) {
+        return true;
+    }
+
     /**
      * Builds a batch of the data by fetching specified rows.
      *

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
index 1886ee5..b977864 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/PreprocessingTrainer.java
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.preprocessing;
 
+import java.util.Map;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -24,8 +25,6 @@ import 
org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 
-import java.util.Map;
-
 /**
  * Trainer for preprocessor.
  *

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
index 8197779..5497177 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
@@ -38,16 +38,34 @@ public class LinearRegressionLSQRTrainer extends 
SingleLabelDatasetTrainer<Linea
     @Override public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> 
datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
 
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> LinearRegressionModel 
updateModel(LinearRegressionModel mdl,
+        DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
         LSQRResult res;
 
         try (LSQROnHeap<K, V> lsqr = new LSQROnHeap<>(
             datasetBuilder,
             new SimpleLabeledDatasetDataBuilder<>(
                 new FeatureExtractorWrapper<>(featureExtractor),
-                lbExtractor.andThen(e -> new double[]{e})
+                lbExtractor.andThen(e -> new double[] {e})
             )
         )) {
-            res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, null);
+            double[] x0 = null;
+            if (mdl != null) {
+                int x0Size = mdl.getWeights().size() + 1;
+                Vector weights = mdl.getWeights().like(x0Size);
+                mdl.getWeights().nonZeroes().forEach(ith -> 
weights.set(ith.index(), ith.get()));
+                weights.set(weights.size() - 1, mdl.getIntercept());
+                x0 = weights.asArray();
+            }
+            res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, x0);
+            if (res == null)
+                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
         }
         catch (Exception e) {
             throw new RuntimeException(e);
@@ -58,4 +76,9 @@ public class LinearRegressionLSQRTrainer extends 
SingleLabelDatasetTrainer<Linea
 
         return new LinearRegressionModel(weights, x[x.length - 1]);
     }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(LinearRegressionModel mdl) {
+        return true;
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
index 44f60d1..125ed24 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.regressions.linear;
 
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.Optional;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
@@ -34,6 +35,7 @@ import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.optimization.LossFunctions;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+import org.jetbrains.annotations.NotNull;
 
 /**
  * Trainer of the linear regression model based on stochastic gradient descent 
algorithm.
@@ -43,16 +45,16 @@ public class LinearRegressionSGDTrainer<P extends 
Serializable> extends SingleLa
     private final UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
 
     /** Max number of iteration. */
-    private final int maxIterations;
+    private int maxIterations = 1000;
 
     /** Batch size. */
-    private final int batchSize;
+    private int batchSize = 10;
 
     /** Number of local iterations. */
-    private final int locIterations;
+    private int locIterations = 100;
 
     /** Seed for random generator. */
-    private final long seed;
+    private long seed = System.currentTimeMillis();
 
     /**
      * Constructs a new instance of linear regression SGD trainer.
@@ -72,10 +74,24 @@ public class LinearRegressionSGDTrainer<P extends 
Serializable> extends SingleLa
         this.seed = seed;
     }
 
+    /**
+     * Constructs a new instance of linear regression SGD trainer.
+     */
+    public LinearRegressionSGDTrainer(UpdatesStrategy<? super 
MultilayerPerceptron, P> updatesStgy) {
+        this.updatesStgy = updatesStgy;
+    }
+
     /** {@inheritDoc} */
     @Override public <K, V> LinearRegressionModel fit(DatasetBuilder<K, V> 
datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
 
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> LinearRegressionModel 
updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
         IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, 
MLPArchitecture> archSupplier = dataset -> {
 
             int cols = dataset.compute(data -> {
@@ -108,7 +124,10 @@ public class LinearRegressionSGDTrainer<P extends 
Serializable> extends SingleLa
 
         IgniteBiFunction<K, V, double[]> lbE = (IgniteBiFunction<K, V, 
double[]>)(k, v) -> new double[] {lbExtractor.apply(k, v)};
 
-        MultilayerPerceptron mlp = trainer.fit(datasetBuilder, 
featureExtractor, lbE);
+        MultilayerPerceptron mlp = Optional.ofNullable(mdl)
+            .map(this::restoreMLPState)
+            .map(m -> trainer.update(m, datasetBuilder, featureExtractor, lbE))
+            .orElseGet(() -> trainer.fit(datasetBuilder, featureExtractor, 
lbE));
 
         double[] p = mlp.parameters().getStorage().data();
 
@@ -117,4 +136,72 @@ public class LinearRegressionSGDTrainer<P extends 
Serializable> extends SingleLa
             p[p.length - 1]
         );
     }
+
+    /**
+     * @param mdl Model.
+     * @return State of the MLP restored from the previously learned model.
+     */
+    @NotNull private MultilayerPerceptron 
restoreMLPState(LinearRegressionModel mdl) {
+        Vector weights = mdl.getWeights();
+        double intercept = mdl.getIntercept();
+        MLPArchitecture architecture1 = new MLPArchitecture(weights.size());
+        architecture1 = architecture1.withAddedLayer(1, true, 
Activators.LINEAR);
+        MLPArchitecture architecture = architecture1;
+        MultilayerPerceptron perceptron = new 
MultilayerPerceptron(architecture);
+
+        Vector mlpState = weights.like(weights.size() + 1);
+        weights.nonZeroes().forEach(ith -> mlpState.set(ith.index(), 
ith.get()));
+        mlpState.set(mlpState.size() - 1, intercept);
+        perceptron.setParameters(mlpState);
+        return perceptron;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(LinearRegressionModel mdl) {
+        return true;
+    }
+
+    /**
+     * Set up the max number of iterations before convergence.
+     *
+     * @param maxIterations The parameter value.
+     * @return Trainer with new max number of iterations before convergence parameter value.
+     */
+    public LinearRegressionSGDTrainer<P> withMaxIterations(int maxIterations) {
+        this.maxIterations = maxIterations;
+        return this;
+    }
+
+    /**
+     * Set up the batchSize parameter.
+     *
+     * @param batchSize The size of learning batch.
+     * @return Trainer with new batch size parameter value.
+     */
+    public LinearRegressionSGDTrainer<P> withBatchSize(int batchSize) {
+        this.batchSize = batchSize;
+        return this;
+    }
+
+    /**
+     * Set up the amount of local iterations of SGD algorithm.
+     *
+     * @param amountOfLocIterations The parameter value.
+     * @return Trainer with new locIterations parameter value.
+     */
+    public LinearRegressionSGDTrainer<P> withLocIterations(int 
amountOfLocIterations) {
+        this.locIterations = amountOfLocIterations;
+        return this;
+    }
+
+    /**
+     * Set up the random seed parameter.
+     *
+     * @param seed Seed for random generator.
+     * @return Trainer with new seed parameter value.
+     */
+    public LinearRegressionSGDTrainer<P> withSeed(long seed) {
+        this.seed = seed;
+        return this;
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
index 6396279..839dab5 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
@@ -34,6 +34,7 @@ import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.optimization.LossFunctions;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+import org.jetbrains.annotations.NotNull;
 
 /**
  * Trainer of the logistic regression model based on stochastic gradient 
descent algorithm.
@@ -76,8 +77,15 @@ public class LogisticRegressionSGDTrainer<P extends 
Serializable> extends Single
     @Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> 
datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
 
-        IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, 
MLPArchitecture> archSupplier = dataset -> {
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> LogisticRegressionModel 
updateModel(LogisticRegressionModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> 
featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
 
+        IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, 
MLPArchitecture> archSupplier = dataset -> {
             int cols = dataset.compute(data -> {
                 if (data.getFeatures() == null)
                     return null;
@@ -106,7 +114,13 @@ public class LogisticRegressionSGDTrainer<P extends 
Serializable> extends Single
             seed
         );
 
-        MultilayerPerceptron mlp = trainer.fit(datasetBuilder, 
featureExtractor, (k, v) -> new double[] {lbExtractor.apply(k, v)});
+        IgniteBiFunction<K, V, double[]> lbExtractorWrapper = (k, v) -> new 
double[] {lbExtractor.apply(k, v)};
+        MultilayerPerceptron mlp;
+        if(mdl != null) {
+            mlp = restoreMLPState(mdl);
+            mlp = trainer.update(mlp, datasetBuilder, featureExtractor, 
lbExtractorWrapper);
+        } else
+            mlp = trainer.fit(datasetBuilder, featureExtractor, 
lbExtractorWrapper);
 
         double[] params = mlp.parameters().getStorage().data();
 
@@ -114,4 +128,28 @@ public class LogisticRegressionSGDTrainer<P extends 
Serializable> extends Single
             params[params.length - 1]
         );
     }
+
+    /**
+     * @param mdl Model.
+     * @return State of the MLP restored from the previously learned model.
+     */
+    @NotNull private MultilayerPerceptron 
restoreMLPState(LogisticRegressionModel mdl) {
+        Vector weights = mdl.weights();
+        double intercept = mdl.intercept();
+        MLPArchitecture architecture1 = new MLPArchitecture(weights.size());
+        architecture1 = architecture1.withAddedLayer(1, true, 
Activators.SIGMOID);
+        MLPArchitecture architecture = architecture1;
+        MultilayerPerceptron perceptron = new 
MultilayerPerceptron(architecture);
+
+        Vector mlpState = weights.like(weights.size() + 1);
+        weights.nonZeroes().forEach(ith -> mlpState.set(ith.index(), 
ith.get()));
+        mlpState.set(mlpState.size() - 1, intercept);
+        perceptron.setParameters(mlpState);
+        return perceptron;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(LogisticRegressionModel mdl) {
+        return true;
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java
index 56d2d29..a7c9118 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassModel.java
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.TreeMap;
 import org.apache.ignite.ml.Exportable;
 import org.apache.ignite.ml.Exporter;
@@ -103,4 +104,12 @@ public class LogRegressionMultiClassModel implements 
Model<Vector, Double>, Expo
     public void add(double clsLb, LogisticRegressionModel mdl) {
         models.put(clsLb, mdl);
     }
+
+    /**
+     * @param clsLb Class label.
+     * @return model for class label if it exists.
+     */
+    public Optional<LogisticRegressionModel> getModel(Double clsLb) {
+        return Optional.ofNullable(models.get(clsLb));
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
index 4885373..eb44301 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
@@ -22,6 +22,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -33,6 +34,7 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.nn.MultilayerPerceptron;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
+import 
org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
 import 
org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
 import 
org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
 import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
@@ -71,6 +73,19 @@ public class LogRegressionMultiClassTrainer<P extends 
Serializable>
         IgniteBiFunction<K, V, Double> lbExtractor) {
         List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
 
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> LogRegressionMultiClassModel 
updateModel(LogRegressionMultiClassModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> 
featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
+
+        if (classes.isEmpty())
+            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
+
         LogRegressionMultiClassModel multiClsMdl = new 
LogRegressionMultiClassModel();
 
         classes.forEach(clsLb -> {
@@ -85,12 +100,23 @@ public class LogRegressionMultiClassTrainer<P extends 
Serializable>
                 else
                     return 0.0;
             };
-            multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, 
featureExtractor, lbTransformer));
+
+            LogisticRegressionModel model = Optional.ofNullable(mdl)
+                .flatMap(multiClassModel -> multiClassModel.getModel(clsLb))
+                .map(learnedModel -> trainer.update(learnedModel, 
datasetBuilder, featureExtractor, lbTransformer))
+                .orElseGet(() -> trainer.fit(datasetBuilder, featureExtractor, 
lbTransformer));
+
+            multiClsMdl.add(clsLb, model);
         });
 
         return multiClsMdl;
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(LogRegressionMultiClassModel mdl) {
+        return true;
+    }
+
     /** Iterates among dataset and collects class labels. */
     private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> 
datasetBuilder,
         IgniteBiFunction<K, V, Double> lbExtractor) {
@@ -121,7 +147,8 @@ public class LogRegressionMultiClassTrainer<P extends 
Serializable>
                 return Stream.of(a, 
b).flatMap(Collection::stream).collect(Collectors.toSet());
             });
 
-            res.addAll(clsLabels);
+            if (clsLabels != null)
+                res.addAll(clsLabels);
 
         }
         catch (Exception e) {
@@ -191,7 +218,7 @@ public class LogRegressionMultiClassTrainer<P extends 
Serializable>
     }
 
     /**
-     * Set up the regularization parameter.
+     * Set up the random seed parameter.
      *
      * @param seed Seed for random generator.
      * @return Trainer with new seed parameter value.

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
index 933a712..573df1a 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
@@ -22,9 +22,11 @@ import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.StorageConstants;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
 import org.apache.ignite.ml.structures.LabeledVector;
 import org.apache.ignite.ml.structures.LabeledVectorSet;
 import 
org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
@@ -61,6 +63,14 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
     @Override public <K, V> SVMLinearBinaryClassificationModel 
fit(DatasetBuilder<K, V> datasetBuilder,
         IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
 
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> SVMLinearBinaryClassificationModel 
updateModel(SVMLinearBinaryClassificationModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> 
featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
         assert datasetBuilder != null;
 
         PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, 
LabeledVector>> partDataBuilder = new 
LabeledDatasetPartitionDataBuilderOnHeap<>(
@@ -74,29 +84,57 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
             (upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
-            final int cols = 
dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
-                if (a == null)
-                    return b == null ? 0 : b;
-                if (b == null)
-                    return a;
-                return b;
-            });
-
-            final int weightVectorSizeWithIntercept = cols + 1;
-
-            weights = 
initializeWeightsWithZeros(weightVectorSizeWithIntercept);
+            if (mdl == null) {
+                final int cols = 
dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
+                    if (a == null)
+                        return b == null ? 0 : b;
+                    if (b == null)
+                        return a;
+                    return b;
+                });
+
+                final int weightVectorSizeWithIntercept = cols + 1;
+                weights = 
initializeWeightsWithZeros(weightVectorSizeWithIntercept);
+            } else {
+                weights = getStateVector(mdl);
+            }
 
             for (int i = 0; i < this.getAmountOfIterations(); i++) {
                 Vector deltaWeights = calculateUpdates(weights, dataset);
+                if (deltaWeights == null)
+                    return 
getLastTrainedModelOrThrowEmptyDatasetException(mdl);
+
                 weights = weights.plus(deltaWeights); // creates new vector
             }
-        }
-        catch (Exception e) {
+        } catch (Exception e) {
             throw new RuntimeException(e);
         }
         return new SVMLinearBinaryClassificationModel(weights.viewPart(1, 
weights.size() - 1), weights.get(0));
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(SVMLinearBinaryClassificationModel 
mdl) {
+        return true;
+    }
+
+    /**
+     * @param mdl Model.
+     * @return vector of model weights with intercept.
+     */
+    private Vector getStateVector(SVMLinearBinaryClassificationModel mdl) {
+        double intercept = mdl.intercept();
+        Vector weights = mdl.weights();
+
+        int stateVectorSize = weights.size() + 1;
+        Vector result = weights.isDense() ?
+            new DenseVector(stateVectorSize) :
+            new SparseVector(stateVectorSize, 
StorageConstants.RANDOM_ACCESS_MODE);
+
+        result.set(0, intercept);
+        weights.nonZeroes().forEach(ith -> result.set(ith.index() + 1, ith.get()));
+        return result;
+    }
+
     /** */
     @NotNull private Vector initializeWeightsWithZeros(int vectorSize) {
         return new DenseVector(vectorSize);

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
index 4b04824..46bf4b2 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
@@ -21,6 +21,7 @@ import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.TreeMap;
 import org.apache.ignite.ml.Exportable;
 import org.apache.ignite.ml.Exporter;
@@ -102,4 +103,12 @@ public class SVMLinearMultiClassClassificationModel 
implements Model<Vector, Dou
     public void add(double clsLb, SVMLinearBinaryClassificationModel mdl) {
         models.put(clsLb, mdl);
     }
+
+    /**
+     * @param clsLb Class label.
+     * @return model trained for target class if it exists.
+     */
+    public Optional<SVMLinearBinaryClassificationModel> 
getModelForClass(double clsLb) {
+        return Optional.ofNullable(models.get(clsLb));
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
index 4b7cc95..b77baa2 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
@@ -57,15 +57,26 @@ public class SVMLinearMultiClassClassificationTrainer
     /**
      * Trains model based on the specified data.
      *
-     * @param datasetBuilder   Dataset builder.
+     * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor      Label extractor.
+     * @param lbExtractor Label extractor.
      * @return Model.
      */
     @Override public <K, V> SVMLinearMultiClassClassificationModel 
fit(DatasetBuilder<K, V> datasetBuilder,
-                                                                
IgniteBiFunction<K, V, Vector> featureExtractor,
-                                                                
IgniteBiFunction<K, V, Double> lbExtractor) {
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        return updateModel(null, datasetBuilder, featureExtractor, 
lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> SVMLinearMultiClassClassificationModel updateModel(
+        SVMLinearMultiClassClassificationModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> 
featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
         List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
+        if (classes.isEmpty())
+            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
 
         SVMLinearMultiClassClassificationModel multiClsMdl = new 
SVMLinearMultiClassClassificationModel();
 
@@ -84,14 +95,60 @@ public class SVMLinearMultiClassClassificationTrainer
                 else
                     return -1.0;
             };
-            multiClsMdl.add(clsLb, trainer.fit(datasetBuilder, 
featureExtractor, lbTransformer));
+
+            SVMLinearBinaryClassificationModel model;
+            if (mdl == null)
+                model = learnNewModel(trainer, datasetBuilder, 
featureExtractor, lbTransformer);
+            else
+                model = updateModel(mdl, clsLb, trainer, datasetBuilder, 
featureExtractor, lbTransformer);
+            multiClsMdl.add(clsLb, model);
         });
 
         return multiClsMdl;
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean 
checkState(SVMLinearMultiClassClassificationModel mdl) {
+        return true;
+    }
+
+    /**
+     * Trains model based on the specified data.
+     *
+     * @param svmTrainer Prepared SVM trainer.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     */
+    private <K, V> SVMLinearBinaryClassificationModel 
learnNewModel(SVMLinearBinaryClassificationTrainer svmTrainer,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> 
featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        return svmTrainer.fit(datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /**
+     * Updates already learned model or fit new model if there is no model for 
current class label.
+     *
+     * @param multiClsMdl Learning multi-class model.
+     * @param clsLb Current class label.
+     * @param svmTrainer Prepared SVM trainer.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     */
+    private <K, V> SVMLinearBinaryClassificationModel 
updateModel(SVMLinearMultiClassClassificationModel multiClsMdl,
+        Double clsLb, SVMLinearBinaryClassificationTrainer svmTrainer, 
DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
+        return multiClsMdl.getModelForClass(clsLb)
+            .map(learnedModel -> svmTrainer.update(learnedModel, 
datasetBuilder, featureExtractor, lbExtractor))
+            .orElseGet(() -> svmTrainer.fit(datasetBuilder, featureExtractor, 
lbExtractor));
+    }
+
     /** Iterates among dataset and collects class labels. */
-    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> 
datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
+    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> 
datasetBuilder,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
         assert datasetBuilder != null;
 
         PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> 
partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
@@ -107,7 +164,8 @@ public class SVMLinearMultiClassClassificationTrainer
 
                 final double[] lbs = data.getY();
 
-                for (double lb : lbs) locClsLabels.add(lb);
+                for (double lb : lbs)
+                    locClsLabels.add(lb);
 
                 return locClsLabels;
             }, (a, b) -> {
@@ -118,8 +176,8 @@ public class SVMLinearMultiClassClassificationTrainer
                 return Stream.of(a, 
b).flatMap(Collection::stream).collect(Collectors.toSet());
             });
 
-            res.addAll(clsLabels);
-
+            if (clsLabels != null)
+                res.addAll(clsLabels);
         } catch (Exception e) {
             throw new RuntimeException(e);
         }
@@ -132,7 +190,7 @@ public class SVMLinearMultiClassClassificationTrainer
      * @param lambda The regularization parameter. Should be more than 0.0.
      * @return Trainer with new lambda parameter value.
      */
-    public SVMLinearMultiClassClassificationTrainer  withLambda(double lambda) 
{
+    public SVMLinearMultiClassClassificationTrainer withLambda(double lambda) {
         assert lambda > 0.0;
         this.lambda = lambda;
         return this;
@@ -162,7 +220,7 @@ public class SVMLinearMultiClassClassificationTrainer
      * @param amountOfIterations The parameter value.
      * @return Trainer with new amountOfIterations parameter value.
      */
-    public SVMLinearMultiClassClassificationTrainer  
withAmountOfIterations(int amountOfIterations) {
+    public SVMLinearMultiClassClassificationTrainer withAmountOfIterations(int 
amountOfIterations) {
         this.amountOfIterations = amountOfIterations;
         return this;
     }
@@ -182,7 +240,7 @@ public class SVMLinearMultiClassClassificationTrainer
      * @param amountOfLocIterations The parameter value.
      * @return Trainer with new amountOfLocIterations parameter value.
      */
-    public SVMLinearMultiClassClassificationTrainer  
withAmountOfLocIterations(int amountOfLocIterations) {
+    public SVMLinearMultiClassClassificationTrainer 
withAmountOfLocIterations(int amountOfLocIterations) {
         this.amountOfLocIterations = amountOfLocIterations;
         return this;
     }

Reply via email to