Repository: ignite
Updated Branches:
  refs/heads/master 02afb437f -> f4c18f114


http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
index 2f5d5d6..fb34c93 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
@@ -26,8 +26,10 @@ import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.logging.MLLogger;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.jetbrains.annotations.NotNull;
 
 /**
  * Interface for trainers. Trainer is just a function which produces model 
from the data.
@@ -53,6 +55,71 @@ public abstract class DatasetTrainer<M extends Model, L> {
         IgniteBiFunction<K, V, L> lbExtractor);
 
     /**
+     * Gets the state of the model passed in the arguments, compares it with the
+     * training parameters of the trainer and, if they fit, updates the model
+     * according to the new data and returns the updated model. Otherwise trains
+     * a new model.
+     *
+     * @param mdl Learned model.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    public <K,V> M update(M mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, L> lbExtractor) {
+
+        if(mdl != null) {
+            if(checkState(mdl)) {
+                return updateModel(mdl, datasetBuilder, featureExtractor, 
lbExtractor);
+            } else {
+                environment.logger(getClass()).log(
+                    MLLogger.VerboseLevel.HIGH,
+                    "Model cannot be updated because of initial state of " +
+                        "it doesn't corresponds to trainer parameters"
+                );
+            }
+        }
+
+        return fit(datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /**
+     * @param mdl Model.
+     * @return {@code true} if the parameters critical for training correspond
+     * to the parameters from the last training.
+     */
+    protected abstract boolean checkState(M mdl);
+
+    /**
+     * Used on update phase when given dataset is empty.
+     * If a last trained model exists, the method returns it. Otherwise it
+     * throws an {@link EmptyDatasetException}.
+     *
+     * @param lastTrainedMdl Model.
+     */
+    @NotNull protected M getLastTrainedModelOrThrowEmptyDatasetException(M 
lastTrainedMdl) {
+        String msg = "Cannot train model on empty dataset";
+        if (lastTrainedMdl != null) {
+            environment.logger(getClass()).log(MLLogger.VerboseLevel.HIGH, 
msg);
+            return lastTrainedMdl;
+        } else
+            throw new EmptyDatasetException();
+    }
+
+    /**
+     * Gets the state of the model passed in the arguments, updates it according
+     * to the new data and returns the updated model.
+     *
+     * @param mdl Learned model.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    protected abstract <K, V> M updateModel(M mdl, DatasetBuilder<K, V> 
datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, L> lbExtractor);
+
+    /**
      * Trains model based on the specified data.
      *
      * @param ignite Ignite instance.
@@ -73,6 +140,27 @@ public abstract class DatasetTrainer<M extends Model, L> {
     }
 
     /**
+     * Gets the state of the model passed in the arguments, updates it according
+     * to the new data and returns the updated model.
+     *
+     * @param mdl Learned model.
+     * @param ignite Ignite instance.
+     * @param cache Ignite cache.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    public <K, V> M update(M mdl, Ignite ignite, IgniteCache<K, V> cache,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, L> lbExtractor) {
+        return update(
+            mdl, new CacheBasedDatasetBuilder<>(ignite, cache),
+            featureExtractor,
+            lbExtractor
+        );
+    }
+
+    /**
      * Trains model based on the specified data.
      *
      * @param ignite Ignite instance.
@@ -94,6 +182,28 @@ public abstract class DatasetTrainer<M extends Model, L> {
     }
 
     /**
+     * Gets the state of the model passed in the arguments, updates it according
+     * to the new data and returns the updated model.
+     *
+     * @param mdl Learned model.
+     * @param ignite Ignite instance.
+     * @param cache Ignite cache.
+     * @param filter Filter for {@code upstream} data.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    public <K, V> M update(M mdl, Ignite ignite, IgniteCache<K, V> cache, 
IgniteBiPredicate<K, V> filter,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, L> lbExtractor) {
+        return update(
+            mdl, new CacheBasedDatasetBuilder<>(ignite, cache, filter),
+            featureExtractor,
+            lbExtractor
+        );
+    }
+
+    /**
      * Trains model based on the specified data.
      *
      * @param data Data.
@@ -114,6 +224,27 @@ public abstract class DatasetTrainer<M extends Model, L> {
     }
 
     /**
+     * Gets the state of the model passed in the arguments, updates it according
+     * to the new data and returns the updated model.
+     *
+     * @param mdl Learned model.
+     * @param data Data.
+     * @param parts Number of partitions.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    public <K, V> M update(M mdl, Map<K, V> data, int parts, 
IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor) {
+        return update(
+            mdl, new LocalDatasetBuilder<>(data, parts),
+            featureExtractor,
+            lbExtractor
+        );
+    }
+
+    /**
      * Trains model based on the specified data.
      *
      * @param data Data.
@@ -136,10 +267,45 @@ public abstract class DatasetTrainer<M extends Model, L> {
     }
 
     /**
+     * Gets the state of the model passed in the arguments, updates it according
+     * to the new data and returns the updated model.
+     *
+     * @param mdl Learned model.
+     * @param data Data.
+     * @param filter Filter for {@code upstream} data.
+     * @param parts Number of partitions.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return Updated model.
+     */
+    public <K, V> M update(M mdl, Map<K, V> data, IgniteBiPredicate<K, V> 
filter, int parts,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor) {
+        return update(
+            mdl, new LocalDatasetBuilder<>(data, filter, parts),
+            featureExtractor,
+            lbExtractor
+        );
+    }
+
+    /**
      * Sets learning Environment
      * @param environment Environment.
      */
     public void setEnvironment(LearningEnvironment environment) {
         this.environment = environment;
     }
