IGNITE-8233: KNN and SVM algorithms don't work when partition doesn't contain data.

this closes #3807
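
The gist of the change, as a hedged sketch (not part of the patch): a partition that holds no upstream data now yields a null partial result (cache-based dataset) or is skipped entirely (local dataset) instead of invoking the map function on empty data. Reduce functions passed to Dataset.compute(...) should therefore tolerate null partials; the `dataset` variable below is assumed to be any Dataset built as in the tests in this diff:

    // Null-tolerant reducer: counts partitions that actually contributed a result.
    IgniteBinaryOperator<Integer> sum = (a, b) -> a == null ? b : b == null ? a : a + b;

    Integer nonEmptyParts = dataset.compute((data, partIdx) -> 1, sum, 0);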


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/ee9ca06a
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/ee9ca06a
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/ee9ca06a

Branch: refs/heads/ignite-7708
Commit: ee9ca06a8cbec6eec2b963fb1db1b3f383fc1837
Parents: 568c3e7
Author: dmitrievanthony <dmitrievanth...@gmail.com>
Authored: Fri Apr 13 18:02:37 2018 +0300
Committer: Yury Babak <yba...@gridgain.com>
Committed: Fri Apr 13 18:02:37 2018 +0300

----------------------------------------------------------------------
 .../dataset/impl/cache/CacheBasedDataset.java   |  14 ++-
 .../dataset/impl/cache/util/ComputeUtils.java   |   9 +-
 .../ml/dataset/impl/local/LocalDataset.java     |  16 ++-
 .../dataset/impl/local/LocalDatasetBuilder.java |   8 +-
 .../classification/KNNClassificationModel.java  |  32 +++--
 .../impl/cache/CacheBasedDatasetTest.java       |   9 +-
 .../ignite/ml/knn/KNNClassificationTest.java    | 120 ++++++++++--------
 .../apache/ignite/ml/knn/KNNRegressionTest.java | 122 +++++++++++--------
 .../ignite/ml/knn/LabeledDatasetHelper.java     |  45 ++-----
 .../ignite/ml/knn/LabeledDatasetTest.java       |  54 ++++----
 10 files changed, 226 insertions(+), 203 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
index 463d496..7428faf 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
@@ -101,12 +101,16 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
                 partDataBuilder
             );
 
-            R res = map.apply(ctx, data, part);
+            if (data != null) {
+                R res = map.apply(ctx, data, part);
 
-            // Saves partition context after update.
-            ComputeUtils.saveContext(Ignition.localIgnite(), datasetCacheName, part, ctx);
+                // Saves partition context after update.
+                ComputeUtils.saveContext(Ignition.localIgnite(), datasetCacheName, part, ctx);
 
-            return res;
+                return res;
+            }
+
+            return null;
         }, reduce, identity);
     }
 
@@ -125,7 +129,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
                 partDataBuilder
             );
 
-            return map.apply(data, part);
+            return data != null ? map.apply(data, part) : null;
         }, reduce, identity);
     }
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
index 0785db2..ce2fcfd 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
@@ -163,9 +163,14 @@ public class ComputeUtils {
             qry.setPartition(part);
 
             long cnt = upstreamCache.localSizeLong(part);
-            try (QueryCursor<Cache.Entry<K, V>> cursor = upstreamCache.query(qry)) {
-                return partDataBuilder.build(new UpstreamCursorAdapter<>(cursor.iterator(), cnt), cnt, ctx);
+
+            if (cnt > 0) {
+                try (QueryCursor<Cache.Entry<K, V>> cursor = upstreamCache.query(qry)) {
+                    return partDataBuilder.build(new UpstreamCursorAdapter<>(cursor.iterator(), cnt), cnt, ctx);
+                }
             }
+
+            return null;
         });
     }
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
index c08b7de..e312b20 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
@@ -55,8 +55,12 @@ public class LocalDataset<C extends Serializable, D extends AutoCloseable> imple
         R identity) {
         R res = identity;
 
-        for (int part = 0; part < ctx.size(); part++)
-            res = reduce.apply(res, map.apply(ctx.get(part), data.get(part), part));
+        for (int part = 0; part < ctx.size(); part++) {
+            D partData = data.get(part);
+
+            if (partData != null)
+                res = reduce.apply(res, map.apply(ctx.get(part), partData, part));
+        }
 
         return res;
     }
@@ -65,8 +69,12 @@ public class LocalDataset<C extends Serializable, D extends AutoCloseable> imple
     @Override public <R> R compute(IgniteBiFunction<D, Integer, R> map, IgniteBinaryOperator<R> reduce, R identity) {
         R res = identity;
 
-        for (int part = 0; part < data.size(); part++)
-            res = reduce.apply(res, map.apply(data.get(part), part));
+        for (int part = 0; part < data.size(); part++) {
+            D partData = data.get(part);
+
+            if (partData != null)
+                res = reduce.apply(res, map.apply(partData, part));
+        }
 
         return res;
     }

http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
index 0dc1ed6..cfc1801 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
@@ -69,16 +69,16 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
         for (int part = 0; part < partitions; part++) {
             int cnt = part == partitions - 1 ? upstreamMap.size() - ptr : Math.min(partSize, upstreamMap.size() - ptr);
 
-            C ctx = partCtxBuilder.build(
+            C ctx = cnt > 0 ? partCtxBuilder.build(
                 new IteratorWindow<>(firstKeysIter, k -> new UpstreamEntry<>(k, upstreamMap.get(k)), cnt),
                 cnt
-            );
+            ) : null;
 
-            D data = partDataBuilder.build(
+            D data = cnt > 0 ? partDataBuilder.build(
                 new IteratorWindow<>(secondKeysIter, k -> new UpstreamEntry<>(k, upstreamMap.get(k)), cnt),
                 cnt,
                 ctx
-            );
+            ) : null;
 
             ctxList.add(ctx);
             dataList.add(data);
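
For illustration only (a sketch, not part of the patch): with LocalDatasetBuilder an empty partition simply means the number of partitions exceeds the number of upstream entries, so cnt == 0 for the surplus partitions; those slots now hold null ctx/data entries, which LocalDataset.compute(...) skips. The variable names below are hypothetical:

    // Hypothetical setup: 2 upstream rows spread over 100 partitions leaves
    // 98 partitions with cnt == 0; their ctx/data slots are now null and are
    // skipped by LocalDataset.compute(...).
    Map<Integer, double[]> upstream = new HashMap<>();
    upstream.put(0, new double[] {1.0, 2.0});
    upstream.put(1, new double[] {3.0, 4.0});

    LocalDatasetBuilder<Integer, double[]> builder = new LocalDatasetBuilder<>(upstream, 100);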

http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
index 693b81d..0f0cc9f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
@@ -151,19 +151,29 @@ public class KNNClassificationModel<K, V> implements Model<Vector, Double>, Expo
      */
    @NotNull private LabeledVector[] getKClosestVectors(LabeledDataset<Double, LabeledVector> trainingData,
         TreeMap<Double, Set<Integer>> distanceIdxPairs) {
-        LabeledVector[] res = new LabeledVector[k];
-        int i = 0;
-        final Iterator<Double> iter = distanceIdxPairs.keySet().iterator();
-        while (i < k) {
-            double key = iter.next();
-            Set<Integer> idxs = distanceIdxPairs.get(key);
-            for (Integer idx : idxs) {
-                res[i] = trainingData.getRow(idx);
-                i++;
-                if (i >= k)
-                    break; // go to next while-loop iteration
+        LabeledVector[] res;
+
+        if (trainingData.rowSize() <= k) {
+            res = new LabeledVector[trainingData.rowSize()];
+            for (int i = 0; i < trainingData.rowSize(); i++)
+                res[i] = trainingData.getRow(i);
+        }
+        else {
+            res = new LabeledVector[k];
+            int i = 0;
+            final Iterator<Double> iter = distanceIdxPairs.keySet().iterator();
+            while (i < k) {
+                double key = iter.next();
+                Set<Integer> idxs = distanceIdxPairs.get(key);
+                for (Integer idx : idxs) {
+                    res[i] = trainingData.getRow(idx);
+                    i++;
+                    if (i >= k)
+                        break; // go to next while-loop iteration
+                }
             }
         }
+
         return res;
     }
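
A hedged usage sketch of the new branch (reusing the data map and trainer from KNNClassificationTest further below; not part of the patch): when the training data holds fewer labeled vectors than k, the model now returns all of them instead of running the iterator past the end of distanceIdxPairs.

    // Sketch only: 6 training rows and k = 10, so all 6 rows are used as neighbours.
    KNNClassificationModel knnMdl = trainer.fit(
        data,                                              // assumed Map<Integer, double[]> with 6 entries
        1,                                                 // single partition
        (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
        (k, v) -> v[2]
    ).withK(10)
        .withDistanceMeasure(new EuclideanDistance())
        .withStrategy(KNNStrategy.SIMPLE);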
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java
index dc0e160..16ba044 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetTest.java
@@ -38,6 +38,7 @@ import org.apache.ignite.internal.processors.cache.distributed.dht.GridDhtPartit
 import org.apache.ignite.internal.util.IgniteUtils;
 import org.apache.ignite.internal.util.typedef.G;
 import org.apache.ignite.lang.IgnitePredicate;
+import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
 /**
@@ -81,9 +82,9 @@ public class CacheBasedDatasetTest extends GridCommonAbstractTest {
 
         CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache);
 
-        CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset = builder.build(
+        CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = builder.build(
             (upstream, upstreamSize) -> upstreamSize,
-            (upstream, upstreamSize, ctx) -> null
+            (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0)
         );
 
         assertTrue("Before computation all partitions should not be reserved",
@@ -133,9 +134,9 @@ public class CacheBasedDatasetTest extends GridCommonAbstractTest {
 
         CacheBasedDatasetBuilder<Integer, String> builder = new CacheBasedDatasetBuilder<>(ignite, upstreamCache);
 
-        CacheBasedDataset<Integer, String, Long, AutoCloseable> dataset = builder.build(
+        CacheBasedDataset<Integer, String, Long, SimpleDatasetData> dataset = builder.build(
             (upstream, upstreamSize) -> upstreamSize,
-            (upstream, upstreamSize, ctx) -> null
+            (upstream, upstreamSize, ctx) -> new SimpleDatasetData(new double[0], 0)
         );
 
         assertTrue("Before computation all partitions should not be reserved",

http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
index 0877fc0..004718e 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
@@ -17,11 +17,11 @@
 
 package org.apache.ignite.ml.knn;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
-import org.junit.Assert;
-import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
 import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
 import org.apache.ignite.ml.knn.classification.KNNStrategy;
@@ -29,121 +29,137 @@ import org.apache.ignite.ml.math.Vector;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import static junit.framework.TestCase.assertEquals;
 
 /** Tests behaviour of KNNClassificationTest. */
+@RunWith(Parameterized.class)
 public class KNNClassificationTest {
-    /** Precision in test checks. */
-    private static final double PRECISION = 1e-2;
+    /** Number of parts to be tested. */
+    private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7, 100};
+
+    /** Number of partitions. */
+    @Parameterized.Parameter
+    public int parts;
+
+    /** Parameters. */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}")
+    public static Iterable<Integer[]> data() {
+        List<Integer[]> res = new ArrayList<>();
+
+        for (int part : partsToBeTested)
+            res.add(new Integer[] {part});
+
+        return res;
+    }
 
     /** */
     @Test
-    public void binaryClassificationTest() {
-
+    public void testBinaryClassificationTest() {
         Map<Integer, double[]> data = new HashMap<>();
-        data.put(0, new double[]{1.0, 1.0, 1.0});
-        data.put(1, new double[]{1.0, 2.0, 1.0});
-        data.put(2, new double[]{2.0, 1.0, 1.0});
-        data.put(3, new double[]{-1.0, -1.0, 2.0});
-        data.put(4, new double[]{-1.0, -2.0, 2.0});
-        data.put(5, new double[]{-2.0, -1.0, 2.0});
+        data.put(0, new double[] {1.0, 1.0, 1.0});
+        data.put(1, new double[] {1.0, 2.0, 1.0});
+        data.put(2, new double[] {2.0, 1.0, 1.0});
+        data.put(3, new double[] {-1.0, -1.0, 2.0});
+        data.put(4, new double[] {-1.0, -2.0, 2.0});
+        data.put(5, new double[] {-2.0, -1.0, 2.0});
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
         KNNClassificationModel knnMdl = trainer.fit(
             data,
-            2,
+            parts,
             (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
             (k, v) -> v[2]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(KNNStrategy.SIMPLE);
 
-        Vector firstVector = new DenseLocalOnHeapVector(new double[]{2.0, 2.0});
-        Assert.assertEquals(knnMdl.apply(firstVector), 1.0, PRECISION);
-        Vector secondVector = new DenseLocalOnHeapVector(new double[]{-2.0, -2.0});
-        Assert.assertEquals(knnMdl.apply(secondVector), 2.0, PRECISION);
+        Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0});
+        assertEquals(knnMdl.apply(firstVector), 1.0);
+        Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0});
+        assertEquals(knnMdl.apply(secondVector), 2.0);
     }
 
     /** */
     @Test
-    public void binaryClassificationWithSmallestKTest() {
+    public void testBinaryClassificationWithSmallestKTest() {
         Map<Integer, double[]> data = new HashMap<>();
-
-        data.put(0, new double[]{1.0, 1.0, 1.0});
-        data.put(1, new double[]{1.0, 2.0, 1.0});
-        data.put(2, new double[]{2.0, 1.0, 1.0});
-        data.put(3, new double[]{-1.0, -1.0, 2.0});
-        data.put(4, new double[]{-1.0, -2.0, 2.0});
-        data.put(5, new double[]{-2.0, -1.0, 2.0});
+        data.put(0, new double[] {1.0, 1.0, 1.0});
+        data.put(1, new double[] {1.0, 2.0, 1.0});
+        data.put(2, new double[] {2.0, 1.0, 1.0});
+        data.put(3, new double[] {-1.0, -1.0, 2.0});
+        data.put(4, new double[] {-1.0, -2.0, 2.0});
+        data.put(5, new double[] {-2.0, -1.0, 2.0});
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
         KNNClassificationModel knnMdl = trainer.fit(
             data,
-            2,
+            parts,
             (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
             (k, v) -> v[2]
         ).withK(1)
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(KNNStrategy.SIMPLE);
 
-        Vector firstVector = new DenseLocalOnHeapVector(new double[]{2.0, 2.0});
-        Assert.assertEquals(knnMdl.apply(firstVector), 1.0, PRECISION);
-        Vector secondVector = new DenseLocalOnHeapVector(new double[]{-2.0, -2.0});
-        Assert.assertEquals(knnMdl.apply(secondVector), 2.0, PRECISION);
+        Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0});
+        assertEquals(knnMdl.apply(firstVector), 1.0);
+        Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0});
+        assertEquals(knnMdl.apply(secondVector), 2.0);
     }
 
     /** */
     @Test
-    public void binaryClassificationFarPointsWithSimpleStrategy() {
+    public void testBinaryClassificationFarPointsWithSimpleStrategy() {
         Map<Integer, double[]> data = new HashMap<>();
-
-        data.put(0, new double[]{10.0, 10.0, 1.0});
-        data.put(1, new double[]{10.0, 20.0, 1.0});
-        data.put(2, new double[]{-1, -1, 1.0});
-        data.put(3, new double[]{-2, -2, 2.0});
-        data.put(4, new double[]{-1.0, -2.0, 2.0});
-        data.put(5, new double[]{-2.0, -1.0, 2.0});
+        data.put(0, new double[] {10.0, 10.0, 1.0});
+        data.put(1, new double[] {10.0, 20.0, 1.0});
+        data.put(2, new double[] {-1, -1, 1.0});
+        data.put(3, new double[] {-2, -2, 2.0});
+        data.put(4, new double[] {-1.0, -2.0, 2.0});
+        data.put(5, new double[] {-2.0, -1.0, 2.0});
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
         KNNClassificationModel knnMdl = trainer.fit(
             data,
-            2,
+            parts,
             (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
             (k, v) -> v[2]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(KNNStrategy.SIMPLE);
 
-        Vector vector = new DenseLocalOnHeapVector(new double[]{-1.01, -1.01});
-        Assert.assertEquals(knnMdl.apply(vector), 2.0, PRECISION);
+        Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01});
+        assertEquals(knnMdl.apply(vector), 2.0);
     }
 
     /** */
     @Test
-    public void binaryClassificationFarPointsWithWeightedStrategy() {
+    public void testBinaryClassificationFarPointsWithWeightedStrategy() {
         Map<Integer, double[]> data = new HashMap<>();
-
-        data.put(0, new double[]{10.0, 10.0, 1.0});
-        data.put(1, new double[]{10.0, 20.0, 1.0});
-        data.put(2, new double[]{-1, -1, 1.0});
-        data.put(3, new double[]{-2, -2, 2.0});
-        data.put(4, new double[]{-1.0, -2.0, 2.0});
-        data.put(5, new double[]{-2.0, -1.0, 2.0});
+        data.put(0, new double[] {10.0, 10.0, 1.0});
+        data.put(1, new double[] {10.0, 20.0, 1.0});
+        data.put(2, new double[] {-1, -1, 1.0});
+        data.put(3, new double[] {-2, -2, 2.0});
+        data.put(4, new double[] {-1.0, -2.0, 2.0});
+        data.put(5, new double[] {-2.0, -1.0, 2.0});
 
         KNNClassificationTrainer trainer = new KNNClassificationTrainer();
 
         KNNClassificationModel knnMdl = trainer.fit(
             data,
-            2,
+            parts,
             (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
             (k, v) -> v[2]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(KNNStrategy.WEIGHTED);
 
-        Vector vector = new DenseLocalOnHeapVector(new double[]{-1.01, -1.01});
-        Assert.assertEquals(knnMdl.apply(vector), 1.0, PRECISION);
+        Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01});
+        assertEquals(knnMdl.apply(vector), 1.0);
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
index ce9cae5..0c26ba9 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
@@ -17,6 +17,11 @@
 
 package org.apache.ignite.ml.knn;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
 import org.apache.ignite.ml.knn.classification.KNNStrategy;
 import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
@@ -25,110 +30,125 @@ import org.apache.ignite.ml.math.Vector;
 import org.apache.ignite.ml.math.distances.EuclideanDistance;
 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.junit.Assert;
-
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 /**
  * Tests for {@link KNNRegressionTrainer}.
  */
+@RunWith(Parameterized.class)
 public class KNNRegressionTest {
+    /** Number of parts to be tested. */
+    private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7, 100};
+
+    /** Number of partitions. */
+    @Parameterized.Parameter
+    public int parts;
+
+    /** Parameters. */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}")
+    public static Iterable<Integer[]> data() {
+        List<Integer[]> res = new ArrayList<>();
+
+        for (int part : partsToBeTested)
+            res.add(new Integer[] {part});
+
+        return res;
+    }
+
     /** */
     @Test
-    public void simpleRegressionWithOneNeighbour() {
+    public void testSimpleRegressionWithOneNeighbour() {
         Map<Integer, double[]> data = new HashMap<>();
-
-        data.put(0, new double[]{11.0, 0, 0, 0, 0, 0});
-        data.put(1, new double[]{12.0, 2.0, 0, 0, 0, 0});
-        data.put(2, new double[]{13.0, 0, 3.0, 0, 0, 0});
-        data.put(3, new double[]{14.0, 0, 0, 4.0, 0, 0});
-        data.put(4, new double[]{15.0, 0, 0, 0, 5.0, 0});
-        data.put(5, new double[]{16.0, 0, 0, 0, 0, 6.0});
+        data.put(0, new double[] {11.0, 0, 0, 0, 0, 0});
+        data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0});
+        data.put(2, new double[] {13.0, 0, 3.0, 0, 0, 0});
+        data.put(3, new double[] {14.0, 0, 0, 4.0, 0, 0});
+        data.put(4, new double[] {15.0, 0, 0, 0, 5.0, 0});
+        data.put(5, new double[] {16.0, 0, 0, 0, 0, 6.0});
 
         KNNRegressionTrainer trainer = new KNNRegressionTrainer();
 
         KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
-            new LocalDatasetBuilder<>(data, 2),
+            new LocalDatasetBuilder<>(data, parts),
             (k, v) -> Arrays.copyOfRange(v, 1, v.length),
             (k, v) -> v[0]
         ).withK(1)
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(KNNStrategy.SIMPLE);
 
-        Vector vector = new DenseLocalOnHeapVector(new double[]{0, 0, 0, 5.0, 0.0});
+        Vector vector = new DenseLocalOnHeapVector(new double[] {0, 0, 0, 5.0, 0.0});
         System.out.println(knnMdl.apply(vector));
         Assert.assertEquals(15, knnMdl.apply(vector), 1E-12);
     }
 
     /** */
     @Test
-    public void longly() {
+    public void testLongly() {
         Map<Integer, double[]> data = new HashMap<>();
-
-        data.put(0, new double[]{60323, 83.0, 234289, 2356, 1590, 107608, 1947});
-        data.put(1, new double[]{61122, 88.5, 259426, 2325, 1456, 108632, 1948});
-        data.put(2, new double[]{60171, 88.2, 258054, 3682, 1616, 109773, 1949});
-        data.put(3, new double[]{61187, 89.5, 284599, 3351, 1650, 110929, 1950});
-        data.put(4, new double[]{63221, 96.2, 328975, 2099, 3099, 112075, 1951});
-        data.put(5, new double[]{63639, 98.1, 346999, 1932, 3594, 113270, 1952});
-        data.put(6, new double[]{64989, 99.0, 365385, 1870, 3547, 115094, 1953});
-        data.put(7, new double[]{63761, 100.0, 363112, 3578, 3350, 116219, 1954});
-        data.put(8, new double[]{66019, 101.2, 397469, 2904, 3048, 117388, 1955});
-        data.put(9, new double[]{68169, 108.4, 442769, 2936, 2798, 120445, 1957});
-        data.put(10, new double[]{66513, 110.8, 444546, 4681, 2637, 121950, 1958});
-        data.put(11, new double[]{68655, 112.6, 482704, 3813, 2552, 123366, 1959});
-        data.put(12, new double[]{69564, 114.2, 502601, 3931, 2514, 125368, 1960});
-        data.put(13, new double[]{69331, 115.7, 518173, 4806, 2572, 127852, 1961});
-        data.put(14, new double[]{70551, 116.9, 554894, 4007, 2827, 130081, 1962});
+        data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947});
+        data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948});
+        data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949});
+        data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950});
+        data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951});
+        data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952});
+        data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953});
+        data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954});
+        data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955});
+        data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957});
+        data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958});
+        data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959});
+        data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960});
+        data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961});
+        data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962});
 
         KNNRegressionTrainer trainer = new KNNRegressionTrainer();
 
         KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
-            new LocalDatasetBuilder<>(data, 2),
+            new LocalDatasetBuilder<>(data, parts),
             (k, v) -> Arrays.copyOfRange(v, 1, v.length),
             (k, v) -> v[0]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(KNNStrategy.SIMPLE);
 
-        Vector vector = new DenseLocalOnHeapVector(new double[]{104.6, 419180, 2822, 2857, 118734, 1956});
+        Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
         System.out.println(knnMdl.apply(vector));
         Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
     }
 
     /** */
+    @Test
     public void testLonglyWithWeightedStrategy() {
         Map<Integer, double[]> data = new HashMap<>();
-
-        data.put(0, new double[]{60323, 83.0, 234289, 2356, 1590, 107608, 1947});
-        data.put(1, new double[]{61122, 88.5, 259426, 2325, 1456, 108632, 1948});
-        data.put(2, new double[]{60171, 88.2, 258054, 3682, 1616, 109773, 1949});
-        data.put(3, new double[]{61187, 89.5, 284599, 3351, 1650, 110929, 1950});
-        data.put(4, new double[]{63221, 96.2, 328975, 2099, 3099, 112075, 1951});
-        data.put(5, new double[]{63639, 98.1, 346999, 1932, 3594, 113270, 1952});
-        data.put(6, new double[]{64989, 99.0, 365385, 1870, 3547, 115094, 1953});
-        data.put(7, new double[]{63761, 100.0, 363112, 3578, 3350, 116219, 1954});
-        data.put(8, new double[]{66019, 101.2, 397469, 2904, 3048, 117388, 1955});
-        data.put(9, new double[]{68169, 108.4, 442769, 2936, 2798, 120445, 1957});
-        data.put(10, new double[]{66513, 110.8, 444546, 4681, 2637, 121950, 1958});
-        data.put(11, new double[]{68655, 112.6, 482704, 3813, 2552, 123366, 1959});
-        data.put(12, new double[]{69564, 114.2, 502601, 3931, 2514, 125368, 1960});
-        data.put(13, new double[]{69331, 115.7, 518173, 4806, 2572, 127852, 1961});
-        data.put(14, new double[]{70551, 116.9, 554894, 4007, 2827, 130081, 1962});
+        data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947});
+        data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948});
+        data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949});
+        data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950});
+        data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951});
+        data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952});
+        data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953});
+        data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954});
+        data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955});
+        data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957});
+        data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958});
+        data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959});
+        data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960});
+        data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961});
+        data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962});
 
         KNNRegressionTrainer trainer = new KNNRegressionTrainer();
 
         KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
-            new LocalDatasetBuilder<>(data, 2),
+            new LocalDatasetBuilder<>(data, parts),
             (k, v) -> Arrays.copyOfRange(v, 1, v.length),
             (k, v) -> v[0]
         ).withK(3)
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(KNNStrategy.SIMPLE);
 
-        Vector vector = new DenseLocalOnHeapVector(new double[]{104.6, 419180, 2822, 2857, 118734, 1956});
+        Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
         System.out.println(knnMdl.apply(vector));
         Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
     }

http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
index a25b303..dbcdb99 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
@@ -21,64 +21,33 @@ import java.io.IOException;
 import java.net.URISyntaxException;
 import java.nio.file.Path;
 import java.nio.file.Paths;
-import org.apache.ignite.Ignite;
 import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
-import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
 /**
  * Base class for decision trees test.
  */
-public class LabeledDatasetHelper extends GridCommonAbstractTest {
-    /** Count of nodes. */
-    private static final int NODE_COUNT = 4;
-
+public class LabeledDatasetHelper {
     /** Separator. */
     private static final String SEPARATOR = "\t";
 
-    /** Grid instance. */
-    protected Ignite ignite;
-
-    /**
-     * Default constructor.
-     */
-    public LabeledDatasetHelper() {
-        super(false);
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-    @Override protected void beforeTest() throws Exception {
-        ignite = grid(NODE_COUNT);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void beforeTestsStarted() throws Exception {
-        for (int i = 1; i <= NODE_COUNT; i++)
-            startGrid(i);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void afterTestsStopped() throws Exception {
-        stopAllGrids();
-    }
-
     /**
      * Loads labeled dataset from file with .txt extension.
      *
      * @param rsrcPath path to dataset.
      * @return null if path is incorrect.
      */
-    LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) {
+    public static LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) {
         try {
-            Path path = Paths.get(this.getClass().getClassLoader().getResource(rsrcPath).toURI());
+            Path path = Paths.get(LabeledDatasetHelper.class.getClassLoader().getResource(rsrcPath).toURI());
             try {
                 return LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, false, isFallOnBadData);
-            } catch (IOException e) {
+            }
+            catch (IOException e) {
                 e.printStackTrace();
             }
-        } catch (URISyntaxException e) {
+        }
+        catch (URISyntaxException e) {
             e.printStackTrace();
             return null;
         }

http://git-wip-us.apache.org/repos/asf/ignite/blob/ee9ca06a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
index 77d40a6..e986740 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
@@ -21,7 +21,6 @@ import java.io.IOException;
 import java.net.URISyntaxException;
 import java.nio.file.Path;
 import java.nio.file.Paths;
-import org.apache.ignite.internal.util.IgniteUtils;
 import org.apache.ignite.ml.math.ExternalizableTest;
 import org.apache.ignite.ml.math.Vector;
 import org.apache.ignite.ml.math.exceptions.CardinalityException;
@@ -32,9 +31,13 @@ import org.apache.ignite.ml.structures.LabeledDataset;
 import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair;
 import org.apache.ignite.ml.structures.LabeledVector;
 import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.fail;
 
 /** Tests behaviour of KNNClassificationTest. */
-public class LabeledDatasetTest extends LabeledDatasetHelper implements ExternalizableTest<LabeledDataset> {
+public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
     /** */
     private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt";
 
@@ -51,9 +54,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
     private static final String IRIS_MISSED_DATA = "datasets/knn/missed_data.txt";
 
     /** */
+    @Test
     public void testFeatureNames() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         double[][] mtx =
             new double[][] {
                 {1.0, 1.0},
@@ -71,9 +73,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
     }
 
     /** */
+    @Test
     public void testAccessMethods() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         double[][] mtx =
             new double[][] {
                 {1.0, 1.0},
@@ -98,9 +99,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
     }
 
     /** */
+    @Test
     public void testFailOnYNull() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         double[][] mtx =
             new double[][] {
                 {1.0, 1.0},
@@ -122,9 +122,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
     }
 
     /** */
+    @Test
     public void testFailOnXNull() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         double[][] mtx =
             new double[][] {};
         double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
@@ -140,18 +139,17 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
     }
 
     /** */
+    @Test
     public void testLoadingCorrectTxtFile() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-        LabeledDataset training = loadDatasetFromTxt(KNN_IRIS_TXT, false);
+        LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(KNN_IRIS_TXT, false);
         assertEquals(training.rowSize(), 150);
     }
 
     /** */
+    @Test
     public void testLoadingEmptyFile() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         try {
-            loadDatasetFromTxt(EMPTY_TXT, false);
+            LabeledDatasetHelper.loadDatasetFromTxt(EMPTY_TXT, false);
             fail("EmptyFileException");
         }
         catch (EmptyFileException e) {
@@ -161,11 +159,10 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
     }
 
     /** */
+    @Test
     public void testLoadingFileWithFirstEmptyRow() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         try {
-            loadDatasetFromTxt(NO_DATA_TXT, false);
+            LabeledDatasetHelper.loadDatasetFromTxt(NO_DATA_TXT, false);
             fail("NoDataException");
         }
         catch (NoDataException e) {
@@ -175,19 +172,17 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
     }
 
     /** */
+    @Test
     public void testLoadingFileWithIncorrectData() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
-        LabeledDataset training = loadDatasetFromTxt(IRIS_INCORRECT_TXT, false);
+        LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, false);
         assertEquals(149, training.rowSize());
     }
 
     /** */
+    @Test
     public void testFailOnLoadingFileWithIncorrectData() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         try {
-            loadDatasetFromTxt(IRIS_INCORRECT_TXT, true);
+            LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, true);
             fail("FileParsingException");
         }
         catch (FileParsingException e) {
@@ -198,9 +193,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
     }
 
     /** */
+    @Test
     public void testLoadingFileWithMissedData() throws URISyntaxException, IOException {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         Path path = Paths.get(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA).toURI());
 
         LabeledDataset training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false);
@@ -209,9 +203,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
     }
 
     /** */
+    @Test
     public void testSplitting() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         double[][] mtx =
             new double[][] {
                 {1.0, 1.0},
@@ -246,9 +239,8 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
     }
 
     /** */
+    @Test
     public void testLabels() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         double[][] mtx =
             new double[][] {
                 {1.0, 1.0},
@@ -267,8 +259,6 @@ public class LabeledDatasetTest extends LabeledDatasetHelper implements External
 
     /** */
     @Override public void testExternalization() {
-        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
         double[][] mtx =
             new double[][] {
                 {1.0, 1.0},
