Repository: ignite
Updated Branches:
  refs/heads/master ed6bf5ac2 -> 6225c56ea


http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java
new file mode 100644
index 0000000..be1724c
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset.primitive;
+
+import java.io.Serializable;
+import java.util.Iterator;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+
+/**
+ * A partition {@code data} builder that makes {@link FeatureMatrixWithLabelsOnHeapData}.
+ *
+ * @param <K> Type of a key in <tt>upstream</tt> data.
+ * @param <V> Type of a value in <tt>upstream</tt> data.
+ * @param <C> Type of a partition <tt>context</tt>.
+ */
+public class FeatureMatrixWithLabelsOnHeapDataBuilder<K, V, C extends 
Serializable>
+    implements PartitionDataBuilder<K, V, C, 
FeatureMatrixWithLabelsOnHeapData> {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 6273736987424171813L;
+
+    /** Function that extracts features from an {@code upstream} data. */
+    private final IgniteBiFunction<K, V, Vector> featureExtractor;
+
+    /** Function that extracts labels from an {@code upstream} data. */
+    private final IgniteBiFunction<K, V, Double> lbExtractor;
+
+    /**
+     * Constructs a new instance of feature matrix with labels on heap data builder.
+     *
+     * @param featureExtractor Function that extracts features from an {@code 
upstream} data.
+     * @param lbExtractor Function that extracts labels from an {@code 
upstream} data.
+     */
+    public FeatureMatrixWithLabelsOnHeapDataBuilder(IgniteBiFunction<K, V, 
Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        this.featureExtractor = featureExtractor;
+        this.lbExtractor = lbExtractor;
+    }
+
+    /** {@inheritDoc} */
+    @Override public FeatureMatrixWithLabelsOnHeapData 
build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) 
{
+        double[][] features = new double[Math.toIntExact(upstreamDataSize)][];
+        double[] labels = new double[Math.toIntExact(upstreamDataSize)];
+
+        int ptr = 0;
+        while (upstreamData.hasNext()) {
+            UpstreamEntry<K, V> entry = upstreamData.next();
+
+            features[ptr] = featureExtractor.apply(entry.getKey(), 
entry.getValue()).asArray();
+
+            labels[ptr] = lbExtractor.apply(entry.getKey(), entry.getValue());
+
+            ptr++;
+        }
+
+        return new FeatureMatrixWithLabelsOnHeapData(features, labels);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
index 8589a79..6ebbda1 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
@@ -17,12 +17,13 @@
 
 package org.apache.ignite.ml.tree.boosting;
 
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
+import org.apache.ignite.ml.composition.boosting.GDBTrainer;
+import 
org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
 import 
org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -54,22 +55,30 @@ public class GDBOnTreesLearningStrategy  extends 
GDBLearningStrategy {
     }
 
     /** {@inheritDoc} */
-    @Override public <K, V> List<Model<Vector, Double>> 
learnModels(DatasetBuilder<K, V> datasetBuilder,
-        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
+    @Override public <K, V> List<Model<Vector, Double>> 
update(GDBTrainer.GDBModel mdlToUpdate,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> 
featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
 
         DatasetTrainer<? extends Model<Vector, Double>, Double> trainer = 
baseMdlTrainerBuilder.get();
         assert trainer instanceof DecisionTree;
         DecisionTree decisionTreeTrainer = (DecisionTree) trainer;
 
-        List<Model<Vector, Double>> models = new ArrayList<>();
+        List<Model<Vector, Double>> models = initLearningState(mdlToUpdate);
+
+        ConvergenceChecker<K,V> convCheck = 
checkConvergenceStgyFactory.create(sampleSize,
+            externalLbToInternalMapping, loss, datasetBuilder, 
featureExtractor, lbExtractor);
+
         try (Dataset<EmptyContext, DecisionTreeData> dataset = 
datasetBuilder.build(
             new EmptyContextBuilder<>(),
             new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor, 
useIndex)
         )) {
             for (int i = 0; i < cntOfIterations; i++) {
-                double[] weights = Arrays.copyOf(compositionWeights, i);
+                double[] weights = Arrays.copyOf(compositionWeights, 
models.size());
                 WeightedPredictionsAggregator aggregator = new 
WeightedPredictionsAggregator(weights, meanLabelValue);
-                Model<Vector, Double> currComposition = new 
ModelsComposition(models, aggregator);
+                ModelsComposition currComposition = new 
ModelsComposition(models, aggregator);
+
+                if(convCheck.isConverged(dataset, currComposition))
+                    break;
 
                 dataset.compute(part -> {
                     if(part.getCopyOfOriginalLabels() == null)
@@ -78,7 +87,7 @@ public class GDBOnTreesLearningStrategy  extends 
GDBLearningStrategy {
                     for(int j = 0; j < part.getLabels().length; j++) {
                         double mdlAnswer = 
currComposition.apply(VectorUtils.of(part.getFeatures()[j]));
                         double originalLbVal = 
externalLbToInternalMapping.apply(part.getCopyOfOriginalLabels()[j]);
-                        part.getLabels()[j] = -lossGradient.apply(sampleSize, 
originalLbVal, mdlAnswer);
+                        part.getLabels()[j] = -loss.gradient(sampleSize, 
originalLbVal, mdlAnswer);
                     }
                 });
 
@@ -92,6 +101,7 @@ public class GDBOnTreesLearningStrategy  extends 
GDBLearningStrategy {
             throw new RuntimeException(e);
         }
 
+        compositionWeights = Arrays.copyOf(compositionWeights, models.size());
         return models;
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
index d5750ea..b8a16dc 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
@@ -19,18 +19,14 @@ package org.apache.ignite.ml.tree.data;
 
 import java.util.ArrayList;
 import java.util.List;
+import 
org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
 import org.apache.ignite.ml.tree.TreeFilter;
 
 /**
- * A partition {@code data} of the containing matrix of features and vector of 
labels stored in heap.
+ * A partition {@code data} of the containing matrix of features and vector of 
labels stored in heap
+ * with index on features.
  */
-public class DecisionTreeData implements AutoCloseable {
-    /** Matrix with features. */
-    private final double[][] features;
-
-    /** Vector with labels. */
-    private final double[] labels;
-
+public class DecisionTreeData extends FeatureMatrixWithLabelsOnHeapData 
implements AutoCloseable {
     /** Copy of vector with original labels. Auxiliary for Gradient Boosting 
on Trees.*/
     private double[] copyOfOriginalLabels;
 
@@ -48,10 +44,7 @@ public class DecisionTreeData implements AutoCloseable {
      * @param buildIdx Build index.
      */
     public DecisionTreeData(double[][] features, double[] labels, boolean 
buildIdx) {
-        assert features.length == labels.length : "Features and labels have to 
be the same length";
-
-        this.features = features;
-        this.labels = labels;
+        super(features, labels);
         this.buildIndex = buildIdx;
 
         indexesCache = new ArrayList<>();
@@ -68,6 +61,8 @@ public class DecisionTreeData implements AutoCloseable {
     public DecisionTreeData filter(TreeFilter filter) {
         int size = 0;
 
+        double[][] features = getFeatures();
+        double[] labels = getLabels();
         for (int i = 0; i < features.length; i++)
             if (filter.test(features[i]))
                 size++;
@@ -95,12 +90,15 @@ public class DecisionTreeData implements AutoCloseable {
      * @param col Column.
      */
     public void sort(int col) {
-        sort(col, 0, features.length - 1);
+        sort(col, 0, getFeatures().length - 1);
     }
 
     /** */
     private void sort(int col, int from, int to) {
         if (from < to) {
+            double[][] features = getFeatures();
+            double[] labels = getLabels();
+
             double pivot = features[(from + to) / 2][col];
 
             int i = from, j = to;
@@ -131,19 +129,11 @@ public class DecisionTreeData implements AutoCloseable {
     }
 
     /** */
-    public double[][] getFeatures() {
-        return features;
-    }
-
-    /** */
-    public double[] getLabels() {
-        return labels;
-    }
-
     public double[] getCopyOfOriginalLabels() {
         return copyOfOriginalLabels;
     }
 
+    /** */
     public void setCopyOfOriginalLabels(double[] copyOfOriginalLabels) {
         this.copyOfOriginalLabels = copyOfOriginalLabels;
     }
@@ -170,7 +160,7 @@ public class DecisionTreeData implements AutoCloseable {
 
         if (depth == indexesCache.size()) {
             if (depth == 0)
-                indexesCache.add(new TreeDataIndex(features, labels));
+                indexesCache.add(new TreeDataIndex(getFeatures(), 
getLabels()));
             else {
                 TreeDataIndex lastIndex = indexesCache.get(depth - 1);
                 indexesCache.add(lastIndex.filter(filter));

http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
index 3e340f6..89b8c9c 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
@@ -22,11 +22,13 @@ import java.util.Map;
 import java.util.function.BiFunction;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.composition.ModelsComposition;
+import 
org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
+import 
org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory;
 import 
org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
 import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
 import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
 import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
 import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
@@ -54,8 +56,8 @@ public class GDBTrainerTest {
             learningSample.put(i, new double[] {xs[i], ys[i]});
         }
 
-        DatasetTrainer<ModelsComposition, Double> trainer
-            = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 
0.0).withUseIndex(true);
+        GDBTrainer trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 3, 0.0)
+            .withUseIndex(true);
 
         Model<Vector, Double> mdl = trainer.fit(
             learningSample, 1,
@@ -74,7 +76,6 @@ public class GDBTrainerTest {
 
         assertEquals(0.0, mse, 0.0001);
 
-        assertTrue(mdl instanceof ModelsComposition);
         ModelsComposition composition = (ModelsComposition)mdl;
         assertTrue(composition.toString().length() > 0);
         assertTrue(composition.toString(true).length() > 0);
@@ -84,6 +85,13 @@ public class GDBTrainerTest {
 
         assertEquals(2000, composition.getModels().size());
         assertTrue(composition.getPredictionsAggregator() instanceof 
WeightedPredictionsAggregator);
+
+        trainer = trainer.withCheckConvergenceStgyFactory(new 
MeanAbsValueConvergenceCheckerFactory(0.1));
+        assertTrue(trainer.fit(
+            learningSample, 1,
+            (k, v) -> VectorUtils.of(v[0]),
+            (k, v) -> v[1]
+        ).getModels().size() < 2000);
     }
 
     /** */
@@ -107,7 +115,7 @@ public class GDBTrainerTest {
     }
 
     /** */
-    private void testClassifier(BiFunction<GDBBinaryClassifierOnTreesTrainer, 
Map<Integer, double[]>,
+    private void testClassifier(BiFunction<GDBTrainer, Map<Integer, double[]>,
         Model<Vector, Double>> fitter) {
         int sampleSize = 100;
         double[] xs = new double[sampleSize];
@@ -122,8 +130,9 @@ public class GDBTrainerTest {
         for (int i = 0; i < sampleSize; i++)
             learningSample.put(i, new double[] {xs[i], ys[i]});
 
-        GDBBinaryClassifierOnTreesTrainer trainer
-            = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 3, 
0.0).withUseIndex(true);
+        GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 500, 
3, 0.0)
+            .withUseIndex(true)
+            .withCheckConvergenceStgyFactory(new 
MeanAbsValueConvergenceCheckerFactory(0.3));
 
         Model<Vector, Double> mdl = fitter.apply(trainer, learningSample);
 
@@ -132,7 +141,7 @@ public class GDBTrainerTest {
             double x = xs[j];
             double y = ys[j];
             double p = mdl.apply(VectorUtils.of(x));
-            if(p != y)
+            if (p != y)
                 errorsCnt++;
         }
 
@@ -142,7 +151,61 @@ public class GDBTrainerTest {
         ModelsComposition composition = (ModelsComposition)mdl;
         composition.getModels().forEach(m -> assertTrue(m instanceof 
DecisionTreeConditionalNode));
 
-        assertEquals(500, composition.getModels().size());
+        assertTrue(composition.getModels().size() < 500);
         assertTrue(composition.getPredictionsAggregator() instanceof 
WeightedPredictionsAggregator);
+
+        trainer = trainer.withCheckConvergenceStgyFactory(new 
ConvergenceCheckerStubFactory());
+        assertEquals(500, ((ModelsComposition)fitter.apply(trainer, 
learningSample)).getModels().size());
+    }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        int sampleSize = 100;
+        double[] xs = new double[sampleSize];
+        double[] ys = new double[sampleSize];
+
+        for (int i = 0; i < sampleSize; i++) {
+            xs[i] = i;
+            ys[i] = ((int)(xs[i] / 10.0) % 2) == 0 ? -1.0 : 1.0;
+        }
+
+        Map<Integer, double[]> learningSample = new HashMap<>();
+        for (int i = 0; i < sampleSize; i++)
+            learningSample.put(i, new double[] {xs[i], ys[i]});
+        IgniteBiFunction<Integer, double[], Vector> fExtr = (k, v) -> 
VectorUtils.of(v[0]);
+        IgniteBiFunction<Integer, double[], Double> lExtr = (k, v) -> v[1];
+
+        GDBTrainer classifTrainer = new GDBBinaryClassifierOnTreesTrainer(0.3, 
500, 3, 0.0)
+            .withUseIndex(true)
+            .withCheckConvergenceStgyFactory(new 
MeanAbsValueConvergenceCheckerFactory(0.3));
+        GDBTrainer regressTrainer = new GDBRegressionOnTreesTrainer(0.3, 500, 
3, 0.0)
+            .withUseIndex(true)
+            .withCheckConvergenceStgyFactory(new 
MeanAbsValueConvergenceCheckerFactory(0.3));
+
+        testUpdate(learningSample, fExtr, lExtr, classifTrainer);
+        testUpdate(learningSample, fExtr, lExtr, regressTrainer);
+    }
+
+    /** */
+    private void testUpdate(Map<Integer, double[]> dataset, 
IgniteBiFunction<Integer, double[], Vector> fExtr,
+        IgniteBiFunction<Integer, double[], Double> lExtr, GDBTrainer trainer) 
{
+
+        ModelsComposition originalMdl = trainer.fit(dataset, 1, fExtr, lExtr);
+        ModelsComposition updatedOnSameDataset = trainer.update(originalMdl, 
dataset, 1, fExtr, lExtr);
+
+        LocalDatasetBuilder<Integer, double[]> epmtyDataset = new 
LocalDatasetBuilder<>(new HashMap<>(), 1);
+        ModelsComposition updatedOnEmptyDataset = 
trainer.updateModel(originalMdl, epmtyDataset, fExtr, lExtr);
+
+        dataset.forEach((k,v) -> {
+            Vector features = fExtr.apply(k, v);
+
+            Double originalAnswer = originalMdl.apply(features);
+            Double updatedMdlAnswer1 = updatedOnSameDataset.apply(features);
+            Double updatedMdlAnswer2 = updatedOnEmptyDataset.apply(features);
+
+            assertEquals(originalAnswer, updatedMdlAnswer1, 0.01);
+            assertEquals(originalAnswer, updatedMdlAnswer2, 0.01);
+        });
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java
new file mode 100644
index 0000000..50fdf8b
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.loss.Loss;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Before;
+
+/** */
+public abstract class ConvergenceCheckerTest {
+    /** Not converged model. */
+    protected ModelsComposition notConvergedMdl = new 
ModelsComposition(Collections.emptyList(), null) {
+        @Override public Double apply(Vector features) {
+            return 2.1 * features.get(0);
+        }
+    };
+
+    /** Converged model. */
+    protected ModelsComposition convergedMdl = new 
ModelsComposition(Collections.emptyList(), null) {
+        @Override public Double apply(Vector features) {
+            return 2 * (features.get(0) + 1);
+        }
+    };
+
+    /** Features extractor. */
+    protected IgniteBiFunction<double[], Double, Vector> fExtr = (x, y) -> 
VectorUtils.of(x);
+
+    /** Label extractor. */
+    protected IgniteBiFunction<double[], Double, Double> lbExtr = (x, y) -> y;
+
+    /** Data. */
+    protected Map<double[], Double> data;
+
+    /** */
+    @Before
+    public void setUp() throws Exception {
+        data = new HashMap<>();
+        for(int i = 0; i < 10; i ++)
+            data.put(new double[]{i, i + 1}, (double)(2 * (i + 1)));
+    }
+
+    /** */
+    public ConvergenceChecker<double[], Double> 
createChecker(ConvergenceCheckerFactory factory,
+        LocalDatasetBuilder<double[], Double> datasetBuilder) {
+
+        return factory.create(data.size(),
+            x -> x,
+            new Loss() {
+                @Override public double error(long sampleSize, double lb, 
double mdlAnswer) {
+                    return mdlAnswer - lb;
+                }
+
+                @Override public double gradient(long sampleSize, double lb, 
double mdlAnswer) {
+                    return mdlAnswer - lb;
+                }
+            },
+            datasetBuilder, fExtr, lbExtr
+        );
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java
new file mode 100644
index 0000000..0b42db8
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerTest.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence.mean;
+
+import 
org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import 
org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest;
+import org.apache.ignite.ml.dataset.impl.local.LocalDataset;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import 
org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
+import 
org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
+import 
org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** */
+public class MeanAbsValueConvergenceCheckerTest extends ConvergenceCheckerTest 
{
+    /** */
+    @Test
+    public void testConvergenceChecking() {
+        LocalDatasetBuilder<double[], Double> datasetBuilder = new 
LocalDatasetBuilder<>(data, 1);
+        ConvergenceChecker<double[], Double> checker = createChecker(
+            new MeanAbsValueConvergenceCheckerFactory(0.1), datasetBuilder);
+
+        double error = checker.computeError(VectorUtils.of(1, 2), 4.0, 
notConvergedMdl);
+        Assert.assertEquals(1.9, error, 0.01);
+        Assert.assertFalse(checker.isConverged(datasetBuilder, 
notConvergedMdl));
+        Assert.assertTrue(checker.isConverged(datasetBuilder, convergedMdl));
+
+        try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> 
dataset = datasetBuilder.build(
+            new EmptyContextBuilder<>(), new 
FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) {
+
+            double onDSError = checker.computeMeanErrorOnDataset(dataset, 
notConvergedMdl);
+            Assert.assertEquals(1.55, onDSError, 0.01);
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /** Mean absolute error is more sensitive to anomalies in the data. */
+    @Test
+    public void testConvergenceCheckingWithAnomaliesInData() {
+        data.put(new double[]{10, 11}, 100000.0);
+        LocalDatasetBuilder<double[], Double> datasetBuilder = new 
LocalDatasetBuilder<>(data, 1);
+        ConvergenceChecker<double[], Double> checker = createChecker(
+            new MeanAbsValueConvergenceCheckerFactory(0.1), datasetBuilder);
+
+        try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> 
dataset = datasetBuilder.build(
+            new EmptyContextBuilder<>(), new 
FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) {
+
+            double onDSError = checker.computeMeanErrorOnDataset(dataset, 
notConvergedMdl);
+            Assert.assertEquals(9090.41, onDSError, 0.01);
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java
new file mode 100644
index 0000000..d6880b4
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerTest.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting.convergence.median;
+
+import 
org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
+import 
org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerTest;
+import org.apache.ignite.ml.dataset.impl.local.LocalDataset;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import 
org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
+import 
org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
+import 
org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** */
+public class MedianOfMedianConvergenceCheckerTest extends 
ConvergenceCheckerTest {
+    /** */
+    @Test
+    public void testConvergenceChecking() {
+        data.put(new double[]{10, 11}, 100000.0);
+        LocalDatasetBuilder<double[], Double> datasetBuilder = new 
LocalDatasetBuilder<>(data, 1);
+
+        ConvergenceChecker<double[], Double> checker = createChecker(
+            new MedianOfMedianConvergenceCheckerFactory(0.1), datasetBuilder);
+
+        double error = checker.computeError(VectorUtils.of(1, 2), 4.0, 
notConvergedMdl);
+        Assert.assertEquals(1.9, error, 0.01);
+        Assert.assertFalse(checker.isConverged(datasetBuilder, 
notConvergedMdl));
+        Assert.assertTrue(checker.isConverged(datasetBuilder, convergedMdl));
+
+        try(LocalDataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> 
dataset = datasetBuilder.build(
+            new EmptyContextBuilder<>(), new 
FeatureMatrixWithLabelsOnHeapDataBuilder<>(fExtr, lbExtr))) {
+
+            double onDSError = checker.computeMeanErrorOnDataset(dataset, 
notConvergedMdl);
+            Assert.assertEquals(1.6, onDSError, 0.01);
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
index f88fd3e..b06fd67 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
@@ -17,185 +17,44 @@
 
 package org.apache.ignite.ml.environment;
 
-import java.util.Arrays;
-import java.util.UUID;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
-import javax.cache.Cache;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.composition.ModelsComposition;
 import org.apache.ignite.ml.dataset.feature.FeatureMeta;
 import org.apache.ignite.ml.environment.logging.ConsoleLogger;
 import org.apache.ignite.ml.environment.logging.MLLogger;
+import org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy;
 import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
 import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer;
 import 
org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies;
-import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
-import org.apache.ignite.thread.IgniteThread;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
 
 /**
  * Tests for {@link LearningEnvironment} that require to start the whole 
Ignite infrastructure. IMPL NOTE based on
  * RandomForestRegressionExample example.
  */
-public class LearningEnvironmentTest extends GridCommonAbstractTest {
-    /** Number of nodes in grid */
-    private static final int NODE_COUNT = 1;
-
-    /** Ignite instance. */
-    private Ignite ignite;
-
-    /** {@inheritDoc} */
-    @Override protected void beforeTestsStarted() throws Exception {
-        for (int i = 1; i <= NODE_COUNT; i++)
-            startGrid(i);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void afterTestsStopped() {
-        stopAllGrids();
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-    @Override protected void beforeTest() {
-        /* Grid instance. */
-        ignite = grid(NODE_COUNT);
-        ignite.configuration().setPeerClassLoadingEnabled(true);
-        
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-    }
-
+public class LearningEnvironmentTest {
     /** */
+    @Test
     public void testBasic() throws InterruptedException {
-        AtomicReference<Integer> actualAmount = new AtomicReference<>(null);
-        AtomicReference<Double> actualMse = new AtomicReference<>(null);
-        AtomicReference<Double> actualMae = new AtomicReference<>(null);
-
-        IgniteThread igniteThread = new 
IgniteThread(ignite.configuration().getIgniteInstanceName(),
-            LearningEnvironmentTest.class.getSimpleName(), () -> {
-            IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
-
-            AtomicInteger idx = new AtomicInteger(0);
-            RandomForestRegressionTrainer trainer = new 
RandomForestRegressionTrainer(
-                IntStream.range(0, data[0].length - 1).mapToObj(
-                    x -> new FeatureMeta("", idx.getAndIncrement(), 
false)).collect(Collectors.toList())
-            ).withCountOfTrees(101)
-                
.withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
-                .withMaxDepth(4)
-                .withMinImpurityDelta(0.)
-                .withSubsampleSize(0.3)
-                .withSeed(0);
-
-            trainer.setEnvironment(LearningEnvironment.builder()
-                
.withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL)
-                
.withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW))
-                .build()
-            );
-
-            ModelsComposition randomForest = trainer.fit(ignite, dataCache,
-                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 
1)),
-                (k, v) -> v[v.length - 1]
-            );
-
-            double mse = 0.0;
-            double mae = 0.0;
-            int totalAmount = 0;
-
-            try (QueryCursor<Cache.Entry<Integer, double[]>> observations = 
dataCache.query(new ScanQuery<>())) {
-                for (Cache.Entry<Integer, double[]> observation : 
observations) {
-                    double difference = estimatePrediction(randomForest, 
observation);
-
-                    mse += Math.pow(difference, 2.0);
-                    mae += Math.abs(difference);
-
-                    totalAmount++;
-                }
-            }
-
-            actualAmount.set(totalAmount);
-
-            mse = mse / totalAmount;
-            actualMse.set(mse);
-
-            mae = mae / totalAmount;
-            actualMae.set(mae);
-        });
-
-        igniteThread.start();
-        igniteThread.join();
-
-        assertEquals("Total amount", 23, (int)actualAmount.get());
-        assertTrue("Mean squared error (MSE)", actualMse.get() > 0);
-        assertTrue("Mean absolute error (MAE)", actualMae.get() > 0);
+        RandomForestRegressionTrainer trainer = new 
RandomForestRegressionTrainer(
+            IntStream.range(0, 0).mapToObj(
+                x -> new FeatureMeta("", 0, 
false)).collect(Collectors.toList())
+        ).withCountOfTrees(101)
+            
.withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
+            .withMaxDepth(4)
+            .withMinImpurityDelta(0.)
+            .withSubsampleSize(0.3)
+            .withSeed(0);
+
+        LearningEnvironment environment = LearningEnvironment.builder()
+            .withParallelismStrategy(ParallelismStrategy.Type.ON_DEFAULT_POOL)
+            
.withLoggingFactory(ConsoleLogger.factory(MLLogger.VerboseLevel.LOW))
+            .build();
+        trainer.setEnvironment(environment);
+        assertEquals(DefaultParallelismStrategy.class, 
environment.parallelismStrategy().getClass());
+        assertEquals(ConsoleLogger.class, environment.logger().getClass());
     }
-
-    /** */
-    private double estimatePrediction(ModelsComposition randomForest, 
Cache.Entry<Integer, double[]> observation) {
-        double[] val = observation.getValue();
-        double[] inputs = Arrays.copyOfRange(val, 0, val.length - 1);
-        double groundTruth = val[val.length - 1];
-
-        double prediction = randomForest.apply(VectorUtils.of(inputs));
-
-        return prediction - groundTruth;
-    }
-
-    /**
-     * Fills cache with data and returns it.
-     *
-     * @param ignite Ignite instance.
-     * @return Filled Ignite Cache.
-     */
-    private IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
-        CacheConfiguration<Integer, double[]> cacheConfiguration = new 
CacheConfiguration<>();
-        cacheConfiguration.setName(UUID.randomUUID().toString());
-        cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 
10));
-
-        IgniteCache<Integer, double[]> cache = 
ignite.createCache(cacheConfiguration);
-
-        for (int i = 0; i < data.length; i++)
-            cache.put(i, data[i]);
-
-        return cache;
-    }
-
-    /**
-     * Part of the Boston housing dataset.
-     */
-    private static final double[][] data = {
-        
{0.02731,0.00,7.070,0,0.4690,6.4210,78.90,4.9671,2,242.0,17.80,396.90,9.14,21.60},
-        
{0.02729,0.00,7.070,0,0.4690,7.1850,61.10,4.9671,2,242.0,17.80,392.83,4.03,34.70},
-        
{0.03237,0.00,2.180,0,0.4580,6.9980,45.80,6.0622,3,222.0,18.70,394.63,2.94,33.40},
-        
{0.06905,0.00,2.180,0,0.4580,7.1470,54.20,6.0622,3,222.0,18.70,396.90,5.33,36.20},
-        
{0.02985,0.00,2.180,0,0.4580,6.4300,58.70,6.0622,3,222.0,18.70,394.12,5.21,28.70},
-        
{0.08829,12.50,7.870,0,0.5240,6.0120,66.60,5.5605,5,311.0,15.20,395.60,12.43,22.90},
-        
{0.14455,12.50,7.870,0,0.5240,6.1720,96.10,5.9505,5,311.0,15.20,396.90,19.15,27.10},
-        
{0.21124,12.50,7.870,0,0.5240,5.6310,100.00,6.0821,5,311.0,15.20,386.63,29.93,16.50},
-        
{0.17004,12.50,7.870,0,0.5240,6.0040,85.90,6.5921,5,311.0,15.20,386.71,17.10,18.90},
-        
{0.22489,12.50,7.870,0,0.5240,6.3770,94.30,6.3467,5,311.0,15.20,392.52,20.45,15.00},
-        
{0.11747,12.50,7.870,0,0.5240,6.0090,82.90,6.2267,5,311.0,15.20,396.90,13.27,18.90},
-        
{0.09378,12.50,7.870,0,0.5240,5.8890,39.00,5.4509,5,311.0,15.20,390.50,15.71,21.70},
-        
{0.62976,0.00,8.140,0,0.5380,5.9490,61.80,4.7075,4,307.0,21.00,396.90,8.26,20.40},
-        
{0.63796,0.00,8.140,0,0.5380,6.0960,84.50,4.4619,4,307.0,21.00,380.02,10.26,18.20},
-        
{0.62739,0.00,8.140,0,0.5380,5.8340,56.50,4.4986,4,307.0,21.00,395.62,8.47,19.90},
-        
{1.05393,0.00,8.140,0,0.5380,5.9350,29.30,4.4986,4,307.0,21.00,386.85,6.58,23.10},
-        
{0.78420,0.00,8.140,0,0.5380,5.9900,81.70,4.2579,4,307.0,21.00,386.75,14.67,17.50},
-        
{0.80271,0.00,8.140,0,0.5380,5.4560,36.60,3.7965,4,307.0,21.00,288.99,11.69,20.20},
-        
{0.72580,0.00,8.140,0,0.5380,5.7270,69.50,3.7965,4,307.0,21.00,390.95,11.28,18.20},
-        
{1.25179,0.00,8.140,0,0.5380,5.5700,98.10,3.7979,4,307.0,21.00,376.57,21.02,13.60},
-        
{0.85204,0.00,8.140,0,0.5380,5.9650,89.20,4.0123,4,307.0,21.00,392.53,13.83,19.60},
-        
{1.23247,0.00,8.140,0,0.5380,6.1420,91.70,3.9769,4,307.0,21.00,396.90,18.72,15.20},
-        
{0.98843,0.00,8.140,0,0.5380,5.8130,100.00,4.0952,4,307.0,21.00,394.54,19.88,14.50}
-    };
-
 }
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/6225c56e/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
----------------------------------------------------------------------
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
index d8fb620..199644b 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
@@ -93,17 +93,21 @@ public class ANNClassificationTest extends TrainerTest {
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(NNStrategy.SIMPLE);
 
-        ANNClassificationModel updatedOnSameDataset = 
trainer.withSeed(1234L).update(originalMdl,
+        ANNClassificationModel updatedOnSameDataset = (ANNClassificationModel) 
trainer.withSeed(1234L).update(originalMdl,
             cacheMock, parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
-        );
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.SIMPLE);
 
-        ANNClassificationModel updatedOnEmptyDataset = 
trainer.withSeed(1234L).update(originalMdl,
+        ANNClassificationModel updatedOnEmptyDataset = 
(ANNClassificationModel) trainer.withSeed(1234L).update(originalMdl,
             new HashMap<Integer, double[]>(), parts,
             (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
             (k, v) -> v[2]
-        );
+        ).withK(3)
+            .withDistanceMeasure(new EuclideanDistance())
+            .withStrategy(NNStrategy.SIMPLE);
 
         Vector v1 = VectorUtils.of(550, 550);
         Vector v2 = VectorUtils.of(-550, -550);

Reply via email to