+
+    /** */
+    public static class EmptyDatasetException extends IllegalArgumentException 
{
+        /** Serial version uid. */
+        private static final long serialVersionUID = 6914650522523293521L;
+
+        /**
+         * Constructs an instance of EmptyDatasetException.
+         */
+        public EmptyDatasetException() {
+            super("Cannot train model on empty dataset");
+        }
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
index de8994a..355048a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
@@ -86,6 +86,29 @@ public abstract class DecisionTree<T extends 
ImpurityMeasure<T>> extends Dataset
         }
     }
 
+    /**
+     * Trains a new model based on the dataset because there is no valid
+     * approach to updating decision trees.
+     *
+     * @param mdl Learned model.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param <K> Type of a key in {@code upstream} data.
+     * @param <V> Type of a value in {@code upstream} data.
+     * @return New model based on new dataset.
+     */
+    @Override public <K, V> DecisionTreeNode updateModel(DecisionTreeNode mdl, 
DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
+        return fit(datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(DecisionTreeNode mdl) {
+        return true;
+    }
+
+    /** */
     public <K,V> DecisionTreeNode fit(Dataset<EmptyContext, DecisionTreeData> 
dataset) {
         return split(dataset, e -> true, 0, 
getImpurityMeasureCalculator(dataset));
     }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
index 559dfff..7832584 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
@@ -64,8 +64,9 @@ public class RandomForestClassifierTrainer
      * This id can be used as index in arrays or lists.
      *
      * @param dataset Dataset.
+     * @return true if initialization was done.
      */
-    @Override protected void init(Dataset<EmptyContext, 
BootstrappedDatasetPartition> dataset) {
+    @Override protected boolean init(Dataset<EmptyContext, 
BootstrappedDatasetPartition> dataset) {
         Set<Double> uniqLabels = dataset.compute(
             x -> {
                 Set<Double> labels = new HashSet<>();
@@ -85,11 +86,14 @@ public class RandomForestClassifierTrainer
             }
         );
 
+        if(uniqLabels == null)
+            return false;
+
         int i = 0;
         for (Double label : uniqLabels)
             lblMapping.put(label, i++);
 
-        super.init(dataset);
+        return super.init(dataset);
     }
 
     /** {@inheritDoc} */

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
index cb25aa3..91fcf0a 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
@@ -30,6 +30,7 @@ import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
+import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -116,7 +117,8 @@ public abstract class RandomForestTrainer<L, S extends 
ImpurityComputer<Bootstra
             new EmptyContextBuilder<>(),
             new BootstrappedDatasetBuilder<>(featureExtractor, lbExtractor, 
cntOfTrees, subsampleSize))) {
 
-            init(dataset);
+            if(!init(dataset))
+                return buildComposition(Collections.emptyList());
             models = fit(dataset);
         }
         catch (Exception e) {
@@ -202,7 +204,8 @@ public abstract class RandomForestTrainer<L, S extends 
ImpurityComputer<Bootstra
      *
      * @param dataset Dataset.
      */
-    protected void init(Dataset<EmptyContext, BootstrappedDatasetPartition> 
dataset) {
+    protected boolean init(Dataset<EmptyContext, BootstrappedDatasetPartition> 
dataset) {
+        return true;
     }
 
     /**
@@ -215,6 +218,8 @@ public abstract class RandomForestTrainer<L, S extends 
ImpurityComputer<Bootstra
         Queue<TreeNode> treesQueue = createRootsQueue();
         ArrayList<TreeRoot> roots = initTrees(treesQueue);
         Map<Integer, BucketMeta> histMeta = computeHistogramMeta(meta, 
dataset);
+        if(histMeta.isEmpty())
+            return Collections.emptyList();
 
         ImpurityHistogramsComputer<S> histogramsComputer = 
createImpurityHistogramsComputer();
         while (!treesQueue.isEmpty()) {
@@ -232,6 +237,23 @@ public abstract class RandomForestTrainer<L, S extends 
ImpurityComputer<Bootstra
         return roots;
     }
 
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(ModelsComposition mdl) {
+        ModelsComposition fakeComposition = 
buildComposition(Collections.emptyList());
+        return mdl.getPredictionsAggregator().getClass() == 
fakeComposition.getPredictionsAggregator().getClass();
+    }
+
+    /** {@inheritDoc} */
+    @Override protected <K, V> ModelsComposition updateModel(ModelsComposition 
mdl, DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+
+        ArrayList<Model<Vector, Double>> oldModels = new 
ArrayList<>(mdl.getModels());
+        ModelsComposition newModels = fit(datasetBuilder, featureExtractor, 
lbExtractor);
+        oldModels.addAll(newModels.getModels());
+
+        return new ModelsComposition(oldModels, 
mdl.getPredictionsAggregator());
+    }
+
     /**
      * Split node with NodeId if need.
      *
@@ -302,6 +324,8 @@ public abstract class RandomForestTrainer<L, S extends 
ImpurityComputer<Bootstra
 
         List<NormalDistributionStatistics> stats = new 
NormalDistributionStatisticsComputer()
             .computeStatistics(meta, dataset);
+        if(stats == null)
+            return Collections.emptyMap();
 
         Map<Integer, BucketMeta> bucketsMeta = new HashMap<>();
         for (int i = 0; i < stats.size(); i++) {

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
index aae5af1..03f044a 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
@@ -27,6 +27,7 @@ import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.jetbrains.annotations.NotNull;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
@@ -39,19 +40,70 @@ public class KMeansTrainerTest {
     /** Precision in test checks. */
     private static final double PRECISION = 1e-2;
 
+    /** Data. */
+    private static final Map<Integer, double[]> data = new HashMap<>();
+
+    static {
+        data.put(0, new double[] {1.0, 1.0, 1.0});
+        data.put(1, new double[] {1.0, 2.0, 1.0});
+        data.put(2, new double[] {2.0, 1.0, 1.0});
+        data.put(3, new double[] {-1.0, -1.0, 2.0});
+        data.put(4, new double[] {-1.0, -2.0, 2.0});
+        data.put(5, new double[] {-2.0, -1.0, 2.0});
+    }
+
     /**
      * A few points, one cluster, one iteration
      */
     @Test
     public void findOneClusters() {
-        Map<Integer, double[]> data = new HashMap<>();
-        data.put(0, new double[]{1.0, 1.0, 1.0});
-        data.put(1, new double[]{1.0, 2.0, 1.0});
-        data.put(2, new double[]{2.0, 1.0, 1.0});
-        data.put(3, new double[]{-1.0, -1.0, 2.0});
-        data.put(4, new double[]{-1.0, -2.0, 2.0});
-        data.put(5, new double[]{-2.0, -1.0, 2.0});
+        KMeansTrainer trainer = createAndCheckTrainer();
+        KMeansModel knnMdl = trainer.withK(1).fit(
+            new LocalDatasetBuilder<>(data, 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Vector firstVector = new DenseVector(new double[] {2.0, 2.0});
+        assertEquals(knnMdl.apply(firstVector), 0.0, PRECISION);
+        Vector secondVector = new DenseVector(new double[] {-2.0, -2.0});
+        assertEquals(knnMdl.apply(secondVector), 0.0, PRECISION);
+        assertEquals(trainer.getMaxIterations(), 1);
+        assertEquals(trainer.getEpsilon(), PRECISION, PRECISION);
+    }
 
+    /** */
+    @Test
+    public void testUpdateMdl() {
+        KMeansTrainer trainer = createAndCheckTrainer();
+        KMeansModel originalMdl = trainer.withK(1).fit(
+            new LocalDatasetBuilder<>(data, 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+        KMeansModel updatedMdlOnSameDataset = trainer.update(
+            originalMdl,
+            new LocalDatasetBuilder<>(data, 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+        KMeansModel updatedMdlOnEmptyDataset = trainer.update(
+            originalMdl,
+            new LocalDatasetBuilder<>(new HashMap<Integer, double[]>(), 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Vector firstVector = new DenseVector(new double[] {2.0, 2.0});
+        Vector secondVector = new DenseVector(new double[] {-2.0, -2.0});
+        assertEquals(originalMdl.apply(firstVector), 
updatedMdlOnSameDataset.apply(firstVector), PRECISION);
+        assertEquals(originalMdl.apply(secondVector), 
updatedMdlOnSameDataset.apply(secondVector), PRECISION);
+        assertEquals(originalMdl.apply(firstVector), 
updatedMdlOnEmptyDataset.apply(firstVector), PRECISION);
+        assertEquals(originalMdl.apply(secondVector), 
updatedMdlOnEmptyDataset.apply(secondVector), PRECISION);
+    }
+
+    /** */
+    @NotNull private KMeansTrainer createAndCheckTrainer() {
         KMeansTrainer trainer = new KMeansTrainer()
             .withDistance(new EuclideanDistance())
             .withK(10)
@@ -61,20 +113,6 @@ public class KMeansTrainerTest {
         assertEquals(10, trainer.getK());
         assertEquals(2, trainer.getSeed());
         assertTrue(trainer.getDistance() instanceof EuclideanDistance);
-
-        KMeansModel knnMdl = trainer
-            .withK(1)
-            .fit(
-                new LocalDatasetBuilder<>(data, 2),
-                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 
1)),
-                (k, v) -> v[2]
-            );
-
-        Vector firstVector = new DenseVector(new double[]{2.0, 2.0});
-        assertEquals(knnMdl.apply(firstVector), 0.0, PRECISION);
-        Vector secondVector = new DenseVector(new double[]{-2.0, -2.0});
-        assertEquals(knnMdl.apply(secondVector), 0.0, PRECISION);
-        assertEquals(trainer.getMaxIterations(), 1);
-        assertEquals(trainer.getEpsilon(), PRECISION, PRECISION);
+        return trainer;
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
index acf28e9..745eac9 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
@@ -22,6 +22,7 @@ import java.util.Set;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
 import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat;
 import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
 import org.apache.ignite.ml.knn.ann.ANNModelFormat;
 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNModelFormat;
@@ -103,11 +104,11 @@ public class CollectionsTest {
 
         test(new SVMLinearBinaryClassificationModel(null, 1.0), new 
SVMLinearBinaryClassificationModel(null, 0.5));
 
-        test(new ANNClassificationModel(new LabeledVectorSet<>()),
-            new ANNClassificationModel(new LabeledVectorSet<>(1, 1, true)));
+        test(new ANNClassificationModel(new LabeledVectorSet<>(), new 
ANNClassificationTrainer.CentroidStat()),
+            new ANNClassificationModel(new LabeledVectorSet<>(1, 1, true), new 
ANNClassificationTrainer.CentroidStat()));
 
-        test(new ANNModelFormat(1, new ManhattanDistance(), NNStrategy.SIMPLE, 
new LabeledVectorSet<>()),
-            new ANNModelFormat(2, new ManhattanDistance(), NNStrategy.SIMPLE, 
new LabeledVectorSet<>()));
+        test(new ANNModelFormat(1, new ManhattanDistance(), NNStrategy.SIMPLE, 
new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()),
+            new ANNModelFormat(2, new ManhattanDistance(), NNStrategy.SIMPLE, 
new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()));
     }
 
     /** Test classes that have all instances equal (eg, metrics). */

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java
index 17d9c1a..9315850 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/LocalModelsTest.java
@@ -32,6 +32,7 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.knn.NNClassificationModel;
 import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
 import org.apache.ignite.ml.knn.ann.ANNModelFormat;
 import org.apache.ignite.ml.knn.ann.ProbableLabel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
@@ -237,7 +238,7 @@ public class LocalModelsTest {
         executeModelTest(mdlFilePath -> {
             final LabeledVectorSet<ProbableLabel, LabeledVector> centers = new 
LabeledVectorSet<>();
 
-            NNClassificationModel mdl = new ANNClassificationModel(centers)
+            NNClassificationModel mdl = new ANNClassificationModel(centers, 
new ANNClassificationTrainer.CentroidStat())
                 .withK(4)
                 .withDistanceMeasure(new ManhattanDistance())
                 .withStrategy(NNStrategy.WEIGHTED);
@@ -250,7 +251,7 @@ public class LocalModelsTest {
             Assert.assertNotNull(load);
 
 
-            NNClassificationModel importedMdl = new 
ANNClassificationModel(load.getCandidates())
+            NNClassificationModel importedMdl = new 
ANNClassificationModel(load.getCandidates(), new 
ANNClassificationTrainer.CentroidStat())
                 .withK(load.getK())
                 .withDistanceMeasure(load.getDistanceMeasure())
                 .withStrategy(load.getStgy());

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
index 4452668..3e340f6 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
@@ -54,7 +54,7 @@ public class GDBTrainerTest {
             learningSample.put(i, new double[] {xs[i], ys[i]});
         }
 
-        DatasetTrainer<Model<Vector, Double>, Double> trainer
+        DatasetTrainer<ModelsComposition, Double> trainer
             = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 
0.0).withUseIndex(true);
 
         Model<Vector, Double> mdl = trainer.fit(

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
index 7289b1d..d8fb620 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
@@ -26,6 +26,7 @@ import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
 import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
 import org.apache.ignite.ml.knn.classification.NNStrategy;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Assert;
 import org.junit.Test;
@@ -68,4 +69,47 @@ public class ANNClassificationTest extends TrainerTest {
         
Assert.assertTrue(mdl.toString(true).contains(NNStrategy.SIMPLE.name()));
         
Assert.assertTrue(mdl.toString(false).contains(NNStrategy.SIMPLE.name()));
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoClusters.length; i++)
+            cacheMock.put(i, twoClusters[i]);
+
+        ANNClassificationTrainer trainer = new ANNClassificationTrainer()
+            .withK(10)
+            .withMaxIterations(10)
+            .withEpsilon(1e-4)
+            .withDistance(new EuclideanDistance());
+
+        ANNClassificationModel originalMdl = (ANNClassificationModel) 
trainer.withSeed(1234L).fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.SIMPLE);
+
+        ANNClassificationModel updatedOnSameDataset = 
trainer.withSeed(1234L).update(originalMdl,
+            cacheMock, parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        ANNClassificationModel updatedOnEmptyDataset = 
trainer.withSeed(1234L).update(originalMdl,
+            new HashMap<Integer, double[]>(), parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Vector v1 = VectorUtils.of(550, 550);
+        Vector v2 = VectorUtils.of(-550, -550);
+        TestUtils.assertEquals(originalMdl.apply(v1), 
updatedOnSameDataset.apply(v1), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v2), 
updatedOnSameDataset.apply(v2), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v1), 
updatedOnEmptyDataset.apply(v1), PRECISION);
+        TestUtils.assertEquals(originalMdl.apply(v2), 
updatedOnEmptyDataset.apply(v2), PRECISION);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
index c5a5c1c..748123a 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
@@ -174,4 +174,43 @@ public class KNNClassificationTest {
         Vector vector = new DenseVector(new double[] {-1.01, -1.01});
         assertEquals(knnMdl.apply(vector), 1.0);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[] {10.0, 10.0, 1.0});
+        data.put(1, new double[] {10.0, 20.0, 1.0});
+        data.put(2, new double[] {-1, -1, 1.0});
+        data.put(3, new double[] {-2, -2, 2.0});
+        data.put(4, new double[] {-1.0, -2.0, 2.0});
+        data.put(5, new double[] {-2.0, -1.0, 2.0});
+
+        KNNClassificationTrainer trainer = new KNNClassificationTrainer();
+
+        KNNClassificationModel originalMdl = 
(KNNClassificationModel)trainer.fit(
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.WEIGHTED);
+
+        KNNClassificationModel updatedOnSameDataset = 
trainer.update(originalMdl,
+            data, parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        KNNClassificationModel updatedOnEmptyDataset = 
trainer.update(originalMdl,
+            new HashMap<Integer, double[]>(), parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Vector vector = new DenseVector(new double[] {-1.01, -1.01});
+        assertEquals(originalMdl.apply(vector), 
updatedOnSameDataset.apply(vector));
+        assertEquals(originalMdl.apply(vector), 
updatedOnEmptyDataset.apply(vector));
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
index 5504e1a..52ff1ec 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
@@ -35,6 +35,8 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import static junit.framework.TestCase.assertEquals;
+
 /**
  * Tests for {@link KNNRegressionTrainer}.
  */
@@ -135,4 +137,42 @@ public class KNNRegressionTest {
         Assert.assertTrue(knnMdl.toString(true).contains(stgy.name()));
         Assert.assertTrue(knnMdl.toString(false).contains(stgy.name()));
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[] {11.0, 0, 0, 0, 0, 0});
+        data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0});
+        data.put(2, new double[] {13.0, 0, 3.0, 0, 0, 0});
+        data.put(3, new double[] {14.0, 0, 0, 4.0, 0, 0});
+        data.put(4, new double[] {15.0, 0, 0, 0, 5.0, 0});
+        data.put(5, new double[] {16.0, 0, 0, 0, 0, 6.0});
+
+        KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+
+        KNNRegressionModel originalMdl = (KNNRegressionModel) trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        ).withK(1)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.SIMPLE);
+
+        KNNRegressionModel updatedOnSameDataset = trainer.update(originalMdl,
+            data, parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        KNNRegressionModel updatedOnEmptyDataset = trainer.update(originalMdl,
+            new HashMap<Integer, double[]>(), parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Vector vector = new DenseVector(new double[] {0, 0, 0, 5.0, 0.0});
+        assertEquals(originalMdl.apply(vector), 
updatedOnSameDataset.apply(vector));
+        assertEquals(originalMdl.apply(vector), 
updatedOnEmptyDataset.apply(vector));
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
index a1d601c..6a6555e 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTrainerTest.java
@@ -29,6 +29,7 @@ import 
org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.SmoothParametrized;
 import 
org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;
 import 
org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
 import 
org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
@@ -154,6 +155,69 @@ public class MLPTrainerTest {
 
             TestUtils.checkIsInEpsilonNeighbourhood(new DenseVector(new 
double[]{0.0}), predict.getRow(0), 1E-1);
         }
+
+        /** */
+        @Test
+        public void testUpdate() {
+            UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> 
updatesStgy = new UpdatesStrategy<>(
+                new SimpleGDUpdateCalculator(0.2),
+                SimpleGDParameterUpdate::sumLocal,
+                SimpleGDParameterUpdate::avg
+            );
+
+            Map<Integer, double[][]> xorData = new HashMap<>();
+            xorData.put(0, new double[][]{{0.0, 0.0}, {0.0}});
+            xorData.put(1, new double[][]{{0.0, 1.0}, {1.0}});
+            xorData.put(2, new double[][]{{1.0, 0.0}, {1.0}});
+            xorData.put(3, new double[][]{{1.0, 1.0}, {0.0}});
+
+            MLPArchitecture arch = new MLPArchitecture(2).
+                withAddedLayer(10, true, Activators.RELU).
+                withAddedLayer(1, false, Activators.SIGMOID);
+
+            MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
+                arch,
+                LossFunctions.MSE,
+                updatesStgy,
+                3000,
+                batchSize,
+                50,
+                123L
+            );
+
+            MultilayerPerceptron originalMdl = trainer.fit(
+                xorData,
+                parts,
+                (k, v) -> VectorUtils.of(v[0]),
+                (k, v) -> v[1]
+            );
+
+            MultilayerPerceptron updatedOnSameDS = trainer.update(
+                originalMdl,
+                xorData,
+                parts,
+                (k, v) -> VectorUtils.of(v[0]),
+                (k, v) -> v[1]
+            );
+
+            MultilayerPerceptron updatedOnEmptyDS = trainer.update(
+                originalMdl,
+                new HashMap<Integer, double[][]>(),
+                parts,
+                (k, v) -> VectorUtils.of(v[0]),
+                (k, v) -> v[1]
+            );
+
+            DenseMatrix matrix = new DenseMatrix(new double[][] {
+                {0.0, 0.0},
+                {0.0, 1.0},
+                {1.0, 0.0},
+                {1.0, 1.0}
+            });
+
+            
TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.apply(matrix).getRow(0), 
updatedOnSameDS.apply(matrix).getRow(0), 1E-1);
+            
TestUtils.checkIsInEpsilonNeighbourhood(originalMdl.apply(matrix).getRow(0), 
updatedOnEmptyDS.apply(matrix).getRow(0), 1E-1);
+        }
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
index d16ae72..9c35ac7 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
@@ -101,4 +101,55 @@ public class LinearRegressionLSQRTrainerTest extends 
TrainerTest {
 
         assertEquals(intercept, mdl.getIntercept(), 1e-6);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Random rnd = new Random(0);
+        Map<Integer, double[]> data = new HashMap<>();
+        double[] coef = new double[100];
+        double intercept = rnd.nextDouble() * 10;
+
+        for (int i = 0; i < 100000; i++) {
+            double[] x = new double[coef.length + 1];
+
+            for (int j = 0; j < coef.length; j++)
+                x[j] = rnd.nextDouble() * 10;
+
+            x[coef.length] = intercept;
+
+            data.put(i, x);
+        }
+
+        LinearRegressionLSQRTrainer trainer = new 
LinearRegressionLSQRTrainer();
+
+        LinearRegressionModel originalModel = trainer.fit(
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[coef.length]
+        );
+
+        LinearRegressionModel updatedOnSameDS = trainer.update(
+            originalModel,
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[coef.length]
+        );
+
+        LinearRegressionModel updatedOnEmpyDS = trainer.update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[coef.length]
+        );
+
+        assertArrayEquals(originalModel.getWeights().getStorage().data(), 
updatedOnSameDS.getWeights().getStorage().data(), 1e-6);
+        assertEquals(originalModel.getIntercept(), 
updatedOnSameDS.getIntercept(), 1e-6);
+
+        assertArrayEquals(originalModel.getWeights().getStorage().data(), 
updatedOnEmpyDS.getWeights().getStorage().data(), 1e-6);
+        assertEquals(originalModel.getIntercept(), 
updatedOnEmpyDS.getIntercept(), 1e-6);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
index 349e712..86b0f27 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
@@ -72,4 +72,66 @@ public class LinearRegressionSGDTrainerTest extends 
TrainerTest {
 
         assertEquals(2.8421709430404007e-14, mdl.getIntercept(), 1e-1);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> data = new HashMap<>();
+        data.put(0, new double[]{-1.0915526, 1.81983527, -0.91409478, 
0.70890712, -24.55724107});
+        data.put(1, new double[]{-0.61072904, 0.37545517, 0.21705352, 
0.09516495, -26.57226867});
+        data.put(2, new double[]{0.05485406, 0.88219898, -0.80584547, 
0.94668307, 61.80919728});
+        data.put(3, new double[]{-0.24835094, -0.34000053, -1.69984651, 
-1.45902635, -161.65525991});
+        data.put(4, new double[]{0.63675392, 0.31675535, 0.38837437, 
-1.1221971, -14.46432611});
+        data.put(5, new double[]{0.14194017, 2.18158997, -0.28397346, 
-0.62090588, -3.2122197});
+        data.put(6, new double[]{-0.53487507, 1.4454797, 0.21570443, 
-0.54161422, -46.5469012});
+        data.put(7, new double[]{-1.58812173, -0.73216803, -2.15670676, 
-1.03195988, -247.23559889});
+        data.put(8, new double[]{0.20702671, 0.92864654, 0.32721202, 
-0.09047503, 31.61484949});
+        data.put(9, new double[]{-0.37890345, -0.04846179, -0.84122753, 
-1.14667474, -124.92598583});
+
+        LinearRegressionSGDTrainer<?> trainer = new 
LinearRegressionSGDTrainer<>(new UpdatesStrategy<>(
+            new RPropUpdateCalculator(),
+            RPropParameterUpdate::sumLocal,
+            RPropParameterUpdate::avg
+        ), 100000, 10, 100, 0L);
+
+        LinearRegressionModel originalModel = trainer.withSeed(0).fit(
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[4]
+        );
+
+
+        LinearRegressionModel updatedOnSameDS = trainer.withSeed(0).update(
+            originalModel,
+            data,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[4]
+        );
+
+        LinearRegressionModel updatedOnEmptyDS = trainer.withSeed(0).update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[4]
+        );
+
+        assertArrayEquals(
+            originalModel.getWeights().getStorage().data(),
+            updatedOnSameDS.getWeights().getStorage().data(),
+            1.0
+        );
+
+        assertEquals(originalModel.getIntercept(), 
updatedOnSameDS.getIntercept(), 1.0);
+
+        assertArrayEquals(
+            originalModel.getWeights().getStorage().data(),
+            updatedOnEmptyDS.getWeights().getStorage().data(),
+            1e-1
+        );
+
+        assertEquals(originalModel.getIntercept(), 
updatedOnEmptyDS.getIntercept(), 1e-1);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
index 1f8c5d1..f08501c 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
@@ -19,9 +19,11 @@ package org.apache.ignite.ml.regressions.logistic;
 
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.optimization.SmoothParametrized;
@@ -81,4 +83,60 @@ public class LogRegMultiClassTrainerTest extends TrainerTest 
{
         TestUtils.assertEquals(2, mdl.apply(VectorUtils.of(-10, -10)), 
PRECISION);
         TestUtils.assertEquals(3, mdl.apply(VectorUtils.of(10, -10)), 
PRECISION);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < fourSetsInSquareVertices.length; i++)
+            cacheMock.put(i, fourSetsInSquareVertices[i]);
+
+        LogRegressionMultiClassTrainer<?> trainer = new 
LogRegressionMultiClassTrainer<>()
+            .withUpdatesStgy(new UpdatesStrategy<>(
+                new SimpleGDUpdateCalculator(0.2),
+                SimpleGDParameterUpdate::sumLocal,
+                SimpleGDParameterUpdate::avg
+            ))
+            .withAmountOfIterations(1000)
+            .withAmountOfLocIterations(10)
+            .withBatchSize(100)
+            .withSeed(123L);
+
+        LogRegressionMultiClassModel originalModel = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        LogRegressionMultiClassModel updatedOnSameDS = trainer.update(
+            originalModel,
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        LogRegressionMultiClassModel updatedOnEmptyDS = trainer.update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        List<Vector> vectors = Arrays.asList(
+            VectorUtils.of(10, 10),
+            VectorUtils.of(-10, 10),
+            VectorUtils.of(-10, -10),
+            VectorUtils.of(10, -10)
+        );
+
+
+        for (Vector vec : vectors) {
+            TestUtils.assertEquals(originalModel.apply(vec), 
updatedOnSameDS.apply(vec), PRECISION);
+            TestUtils.assertEquals(originalModel.apply(vec), 
updatedOnEmptyDS.apply(vec), PRECISION);
+        }
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
index 5bd2dbd..1da0d1a 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
 import 
org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
@@ -60,4 +61,49 @@ public class LogisticRegressionSGDTrainerTest extends 
TrainerTest {
         TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), 
PRECISION);
         TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), 
PRECISION);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+        LogisticRegressionSGDTrainer<?> trainer = new 
LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
+            new SimpleGDUpdateCalculator().withLearningRate(0.2),
+            SimpleGDParameterUpdate::sumLocal,
+            SimpleGDParameterUpdate::avg
+        ), 100000, 10, 100, 123L);
+
+        LogisticRegressionModel originalModel = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        LogisticRegressionModel updatedOnSameDS = trainer.update(
+            originalModel,
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        LogisticRegressionModel updatedOnEmptyDS = trainer.update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        Vector v1 = VectorUtils.of(100, 10);
+        Vector v2 = VectorUtils.of(10, 100);
+        TestUtils.assertEquals(originalModel.apply(v1), 
updatedOnSameDS.apply(v1), PRECISION);
+        TestUtils.assertEquals(originalModel.apply(v2), 
updatedOnSameDS.apply(v2), PRECISION);
+        TestUtils.assertEquals(originalModel.apply(v2), 
updatedOnEmptyDS.apply(v2), PRECISION);
+        TestUtils.assertEquals(originalModel.apply(v1), 
updatedOnEmptyDS.apply(v1), PRECISION);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
index 5630bee..263bb6d 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
 
@@ -52,4 +53,44 @@ public class SVMBinaryTrainerTest extends TrainerTest {
         TestUtils.assertEquals(-1, mdl.apply(VectorUtils.of(100, 10)), 
PRECISION);
         TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), 
PRECISION);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+        SVMLinearBinaryClassificationTrainer trainer = new 
SVMLinearBinaryClassificationTrainer()
+            .withAmountOfIterations(1000)
+            .withSeed(1234L);
+
+        SVMLinearBinaryClassificationModel originalModel = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        SVMLinearBinaryClassificationModel updatedOnSameDS = trainer.update(
+            originalModel,
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        SVMLinearBinaryClassificationModel updatedOnEmptyDS = trainer.update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        Vector v = VectorUtils.of(100, 10);
+        TestUtils.assertEquals(originalModel.apply(v), 
updatedOnSameDS.apply(v), PRECISION);
+        TestUtils.assertEquals(originalModel.apply(v), 
updatedOnEmptyDS.apply(v), PRECISION);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
index 7ea28c2..e0c62af 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.Map;
 import org.apache.ignite.ml.TestUtils;
 import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
 
@@ -54,4 +55,46 @@ public class SVMMultiClassTrainerTest extends TrainerTest {
         TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), 
PRECISION);
         TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), 
PRECISION);
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+        SVMLinearMultiClassClassificationTrainer trainer = new 
SVMLinearMultiClassClassificationTrainer()
+            .withLambda(0.3)
+            .withAmountOfLocIterations(10)
+            .withAmountOfIterations(100)
+            .withSeed(1234L);
+
+        SVMLinearMultiClassClassificationModel originalModel = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        SVMLinearMultiClassClassificationModel updatedOnSameDS = 
trainer.update(
+            originalModel,
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        SVMLinearMultiClassClassificationModel updatedOnEmptyDS = 
trainer.update(
+            originalModel,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        Vector v = VectorUtils.of(100, 10);
+        TestUtils.assertEquals(originalModel.apply(v), 
updatedOnSameDS.apply(v), PRECISION);
+        TestUtils.assertEquals(originalModel.apply(v), 
updatedOnEmptyDS.apply(v), PRECISION);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
index 4abf508..087f4e8 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
@@ -24,6 +24,7 @@ import java.util.Map;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import 
org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -75,7 +76,7 @@ public class RandomForestClassifierTrainerTest {
         }
 
         ArrayList<FeatureMeta> meta = new ArrayList<>();
-        for(int i = 0; i < 4; i++)
+        for (int i = 0; i < 4; i++)
             meta.add(new FeatureMeta("", i, false));
         RandomForestClassifierTrainer trainer = new 
RandomForestClassifierTrainer(meta)
             .withCountOfTrees(5)
@@ -86,4 +87,34 @@ public class RandomForestClassifierTrainerTest {
         assertTrue(mdl.getPredictionsAggregator() instanceof 
OnMajorityPredictionsAggregator);
         assertEquals(5, mdl.getModels().size());
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        int sampleSize = 1000;
+        Map<double[], Double> sample = new HashMap<>();
+        for (int i = 0; i < sampleSize; i++) {
+            double x1 = i;
+            double x2 = x1 / 10.0;
+            double x3 = x2 / 10.0;
+            double x4 = x3 / 10.0;
+
+            sample.put(new double[] {x1, x2, x3, x4}, (double)(i % 2));
+        }
+
+        ArrayList<FeatureMeta> meta = new ArrayList<>();
+        for (int i = 0; i < 4; i++)
+            meta.add(new FeatureMeta("", i, false));
+        RandomForestClassifierTrainer trainer = new 
RandomForestClassifierTrainer(meta)
+            .withCountOfTrees(100)
+            .withFeaturesCountSelectionStrgy(x -> 2);
+
+        ModelsComposition originalModel = trainer.fit(sample, parts, (k, v) -> 
VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnSameDS = trainer.update(originalModel, 
sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnEmptyDS = trainer.update(originalModel, new 
HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+
+        Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005);
+        assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), 0.01);
+        assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), 0.01);
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/f4c18f11/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
index c4a4a75..fcc20bd 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
@@ -24,6 +24,7 @@ import java.util.Map;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import 
org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -82,4 +83,34 @@ public class RandomForestRegressionTrainerTest {
         assertTrue(mdl.getPredictionsAggregator() instanceof 
MeanValuePredictionsAggregator);
         assertEquals(5, mdl.getModels().size());
     }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        int sampleSize = 1000;
+        Map<double[], Double> sample = new HashMap<>();
+        for (int i = 0; i < sampleSize; i++) {
+            double x1 = i;
+            double x2 = x1 / 10.0;
+            double x3 = x2 / 10.0;
+            double x4 = x3 / 10.0;
+
+            sample.put(new double[] {x1, x2, x3, x4}, (double)(i % 2));
+        }
+
+        ArrayList<FeatureMeta> meta = new ArrayList<>();
+        for (int i = 0; i < 4; i++)
+            meta.add(new FeatureMeta("", i, false));
+        RandomForestRegressionTrainer trainer = new 
RandomForestRegressionTrainer(meta)
+            .withCountOfTrees(100)
+            .withFeaturesCountSelectionStrgy(x -> 2);
+
+        ModelsComposition originalModel = trainer.fit(sample, parts, (k, v) -> 
VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnSameDS = trainer.update(originalModel, 
sample, parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+        ModelsComposition updatedOnEmptyDS = trainer.update(originalModel, new 
HashMap<double[], Double>(), parts, (k, v) -> VectorUtils.of(k), (k, v) -> v);
+
+        Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005);
+        assertEquals(originalModel.apply(v), updatedOnSameDS.apply(v), 0.1);
+        assertEquals(originalModel.apply(v), updatedOnEmptyDS.apply(v), 0.1);
+    }
 }

Reply via email to