IGNITE-9393: [ML] KMeans fails on complex data in cache

this closes #4628


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/3f184913
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/3f184913
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/3f184913

Branch: refs/heads/ignite-9273
Commit: 3f1849131a47a3450a5768c5d29dacfb7ee98923
Parents: e2ff347
Author: zaleslaw <zaleslaw....@gmail.com>
Authored: Tue Aug 28 14:03:37 2018 +0300
Committer: Yury Babak <yba...@gridgain.com>
Committed: Tue Aug 28 14:03:37 2018 +0300

----------------------------------------------------------------------
 .../ml/clustering/kmeans/KMeansTrainer.java     |  60 +++++++----
 .../ml/knn/ann/ANNClassificationTrainer.java    |  12 ++-
 .../classification/KNNClassificationModel.java  |   9 +-
 .../ignite/ml/math/isolve/lsqr/LSQROnHeap.java  |   8 +-
 .../linear/LinearRegressionSGDTrainer.java      |  10 +-
 .../binomial/LogisticRegressionSGDTrainer.java  |  12 ++-
 .../LogRegressionMultiClassTrainer.java         |  25 +++--
 .../SVMLinearBinaryClassificationTrainer.java   |  68 +++++++++++--
 ...VMLinearMultiClassClassificationTrainer.java |  34 ++++++-
 .../ignite/ml/knn/ANNClassificationTest.java    |   3 -
 .../ml/svm/SVMBinaryTrainerIntegrationTest.java | 102 -------------------
 .../ignite/ml/svm/SVMBinaryTrainerTest.java     |   3 +-
 .../ignite/ml/svm/SVMMultiClassTrainerTest.java |   3 +-
 .../org/apache/ignite/ml/svm/SVMTestSuite.java  |   1 -
 14 files changed, 191 insertions(+), 159 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
index c005312..5b880fcc 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
@@ -65,13 +65,13 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
     /**
      * Trains model based on the specified data.
      *
-     * @param datasetBuilder   Dataset builder.
+     * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor      Label extractor.
+     * @param lbExtractor Label extractor.
      * @return Model.
      */
     @Override public <K, V> KMeansModel fit(DatasetBuilder<K, V> 
datasetBuilder,
-                                            IgniteBiFunction<K, V, Vector> 
featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, 
V, Double> lbExtractor) {
         assert datasetBuilder != null;
 
         PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, 
LabeledVector>> partDataBuilder = new 
LabeledDatasetPartitionDataBuilderOnHeap<>(
@@ -85,7 +85,14 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
             (upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
-            final int cols = 
dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> a 
== null ? b : a);
+            final int cols = 
dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
+                if (a == null)
+                    return b == null ? 0 : b;
+                if (b == null)
+                    return a;
+                return b;
+            });
+
             centers = initClusterCentersRandomly(dataset, k);
 
             boolean converged = false;
@@ -113,7 +120,8 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
                         centers[i] = newCentroids[i];
                 }
             }
-        } catch (Exception e) {
+        }
+        catch (Exception e) {
             throw new RuntimeException(e);
         }
         return new KMeansModel(centers, distance);
@@ -124,15 +132,14 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
      *
      * @param centers Current centers on the current iteration.
      * @param dataset Dataset.
-     * @param cols    Amount of columns.
+     * @param cols Amount of columns.
      * @return Helper data to calculate the new centroids.
      */
     private TotalCostAndCounts calcDataForNewCentroids(Vector[] centers,
-                                                       Dataset<EmptyContext, 
LabeledVectorSet<Double, LabeledVector>> dataset, int cols) {
+        Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> 
dataset, int cols) {
         final Vector[] finalCenters = centers;
 
         return dataset.compute(data -> {
-
             TotalCostAndCounts res = new TotalCostAndCounts();
 
             for (int i = 0; i < data.rowSize(); i++) {
@@ -147,20 +154,29 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
 
                 int finalI = i;
                 res.sums.compute(centroidIdx,
-                    (IgniteBiFunction<Integer, Vector, Vector>) (ind, v) -> 
v.plus(data.getRow(finalI).features()));
+                    (IgniteBiFunction<Integer, Vector, Vector>)(ind, v) -> {
+                        Vector features = data.getRow(finalI).features();
+                        return v == null ? features : v.plus(features);
+                    });
 
                 res.counts.merge(centroidIdx, 1,
-                    (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> 
i1 + i2);
+                    (IgniteBiFunction<Integer, Integer, Integer>)(i1, i2) -> 
i1 + i2);
             }
             return res;
-        }, (a, b) -> a == null ? b : a.merge(b));
+        }, (a, b) -> {
+            if (a == null)
+                return b == null ? new TotalCostAndCounts() : b;
+            if (b == null)
+                return a;
+            return a.merge(b);
+        });
     }
 
     /**
      * Find the closest cluster center index and distance to it from a given 
point.
      *
      * @param centers Centers to look in.
-     * @param pnt     Point.
+     * @param pnt Point.
      */
     private IgniteBiTuple<Integer, Double> findClosestCentroid(Vector[] 
centers, LabeledVector pnt) {
         double bestDistance = Double.POSITIVE_INFINITY;
@@ -180,12 +196,11 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
      * K cluster centers are initialized randomly.
      *
      * @param dataset The dataset to pick up random centers.
-     * @param k       Amount of clusters.
+     * @param k Amount of clusters.
      * @return K cluster centers.
      */
     private Vector[] initClusterCentersRandomly(Dataset<EmptyContext, 
LabeledVectorSet<Double, LabeledVector>> dataset,
-                                                int k) {
-
+        int k) {
         Vector[] initCenters = new DenseVector[k];
 
         // Gets k or less vectors from each partition.
@@ -211,12 +226,19 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
 
                         rndPnt.add(data.getRow(nextIdx));
                     }
-                } else // If it's not enough vectors to pick k vectors.
+                }
+                else // If it's not enough vectors to pick k vectors.
                     for (int i = 0; i < data.rowSize(); i++)
                         rndPnt.add(data.getRow(i));
             }
             return rndPnt;
-        }, (a, b) -> a == null ? b : Stream.concat(a.stream(), 
b.stream()).collect(Collectors.toList()));
+        }, (a, b) -> {
+            if (a == null)
+                return b == null ? new ArrayList<>() : b;
+            if (b == null)
+                return a;
+            return Stream.concat(a.stream(), 
b.stream()).collect(Collectors.toList());
+        });
 
         // Shuffle them.
         Collections.shuffle(rndPnts);
@@ -228,7 +250,8 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
                 rndPnts.remove(rndPnt);
                 initCenters[i] = rndPnt.features();
             }
-        } else
+        }
+        else
             throw new RuntimeException("The KMeans Trainer required more than 
" + k + " vectors to find " + k + " clusters");
 
         return initCenters;
@@ -245,7 +268,6 @@ public class KMeansTrainer extends 
SingleLabelDatasetTrainer<KMeansModel> {
         /** Count of points closest to the center with a given index. */
         ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();
 
-
         /** Count of points closest to the center with a given index. */
         ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> 
centroidStat = new ConcurrentHashMap<>();
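
Note: the recurring change across this commit is the reducer passed to dataset.compute(...). The previous reducers, typically (a, b) -> a == null ? b : a or a.merge(b), break as soon as a partition contributes no data and hands the reducer a null on either side; the sums map's compute(...) call could likewise receive a null value for a first-time key. Below is a minimal standalone sketch of the null-safe shape now used throughout; the helper class, its name, and the identity/merge suppliers are illustrative and not part of the patch.

import java.util.function.BinaryOperator;
import java.util.function.Supplier;

/** Illustrative helper (not in the patch): builds a reducer that tolerates nulls from empty partitions. */
final class NullSafeReducer {
    static <T> BinaryOperator<T> of(Supplier<T> identity, BinaryOperator<T> merge) {
        return (a, b) -> {
            if (a == null)
                return b == null ? identity.get() : b; // both partitions empty: fall back to an empty result
            if (b == null)
                return a;
            return merge.apply(a, b);                  // both present: delegate to the real merge
        };
    }
}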
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
index 282be3c..1c45812 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
@@ -149,9 +149,7 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
             (upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
-
             return dataset.compute(data -> {
-
                 CentroidStat res = new CentroidStat();
 
                 for (int i = 0; i < data.rowSize(); i++) {
@@ -171,7 +169,7 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
                         centroidStat.put(lb, 1);
                         res.centroidStat.put(centroidIdx, centroidStat);
                     } else {
-                        int cnt = centroidStat.containsKey(lb) ? 
centroidStat.get(lb) : 0;
+                        int cnt = centroidStat.getOrDefault(lb, 0);
                         centroidStat.put(lb, cnt + 1);
                     }
 
@@ -179,7 +177,13 @@ public class ANNClassificationTrainer extends 
SingleLabelDatasetTrainer<ANNClass
                         (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) 
-> i1 + i2);
                 }
                 return res;
-            }, (a, b) -> a == null ? b : a.merge(b));
+            }, (a, b) -> {
+                if (a == null)
+                    return b == null ? new CentroidStat() : b;
+                if (b == null)
+                    return a;
+                return a.merge(b);
+            });
 
         } catch (Exception e) {
             throw new RuntimeException(e);
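
Besides the null-safe reducer, the hunk above also collapses the manual containsKey/get pair into Map.getOrDefault when counting labels per centroid. A small sketch of that step in isolation; the map type and variable names mirror the diff, but the wrapper class is hypothetical.

import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical wrapper around the label-counting step shown in the hunk above. */
final class CentroidLabelCounter {
    static void count(ConcurrentHashMap<Double, Integer> centroidStat, double lb) {
        // getOrDefault yields 0 for an unseen label, replacing centroidStat.containsKey(lb) ? centroidStat.get(lb) : 0.
        int cnt = centroidStat.getOrDefault(lb, 0);
        centroidStat.put(lb, cnt + 1);
    }
}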

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
index 3404ae8..0b88f81 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
@@ -17,6 +17,7 @@
 
 package org.apache.ignite.ml.knn.classification;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
@@ -79,7 +80,13 @@ public class KNNClassificationModel extends 
NNClassificationModel implements Exp
         List<LabeledVector> neighborsFromPartitions = dataset.compute(data -> {
             TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, 
data);
             return Arrays.asList(getKClosestVectors(data, distanceIdxPairs));
-        }, (a, b) -> a == null ? b : Stream.concat(a.stream(), 
b.stream()).collect(Collectors.toList()));
+        }, (a, b) -> {
+            if (a == null)
+                return b == null ? new ArrayList<>() : b;
+            if (b == null)
+                return a;
+            return Stream.concat(a.stream(), 
b.stream()).collect(Collectors.toList());
+        });
 
         LabeledVectorSet<Double, LabeledVector> neighborsToFilter = 
buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
index e138cf3..f75caef 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
@@ -103,7 +103,13 @@ public class LSQROnHeap<K, V> extends AbstractLSQR 
implements AutoCloseable {
     @Override protected int getColumns() {
         return dataset.compute(
             data -> data.getFeatures() == null ? null : 
data.getFeatures().length / data.getRows(),
-            (a, b) -> a == null ? b : a
+            (a, b) -> {
+                if (a == null)
+                    return b == null ? 0 : b;
+                if (b == null)
+                    return a;
+                return b;
+            }
         );
     }
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
index 2237c95..44f60d1 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
@@ -82,7 +82,13 @@ public class LinearRegressionSGDTrainer<P extends 
Serializable> extends SingleLa
                 if (data.getFeatures() == null)
                     return null;
                 return data.getFeatures().length / data.getRows();
-            }, (a, b) -> a == null ? b : a);
+            }, (a, b) -> {
+                if (a == null)
+                    return b == null ? 0 : b;
+                if (b == null)
+                    return a;
+                return b;
+            });
 
             MLPArchitecture architecture = new MLPArchitecture(cols);
             architecture = architecture.withAddedLayer(1, true, 
Activators.LINEAR);
@@ -100,7 +106,7 @@ public class LinearRegressionSGDTrainer<P extends 
Serializable> extends SingleLa
             seed
         );
 
-        IgniteBiFunction<K, V, double[]> lbE = (IgniteBiFunction<K, V, 
double[]>)(k, v) -> new double[]{lbExtractor.apply(k, v)};
+        IgniteBiFunction<K, V, double[]> lbE = (IgniteBiFunction<K, V, 
double[]>)(k, v) -> new double[] {lbExtractor.apply(k, v)};
 
         MultilayerPerceptron mlp = trainer.fit(datasetBuilder, 
featureExtractor, lbE);
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
index 840a18d..6396279 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
@@ -64,7 +64,7 @@ public class LogisticRegressionSGDTrainer<P extends 
Serializable> extends Single
      * @param seed Seed for random generator.
      */
     public LogisticRegressionSGDTrainer(UpdatesStrategy<? super 
MultilayerPerceptron, P> updatesStgy, int maxIterations,
-                                        int batchSize, int locIterations, long 
seed) {
+        int batchSize, int locIterations, long seed) {
         this.updatesStgy = updatesStgy;
         this.maxIterations = maxIterations;
         this.batchSize = batchSize;
@@ -82,7 +82,13 @@ public class LogisticRegressionSGDTrainer<P extends 
Serializable> extends Single
                 if (data.getFeatures() == null)
                     return null;
                 return data.getFeatures().length / data.getRows();
-            }, (a, b) -> a == null ? b : a);
+            }, (a, b) -> {
+                if (a == null)
+                    return b == null ? 0 : b;
+                if (b == null)
+                    return a;
+                return b;
+            });
 
             MLPArchitecture architecture = new MLPArchitecture(cols);
             architecture = architecture.withAddedLayer(1, true, 
Activators.SIGMOID);
@@ -100,7 +106,7 @@ public class LogisticRegressionSGDTrainer<P extends 
Serializable> extends Single
             seed
         );
 
-        MultilayerPerceptron mlp = trainer.fit(datasetBuilder, 
featureExtractor, (k, v) -> new double[]{lbExtractor.apply(k, v)});
+        MultilayerPerceptron mlp = trainer.fit(datasetBuilder, 
featureExtractor, (k, v) -> new double[] {lbExtractor.apply(k, v)});
 
         double[] params = mlp.parameters().getStorage().data();
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
index 1ed938a..4885373 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
@@ -61,14 +61,14 @@ public class LogRegressionMultiClassTrainer<P extends 
Serializable>
     /**
      * Trains model based on the specified data.
      *
-     * @param datasetBuilder   Dataset builder.
+     * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor      Label extractor.
+     * @param lbExtractor Label extractor.
      * @return Model.
      */
     @Override public <K, V> LogRegressionMultiClassModel fit(DatasetBuilder<K, 
V> datasetBuilder,
-                                                                
IgniteBiFunction<K, V, Vector> featureExtractor,
-                                                                
IgniteBiFunction<K, V, Double> lbExtractor) {
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
         List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
 
         LogRegressionMultiClassModel multiClsMdl = new 
LogRegressionMultiClassModel();
@@ -92,7 +92,8 @@ public class LogRegressionMultiClassTrainer<P extends 
Serializable>
     }
 
     /** Iterates among dataset and collects class labels. */
-    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> 
datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
+    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> 
datasetBuilder,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
         assert datasetBuilder != null;
 
         PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> 
partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
@@ -108,14 +109,22 @@ public class LogRegressionMultiClassTrainer<P extends 
Serializable>
 
                 final double[] lbs = data.getY();
 
-                for (double lb : lbs) locClsLabels.add(lb);
+                for (double lb : lbs)
+                    locClsLabels.add(lb);
 
                 return locClsLabels;
-            }, (a, b) -> a == null ? b : Stream.of(a, 
b).flatMap(Collection::stream).collect(Collectors.toSet()));
+            }, (a, b) -> {
+                if (a == null)
+                    return b == null ? new HashSet<>() : b;
+                if (b == null)
+                    return a;
+                return Stream.of(a, 
b).flatMap(Collection::stream).collect(Collectors.toSet());
+            });
 
             res.addAll(clsLabels);
 
-        } catch (Exception e) {
+        }
+        catch (Exception e) {
             throw new RuntimeException(e);
         }
         return res;

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
index 4f11318..933a712 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
@@ -17,7 +17,7 @@
 
 package org.apache.ignite.ml.svm;
 
-import java.util.concurrent.ThreadLocalRandom;
+import java.util.Random;
 import org.apache.ignite.ml.dataset.Dataset;
 import org.apache.ignite.ml.dataset.DatasetBuilder;
 import org.apache.ignite.ml.dataset.PartitionDataBuilder;
@@ -47,12 +47,15 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
     /** Regularization parameter. */
     private double lambda = 0.4;
 
+    /** The seed number. */
+    private long seed;
+
     /**
      * Trains model based on the specified data.
      *
-     * @param datasetBuilder   Dataset builder.
+     * @param datasetBuilder Dataset builder.
      * @param featureExtractor Feature extractor.
-     * @param lbExtractor      Label extractor.
+     * @param lbExtractor Label extractor.
      * @return Model.
      */
     @Override public <K, V> SVMLinearBinaryClassificationModel 
fit(DatasetBuilder<K, V> datasetBuilder,
@@ -67,19 +70,28 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
 
         Vector weights;
 
-        try(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> 
dataset = datasetBuilder.build(
+        try (Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> 
dataset = datasetBuilder.build(
             (upstream, upstreamSize) -> new EmptyContext(),
             partDataBuilder
         )) {
-            final int cols = 
dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> a 
== null ? b : a);
+            final int cols = 
dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
+                if (a == null)
+                    return b == null ? 0 : b;
+                if (b == null)
+                    return a;
+                return b;
+            });
+
             final int weightVectorSizeWithIntercept = cols + 1;
+
             weights = 
initializeWeightsWithZeros(weightVectorSizeWithIntercept);
 
             for (int i = 0; i < this.getAmountOfIterations(); i++) {
                 Vector deltaWeights = calculateUpdates(weights, dataset);
                 weights = weights.plus(deltaWeights); // creates new vector
             }
-        } catch (Exception e) {
+        }
+        catch (Exception e) {
             throw new RuntimeException(e);
         }
         return new SVMLinearBinaryClassificationModel(weights.viewPart(1, 
weights.size() - 1), weights.get(0));
@@ -87,11 +99,12 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
 
     /** */
     @NotNull private Vector initializeWeightsWithZeros(int vectorSize) {
-            return new DenseVector(vectorSize);
+        return new DenseVector(vectorSize);
     }
 
     /** */
-    private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, 
LabeledVectorSet<Double, LabeledVector>> dataset) {
+    private Vector calculateUpdates(Vector weights,
+        Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> 
dataset) {
         return dataset.compute(data -> {
             Vector copiedWeights = weights.copy();
             Vector deltaWeights = initializeWeightsWithZeros(weights.size());
@@ -100,8 +113,10 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
             Vector tmpAlphas = initializeWeightsWithZeros(amountOfObservation);
             Vector deltaAlphas = 
initializeWeightsWithZeros(amountOfObservation);
 
+            Random random = new Random(seed);
+
             for (int i = 0; i < this.getAmountOfLocIterations(); i++) {
-                int randomIdx = 
ThreadLocalRandom.current().nextInt(amountOfObservation);
+                int randomIdx = random.nextInt(amountOfObservation);
 
                 Deltas deltas = getDeltas(data, copiedWeights, 
amountOfObservation, tmpAlphas, randomIdx);
 
@@ -112,12 +127,18 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
                 deltaAlphas.set(randomIdx, deltaAlphas.get(randomIdx) + 
deltas.deltaAlpha);
             }
             return deltaWeights;
-        }, (a, b) -> a == null ? b : a.plus(b));
+        }, (a, b) -> {
+            if (a == null)
+                return b == null ? new DenseVector() : b;
+            if (b == null)
+                return a;
+            return a.plus(b);
+        });
     }
 
     /** */
     private Deltas getDeltas(LabeledVectorSet data, Vector copiedWeights, int 
amountOfObservation, Vector tmpAlphas,
-                             int randomIdx) {
+        int randomIdx) {
         LabeledVector row = (LabeledVector)data.getRow(randomIdx);
         Double lb = (Double)row.label();
         Vector v = makeVectorWithInterceptElement(row);
@@ -191,6 +212,7 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
 
     /**
      * Set up the regularization parameter.
+     *
      * @param lambda The regularization parameter. Should be more than 0.0.
      * @return Trainer with new lambda parameter value.
      */
@@ -202,6 +224,7 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
 
     /**
      * Gets the regularization lambda.
+     *
      * @return The parameter value.
      */
     public double lambda() {
@@ -210,6 +233,7 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
 
     /**
      * Gets the amount of outer iterations of SCDA algorithm.
+     *
      * @return The parameter value.
      */
     public int getAmountOfIterations() {
@@ -218,6 +242,7 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
 
     /**
      * Set up the amount of outer iterations of SCDA algorithm.
+     *
      * @param amountOfIterations The parameter value.
      * @return Trainer with new amountOfIterations parameter value.
      */
@@ -228,6 +253,7 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
 
     /**
      * Gets the amount of local iterations of SCDA algorithm.
+     *
      * @return The parameter value.
      */
     public int getAmountOfLocIterations() {
@@ -236,6 +262,7 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
 
     /**
      * Set up the amount of local iterations of SCDA algorithm.
+     *
      * @param amountOfLocIterations The parameter value.
      * @return Trainer with new amountOfLocIterations parameter value.
      */
@@ -244,6 +271,25 @@ public class SVMLinearBinaryClassificationTrainer extends 
SingleLabelDatasetTrai
         return this;
     }
 
+    /**
+     * Gets the seed number.
+     *
+     * @return The parameter value.
+     */
+    public long getSeed() {
+        return seed;
+    }
+
+    /**
+     * Set up the seed.
+     *
+     * @param seed The parameter value.
+     * @return Model with new seed parameter value.
+     */
+    public SVMLinearBinaryClassificationTrainer withSeed(long seed) {
+        this.seed = seed;
+        return this;
+    }
 }
 
 /** This is a helper class to handle pair results which are returned from the 
calculation method. */
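
The trainer above now draws its local-iteration indices from a Random seeded via withSeed(...) instead of ThreadLocalRandom, so training runs (and the reworked unit tests below) become reproducible. A hedged usage sketch follows, assuming a (key, double[]) cache where column 0 holds the label and the remaining columns the features; the Ignite instance, the cache, and the chosen parameter values are placeholders, not part of the patch.

import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;

/** Usage sketch: fixing the seed makes the stochastic coordinate updates repeatable across runs. */
class SeededSvmExample {
    static SVMLinearBinaryClassificationModel train(Ignite ignite, IgniteCache<Integer, double[]> data) {
        SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer()
            .withAmountOfIterations(20)
            .withAmountOfLocIterations(10)
            .withLambda(0.4)
            .withSeed(1234L);

        // Column 0 is the label; the rest of the array is the feature vector.
        return trainer.fit(
            ignite,
            data,
            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
            (k, v) -> v[0]
        );
    }
}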

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
index 7069c4d..4b7cc95 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
@@ -51,6 +51,9 @@ public class SVMLinearMultiClassClassificationTrainer
     /** Regularization parameter. */
     private double lambda = 0.2;
 
+    /** The seed number. */
+    private long seed;
+
     /**
      * Trains model based on the specified data.
      *
@@ -70,7 +73,8 @@ public class SVMLinearMultiClassClassificationTrainer
             SVMLinearBinaryClassificationTrainer trainer = new 
SVMLinearBinaryClassificationTrainer()
                 .withAmountOfIterations(this.amountOfIterations())
                 .withAmountOfLocIterations(this.amountOfLocIterations())
-                .withLambda(this.lambda());
+                .withLambda(this.lambda())
+                .withSeed(this.seed);
 
             IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> {
                 Double lb = lbExtractor.apply(k, v);
@@ -106,7 +110,13 @@ public class SVMLinearMultiClassClassificationTrainer
                 for (double lb : lbs) locClsLabels.add(lb);
 
                 return locClsLabels;
-            }, (a, b) -> a == null ? b : Stream.of(a, 
b).flatMap(Collection::stream).collect(Collectors.toSet()));
+            }, (a, b) -> {
+                if (a == null)
+                    return b == null ? new HashSet<>() : b;
+                if (b == null)
+                    return a;
+                return Stream.of(a, 
b).flatMap(Collection::stream).collect(Collectors.toSet());
+            });
 
             res.addAll(clsLabels);
 
@@ -176,4 +186,24 @@ public class SVMLinearMultiClassClassificationTrainer
         this.amountOfLocIterations = amountOfLocIterations;
         return this;
     }
+
+    /**
+     * Gets the seed number.
+     *
+     * @return The parameter value.
+     */
+    public long getSeed() {
+        return seed;
+    }
+
+    /**
+     * Set up the seed.
+     *
+     * @param seed The parameter value.
+     * @return Model with new seed parameter value.
+     */
+    public SVMLinearMultiClassClassificationTrainer withSeed(long seed) {
+        this.seed = seed;
+        return this;
+    }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
index aed6387..7289b1d 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
@@ -62,9 +62,6 @@ public class ANNClassificationTest extends TrainerTest {
             .withDistanceMeasure(new EuclideanDistance())
             .withStrategy(NNStrategy.SIMPLE);
 
-        TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(550, 550)), 
PRECISION);
-        TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-550, -550)), 
PRECISION);
-
         Assert.assertNotNull(((ANNClassificationModel) mdl).getCandidates());
 
         Assert.assertTrue(mdl.toString().contains(NNStrategy.SIMPLE.name()));

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java
deleted file mode 100644
index d227de7..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerIntegrationTest.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm;
-
-import java.util.Arrays;
-import java.util.UUID;
-import java.util.concurrent.ThreadLocalRandom;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
-
-/**
- * Tests for {@link SVMLinearBinaryClassificationTrainer} that require to 
start the whole Ignite infrastructure.
- */
-public class SVMBinaryTrainerIntegrationTest extends GridCommonAbstractTest {
-    /** Fixed size of Dataset. */
-    private static final int AMOUNT_OF_OBSERVATIONS = 1000;
-
-    /** Fixed size of columns in Dataset. */
-    private static final int AMOUNT_OF_FEATURES = 2;
-
-    /** Precision in test checks. */
-    private static final double PRECISION = 1e-2;
-
-    /** Number of nodes in grid */
-    private static final int NODE_COUNT = 3;
-
-    /** Ignite instance. */
-    private Ignite ignite;
-
-    /** {@inheritDoc} */
-    @Override protected void beforeTestsStarted() throws Exception {
-        for (int i = 1; i <= NODE_COUNT; i++)
-            startGrid(i);
-    }
-
-    /** {@inheritDoc} */
-    @Override protected void afterTestsStopped() {
-        stopAllGrids();
-    }
-
-    /**
-     * {@inheritDoc}
-     */
-    @Override protected void beforeTest() throws Exception {
-        /* Grid instance. */
-        ignite = grid(NODE_COUNT);
-        ignite.configuration().setPeerClassLoadingEnabled(true);
-        
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-    }
-
-    /**
-     * Test trainer on classification model y = x.
-     */
-    public void testTrainWithTheLinearlySeparableCase() {
-        IgniteCache<Integer, double[]> data = 
ignite.getOrCreateCache(UUID.randomUUID().toString());
-
-        ThreadLocalRandom rndX = ThreadLocalRandom.current();
-        ThreadLocalRandom rndY = ThreadLocalRandom.current();
-
-        for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
-            double x = rndX.nextDouble(-1000, 1000);
-            double y = rndY.nextDouble(-1000, 1000);
-            double[] vec = new double[AMOUNT_OF_FEATURES + 1];
-            vec[0] = y - x > 0 ? 1 : -1; // assign label.
-            vec[1] = x;
-            vec[2] = y;
-            data.put(i, vec);
-        }
-
-        SVMLinearBinaryClassificationTrainer trainer = new 
SVMLinearBinaryClassificationTrainer();
-
-        SVMLinearBinaryClassificationModel mdl = trainer.fit(
-            ignite,
-            data,
-            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
-            (k, v) -> v[0]
-        );
-
-        TestUtils.assertEquals(-1, mdl.apply(new DenseVector(new double[]{100, 
10})), PRECISION);
-        TestUtils.assertEquals(1, mdl.apply(new DenseVector(new double[]{10, 
100})), PRECISION);
-    }
-}

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
index b772177..5630bee 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMBinaryTrainerTest.java
@@ -39,7 +39,8 @@ public class SVMBinaryTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        SVMLinearBinaryClassificationTrainer trainer = new 
SVMLinearBinaryClassificationTrainer();
+        SVMLinearBinaryClassificationTrainer trainer = new 
SVMLinearBinaryClassificationTrainer()
+            .withSeed(1234L);
 
         SVMLinearBinaryClassificationModel mdl = trainer.fit(
             cacheMock,

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
index f2328f8..7ea28c2 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMMultiClassTrainerTest.java
@@ -42,7 +42,8 @@ public class SVMMultiClassTrainerTest extends TrainerTest {
         SVMLinearMultiClassClassificationTrainer trainer = new 
SVMLinearMultiClassClassificationTrainer()
             .withLambda(0.3)
             .withAmountOfLocIterations(10)
-            .withAmountOfIterations(20);
+            .withAmountOfIterations(20)
+            .withSeed(1234L);
 
         SVMLinearMultiClassClassificationModel mdl = trainer.fit(
             cacheMock,

http://git-wip-us.apache.org/repos/asf/ignite/blob/3f184913/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
index 822ad18..df7263f 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/svm/SVMTestSuite.java
@@ -28,7 +28,6 @@ import org.junit.runners.Suite;
     SVMModelTest.class,
     SVMBinaryTrainerTest.class,
     SVMMultiClassTrainerTest.class,
-    SVMBinaryTrainerIntegrationTest.class
 })
 public class SVMTestSuite {
     // No-op.
